sam2-api / download_model.py
gbreadman13code
Deploy SAM2 segmentation API
4f2b4bb
#!/usr/bin/env python3
"""
Скрипт для скачивания модели SAM2.
Блин, Facebook не может нормально в pip packaging, поэтому качаем руками.
"""
import os
import urllib.request
import sys
# Директория для чекпоинтов
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
# Модели на выбор
MODELS = {
"tiny": {
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
"filename": "sam2.1_hiera_tiny.pt",
"size": "~39MB"
},
"small": {
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
"filename": "sam2.1_hiera_small.pt",
"size": "~46MB"
},
"base_plus": {
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
"filename": "sam2.1_hiera_base_plus.pt",
"size": "~81MB"
},
"large": {
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
"filename": "sam2.1_hiera_large.pt",
"size": "~224MB"
}
}
def download_model(model_name="tiny"):
"""Качает модель, показывает прогресс"""
if model_name not in MODELS:
print(f"Неизвестная модель: {model_name}")
print(f"Доступные: {', '.join(MODELS.keys())}")
sys.exit(1)
model_info = MODELS[model_name]
filepath = os.path.join(CHECKPOINT_DIR, model_info["filename"])
if os.path.exists(filepath):
print(f"Модель уже скачана: {filepath}")
return filepath
print(f"Качаю {model_name} модель ({model_info['size']})...")
print(f"URL: {model_info['url']}")
def progress_hook(block_num, block_size, total_size):
downloaded = block_num * block_size
if total_size > 0:
percent = min(100, downloaded * 100 / total_size)
sys.stdout.write(f"\rПрогресс: {percent:.1f}%")
sys.stdout.flush()
try:
urllib.request.urlretrieve(
model_info["url"],
filepath,
reporthook=progress_hook
)
print(f"\n✓ Модель скачана: {filepath}")
return filepath
except Exception as e:
print(f"\n✗ Ошибка при скачивании: {e}")
if os.path.exists(filepath):
os.remove(filepath)
sys.exit(1)
if __name__ == "__main__":
model_name = sys.argv[1] if len(sys.argv) > 1 else "tiny"
download_model(model_name)