Commit
·
e8dc0c3
1
Parent(s):
caa71f5
MPS support updates
Browse files- tasks/image_classification/train.py +7 -2
- tasks/mazes/train.py +9 -3
- tasks/parity/train.py +8 -2
- tasks/qamnist/train.py +8 -3
- tasks/sort/train.py +8 -3
tasks/image_classification/train.py
CHANGED
|
@@ -212,8 +212,13 @@ if __name__=='__main__':
|
|
| 212 |
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 213 |
print(args, file=f)
|
| 214 |
|
| 215 |
-
# Configure device string
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
print(f'Running model {args.model} on {device}')
|
| 218 |
|
| 219 |
# Build model conditionally
|
|
|
|
| 212 |
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 213 |
print(args, file=f)
|
| 214 |
|
| 215 |
+
# Configure device string (support MPS on macOS)
|
| 216 |
+
if args.device[0] != -1:
|
| 217 |
+
device = f'cuda:{args.device[0]}'
|
| 218 |
+
elif torch.backends.mps.is_available():
|
| 219 |
+
device = 'mps'
|
| 220 |
+
else:
|
| 221 |
+
device = 'cpu'
|
| 222 |
print(f'Running model {args.model} on {device}')
|
| 223 |
|
| 224 |
# Build model conditionally
|
tasks/mazes/train.py
CHANGED
|
@@ -151,9 +151,15 @@ if __name__=='__main__':
|
|
| 151 |
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 152 |
print(args, file=f)
|
| 153 |
|
| 154 |
-
# Configure device string
|
| 155 |
-
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
# Build model conditionally
|
| 159 |
model = None
|
|
|
|
| 151 |
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 152 |
print(args, file=f)
|
| 153 |
|
| 154 |
+
# Configure device string (support MPS on macOS)
|
| 155 |
+
if args.device[0] != -1:
|
| 156 |
+
device = f'cuda:{args.device[0]}'
|
| 157 |
+
elif torch.backends.mps.is_available():
|
| 158 |
+
device = 'mps'
|
| 159 |
+
else:
|
| 160 |
+
device = 'cpu'
|
| 161 |
+
print(f'Running model {args.model} on {device}')
|
| 162 |
+
|
| 163 |
|
| 164 |
# Build model conditionally
|
| 165 |
model = None
|
tasks/parity/train.py
CHANGED
|
@@ -112,8 +112,14 @@ if __name__=='__main__':
|
|
| 112 |
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 113 |
print(args, file=f)
|
| 114 |
|
| 115 |
-
# Configure device string
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
# Build model
|
| 119 |
model = prepare_model(prediction_reshaper, args, device)
|
|
|
|
| 112 |
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 113 |
print(args, file=f)
|
| 114 |
|
| 115 |
+
# Configure device string (support MPS on macOS)
|
| 116 |
+
if args.device[0] != -1:
|
| 117 |
+
device = f'cuda:{args.device[0]}'
|
| 118 |
+
elif torch.backends.mps.is_available():
|
| 119 |
+
device = 'mps'
|
| 120 |
+
else:
|
| 121 |
+
device = 'cpu'
|
| 122 |
+
print(f'Running model {args.model} on {device}')
|
| 123 |
|
| 124 |
# Build model
|
| 125 |
model = prepare_model(prediction_reshaper, args, device)
|
tasks/qamnist/train.py
CHANGED
|
@@ -141,9 +141,14 @@ if __name__=='__main__':
|
|
| 141 |
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 142 |
print(args, file=f)
|
| 143 |
|
| 144 |
-
# Configure device string
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
# Build model
|
| 149 |
model = prepare_model(args, device)
|
|
|
|
| 141 |
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 142 |
print(args, file=f)
|
| 143 |
|
| 144 |
+
# Configure device string (support MPS on macOS)
|
| 145 |
+
if args.device[0] != -1:
|
| 146 |
+
device = f'cuda:{args.device[0]}'
|
| 147 |
+
elif torch.backends.mps.is_available():
|
| 148 |
+
device = 'mps'
|
| 149 |
+
else:
|
| 150 |
+
device = 'cpu'
|
| 151 |
+
print(f'Running model {args.model} on {device}')
|
| 152 |
|
| 153 |
# Build model
|
| 154 |
model = prepare_model(args, device)
|
tasks/sort/train.py
CHANGED
|
@@ -154,9 +154,14 @@ if __name__=='__main__':
|
|
| 154 |
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 155 |
print(args, file=f)
|
| 156 |
|
| 157 |
-
# Configure device string
|
| 158 |
-
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
# Build model
|
| 162 |
model = ContinuousThoughtMachineSORT(
|
|
|
|
| 154 |
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 155 |
print(args, file=f)
|
| 156 |
|
| 157 |
+
# Configure device string (support MPS on macOS)
|
| 158 |
+
if args.device[0] != -1:
|
| 159 |
+
device = f'cuda:{args.device[0]}'
|
| 160 |
+
elif torch.backends.mps.is_available():
|
| 161 |
+
device = 'mps'
|
| 162 |
+
else:
|
| 163 |
+
device = 'cpu'
|
| 164 |
+
print(f'Running model {args.model} on {device}')
|
| 165 |
|
| 166 |
# Build model
|
| 167 |
model = ContinuousThoughtMachineSORT(
|