from typing import Optional, Union

import torch
import torch.nn as nn
from transformers import LlamaForCausalLM, PreTrainedModel


class PianoLLaMA(PreTrainedModel):
    """Piano music generation model built on the LLaMA architecture.

    Pure model layer: it wraps a ``LlamaForCausalLM`` and performs no file
    I/O of its own.
    """

    _supports_flash_attn_2 = True
    _supports_flash_attn = True
    _supports_sdpa = True

    def __init__(self, config):
        """Build the wrapped causal LM from ``config`` and initialise it.

        Args:
            config: Model configuration. It is handed to ``LlamaForCausalLM``
                and must expose ``pad_token_id``, ``bos_token_id``,
                ``eos_token_id`` and ``initializer_range``.
        """
        # The base-class constructor already stores ``config`` on
        # ``self.config``, so it is not assigned a second time here.
        super().__init__(config)
        self.model = LlamaForCausalLM(self.config)

        # Special-token ids are cached so generate_music() can pass them on.
        self.pad_token_id = config.pad_token_id
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id

        # NOTE(review): Hugging Face models conventionally declare
        # ``config_class`` and end ``__init__`` with ``self.post_init()``;
        # confirm whether the explicit ``apply`` below is a deliberate
        # substitute before changing it.
        self.model.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise a single submodule (called once per submodule by ``apply``).

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a normal
        distribution with mean 0 and std ``config.initializer_range``.
        Linear biases and the embedding row at ``padding_idx`` are zeroed.
        Any other module type is left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        """Delegate the forward pass to the wrapped ``LlamaForCausalLM``.

        Args:
            input_ids: Token ids, forwarded unchanged.
            attention_mask: Optional attention mask, forwarded unchanged.
            position_ids: Optional position ids, forwarded unchanged.
            labels: Optional labels, forwarded unchanged.

        Returns:
            Whatever the wrapped model returns for these arguments.
        """
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            labels=labels,
            # The key/value cache is only requested when no labels are given.
            use_cache=labels is None,
        )

    @torch.no_grad()
    def generate_music(
        self,
        initial_tokens: torch.Tensor,
        device: Optional[Union[str, torch.device]] = None,
        max_length: int = 8192,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.0,
    ) -> torch.Tensor:
        """Sample a continuation of ``initial_tokens`` from the wrapped model.

        This switches the model to eval mode and does not restore the
        previous mode afterwards.

        Args:
            initial_tokens: Prompt token ids. A 1-D tensor is treated as a
                single sequence and gets a leading batch dimension; a 2-D
                tensor is used as an already-batched prompt.
            device: Device the prompt is moved to before generation. When
                ``None`` the device of the model itself is used, so the call
                also works on hosts without CUDA.
            max_length: Forwarded to ``generate`` as ``max_length``.
            temperature: Forwarded to ``generate`` as ``temperature``.
            top_k: Forwarded to ``generate`` as ``top_k``.
            top_p: Forwarded to ``generate`` as ``top_p``.
            repetition_penalty: Forwarded to ``generate`` as
                ``repetition_penalty``.

        Returns:
            The tensor returned by ``generate``, moved to the CPU.

        Raises:
            ValueError: If ``initial_tokens`` is neither 1-D nor 2-D.
        """
        self.eval()

        if initial_tokens.dim() == 1:
            input_ids = initial_tokens.unsqueeze(0)
        elif initial_tokens.dim() == 2:
            input_ids = initial_tokens
        else:
            raise ValueError(
                "initial_tokens must be 1-D or 2-D, "
                f"got {initial_tokens.dim()} dimensions"
            )

        # Follow the model's own device unless the caller asked for one, so
        # the prompt and the weights never end up on different devices.
        if device is None:
            device = self.device
        input_ids = input_ids.to(device)

        output = self.model.generate(
            input_ids=input_ids,
            max_length=max_length,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
        )
        return output.cpu()