LukeDarlow commited on
Commit
69b35a9
·
1 Parent(s): 52d4c84

added support for a small maze dataset, with the purpose of debugging and iteration

Browse files
tasks/mazes/README.md CHANGED
@@ -8,3 +8,9 @@ To run the maze training that we used for the paper, run the following command f
8
  ```
9
  python -m tasks.mazes.train --d_model 2048 --d_input 512 --synapse_depth 4 --heads 8 --n_synch_out 64 --n_synch_action 32 --neuron_select_type first-last --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 32 --dropout 0.1 --no-do_normalisation --positional_embedding_type none --backbone_type resnet34-2 --batch_size 64 --batch_size_test 64 --lr 1e-4 --training_iterations 1000001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs/mazes/d=2048--i=512--h=8--ns=64-32--iters=75x25--h=32--drop=0.1--pos=none--back=34-2--seed=42 --dataset mazes-medium --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50
10
  ```
 
 
 
 
 
 
 
8
  ```
9
  python -m tasks.mazes.train --d_model 2048 --d_input 512 --synapse_depth 4 --heads 8 --n_synch_out 64 --n_synch_action 32 --neuron_select_type first-last --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 32 --dropout 0.1 --no-do_normalisation --positional_embedding_type none --backbone_type resnet34-2 --batch_size 64 --batch_size_test 64 --lr 1e-4 --training_iterations 1000001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs/mazes/d=2048--i=512--h=8--ns=64-32--iters=75x25--h=32--drop=0.1--pos=none--back=34-2--seed=42 --dataset mazes-medium --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50
10
  ```
11
+
12
+ ## Small training run
13
+ We also provide a 'mazes-small' dataset (see [here](https://drive.google.com/file/d/1cBgqhaUUtsrll8-o2VY42hPpyBcfFv86/view?usp=drivesdk)) for fast iteration and testing ideas. The following command can train a CTM locally without a GPU in 12-24 hours:
14
+ ```
15
+ python -m tasks.mazes.train --dataset mazes-small --maze_route_length 50 --cirriculum_lookahead 5 --model ctm --d_model 1024 --d_input 256 --backbone_type resnet18-1 --synapse_depth 8 --heads 4 --n_synch_out 128 --n_synch_action 128 --neuron_select_type random-pairing --memory_length 25 --iterations 50 --training_iterations 100001 --lr 1e-4 --batch_size 64 --batch_size_test 32 --n_test_batches 50 --log_dir logs/mazes-small-tester --track_every 2000
16
+ ```
tasks/mazes/plotting.py CHANGED
@@ -96,7 +96,7 @@ def make_maze_gif(inputs, predictions, targets, attention_tracking, save_locatio
96
  route_colours = []
97
  solution_maze = draw_path(np.moveaxis(inputs, 0, -1), targets)
98
 
99
- # cv2.imwrite(f'{save_location}/ground_truth.png', solution_maze[:,:,::-1]*255)
100
  mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
101
  ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
102
  ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
@@ -104,9 +104,25 @@ def make_maze_gif(inputs, predictions, targets, attention_tracking, save_locatio
104
  ['head_0', 'head_1', 'head_2', 'head_3', 'head_4', 'head_5', 'head_6', 'head_7'],
105
  ['head_8', 'head_9', 'head_10', 'head_11', 'head_12', 'head_13', 'head_14', 'head_15'],
106
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  img_aspect = 1
108
  figscale = 1
109
- aspect_ratio = (8 * figscale, 6 * figscale * img_aspect) # W, H
110
 
111
  route_steps = [np.unravel_index(np.argmax((inputs == np.reshape(np.array([1, 0, 0]), (3, 1, 1))).all(0)), inputs.shape[1:])] # Starting point
112
  frames = []
 
96
  route_colours = []
97
  solution_maze = draw_path(np.moveaxis(inputs, 0, -1), targets)
98
 
99
+ n_heads = attention_tracking.shape[1]
100
  mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
101
  ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
102
  ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
 
104
  ['head_0', 'head_1', 'head_2', 'head_3', 'head_4', 'head_5', 'head_6', 'head_7'],
105
  ['head_8', 'head_9', 'head_10', 'head_11', 'head_12', 'head_13', 'head_14', 'head_15'],
106
  ]
107
+ if n_heads == 8:
108
+ mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
109
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
110
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
111
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
112
+ ['head_0', 'head_1', 'head_2', 'head_3', 'head_4', 'head_5', 'head_6', 'head_7'],
113
+ ]
114
+ elif n_heads == 4:
115
+ mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
116
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
117
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
118
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
119
+ ['head_0', 'head_0', 'head_1', 'head_1', 'head_2', 'head_2', 'head_3', 'head_3'],
120
+ ['head_0', 'head_0', 'head_1', 'head_1', 'head_2', 'head_2', 'head_3', 'head_3'],
121
+ ]
122
+
123
  img_aspect = 1
124
  figscale = 1
125
+ aspect_ratio = (len(mosaic[0]) * figscale, len(mosaic) * figscale * img_aspect) # W, H
126
 
127
  route_steps = [np.unravel_index(np.argmax((inputs == np.reshape(np.array([1, 0, 0]), (3, 1, 1))).all(0)), inputs.shape[1:])] # Starting point
128
  frames = []
tasks/mazes/train.py CHANGED
@@ -108,7 +108,7 @@ def parse_args():
108
 
109
  # Logging and Saving
110
  parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
111
- parser.add_argument('--dataset', type=str, default='mazes-medium', help='Dataset to use.', choices=['mazes-medium', 'mazes-large'])
112
  parser.add_argument('--data_root', type=str, default='data/mazes', help='Data root.')
113
 
114
  parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
@@ -139,7 +139,7 @@ if __name__=='__main__':
139
  set_seed(args.seed, False)
140
  if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
141
 
142
- assert args.dataset in ['mazes-medium', 'mazes-large']
143
 
144
 
145
 
 
108
 
109
  # Logging and Saving
110
  parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
111
+ parser.add_argument('--dataset', type=str, default='mazes-medium', help='Dataset to use.', choices=['mazes-medium', 'mazes-large', 'mazes-small'])
112
  parser.add_argument('--data_root', type=str, default='data/mazes', help='Data root.')
113
 
114
  parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
 
139
  set_seed(args.seed, False)
140
  if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
141
 
142
+ assert args.dataset in ['mazes-medium', 'mazes-large', 'mazes-small']
143
 
144
 
145