Fix:Added more env variables to run model training with accelerate
Browse files- Dockerfile +26 -8
Dockerfile
CHANGED
|
@@ -3,36 +3,54 @@
|
|
| 3 |
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
|
| 4 |
|
| 5 |
# Set architecture list for A10G (Ampere, Compute Capability 8.6)
|
| 6 |
-
# This ensures that if any CUDA extensions are built, they target the correct architecture
|
| 7 |
ENV TORCH_CUDA_ARCH_LIST="8.6"
|
| 8 |
|
| 9 |
-
# Set working directory
|
| 10 |
-
WORKDIR /app
|
| 11 |
-
|
| 12 |
# Install system dependencies (ffmpeg for imageio/visualization, git for pip)
|
| 13 |
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
|
| 14 |
ffmpeg \
|
| 15 |
git \
|
| 16 |
&& rm -rf /var/lib/apt/lists/*
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
# Install dependencies
|
| 19 |
COPY requirements.txt .
|
| 20 |
|
| 21 |
# 1. Remove torch and torchvision from requirements.txt to prevent pip from upgrading them
|
| 22 |
-
# and replacing the optimized base image version with a generic wheel.
|
| 23 |
# 2. Install the rest of the requirements.
|
| 24 |
# 3. Explicitly ensure compatible torchvision is installed (0.16.0 matches torch 2.1.0).
|
| 25 |
RUN sed -i '/torch/d' requirements.txt && \
|
| 26 |
pip install --no-cache-dir -r requirements.txt && \
|
| 27 |
pip install --no-cache-dir torchvision==0.16.0
|
| 28 |
|
| 29 |
-
# Configure Accelerate (default to fp16 for speed)
|
| 30 |
-
RUN python -c "from accelerate.utils import write_basic_config; write_basic_config(mixed_precision='fp16')"
|
| 31 |
-
|
| 32 |
# Copy all project files into the container
|
| 33 |
COPY . .
|
| 34 |
|
|
|
|
| 35 |
COPY entrypoint.sh /app/entrypoint.sh
|
| 36 |
RUN chmod +x /app/entrypoint.sh
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
ENTRYPOINT ["/app/entrypoint.sh"]
|
| 38 |
CMD ["--energy_head_enabled", "--loss_type", "energy_contrastive", "--push_to_hub", "--hub_model_id", "Uday/ctm-energy-based-halting"]
|
|
|
|
| 3 |
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
|
| 4 |
|
| 5 |
# Set architecture list for A10G (Ampere, Compute Capability 8.6)
|
|
|
|
| 6 |
ENV TORCH_CUDA_ARCH_LIST="8.6"
|
| 7 |
|
|
|
|
|
|
|
|
|
|
| 8 |
# Install system dependencies (ffmpeg for imageio/visualization, git for pip)
|
| 9 |
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
|
| 10 |
ffmpeg \
|
| 11 |
git \
|
| 12 |
&& rm -rf /var/lib/apt/lists/*
|
| 13 |
|
| 14 |
+
# Create a non-root user to match HF Spaces default (user 1000)
|
| 15 |
+
RUN useradd -m -u 1000 user
|
| 16 |
+
|
| 17 |
+
# Set working directory
|
| 18 |
+
WORKDIR /app
|
| 19 |
+
|
| 20 |
# Install dependencies
|
| 21 |
COPY requirements.txt .
|
| 22 |
|
| 23 |
# 1. Remove torch and torchvision from requirements.txt to prevent pip from upgrading them
|
|
|
|
| 24 |
# 2. Install the rest of the requirements.
|
| 25 |
# 3. Explicitly ensure compatible torchvision is installed (0.16.0 matches torch 2.1.0).
|
| 26 |
RUN sed -i '/torch/d' requirements.txt && \
|
| 27 |
pip install --no-cache-dir -r requirements.txt && \
|
| 28 |
pip install --no-cache-dir torchvision==0.16.0
|
| 29 |
|
|
|
|
|
|
|
|
|
|
| 30 |
# Copy all project files into the container
|
| 31 |
COPY . .
|
| 32 |
|
| 33 |
+
# Copy entrypoint
|
| 34 |
COPY entrypoint.sh /app/entrypoint.sh
|
| 35 |
RUN chmod +x /app/entrypoint.sh
|
| 36 |
+
|
| 37 |
+
# Set up environment variables for the user
|
| 38 |
+
ENV HOME=/home/user \
|
| 39 |
+
PATH=/home/user/.local/bin:$PATH \
|
| 40 |
+
MPLCONFIGDIR=/tmp/matplotlib \
|
| 41 |
+
NUMBA_CACHE_DIR=/tmp/numba_cache
|
| 42 |
+
|
| 43 |
+
# Create cache directories with correct permissions
|
| 44 |
+
RUN mkdir -p /tmp/matplotlib /tmp/numba_cache && \
|
| 45 |
+
chmod 777 /tmp/matplotlib /tmp/numba_cache && \
|
| 46 |
+
chown -R user:user /app
|
| 47 |
+
|
| 48 |
+
# Switch to the non-root user
|
| 49 |
+
USER user
|
| 50 |
+
|
| 51 |
+
# Configure Accelerate for the user (default to fp16 for speed)
|
| 52 |
+
# This writes to ~/.cache/huggingface/accelerate/default_config.yaml
|
| 53 |
+
RUN python -c "from accelerate.utils import write_basic_config; write_basic_config(mixed_precision='fp16')"
|
| 54 |
+
|
| 55 |
ENTRYPOINT ["/app/entrypoint.sh"]
|
| 56 |
CMD ["--energy_head_enabled", "--loss_type", "energy_contrastive", "--push_to_hub", "--hub_model_id", "Uday/ctm-energy-based-halting"]
|