File size: 2,680 Bytes
4f2b4bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/env python3
"""
Скрипт для скачивания модели SAM2.
Блин, Facebook не может нормально в pip packaging, поэтому качаем руками.
"""

import os
import urllib.request
import sys

# Директория для чекпоинтов
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Модели на выбор
MODELS = {
    "tiny": {
        "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
        "filename": "sam2.1_hiera_tiny.pt",
        "size": "~39MB"
    },
    "small": {
        "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
        "filename": "sam2.1_hiera_small.pt",
        "size": "~46MB"
    },
    "base_plus": {
        "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
        "filename": "sam2.1_hiera_base_plus.pt",
        "size": "~81MB"
    },
    "large": {
        "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
        "filename": "sam2.1_hiera_large.pt",
        "size": "~224MB"
    }
}

def download_model(model_name="tiny"):
    """Качает модель, показывает прогресс"""
    if model_name not in MODELS:
        print(f"Неизвестная модель: {model_name}")
        print(f"Доступные: {', '.join(MODELS.keys())}")
        sys.exit(1)
    
    model_info = MODELS[model_name]
    filepath = os.path.join(CHECKPOINT_DIR, model_info["filename"])
    
    if os.path.exists(filepath):
        print(f"Модель уже скачана: {filepath}")
        return filepath
    
    print(f"Качаю {model_name} модель ({model_info['size']})...")
    print(f"URL: {model_info['url']}")
    
    def progress_hook(block_num, block_size, total_size):
        downloaded = block_num * block_size
        if total_size > 0:
            percent = min(100, downloaded * 100 / total_size)
            sys.stdout.write(f"\rПрогресс: {percent:.1f}%")
            sys.stdout.flush()
    
    try:
        urllib.request.urlretrieve(
            model_info["url"],
            filepath,
            reporthook=progress_hook
        )
        print(f"\n✓ Модель скачана: {filepath}")
        return filepath
    except Exception as e:
        print(f"\n✗ Ошибка при скачивании: {e}")
        if os.path.exists(filepath):
            os.remove(filepath)
        sys.exit(1)

if __name__ == "__main__":
    model_name = sys.argv[1] if len(sys.argv) > 1 else "tiny"
    download_model(model_name)