| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ CED model configuration""" |
| |
|
| |
|
| | from transformers import PretrainedConfig |
| | from transformers.utils import logging |
| | from transformers.utils.hub import cached_file |
| |
|
# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)

# Maps each known CED checkpoint identifier to the URL of its `config.json`.
CED_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "mispeech/ced-tiny": "https://huggingface.co/mispeech/ced-tiny/resolve/main/config.json",
}
| |
|
| |
|
class CedConfig(PretrainedConfig):
    r"""
    Configuration class for the CED model.

    Args:
        name (str, *optional*):
            Name of the pre-defined configuration. Can be "ced-tiny", "ced-mini", "ced-small" or "ced-base". When
            one of these names is given, `embed_dim` and `num_heads` are overridden by the preset values; any other
            value (including `None`) leaves the passed arguments untouched.
        attn_drop_rate (float, *optional*, defaults to 0.0):
            Dropout probability for attention weights.
        depth (int, *optional*, defaults to 12): Number of transformer layers.
        drop_path_rate (float, *optional*, defaults to 0.0): Drop path is taken from timm.
        drop_rate (float, *optional*, defaults to 0.0):
            Dropout probability for input embeddings.
        embed_dim (int, *optional*, defaults to 768):
            Dimensionality of the audio patch embeddings.
        eval_avg (str, *optional*, defaults to `"mean"`):
            Type of pooling to use for evaluation. Can be "mean", "token", "dm" or "logit".
        mlp_ratio (float, *optional*, defaults to 4.0):
            Ratio of hidden size in the feedforward layer to the embedding size.
        num_heads (int, *optional*, defaults to 12): Number of attention heads.
        outputdim (int, *optional*, defaults to 527):
            Dimensionality of the output. When it equals 527, `id2label` / `label2id` are populated from the
            `class_labels_indices.csv` file of the `topel/ConvNeXt-Tiny-AT` hub repository; otherwise both are
            set to `None`.
        patch_size (int, *optional*, defaults to 16): Size of the patches.
        patch_stride (int, *optional*, defaults to 16): Stride of the patches.
        pooling (str, *optional*, defaults to `"mean"`):
            Type of pooling to use for the output. Can be "mean", "token", "dm" or "logit".
        qkv_bias (bool, *optional*, defaults to `True`):
            Whether to include bias terms in the query, key and value projections.
        target_length (int, *optional*, defaults to 1012): Frames of an audio chunk.
        kwargs:
            Forwarded to `PretrainedConfig.__init__`. The following keys are additionally read here and stored as
            attributes of the same name (presumably audio front-end settings -- confirm against the feature
            extractor): `center` (defaults to `True`), `f_max` (8000), `f_min` (0), `hop_size` (160),
            `n_fft` (512), `n_mels` (64), `pad_last` (`True`) and `win_size` (512).

    Raises:
        ValueError: If `pooling` is not one of "mean", "token", "dm" or "logit".
    """

    model_type = "ced"

    def __init__(
        self,
        name=None,
        attn_drop_rate=0.0,
        depth=12,
        drop_path_rate=0.0,
        drop_rate=0.0,
        embed_dim=768,
        eval_avg="mean",
        mlp_ratio=4.0,
        num_heads=12,
        outputdim=527,
        patch_size=16,
        patch_stride=16,
        pooling="mean",
        qkv_bias=True,
        target_length=1012,
        **kwargs,
    ):
        r"""
        Build a CED configuration; see the class docstring for the meaning of every argument.
        """
        super().__init__(**kwargs)

        # A recognised preset name takes precedence over explicitly passed
        # `embed_dim` / `num_heads` values.
        if name == "ced-tiny":
            embed_dim = 192
            num_heads = 3
        elif name == "ced-mini":
            embed_dim = 256
            num_heads = 4
        elif name == "ced-small":
            embed_dim = 384
            num_heads = 6
        elif name == "ced-base":
            embed_dim = 768
            num_heads = 12
        else:
            logger.info("No model name specified for CedConfig, use default settings.")

        # Explicit check instead of `assert`, which is removed when Python runs with -O.
        valid_poolings = ("mean", "token", "dm", "logit")
        if pooling not in valid_poolings:
            raise ValueError(f"`pooling` must be one of {valid_poolings}, got {pooling!r}.")

        self.name = name
        self.attn_drop_rate = attn_drop_rate
        self.center = kwargs.get("center", True)
        self.depth = depth
        self.drop_path_rate = drop_path_rate
        self.drop_rate = drop_rate
        self.embed_dim = embed_dim
        self.eval_avg = eval_avg
        self.f_max = kwargs.get("f_max", 8000)
        self.f_min = kwargs.get("f_min", 0)
        self.hop_size = kwargs.get("hop_size", 160)
        self.mlp_ratio = mlp_ratio
        self.n_fft = kwargs.get("n_fft", 512)
        self.n_mels = kwargs.get("n_mels", 64)
        self.num_heads = num_heads
        self.outputdim = outputdim
        self.pad_last = kwargs.get("pad_last", True)
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.pooling = pooling
        self.qkv_bias = qkv_bias
        self.target_length = target_length
        self.win_size = kwargs.get("win_size", 512)
        self.loss = "BCE"

        if self.outputdim == 527:
            labels_path = cached_file("topel/ConvNeXt-Tiny-AT", "class_labels_indices.csv")
            with open(labels_path, "r", encoding="utf-8") as f:
                # The first line is skipped (treated as a header row).
                rows = f.readlines()[1:]

            id2label = {}
            for line in rows:
                if not line.strip():
                    # Tolerate blank lines instead of failing on int("").
                    continue
                # NOTE(review): a plain comma split truncates a quoted label that itself
                # contains a comma; kept as-is so existing label strings do not change.
                # Confirm against the CSV before switching to the `csv` module.
                fields = line.split(",", maxsplit=3)
                id2label[int(fields[0])] = fields[2].replace('"', "").strip("\n")

            self.id2label = id2label
            self.label2id = {v: k for k, v in self.id2label.items()}
        else:
            self.id2label = None
            self.label2id = None
| |
|