Commit
·
69b35a9
1
Parent(s):
52d4c84
added support for a small maze dataset, with the purpose of debugging and iteration
Browse files- tasks/mazes/README.md +6 -0
- tasks/mazes/plotting.py +18 -2
- tasks/mazes/train.py +2 -2
tasks/mazes/README.md
CHANGED
|
@@ -8,3 +8,9 @@ To run the maze training that we used for the paper, run the following command f
|
|
| 8 |
```
|
| 9 |
python -m tasks.mazes.train --d_model 2048 --d_input 512 --synapse_depth 4 --heads 8 --n_synch_out 64 --n_synch_action 32 --neuron_select_type first-last --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 32 --dropout 0.1 --no-do_normalisation --positional_embedding_type none --backbone_type resnet34-2 --batch_size 64 --batch_size_test 64 --lr 1e-4 --training_iterations 1000001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs/mazes/d=2048--i=512--h=8--ns=64-32--iters=75x25--h=32--drop=0.1--pos=none--back=34-2--seed=42 --dataset mazes-medium --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50
|
| 10 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
```
|
| 9 |
python -m tasks.mazes.train --d_model 2048 --d_input 512 --synapse_depth 4 --heads 8 --n_synch_out 64 --n_synch_action 32 --neuron_select_type first-last --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 32 --dropout 0.1 --no-do_normalisation --positional_embedding_type none --backbone_type resnet34-2 --batch_size 64 --batch_size_test 64 --lr 1e-4 --training_iterations 1000001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs/mazes/d=2048--i=512--h=8--ns=64-32--iters=75x25--h=32--drop=0.1--pos=none--back=34-2--seed=42 --dataset mazes-medium --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50
|
| 10 |
```
|
| 11 |
+
|
| 12 |
+
## Small training run
|
| 13 |
+
We also provide a 'mazes-small' dataset (see [here](https://drive.google.com/file/d/1cBgqhaUUtsrll8-o2VY42hPpyBcfFv86/view?usp=drivesdk)) for fast iteration and testing ideas. The following command can train a CTM locally without a GPU in 12-24 hours:
|
| 14 |
+
```
|
| 15 |
+
python -m tasks.mazes.train --dataset mazes-small --maze_route_length 50 --cirriculum_lookahead 5 --model ctm --d_model 1024 --d_input 256 --backbone_type resnet18-1 --synapse_depth 8 --heads 4 --n_synch_out 128 --n_synch_action 128 --neuron_select_type random-pairing --memory_length 25 --iterations 50 --training_iterations 100001 --lr 1e-4 --batch_size 64 --batch_size_test 32 --n_test_batches 50 --log_dir logs/mazes-small-tester --track_every 2000
|
| 16 |
+
```
|
tasks/mazes/plotting.py
CHANGED
|
@@ -96,7 +96,7 @@ def make_maze_gif(inputs, predictions, targets, attention_tracking, save_locatio
|
|
| 96 |
route_colours = []
|
| 97 |
solution_maze = draw_path(np.moveaxis(inputs, 0, -1), targets)
|
| 98 |
|
| 99 |
-
|
| 100 |
mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 101 |
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 102 |
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
|
@@ -104,9 +104,25 @@ def make_maze_gif(inputs, predictions, targets, attention_tracking, save_locatio
|
|
| 104 |
['head_0', 'head_1', 'head_2', 'head_3', 'head_4', 'head_5', 'head_6', 'head_7'],
|
| 105 |
['head_8', 'head_9', 'head_10', 'head_11', 'head_12', 'head_13', 'head_14', 'head_15'],
|
| 106 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
img_aspect = 1
|
| 108 |
figscale = 1
|
| 109 |
-
aspect_ratio = (
|
| 110 |
|
| 111 |
route_steps = [np.unravel_index(np.argmax((inputs == np.reshape(np.array([1, 0, 0]), (3, 1, 1))).all(0)), inputs.shape[1:])] # Starting point
|
| 112 |
frames = []
|
|
|
|
| 96 |
route_colours = []
|
| 97 |
solution_maze = draw_path(np.moveaxis(inputs, 0, -1), targets)
|
| 98 |
|
| 99 |
+
n_heads = attention_tracking.shape[1]
|
| 100 |
mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 101 |
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 102 |
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
|
|
|
| 104 |
['head_0', 'head_1', 'head_2', 'head_3', 'head_4', 'head_5', 'head_6', 'head_7'],
|
| 105 |
['head_8', 'head_9', 'head_10', 'head_11', 'head_12', 'head_13', 'head_14', 'head_15'],
|
| 106 |
]
|
| 107 |
+
if n_heads == 8:
|
| 108 |
+
mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 109 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 110 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 111 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 112 |
+
['head_0', 'head_1', 'head_2', 'head_3', 'head_4', 'head_5', 'head_6', 'head_7'],
|
| 113 |
+
]
|
| 114 |
+
elif n_heads == 4:
|
| 115 |
+
mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 116 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 117 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 118 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 119 |
+
['head_0', 'head_0', 'head_1', 'head_1', 'head_2', 'head_2', 'head_3', 'head_3'],
|
| 120 |
+
['head_0', 'head_0', 'head_1', 'head_1', 'head_2', 'head_2', 'head_3', 'head_3'],
|
| 121 |
+
]
|
| 122 |
+
|
| 123 |
img_aspect = 1
|
| 124 |
figscale = 1
|
| 125 |
+
aspect_ratio = (len(mosaic[0]) * figscale, len(mosaic) * figscale * img_aspect) # W, H
|
| 126 |
|
| 127 |
route_steps = [np.unravel_index(np.argmax((inputs == np.reshape(np.array([1, 0, 0]), (3, 1, 1))).all(0)), inputs.shape[1:])] # Starting point
|
| 128 |
frames = []
|
tasks/mazes/train.py
CHANGED
|
@@ -108,7 +108,7 @@ def parse_args():
|
|
| 108 |
|
| 109 |
# Logging and Saving
|
| 110 |
parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
|
| 111 |
-
parser.add_argument('--dataset', type=str, default='mazes-medium', help='Dataset to use.', choices=['mazes-medium', 'mazes-large'])
|
| 112 |
parser.add_argument('--data_root', type=str, default='data/mazes', help='Data root.')
|
| 113 |
|
| 114 |
parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
|
|
@@ -139,7 +139,7 @@ if __name__=='__main__':
|
|
| 139 |
set_seed(args.seed, False)
|
| 140 |
if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
|
| 141 |
|
| 142 |
-
assert args.dataset in ['mazes-medium', 'mazes-large']
|
| 143 |
|
| 144 |
|
| 145 |
|
|
|
|
| 108 |
|
| 109 |
# Logging and Saving
|
| 110 |
parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
|
| 111 |
+
parser.add_argument('--dataset', type=str, default='mazes-medium', help='Dataset to use.', choices=['mazes-medium', 'mazes-large', 'mazes-small'])
|
| 112 |
parser.add_argument('--data_root', type=str, default='data/mazes', help='Data root.')
|
| 113 |
|
| 114 |
parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
|
|
|
|
| 139 |
set_seed(args.seed, False)
|
| 140 |
if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
|
| 141 |
|
| 142 |
+
assert args.dataset in ['mazes-medium', 'mazes-large', 'mazes-small']
|
| 143 |
|
| 144 |
|
| 145 |
|