Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| import tempfile | |
| import shutil | |
| from pathlib import Path | |
| from model_manager import ModelManager | |
| # サンプルファイルのディレクトリを定義(絶対パスに解決) | |
| EXAMPLES_DIR = (Path(__file__).parent / "example").resolve() | |
| OUTPUT_DIR = (Path(__file__).parent / "output").resolve() | |
| # 出力ディレクトリの作成 | |
| OUTPUT_DIR.mkdir(exist_ok=True) | |
| # インポートエラーのデバッグ情報を表示 | |
| try: | |
| import filetype | |
| print("✅ filetype module imported successfully") | |
| except ImportError as e: | |
| print(f"⚠️ filetype import failed: {e}") | |
| print("Using fallback file type detection") | |
| # モデルの初期化 | |
| print("=== モデルの初期化開始 ===") | |
| # PyTorchモデルを使用(TensorRTモデルは非常に大きいため) | |
| USE_PYTORCH = True | |
| model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH) | |
| if not model_manager.setup_models(): | |
| raise RuntimeError("モデルのセットアップに失敗しました。") | |
| # SDKの初期化 | |
| if USE_PYTORCH: | |
| data_root = "./checkpoints/ditto_pytorch" | |
| cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl" | |
| else: | |
| data_root = "./checkpoints/ditto_trt_Ampere_Plus" | |
| cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl" | |
| # SDK初期化のためのグローバル変数 | |
| SDK = None | |
| try: | |
| # モジュールをインポート | |
| from stream_pipeline_offline import StreamSDK | |
| from inference import run, seed_everything | |
| SDK = StreamSDK(cfg_pkl, data_root) | |
| print("✅ SDK初期化成功") | |
| except Exception as e: | |
| print(f"❌ SDK初期化エラー: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| raise | |
| def process_talking_head(audio_file, source_image): | |
| """音声とソース画像からTalking Headビデオを生成""" | |
| if audio_file is None: | |
| return None, "音声ファイルをアップロードしてください。" | |
| if source_image is None: | |
| return None, "ソース画像をアップロードしてください。" | |
| try: | |
| # 出力ファイルの作成(出力ディレクトリ内) | |
| import uuid | |
| output_filename = f"{uuid.uuid4()}.mp4" | |
| output_path = str(OUTPUT_DIR / output_filename) | |
| # 処理実行 | |
| print(f"処理開始: audio={audio_file}, image={source_image}") | |
| seed_everything(1024) | |
| run(SDK, audio_file, source_image, output_path) | |
| # 結果の確認 | |
| if os.path.exists(output_path) and os.path.getsize(output_path) > 0: | |
| return output_path, "✅ 処理が完了しました!" | |
| else: | |
| return None, "❌ 処理に失敗しました。出力ファイルが生成されませんでした。" | |
| except Exception as e: | |
| import traceback | |
| error_msg = f"❌ エラーが発生しました: {str(e)}\n{traceback.format_exc()}" | |
| print(error_msg) | |
| return None, error_msg | |
| # Gradio UI | |
| with gr.Blocks(title="DittoTalkingHead") as demo: | |
| gr.Markdown(""" | |
| # DittoTalkingHead - Talking Head Generation | |
| 音声とソース画像から、リアルなTalking Headビデオを生成します。 | |
| ## 使い方 | |
| 1. **音声ファイル**(WAV形式)をアップロード | |
| 2. **ソース画像**(PNG/JPG形式)をアップロード | |
| 3. **生成**ボタンをクリック | |
| ⚠️ 初回実行時は、モデルのダウンロードのため時間がかかります(約2.5GB)。 | |
| ### 技術仕様 | |
| - **モデル**: DittoTalkingHead (PyTorch版) | |
| - **GPU**: NVIDIA A100推奨 | |
| - **モデル提供**: [digital-avatar/ditto-talkinghead](https://huggingface.co/digital-avatar/ditto-talkinghead) | |
| """) | |
| with gr.Row(): | |
| with gr.Column(): | |
| audio_input = gr.Audio( | |
| label="音声ファイル (WAV)", | |
| type="filepath" | |
| ) | |
| image_input = gr.Image( | |
| label="ソース画像", | |
| type="filepath" | |
| ) | |
| generate_btn = gr.Button("生成", variant="primary") | |
| with gr.Column(): | |
| video_output = gr.Video( | |
| label="生成されたビデオ" | |
| ) | |
| status_output = gr.Textbox( | |
| label="ステータス", | |
| lines=3 | |
| ) | |
| # サンプル | |
| example_audio = EXAMPLES_DIR / "audio.wav" | |
| example_image = EXAMPLES_DIR / "image.png" | |
| if example_audio.exists() and example_image.exists(): | |
| gr.Examples( | |
| examples=[ | |
| [str(example_audio), str(example_image)] | |
| ], | |
| inputs=[audio_input, image_input], | |
| outputs=[video_output, status_output], | |
| fn=process_talking_head, | |
| cache_examples=True | |
| ) | |
| # イベントハンドラ | |
| generate_btn.click( | |
| fn=process_talking_head, | |
| inputs=[audio_input, image_input], | |
| outputs=[video_output, status_output] | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=False, | |
| allowed_paths=[str(EXAMPLES_DIR), str(OUTPUT_DIR)] | |
| ) |