# Musci-ASR-2.4B / chat_template_default.py
# Musci-research's picture
# upload Musci-ASR-2.4B
# 6cb6a8a verified
from dataclasses import dataclass
from typing import Literal, Optional
import numpy
@dataclass
class MusciChatTemplateSegment:
    """One segment of a Musci chat template.

    The ``type`` field selects which of the other fields are required; the
    requirements are enforced in ``__post_init__``:

    * ``"constant_text_token"`` -- ``text_ids`` must be provided.
    * ``"text_token"`` -- ``text_token_key`` and ``text_token_idx`` must both
      be provided.
    * ``"audio_token"`` / ``"audio_contiguous"`` -- ``add_loss`` must be
      falsy. Because ``add_loss`` defaults to ``True``, audio segments have
      to pass ``add_loss=False`` explicitly.

    Any other ``type`` value is accepted without validation.
    """

    # Segment kind; decides which validation rule applies.
    type: Literal["constant_text_token", "text_token", "audio_token", "audio_contiguous"]
    # NOTE(review): presumably marks whether the segment contributes to the
    # training loss -- confirm against the consumer of this template.
    add_loss: bool = True
    # Token ids for a "constant_text_token" segment.
    text_ids: Optional[numpy.ndarray] = None
    # Index paired with ``text_token_key`` for a "text_token" segment.
    text_token_idx: Optional[int] = None
    # Key paired with ``text_token_idx`` for a "text_token" segment.
    text_token_key: Optional[str] = None

    def __post_init__(self) -> None:
        """Check that the fields required by ``self.type`` are consistent.

        Raises:
            AssertionError: If a "constant_text_token" segment has no
                ``text_ids``, a "text_token" segment lacks
                ``text_token_key`` or ``text_token_idx``, or an audio
                segment has a truthy ``add_loss``.
        """
        # Raised explicitly rather than via ``assert`` so the checks still run
        # under ``python -O``; the exception type stays AssertionError so
        # existing callers observe the same failure.
        if self.type == "constant_text_token":
            if self.text_ids is None:
                raise AssertionError(
                    "constant_text_token segment requires text_ids"
                )
        elif self.type == "text_token":
            if self.text_token_key is None or self.text_token_idx is None:
                raise AssertionError(
                    "text_token segment requires text_token_key and text_token_idx"
                )
        elif self.type in ("audio_token", "audio_contiguous"):
            if self.add_loss:
                raise AssertionError(
                    f"{self.type} segment must be created with add_loss=False"
                )
# Exported alongside the template; this default variant leaves it empty.
STYLE_CONTROL_TEXT = ""

# Default template: five segments, listed in the order they are laid out.
# The bracketed token strings below are the original author's annotations of
# what each id sequence represents.
chat_template = [
    # Fixed opener -- <|im_start|>user\n<|audio_start|>
    MusciChatTemplateSegment(
        add_loss=False,
        type="constant_text_token",
        text_ids=numpy.array([151644, 872, 198, 151669]),
    ),
    # Audio placeholder; audio segment types must be built with add_loss=False.
    MusciChatTemplateSegment(add_loss=False, type="audio_contiguous"),
    # Fixed bridge -- <|audio_end|><|im_end|>\n<|im_start|>assistant\n
    MusciChatTemplateSegment(
        add_loss=False,
        type="constant_text_token",
        text_ids=numpy.array([151670, 151645, 198, 151644, 77091, 198]),
    ),
    # Variable text looked up via key "text_token_transcript", index 0.
    MusciChatTemplateSegment(
        add_loss=True,
        type="text_token",
        text_token_idx=0,
        text_token_key="text_token_transcript",
    ),
    # Fixed closer -- <|im_end|>
    MusciChatTemplateSegment(
        add_loss=True,
        type="constant_text_token",
        text_ids=numpy.array([151645]),
    ),
]

__all__ = ["MusciChatTemplateSegment", "STYLE_CONTROL_TEXT", "chat_template"]