thecollabagepatch committed
Commit 76c5192 · Parent(s): f9457b6

no more duration params

Files changed (1): app.py (+116 −103)
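In practice the commit does two things visible in the hunks below: it re-enables generate_drum_sample with a bare @spaces.GPU decorator (no duration argument), and it comments out the continuation helpers (continue_drum_sample, generate_music, continue_music) together with their click handlers. A condensed sketch of the re-enabled function, restated here with its imports for context (the duration=10 passed to set_generation_params is the clip length in seconds, not a GPU time budget):

import spaces
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write

@spaces.GPU  # bare decorator: no explicit ZeroGPU duration is requested anymore
def generate_drum_sample() -> str:
    model = MusicGen.get_pretrained('pharoAIsanders420/micro-musicgen-jungle')
    model.set_generation_params(duration=10)  # length of the generated clip, unrelated to the decorator
    wav = model.generate_unconditional(1).squeeze(0)  # one unconditional sample, shape [C, T]
    # audio_write appends the .wav extension to the stem name
    audio_write('jungle', wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
    return 'jungle.wav'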
app.py CHANGED
@@ -18,8 +18,21 @@ def preprocess_audio(waveform):
    waveform_np = waveform.cpu().squeeze().numpy()
    return torch.from_numpy(waveform_np).unsqueeze(0).to(device)

+ @spaces.GPU
+ def generate_drum_sample() -> str:
+     model = MusicGen.get_pretrained('pharoAIsanders420/micro-musicgen-jungle')
+     model.set_generation_params(duration=10)
+     wav = model.generate_unconditional(1).squeeze(0)
+
+     filename_without_extension = f'jungle'
+     filename_with_extension = f'{filename_without_extension}.wav'
+
+     audio_write(filename_without_extension, wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
+
+     return filename_with_extension
+
 # @spaces.GPU(duration=10)
- # def generate_drum_sample() -> str:
+ # def generate_drum_sample():
 # model = MusicGen.get_pretrained('pharoAIsanders420/micro-musicgen-jungle')
 # model.set_generation_params(duration=10)
 # wav = model.generate_unconditional(1).squeeze(0)
@@ -31,140 +44,140 @@ def preprocess_audio(waveform):

 # return filename_with_extension

- @spaces.GPU(duration=10)
- def continue_drum_sample(existing_audio_path):
-     if existing_audio_path is None:
-         return None
-
-     existing_audio, sr = torchaudio.load(existing_audio_path)
-     existing_audio = existing_audio.to(device)
-
-     prompt_duration = 2
-     output_duration = 10
-
-     num_samples = int(prompt_duration * sr)
-     if existing_audio.shape[1] < num_samples:
-         raise ValueError("The existing audio is too short for the specified prompt duration.")
-
-     start_sample = existing_audio.shape[1] - num_samples
-     prompt_waveform = existing_audio[..., start_sample:]
-
-     model = MusicGen.get_pretrained('pharoAIsanders420/micro-musicgen-jungle')
-     model.set_generation_params(duration=output_duration)
-
-     output = model.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
-     output = output.to(device)
-
-     if output.dim() == 3:
-         output = output.squeeze(0)
-
-     if output.dim() == 1:
-         output = output.unsqueeze(0)
-
-     combined_audio = torch.cat((existing_audio, output), dim=1)
-     combined_audio = combined_audio.cpu()
-
-     combined_file_path = f'./continued_jungle_{random.randint(1000, 9999)}.wav'
-     torchaudio.save(combined_file_path, combined_audio, sr)
-
-     return combined_file_path
-
- @spaces.GPU(duration=120)
- def generate_music(wav_filename, prompt_duration, musicgen_model, output_duration):
-     if wav_filename is None:
-         return None
-
-     song, sr = torchaudio.load(wav_filename)
-     song = song.to(device)
-
-     model_name = musicgen_model.split(" ")[0]
-     model_continue = MusicGen.get_pretrained(model_name)
-
-     model_continue.set_generation_params(
-         use_sampling=True,
-         top_k=250,
-         top_p=0.0,
-         temperature=1.0,
-         duration=output_duration,
-         cfg_coef=3
-     )
-
-     prompt_waveform = song[..., :int(prompt_duration * sr)]
-     prompt_waveform = preprocess_audio(prompt_waveform)
-
-     output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
-     output = output.cpu()
-
-     if len(output.size()) > 2:
-         output = output.squeeze()
-
-     filename_without_extension = f'continued_music'
-     filename_with_extension = f'{filename_without_extension}.wav'
-     audio_write(filename_without_extension, output, model_continue.sample_rate, strategy="loudness", loudness_compressor=True)
-
-     return filename_with_extension
-
- @spaces.GPU(duration=120)
- def continue_music(input_audio_path, prompt_duration, musicgen_model, output_duration):
-     if input_audio_path is None:
-         return None
-
-     song, sr = torchaudio.load(input_audio_path)
-     song = song.to(device)
-
-     model_continue = MusicGen.get_pretrained(musicgen_model.split(" ")[0])
-     model_continue.set_generation_params(
-         use_sampling=True,
-         top_k=250,
-         top_p=0.0,
-         temperature=1.0,
-         duration=output_duration,
-         cfg_coef=3
-     )
-
-     original_audio = AudioSegment.from_mp3(input_audio_path)
-     current_audio = original_audio
-
-     file_paths_for_cleanup = []
-
-     for i in range(1):
-         num_samples = int(prompt_duration * sr)
-         if current_audio.duration_seconds * 1000 < prompt_duration * 1000:
-             raise ValueError("The prompt_duration is longer than the current audio length.")
-
-         start_time = current_audio.duration_seconds * 1000 - prompt_duration * 1000
-         prompt_audio = current_audio[start_time:]
-
-         prompt_bytes = prompt_audio.export(format="wav").read()
-         prompt_waveform, _ = torchaudio.load(io.BytesIO(prompt_bytes))
-         prompt_waveform = prompt_waveform.to(device)
-
-         prompt_waveform = preprocess_audio(prompt_waveform)
-
-         output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
-         output = output.cpu()
-
-         if len(output.size()) > 2:
-             output = output.squeeze()
-
-         filename_without_extension = f'continue_{i}'
-         filename_with_extension = f'{filename_without_extension}.wav'
-         correct_filename_extension = f'{filename_without_extension}.wav.wav'
-
-         audio_write(filename_with_extension, output, model_continue.sample_rate, strategy="loudness", loudness_compressor=True)
-         generated_audio_segment = AudioSegment.from_wav(correct_filename_extension)
-
-         current_audio = current_audio[:start_time] + generated_audio_segment
-
-         file_paths_for_cleanup.append(correct_filename_extension)
-
-     combined_audio_filename = f"combined_audio_{random.randint(1, 10000)}.mp3"
-     current_audio.export(combined_audio_filename, format="mp3")
-
-     for file_path in file_paths_for_cleanup:
-         os.remove(file_path)
-
-     return combined_audio_filename
+ # @spaces.GPU(duration=10)
+ # def continue_drum_sample(existing_audio_path):
+ # if existing_audio_path is None:
+ # return None
+
+ # existing_audio, sr = torchaudio.load(existing_audio_path)
+ # existing_audio = existing_audio.to(device)
+
+ # prompt_duration = 2
+ # output_duration = 10
+
+ # num_samples = int(prompt_duration * sr)
+ # if existing_audio.shape[1] < num_samples:
+ # raise ValueError("The existing audio is too short for the specified prompt duration.")
+
+ # start_sample = existing_audio.shape[1] - num_samples
+ # prompt_waveform = existing_audio[..., start_sample:]
+
+ # model = MusicGen.get_pretrained('pharoAIsanders420/micro-musicgen-jungle')
+ # model.set_generation_params(duration=output_duration)
+
+ # output = model.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
+ # output = output.to(device)
+
+ # if output.dim() == 3:
+ # output = output.squeeze(0)
+
+ # if output.dim() == 1:
+ # output = output.unsqueeze(0)
+
+ # combined_audio = torch.cat((existing_audio, output), dim=1)
+ # combined_audio = combined_audio.cpu()
+
+ # combined_file_path = f'./continued_jungle_{random.randint(1000, 9999)}.wav'
+ # torchaudio.save(combined_file_path, combined_audio, sr)
+
+ # return combined_file_path
+
+ # @spaces.GPU(duration=120)
+ # def generate_music(wav_filename, prompt_duration, musicgen_model, output_duration):
+ # if wav_filename is None:
+ # return None
+
+ # song, sr = torchaudio.load(wav_filename)
+ # song = song.to(device)
+
+ # model_name = musicgen_model.split(" ")[0]
+ # model_continue = MusicGen.get_pretrained(model_name)
+
+ # model_continue.set_generation_params(
+ # use_sampling=True,
+ # top_k=250,
+ # top_p=0.0,
+ # temperature=1.0,
+ # duration=output_duration,
+ # cfg_coef=3
+ # )
+
+ # prompt_waveform = song[..., :int(prompt_duration * sr)]
+ # prompt_waveform = preprocess_audio(prompt_waveform)
+
+ # output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
+ # output = output.cpu()
+
+ # if len(output.size()) > 2:
+ # output = output.squeeze()
+
+ # filename_without_extension = f'continued_music'
+ # filename_with_extension = f'{filename_without_extension}.wav'
+ # audio_write(filename_without_extension, output, model_continue.sample_rate, strategy="loudness", loudness_compressor=True)
+
+ # return filename_with_extension
+
+ # @spaces.GPU(duration=120)
+ # def continue_music(input_audio_path, prompt_duration, musicgen_model, output_duration):
+ # if input_audio_path is None:
+ # return None
+
+ # song, sr = torchaudio.load(input_audio_path)
+ # song = song.to(device)
+
+ # model_continue = MusicGen.get_pretrained(musicgen_model.split(" ")[0])
+ # model_continue.set_generation_params(
+ # use_sampling=True,
+ # top_k=250,
+ # top_p=0.0,
+ # temperature=1.0,
+ # duration=output_duration,
+ # cfg_coef=3
+ # )
+
+ # original_audio = AudioSegment.from_mp3(input_audio_path)
+ # current_audio = original_audio
+
+ # file_paths_for_cleanup = []
+
+ # for i in range(1):
+ # num_samples = int(prompt_duration * sr)
+ # if current_audio.duration_seconds * 1000 < prompt_duration * 1000:
+ # raise ValueError("The prompt_duration is longer than the current audio length.")
+
+ # start_time = current_audio.duration_seconds * 1000 - prompt_duration * 1000
+ # prompt_audio = current_audio[start_time:]
+
+ # prompt_bytes = prompt_audio.export(format="wav").read()
+ # prompt_waveform, _ = torchaudio.load(io.BytesIO(prompt_bytes))
+ # prompt_waveform = prompt_waveform.to(device)
+
+ # prompt_waveform = preprocess_audio(prompt_waveform)
+
+ # output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
+ # output = output.cpu()
+
+ # if len(output.size()) > 2:
+ # output = output.squeeze()
+
+ # filename_without_extension = f'continue_{i}'
+ # filename_with_extension = f'{filename_without_extension}.wav'
+ # correct_filename_extension = f'{filename_without_extension}.wav.wav'
+
+ # audio_write(filename_with_extension, output, model_continue.sample_rate, strategy="loudness", loudness_compressor=True)
+ # generated_audio_segment = AudioSegment.from_wav(correct_filename_extension)
+
+ # current_audio = current_audio[:start_time] + generated_audio_segment
+
+ # file_paths_for_cleanup.append(correct_filename_extension)
+
+ # combined_audio_filename = f"combined_audio_{random.randint(1, 10000)}.mp3"
+ # current_audio.export(combined_audio_filename, format="mp3")
+
+ # for file_path in file_paths_for_cleanup:
+ # os.remove(file_path)
+
+ # return combined_audio_filename

 # Define the expandable sections (keeping your existing content)
 musicgen_micro_blurb = """
@@ -266,9 +279,9 @@ with gr.Blocks() as iface:

 # Connecting the components
 # generate_button.click(generate_drum_sample, outputs=[drum_audio])
-     continue_drum_sample_button.click(continue_drum_sample, inputs=[drum_audio], outputs=[drum_audio])
-     generate_music_button.click(generate_music, inputs=[drum_audio, prompt_duration, musicgen_model, output_duration], outputs=[output_audio])
-     continue_button.click(continue_music, inputs=[output_audio, prompt_duration, musicgen_model, output_duration], outputs=continue_output_audio)
+ # continue_drum_sample_button.click(continue_drum_sample, inputs=[drum_audio], outputs=[drum_audio])
+ # generate_music_button.click(generate_music, inputs=[drum_audio, prompt_duration, musicgen_model, output_duration], outputs=[output_audio])
+ # continue_button.click(continue_music, inputs=[output_audio, prompt_duration, musicgen_model, output_duration], outputs=continue_output_audio)

 if __name__ == "__main__":
     iface.launch()