File size: 844 Bytes
1ff38fb b695230 1ff38fb 6914bc9 1ff38fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
#!/bin/bash
# Container entrypoint: sanitize OMP_NUM_THREADS, generate an Accelerate
# config for the hardware present at runtime, then replace this shell with
# `accelerate launch` running the image-classification training script.
#
# Arguments:
#   All container CMD arguments are forwarded to train_energy.py.
# Environment:
#   OMP_NUM_THREADS  optional; if set to anything other than a positive
#                    integer it is reset to 1. Left untouched when unset.
#   HF_TOKEN         optional; when non-empty, forwarded as --hub_token.
set -e

# Collect arguments passed to the container (CMD).
args=("$@")

# Sanitize OMP_NUM_THREADS if it is set but is not a positive integer
# (e.g. "3500m" from HF Spaces). An unset variable is deliberately left
# alone so the runtime/launcher chooses its own default instead of being
# silently forced to one thread; a set-but-empty value is still reset.
if [[ -n "${OMP_NUM_THREADS+set}" && ! "$OMP_NUM_THREADS" =~ ^0*[1-9][0-9]*$ ]]; then
  printf "WARNING: OMP_NUM_THREADS is '%s', which is not a positive integer. Resetting to 1.\n" \
    "$OMP_NUM_THREADS" >&2
  export OMP_NUM_THREADS=1
fi

# If HF_TOKEN is set, append it to the arguments. It is appended after the
# CMD arguments, so it lands last on the command line.
# NOTE(review): a secret passed on argv is visible in process listings
# (`ps`); if train_energy.py can read HF_TOKEN from the environment itself,
# prefer that — confirm against the training script before changing.
if [[ -n "${HF_TOKEN:-}" ]]; then
  args+=(--hub_token "$HF_TOKEN")
fi

# Generate the Accelerate config at runtime so it reflects the GPUs of the
# machine the container actually runs on. By default this writes to
# ~/.cache/huggingface/accelerate/default_config.yaml.
# NOTE(review): write_basic_config is believed to skip writing when a config
# file already exists, so a config baked into the image would not be
# refreshed — confirm against the installed accelerate version.
python -c "from accelerate.utils import write_basic_config; write_basic_config(mixed_precision='fp16')"

# Replace this shell with the launcher (exec) so it becomes the container's
# main process and receives stop signals directly.
exec accelerate launch tasks/image_classification/train_energy.py "${args[@]}"
|