Uday committed on
Commit
cc5b395
·
1 Parent(s): b695230

Fix: Added more env variables to run model training with Accelerate

Browse files
Files changed (1) hide show
  1. Dockerfile +26 -8
Dockerfile CHANGED
@@ -3,36 +3,54 @@
3
  FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
4
 
5
  # Set architecture list for A10G (Ampere, Compute Capability 8.6)
6
- # This ensures that if any CUDA extensions are built, they target the correct architecture
7
  ENV TORCH_CUDA_ARCH_LIST="8.6"
8
 
9
- # Set working directory
10
- WORKDIR /app
11
-
12
  # Install system dependencies (ffmpeg for imageio/visualization, git for pip)
13
  RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
14
  ffmpeg \
15
  git \
16
  && rm -rf /var/lib/apt/lists/*
17
 
 
 
 
 
 
 
18
  # Install dependencies
19
  COPY requirements.txt .
20
 
21
  # 1. Remove torch and torchvision from requirements.txt to prevent pip from upgrading them
22
- # and replacing the optimized base image version with a generic wheel.
23
  # 2. Install the rest of the requirements.
24
  # 3. Explicitly ensure compatible torchvision is installed (0.16.0 matches torch 2.1.0).
25
  RUN sed -i '/torch/d' requirements.txt && \
26
  pip install --no-cache-dir -r requirements.txt && \
27
  pip install --no-cache-dir torchvision==0.16.0
28
 
29
- # Configure Accelerate (default to fp16 for speed)
30
- RUN python -c "from accelerate.utils import write_basic_config; write_basic_config(mixed_precision='fp16')"
31
-
32
  # Copy all project files into the container
33
  COPY . .
34
 
 
35
  COPY entrypoint.sh /app/entrypoint.sh
36
  RUN chmod +x /app/entrypoint.sh
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  ENTRYPOINT ["/app/entrypoint.sh"]
38
  CMD ["--energy_head_enabled", "--loss_type", "energy_contrastive", "--push_to_hub", "--hub_model_id", "Uday/ctm-energy-based-halting"]
 
3
  FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
4
 
5
  # Set architecture list for A10G (Ampere, Compute Capability 8.6)
 
6
  ENV TORCH_CUDA_ARCH_LIST="8.6"
7
 
 
 
 
8
  # Install system dependencies (ffmpeg for imageio/visualization, git for pip)
9
  RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
10
  ffmpeg \
11
  git \
12
  && rm -rf /var/lib/apt/lists/*
13
 
14
+ # Create a non-root user to match HF Spaces default (user 1000)
15
+ RUN useradd -m -u 1000 user
16
+
17
+ # Set working directory
18
+ WORKDIR /app
19
+
20
  # Install dependencies
21
  COPY requirements.txt .
22
 
23
  # 1. Remove torch and torchvision from requirements.txt to prevent pip from upgrading them
 
24
  # 2. Install the rest of the requirements.
25
  # 3. Explicitly ensure compatible torchvision is installed (0.16.0 matches torch 2.1.0).
26
  RUN sed -i '/torch/d' requirements.txt && \
27
  pip install --no-cache-dir -r requirements.txt && \
28
  pip install --no-cache-dir torchvision==0.16.0
29
 
 
 
 
30
  # Copy all project files into the container
31
  COPY . .
32
 
33
+ # Copy entrypoint
34
  COPY entrypoint.sh /app/entrypoint.sh
35
  RUN chmod +x /app/entrypoint.sh
36
+
37
+ # Set up environment variables for the user
38
+ ENV HOME=/home/user \
39
+ PATH=/home/user/.local/bin:$PATH \
40
+ MPLCONFIGDIR=/tmp/matplotlib \
41
+ NUMBA_CACHE_DIR=/tmp/numba_cache
42
+
43
+ # Create cache directories with correct permissions
44
+ RUN mkdir -p /tmp/matplotlib /tmp/numba_cache && \
45
+ chmod 777 /tmp/matplotlib /tmp/numba_cache && \
46
+ chown -R user:user /app
47
+
48
+ # Switch to the non-root user
49
+ USER user
50
+
51
+ # Configure Accelerate for the user (default to fp16 for speed)
52
+ # This writes to ~/.cache/huggingface/accelerate/default_config.yaml
53
+ RUN python -c "from accelerate.utils import write_basic_config; write_basic_config(mixed_precision='fp16')"
54
+
55
  ENTRYPOINT ["/app/entrypoint.sh"]
56
  CMD ["--energy_head_enabled", "--loss_type", "energy_contrastive", "--push_to_hub", "--hub_model_id", "Uday/ctm-energy-based-halting"]