| from __future__ import annotations | |
| import os | |
| import sys | |
| from collections.abc import MutableSequence | |
| def _arg_name_to_option(arg_name: str) -> str: | |
| arg_name = str(arg_name or "").strip() | |
| if not arg_name: | |
| return "" | |
| return arg_name if arg_name.startswith("--") else f"--{arg_name}" | |
| def _cuda_visible_device(device: str) -> str: | |
| device = str(device or "").strip().lower() | |
| if device.startswith("cuda:"): | |
| device = device.split(":", 1)[1] | |
| return device if device.isdigit() else "" | |
| def _rewrite_arg_value(argv: MutableSequence[str], option: str, value: str) -> None: | |
| for index, arg in enumerate(argv): | |
| if arg == option and index + 1 < len(argv): | |
| argv[index + 1] = value | |
| return | |
| if str(arg).startswith(f"{option}="): | |
| argv[index] = f"{option}={value}" | |
| return | |
def set_default_cuda_device_from_arg(arg_name: str, default_device: str = "cuda:0") -> bool:
    """Pin the process to the CUDA device named on the command line.

    Scans ``sys.argv`` (skipping the program name) for the first occurrence
    of the option derived from ``arg_name``, written either as
    ``option value`` or ``option=value``. When that value resolves to a
    numeric device index, the index is stored in the ``CUDA_VISIBLE_DEVICES``
    environment variable and the option's value in ``sys.argv`` is replaced
    with ``default_device``.

    Args:
        arg_name: Name of the command-line option that carries the device.
        default_device: Value written back into ``sys.argv`` for the option.

    Returns:
        ``True`` when the environment variable was set and the argument was
        rewritten; ``False`` when the option name is blank, the option is not
        present with a value, or its value is not a numeric device.
    """
    flag = _arg_name_to_option(arg_name)
    if not flag:
        return False

    args = sys.argv
    inline_prefix = f"{flag}="
    # Start at 1: args[0] is the program name, never an option.
    for position in range(1, len(args)):
        token = args[position]
        if token == flag and position + 1 < len(args):
            raw_value = args[position + 1]
            break
        token_text = str(token)
        if token_text.startswith(inline_prefix):
            raw_value = token_text.split("=", 1)[1]
            break
    else:
        # Loop finished without a break: the option was never supplied.
        return False

    device_index = _cuda_visible_device(raw_value)
    if not device_index:
        return False

    os.environ["CUDA_VISIBLE_DEVICES"] = device_index
    _rewrite_arg_value(args, flag, default_device)
    return True
Xet Storage Details
- Size:
- 1.6 kB
- Xet hash:
- 4be669b934a7bbc221b1142238fbc585ddd8151de2d56df94f980ed3892e4153
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.