# prompt_model/model.py
# Uploaded by marisa0v0: prompt-conditioned mel→acc model (best checkpoint).
# Commit 8b64619 (verified).
import torch
import torch.nn as nn
from transformers import LlamaForCausalLM, PreTrainedModel
from typing import Optional
class PianoLLaMA(PreTrainedModel):
    """Piano generation model built on the LLaMA architecture.

    Pure model layer (no I/O): wraps a ``LlamaForCausalLM`` constructed from
    ``config`` and exposes a training ``forward`` plus a sampling helper,
    ``generate_music``.
    """

    # Attention backends advertised to transformers; the kernels themselves
    # live in the wrapped LlamaForCausalLM.
    _supports_flash_attn_2 = True
    _supports_flash_attn = True
    _supports_sdpa = True

    def __init__(self, config):
        """Build the wrapped causal LM from ``config`` and re-initialise it.

        Args:
            config: A transformers config accepted by ``LlamaForCausalLM``;
                ``pad_token_id``, ``bos_token_id``, ``eos_token_id`` and
                ``initializer_range`` are read from it.
        """
        super().__init__(config)
        # Re-bind the caller's config object explicitly so the wrapped model
        # is built from exactly the object that was passed in.
        self.config = config
        self.model = LlamaForCausalLM(self.config)
        self.pad_token_id = config.pad_token_id
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        # Re-draw every Linear / Embedding weight with this class's scheme.
        self.model.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one submodule in place (applied via ``nn.Module.apply``).

        Linear and Embedding weights are drawn from
        N(0, ``config.initializer_range``); Linear biases are zeroed, and an
        Embedding's ``padding_idx`` row (when set) is zeroed. Other module
        types are left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        """Run the wrapped ``LlamaForCausalLM`` and return its output unchanged.

        When ``labels`` are given (training), transformers computes the
        language-modelling loss and the KV cache is disabled, since it is only
        useful for incremental decoding.
        """
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            labels=labels,
            use_cache=labels is None,
        )

    # ==================== Autoregressive generation ====================
    @torch.no_grad()
    def generate_music(
        self,
        initial_tokens: torch.Tensor,
        device: str = 'cuda',
        max_length: int = 8192,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.0,
    ) -> torch.Tensor:
        """Sample a continuation of ``initial_tokens`` with the wrapped model.

        Args:
            initial_tokens: Prompt token ids, either one sequence of shape
                ``(seq,)`` or a batch of shape ``(batch, seq)``.
            device: Device the prompt is moved to; should be the device the
                model is on.
            max_length: Maximum total length (prompt + generated tokens)
                forwarded to ``generate``.
            temperature, top_k, top_p, repetition_penalty: Sampling
                parameters forwarded to ``generate`` (``do_sample=True``).

        Returns:
            Generated ids, prompt included, moved to CPU with shape
            ``(batch, total_len)``; ``batch`` is 1 for a 1-D prompt.

        Raises:
            ValueError: If ``initial_tokens`` is neither 1-D nor 2-D.
        """
        if initial_tokens.dim() == 1:
            input_ids = initial_tokens.unsqueeze(0)
        elif initial_tokens.dim() == 2:
            input_ids = initial_tokens
        else:
            raise ValueError(
                "initial_tokens must be 1-D (seq,) or 2-D (batch, seq), "
                f"got shape {tuple(initial_tokens.shape)}"
            )
        input_ids = input_ids.to(device)

        # Pass the mask explicitly instead of letting generate() infer it (and
        # warn). Same rule transformers uses when inferring: pad positions are
        # masked only when a pad id is set and it differs from the eos id.
        if self.pad_token_id is not None and self.pad_token_id != self.eos_token_id:
            attention_mask = input_ids.ne(self.pad_token_id).long()
        else:
            attention_mask = torch.ones_like(input_ids)

        # Sample in eval mode, then restore the caller's mode so that calling
        # this mid-training does not silently disable dropout afterwards.
        was_training = self.training
        self.eval()
        try:
            output = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                do_sample=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                eos_token_id=self.eos_token_id,
                pad_token_id=self.pad_token_id,
            )
        finally:
            self.train(was_training)
        return output.cpu()