File size: 1,034 Bytes
1ff38fb b695230 1ff38fb 6914bc9 32e3089 1ff38fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
#!/bin/bash
# Abort on the first unhandled non-zero exit status.
set -o errexit

# Snapshot the container's CMD arguments into an array so later steps can
# append to them without any word-splitting surprises.
declare -a args=("$@")
# Sanitize OMP_NUM_THREADS when it is set to something that is not a positive
# integer (e.g. a CPU quantity such as "3500m" from HF Spaces, or "0").
#######################################
# Replace an invalid, non-empty OMP_NUM_THREADS with 1.
# Globals:   OMP_NUM_THREADS (read; exported as 1 when invalid)
# Arguments: none
# Outputs:   a warning on stderr when the value is replaced
# Returns:   0
#######################################
sanitize_omp_num_threads() {
  # An unset/empty variable is not a malformed value: leave it untouched
  # instead of forcing a single thread.
  if [[ -z "${OMP_NUM_THREADS:-}" ]]; then
    return 0
  fi
  # Accept only positive base-10 integers; leading zeros ("04") are tolerated,
  # but an all-zero value is rejected.
  if ! [[ "$OMP_NUM_THREADS" =~ ^0*[1-9][0-9]*$ ]]; then
    printf "WARNING: OMP_NUM_THREADS is '%s', which is not a positive integer. Resetting to 1.\n" \
      "$OMP_NUM_THREADS" >&2
    export OMP_NUM_THREADS=1
  fi
}
sanitize_omp_num_threads
# Forward the Hugging Face token to the training script when one is provided;
# with no (or an empty) HF_TOKEN the argument list is left as-is.
# NOTE(review): a secret on the command line is visible in the process list
# (e.g. `ps`) and in any log that echoes argv; if train_energy.py can read
# HF_TOKEN from the environment instead, prefer that.
[[ -z "$HF_TOKEN" ]] || args+=(--hub_token "$HF_TOKEN")
# Write Accelerate's basic config when the container starts (rather than at
# image build time) so the GPUs actually present at runtime are detected.
# This writes to ~/.cache/huggingface/accelerate/default_config.yaml.
# NOTE(review): write_basic_config appears to leave an already-existing config
# file untouched, so a config baked into the image or persisted from an earlier
# run would take precedence over runtime detection — confirm, and remove the
# stale file first if that matters.
python - <<'PY'
from accelerate.utils import write_basic_config
write_basic_config(mixed_precision='fp16')
PY
# HF Spaces health-checks port 7860, so run a throwaway static file server in
# the background alongside training. It serves the current working directory,
# which is expected to contain index.html.
# NOTE(review): every other file in the working directory is served publicly
# as well; consider pointing http.server at a dedicated directory.
readonly health_port=7860
python -m http.server "$health_port" &
# Hand the process over to Accelerate. `exec` replaces this shell, so the
# launcher's exit status becomes this script's exit status.
launch_cmd=(accelerate launch tasks/image_classification/train_energy.py)
exec "${launch_cmd[@]}" "${args[@]}"
|