Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Скрипт для скачивания модели SAM2. | |
| Блин, Facebook не может нормально в pip packaging, поэтому качаем руками. | |
| """ | |
| import os | |
| import urllib.request | |
| import sys | |
| # Директория для чекпоинтов | |
| CHECKPOINT_DIR = "checkpoints" | |
| os.makedirs(CHECKPOINT_DIR, exist_ok=True) | |
| # Модели на выбор | |
| MODELS = { | |
| "tiny": { | |
| "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt", | |
| "filename": "sam2.1_hiera_tiny.pt", | |
| "size": "~39MB" | |
| }, | |
| "small": { | |
| "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt", | |
| "filename": "sam2.1_hiera_small.pt", | |
| "size": "~46MB" | |
| }, | |
| "base_plus": { | |
| "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt", | |
| "filename": "sam2.1_hiera_base_plus.pt", | |
| "size": "~81MB" | |
| }, | |
| "large": { | |
| "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt", | |
| "filename": "sam2.1_hiera_large.pt", | |
| "size": "~224MB" | |
| } | |
| } | |
| def download_model(model_name="tiny"): | |
| """Качает модель, показывает прогресс""" | |
| if model_name not in MODELS: | |
| print(f"Неизвестная модель: {model_name}") | |
| print(f"Доступные: {', '.join(MODELS.keys())}") | |
| sys.exit(1) | |
| model_info = MODELS[model_name] | |
| filepath = os.path.join(CHECKPOINT_DIR, model_info["filename"]) | |
| if os.path.exists(filepath): | |
| print(f"Модель уже скачана: {filepath}") | |
| return filepath | |
| print(f"Качаю {model_name} модель ({model_info['size']})...") | |
| print(f"URL: {model_info['url']}") | |
| def progress_hook(block_num, block_size, total_size): | |
| downloaded = block_num * block_size | |
| if total_size > 0: | |
| percent = min(100, downloaded * 100 / total_size) | |
| sys.stdout.write(f"\rПрогресс: {percent:.1f}%") | |
| sys.stdout.flush() | |
| try: | |
| urllib.request.urlretrieve( | |
| model_info["url"], | |
| filepath, | |
| reporthook=progress_hook | |
| ) | |
| print(f"\n✓ Модель скачана: {filepath}") | |
| return filepath | |
| except Exception as e: | |
| print(f"\n✗ Ошибка при скачивании: {e}") | |
| if os.path.exists(filepath): | |
| os.remove(filepath) | |
| sys.exit(1) | |
| if __name__ == "__main__": | |
| model_name = sys.argv[1] if len(sys.argv) > 1 else "tiny" | |
| download_model(model_name) | |