File size: 1,034 Bytes
1ff38fb
 
 
 
 
 
b695230
 
 
 
 
 
1ff38fb
 
 
 
 
6914bc9
 
 
 
32e3089
 
 
 
1ff38fb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#!/bin/bash
# Container entrypoint: sanitizes the environment, writes an Accelerate config
# for the hardware visible at runtime, starts a placeholder web server for the
# Space health check on port 7860, then replaces this shell with
# `accelerate launch` running tasks/image_classification/train_energy.py.
#
# Arguments:
#   All positional arguments (the container CMD) are forwarded to the training
#   script unchanged.
# Environment:
#   OMP_NUM_THREADS  reset to 1 unless it is a positive integer (e.g. a CPU
#                    quantity such as "3500m" is replaced; unset/empty also
#                    becomes 1).
#   HF_TOKEN         if non-empty, appended to the arguments as
#                    `--hub_token <token>`.
set -e

# Collect arguments passed to the container (CMD).
args=("$@")

# OpenMP expects a positive integer thread count. HF Spaces may export a CPU
# quantity like "3500m", and "0" is not a valid count either, so anything else
# is replaced with 1. ${...:-} keeps this safe if `set -u` is ever enabled.
if ! [[ "${OMP_NUM_THREADS:-}" =~ ^[1-9][0-9]*$ ]]; then
    printf "WARNING: OMP_NUM_THREADS is '%s', which is not a positive integer. Resetting to 1.\n" \
        "${OMP_NUM_THREADS:-}" >&2
    export OMP_NUM_THREADS=1
fi

# If HF_TOKEN is set, forward it to the training script.
# NOTE(review): a token on the command line is visible to anything that can
# list processes; if train_energy.py can read HF_TOKEN from the environment,
# prefer that and drop this.
if [[ -n "${HF_TOKEN:-}" ]]; then
    args+=(--hub_token "$HF_TOKEN")
fi

# Generate the Accelerate config at runtime so it reflects the GPUs actually
# visible to this container. write_basic_config() will not overwrite an
# existing file, so writing to the default cache location could silently keep
# a stale config. Write into a fresh private directory instead and pass that
# file to `accelerate launch` explicitly. The path goes through argv rather
# than being spliced into the Python source.
accelerate_config_dir="$(mktemp -d)"
accelerate_config="${accelerate_config_dir}/default_config.yaml"
python -c 'import sys
from accelerate.utils import write_basic_config
write_basic_config(mixed_precision="fp16", save_location=sys.argv[1])' \
    "$accelerate_config"

# Start a dummy web server in the background to satisfy the HF Spaces health
# check (port 7860). It serves the current directory, which should contain
# index.html.
# NOTE(review): http.server exposes every file under the working directory;
# consider `--directory` pointing at a folder that only holds index.html.
python -m http.server 7860 &

# Replace this shell with the training run (exec), using the config written
# above.
exec accelerate launch --config_file "$accelerate_config" \
    tasks/image_classification/train_energy.py "${args[@]}"