LukeDarlow committed
Commit e8dc0c3 · 1 Parent(s): caa71f5

MPS support updates

tasks/image_classification/train.py CHANGED
@@ -212,8 +212,13 @@ if __name__=='__main__':
     with open(f'{args.log_dir}/args.txt', 'w') as f:
         print(args, file=f)
 
-    # Configure device string
-    device = f'cuda:{args.device[0]}' if args.device[0] != -1 else 'cpu'
+    # Configure device string (support MPS on macOS)
+    if args.device[0] != -1:
+        device = f'cuda:{args.device[0]}'
+    elif torch.backends.mps.is_available():
+        device = 'mps'
+    else:
+        device = 'cpu'
     print(f'Running model {args.model} on {device}')
 
     # Build model conditionally
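For reference, the added branch keeps the explicit CUDA index when one is passed via args.device, falls back to Apple's MPS backend via torch.backends.mps.is_available(), and defaults to CPU otherwise. A minimal standalone sketch of the same selection logic (the select_device helper name is illustrative only, not part of this commit):

import torch

def select_device(device_ids):
    # Mirrors the branch added above: explicit CUDA index if requested,
    # otherwise MPS on Apple silicon when available, otherwise CPU.
    # (select_device is a hypothetical helper name, not in the repo.)
    if device_ids[0] != -1:
        return f'cuda:{device_ids[0]}'
    if torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'

print(select_device([-1]))  # e.g. 'mps' on an MPS-capable Mac, else 'cpu'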
tasks/mazes/train.py CHANGED
@@ -151,9 +151,15 @@ if __name__=='__main__':
     with open(f'{args.log_dir}/args.txt', 'w') as f:
         print(args, file=f)
 
-    # Configure device string
-    device = f'cuda:{args.device[0]}' if args.device[0] != -1 else 'cpu'
-    print(f'Running model {args.model} on {device} for dataset {args.dataset}')
+    # Configure device string (support MPS on macOS)
+    if args.device[0] != -1:
+        device = f'cuda:{args.device[0]}'
+    elif torch.backends.mps.is_available():
+        device = 'mps'
+    else:
+        device = 'cpu'
+    print(f'Running model {args.model} on {device}')
+
 
     # Build model conditionally
     model = None
tasks/parity/train.py CHANGED
@@ -112,8 +112,14 @@ if __name__=='__main__':
     with open(f'{args.log_dir}/args.txt', 'w') as f:
         print(args, file=f)
 
-    # Configure device string
-    device = f'cuda:{args.device[0]}' if args.device[0] != -1 else 'cpu'
+    # Configure device string (support MPS on macOS)
+    if args.device[0] != -1:
+        device = f'cuda:{args.device[0]}'
+    elif torch.backends.mps.is_available():
+        device = 'mps'
+    else:
+        device = 'cpu'
+    print(f'Running model {args.model} on {device}')
 
     # Build model
     model = prepare_model(prediction_reshaper, args, device)
tasks/qamnist/train.py CHANGED
@@ -141,9 +141,14 @@ if __name__=='__main__':
     with open(f'{args.log_dir}/args.txt', 'w') as f:
         print(args, file=f)
 
-    # Configure device string
-    device = f'cuda:{args.device[0]}' if args.device[0] != -1 else 'cpu'
-    print(f'Running on {device}')
+    # Configure device string (support MPS on macOS)
+    if args.device[0] != -1:
+        device = f'cuda:{args.device[0]}'
+    elif torch.backends.mps.is_available():
+        device = 'mps'
+    else:
+        device = 'cpu'
+    print(f'Running model {args.model} on {device}')
 
     # Build model
     model = prepare_model(args, device)
tasks/sort/train.py CHANGED
@@ -154,9 +154,14 @@ if __name__=='__main__':
     with open(f'{args.log_dir}/args.txt', 'w') as f:
         print(args, file=f)
 
-    # Configure device string
-    device = f'cuda:{args.device[0]}' if args.device[0] != -1 else 'cpu'
-    print(f'Running on {device}')
+    # Configure device string (support MPS on macOS)
+    if args.device[0] != -1:
+        device = f'cuda:{args.device[0]}'
+    elif torch.backends.mps.is_available():
+        device = 'mps'
+    else:
+        device = 'cpu'
+    print(f'Running model {args.model} on {device}')
 
     # Build model
     model = ContinuousThoughtMachineSORT(
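As a quick sanity check before launching any of these training scripts on macOS, MPS availability can be confirmed with the standard PyTorch queries (assuming PyTorch 1.12 or later, where the MPS backend was introduced):

import torch

print(torch.backends.mps.is_available())  # True if the MPS runtime can be used on this machine
print(torch.backends.mps.is_built())      # True if this PyTorch build was compiled with MPS support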