AiCoderv2 committed (verified)
Commit b36751d · 1 Parent(s): 0a828ee

Update app.py from anycoder

Files changed (1)
  1. app.py +28 -35
app.py CHANGED
@@ -1,10 +1,12 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import AutoProcessor, MusicgenForConditionalGeneration
 import os
 from pathlib import Path
 import time
 import tempfile
+import numpy as np
+from scipy.io.wavfile import write
 
 # Custom theme for music maker
 custom_theme = gr.themes.Soft(
@@ -27,7 +29,7 @@ MODEL_CACHE_DIR = Path.home() / ".cache" / "huggingface" / "musicgen"
 MAX_NEW_TOKENS = 250
 AUDIO_DURATION = 10 # seconds
 
-# Initialize model and tokenizer
+# Initialize model and processor
 def load_model():
     """Load the MusicGen model with caching"""
     if not os.path.exists(MODEL_CACHE_DIR):
@@ -36,14 +38,14 @@ def load_model():
     print("Loading MusicGen model...")
     start_time = time.time()
 
-    # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(
+    # Load processor (replaces tokenizer for MusicGen)
+    processor = AutoProcessor.from_pretrained(
         MODEL_NAME,
         cache_dir=MODEL_CACHE_DIR
     )
 
-    # Load model
-    model = AutoModelForSeq2SeqLM.from_pretrained(
+    # Load model - MusicGen uses MusicgenForConditionalGeneration
+    model = MusicgenForConditionalGeneration.from_pretrained(
         MODEL_NAME,
         cache_dir=MODEL_CACHE_DIR,
         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
@@ -54,10 +56,10 @@ def load_model():
 
     load_time = time.time() - start_time
     print(f"Model loaded in {load_time:.2f} seconds")
-    return model, tokenizer
+    return model, processor
 
 # Global variables for model
-model, tokenizer = load_model()
+model, processor = load_model()
 
 def generate_music(prompt, duration, temperature, top_k):
     """
@@ -73,47 +75,38 @@ def generate_music(prompt, duration, temperature, top_k):
         Generated audio file path
     """
     try:
-        # Generate music
-        inputs = tokenizer(
-            [prompt],
-            padding="max_length",
-            truncation=True,
-            max_length=64,
+        # Generate music using MusicGen
+        inputs = processor(
+            text=[prompt],
+            padding=True,
             return_tensors="pt"
         ).to(model.device)
 
         # Generate audio
-        with torch.no_grad():
-            audio_values = model.generate(
-                **inputs,
-                do_sample=True,
-                max_new_tokens=MAX_NEW_TOKENS,
-                temperature=temperature,
-                top_k=top_k
-            )
+        audio_values = model.generate(
+            **inputs,
+            max_new_tokens=MAX_NEW_TOKENS,
+            do_sample=True,
+            temperature=temperature,
+            top_k=top_k
+        )
 
         # Convert to audio file
         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
-            # Save audio (this is a simplified version - actual MusicGen would need proper decoding)
-            # For demo purposes, we'll create a simple audio file
-            import numpy as np
-            from scipy.io.wavfile import write
+            # Get sampling rate from model config
+            sampling_rate = model.config.audio_encoder.sample_rate
 
-            # Generate simple sine wave for demo
-            sample_rate = 44100
-            t = np.linspace(0, duration, int(sample_rate * duration), False)
-            frequency = 440 # A4 note
-            audio_data = np.sin(2 * np.pi * frequency * t) * 0.5
+            # Convert audio tensor to numpy array
+            audio_data = audio_values[0, 0].cpu().numpy()
 
-            # Add some variation based on prompt length
-            if len(prompt) > 20:
-                audio_data = audio_data * 0.8 + np.random.normal(0, 0.1, len(audio_data))
+            # Normalize audio
+            audio_data = audio_data / np.max(np.abs(audio_data)) * 0.9
 
             # Convert to 16-bit PCM format
             audio_data = (audio_data * 32767).astype(np.int16)
 
             # Write to file
-            write(temp_file.name, sample_rate, audio_data)
+            write(temp_file.name, sampling_rate, audio_data)
 
         return temp_file.name
 
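
For reference, the generation path this commit switches to can be exercised on its own. The sketch below is an illustration only, not part of the commit: it assumes MODEL_NAME resolves to a MusicGen checkpoint such as facebook/musicgen-small (the constant is defined outside the changed hunks) and mirrors the processor, generate, and WAV-write sequence introduced above.

# Standalone sketch of the new MusicGen path; the checkpoint name is an assumption.
import numpy as np
import torch
from scipy.io.wavfile import write
from transformers import AutoProcessor, MusicgenForConditionalGeneration

model_name = "facebook/musicgen-small"  # stand-in for the app's MODEL_NAME
processor = AutoProcessor.from_pretrained(model_name)
model = MusicgenForConditionalGeneration.from_pretrained(model_name)

# Tokenize the text prompt for MusicGen's text encoder
inputs = processor(text=["lo-fi beat with soft piano"], padding=True, return_tensors="pt")

# Sample audio tokens; clip length grows with max_new_tokens
with torch.no_grad():
    audio_values = model.generate(**inputs, do_sample=True, max_new_tokens=256)

# Output shape is (batch, channels, samples); take the first mono clip
audio = audio_values[0, 0].cpu().numpy()
audio = audio / np.max(np.abs(audio)) * 0.9   # peak-normalize, as the new code does
pcm = (audio * 32767).astype(np.int16)        # 16-bit PCM
write("musicgen_demo.wav", model.config.audio_encoder.sampling_rate, pcm)

Two details worth double-checking against the installed transformers version: the EnCodec audio-encoder config exposes its rate as sampling_rate (the hunk above reads sample_rate), and MusicGen emits on the order of 50 audio tokens per second, so MAX_NEW_TOKENS = 250 yields roughly 5 seconds of audio; the duration argument is not mapped to max_new_tokens in this hunk.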