File size: 1,863 Bytes
1ff38fb
 
 
c8c8629
1ff38fb
 
c8c8629
1ff38fb
4472870
1ff38fb
 
 
c8c8629
cc5b395
 
 
 
 
 
1ff38fb
 
 
 
 
 
 
 
 
 
 
 
 
cc5b395
1ff38fb
 
cc5b395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6914bc9
cc5b395
1ff38fb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Base image: PyTorch 2.1.0 built against CUDA 12.1 / cuDNN 8 ("runtime" flavour).
# This combination supports the NVIDIA A10G (Ampere) GPU used for training.
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime

# Target only Ampere compute capability 8.6 (A10G) for any CUDA code that
# honours TORCH_CUDA_ARCH_LIST (e.g. PyTorch C++/CUDA extension builds).
ENV TORCH_CUDA_ARCH_LIST=8.6

# Install system dependencies:
#   ca-certificates - TLS root certificates for git/pip over HTTPS; git only
#                     "Recommends" it, so it must be listed explicitly once
#                     recommended packages are skipped
#   ffmpeg          - video encoding for imageio/visualization
#   git             - used by pip for VCS (git+...) requirements
# --no-install-recommends keeps optional packages out of the image, and the apt
# lists are deleted in the same layer so they never persist in the image.
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        ca-certificates \
        ffmpeg \
        git \
    && rm -rf /var/lib/apt/lists/*

# Hugging Face Spaces runs containers as UID 1000; create a matching account
# with a home directory so file ownership lines up at runtime.
RUN useradd --create-home --uid 1000 user

# Set working directory
WORKDIR /app

# Copy requirements.txt on its own (before the rest of the source) so the
# expensive dependency layer below is reused when only project code changes.
COPY requirements.txt .

# Install Python dependencies without replacing the base image's torch 2.1.0:
#   1. Remove only the torch / torchvision requirement lines. Matching the bare
#      substring "torch" would also drop unrelated packages such as
#      torchmetrics, torchaudio or pytorch-lightning.
#   2. Pin torch and torchvision through a constraints file so no other
#      requirement can pull in a different torch as a transitive dependency.
#      An incompatible requirement now fails the build loudly instead of
#      silently swapping out the CUDA build shipped with the base image.
#   3. torchvision 0.16.0 is the release paired with torch 2.1.0; it is
#      resolved in the same pip run so the resolver sees every constraint.
RUN sed -i -E '/^[[:space:]]*(torch|torchvision)[[:space:]]*([<>=!~;@#[].*)?$/d' requirements.txt && \
    printf 'torch==2.1.0\ntorchvision==0.16.0\n' > /tmp/torch-constraints.txt && \
    pip install --no-cache-dir -c /tmp/torch-constraints.txt -r requirements.txt torchvision==0.16.0 && \
    rm /tmp/torch-constraints.txt

# Copy all project files into the container, owned by the runtime user.
# Setting ownership at copy time avoids a later `chown -R /app`, which would
# duplicate every project file into an extra image layer. This also copies
# entrypoint.sh, so it needs no separate COPY.
COPY --chown=user:user . .

# Set up environment variables for the user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH \
    MPLCONFIGDIR=/tmp/matplotlib \
    NUMBA_CACHE_DIR=/tmp/numba_cache

# Make the entrypoint executable and create the matplotlib / numba cache
# directories. The cache directories are given to the runtime user instead of
# being world-writable. /app itself (created as root by WORKDIR) is chowned
# non-recursively so the user can still create new files directly in it.
RUN chmod +x /app/entrypoint.sh && \
    mkdir -p /tmp/matplotlib /tmp/numba_cache && \
    chown user:user /app /tmp/matplotlib /tmp/numba_cache

# Drop root privileges for everything that runs when the container starts.
USER user

# entrypoint.sh handles Accelerate configuration at runtime. With exec-form
# ENTRYPOINT, the CMD array below is appended as its default arguments;
# override CMD at `docker run` to change the training options.
ENTRYPOINT ["/app/entrypoint.sh"]
CMD ["--energy_head_enabled", \
     "--loss_type", "energy_contrastive", \
     "--push_to_hub", \
     "--hub_model_id", "Uday/ctm-energy-based-halting"]