File size: 2,473 Bytes
8b64619 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 | import torch
import torch.nn as nn
from transformers import LlamaForCausalLM, PreTrainedModel
from typing import Optional
class PianoLLaMA(PreTrainedModel):
    """Piano-music generation model built on the LLaMA architecture.

    Pure model layer (no I/O): it wraps a ``LlamaForCausalLM`` constructed
    from ``config`` and exposes a training ``forward`` plus an autoregressive
    sampling helper, ``generate_music``.
    """

    # Attention backends this wrapper declares support for.
    _supports_flash_attn_2 = True
    _supports_flash_attn = True
    _supports_sdpa = True

    def __init__(self, config):
        super().__init__(config)
        # Keep a reference to exactly the config object the caller passed in.
        self.config = config
        self.model = LlamaForCausalLM(self.config)
        # Special-token ids cached for use by generate_music.
        self.pad_token_id = config.pad_token_id
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        # Re-initialise every Linear / Embedding in the wrapped model with this
        # class's scheme (see _init_weights).
        self.model.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one submodule in place.

        ``nn.Linear``: weight ~ N(0, initializer_range), bias zeroed.
        ``nn.Embedding``: weight ~ N(0, initializer_range), and the row at
        ``padding_idx`` (if any) is zeroed. Other module types are untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        """Run the wrapped ``LlamaForCausalLM``.

        Args:
            input_ids: token ids fed to the wrapped model.
            attention_mask: optional mask forwarded unchanged.
            position_ids: optional position ids forwarded unchanged.
            labels: optional targets forwarded unchanged.

        Returns:
            The wrapped model's output object, returned as-is.
        """
        # The KV cache is only useful for inference, so disable it whenever
        # labels are supplied (i.e. a training / loss-computation call).
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            labels=labels,
            use_cache=labels is None,
        )

    # ==================== Autoregressive generation ====================
    @torch.no_grad()
    def generate_music(
        self,
        initial_tokens: torch.Tensor,
        device: str = 'cuda',
        max_length: int = 8192,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.0,
    ) -> torch.Tensor:
        """Sample a continuation of ``initial_tokens`` with the wrapped model.

        Args:
            initial_tokens: prompt token ids, either a single sequence of shape
                ``(seq,)`` or a batch of shape ``(batch, seq)``; a plain list of
                ints is also accepted. An empty prompt is seeded with
                ``bos_token_id``.
            device: device the prompt is moved to before generation (presumably
                must match the device the model lives on).
            max_length: forwarded to ``generate`` as ``max_length``.
            temperature: sampling temperature forwarded to ``generate``.
            top_k: top-k filtering value forwarded to ``generate``.
            top_p: nucleus-sampling value forwarded to ``generate``.
            repetition_penalty: forwarded to ``generate``.

        Returns:
            The generated id tensor moved to the CPU, with a leading batch
            dimension (size 1 for a 1-D prompt).

        Raises:
            ValueError: if the prompt is not 1-D or 2-D, or is empty while no
                ``bos_token_id`` is configured.
        """
        input_ids = torch.as_tensor(initial_tokens)
        if input_ids.dim() == 1:
            # Single sequence: add the batch dimension generate() expects.
            input_ids = input_ids.unsqueeze(0)
        elif input_ids.dim() != 2:
            raise ValueError(
                "initial_tokens must be 1-D (seq,) or 2-D (batch, seq), "
                f"got shape {tuple(input_ids.shape)}"
            )
        if input_ids.size(-1) == 0:
            if self.bos_token_id is None:
                raise ValueError(
                    "initial_tokens is empty and config.bos_token_id is not set"
                )
            input_ids = torch.full(
                (input_ids.size(0), 1), self.bos_token_id, dtype=torch.long
            )
        input_ids = input_ids.to(device)

        # Sampling needs eval mode (no dropout), but the caller may be in the
        # middle of training: remember every submodule's mode and restore it
        # afterwards so this call has no lasting side effect.
        previous_modes = [(m, m.training) for m in self.modules()]
        self.eval()
        try:
            output = self.model.generate(
                input_ids=input_ids,
                max_length=max_length,
                do_sample=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                eos_token_id=self.eos_token_id,
                pad_token_id=self.pad_token_id,
            )
        finally:
            for module, was_training in previous_modes:
                module.training = was_training
        return output.cpu()
|