Merge pull request #16 from ciaran-regan-ie/main
Browse filesAdditional Examples to aid in understanding the CTM.
- .gitignore +2 -0
- examples/01_mnist.ipynb +0 -0
- examples/03_mazes.ipynb +0 -0
- examples/04_parity.ipynb +0 -0
- tasks/mazes/plotting.py +2 -2
.gitignore
CHANGED
|
@@ -17,4 +17,6 @@ data/*
|
|
| 17 |
examples/*
|
| 18 |
!examples/01_mnist.ipynb
|
| 19 |
!examples/02_inference.ipynb
|
|
|
|
|
|
|
| 20 |
checkpoints
|
|
|
|
| 17 |
examples/*
|
| 18 |
!examples/01_mnist.ipynb
|
| 19 |
!examples/02_inference.ipynb
|
| 20 |
+
!examples/03_mazes.ipynb
|
| 21 |
+
!examples/04_parity.ipynb
|
| 22 |
checkpoints
|
examples/01_mnist.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/03_mazes.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/04_parity.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tasks/mazes/plotting.py
CHANGED
|
@@ -88,7 +88,7 @@ def draw_path(x, route, valid_only=False, gt=False, cmap=None):
|
|
| 88 |
|
| 89 |
return x
|
| 90 |
|
| 91 |
-
def make_maze_gif(inputs, predictions, targets, attention_tracking, save_location):
|
| 92 |
"""
|
| 93 |
Expect inputs, predictions, targets as numpy arrays
|
| 94 |
"""
|
|
@@ -130,7 +130,7 @@ def make_maze_gif(inputs, predictions, targets, attention_tracking, save_locatio
|
|
| 130 |
cmap_viridis = plt.get_cmap('viridis')
|
| 131 |
step_linspace = np.linspace(0, 1, predictions.shape[-1]) # For sampling colours
|
| 132 |
with tqdm(total=predictions.shape[-1], initial=0, leave=True, position=1, dynamic_ncols=True) as pbar:
|
| 133 |
-
pbar.set_description('Processing frames for maze plotting')
|
| 134 |
for stepi in np.arange(0, predictions.shape[-1], 1):
|
| 135 |
fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
|
| 136 |
for ax in axes.values():
|
|
|
|
| 88 |
|
| 89 |
return x
|
| 90 |
|
| 91 |
+
def make_maze_gif(inputs, predictions, targets, attention_tracking, save_location, verbose=True):
|
| 92 |
"""
|
| 93 |
Expect inputs, predictions, targets as numpy arrays
|
| 94 |
"""
|
|
|
|
| 130 |
cmap_viridis = plt.get_cmap('viridis')
|
| 131 |
step_linspace = np.linspace(0, 1, predictions.shape[-1]) # For sampling colours
|
| 132 |
with tqdm(total=predictions.shape[-1], initial=0, leave=True, position=1, dynamic_ncols=True) as pbar:
|
| 133 |
+
if verbose: pbar.set_description('Processing frames for maze plotting')
|
| 134 |
for stepi in np.arange(0, predictions.shape[-1], 1):
|
| 135 |
fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
|
| 136 |
for ax in axes.values():
|