| | |
| | import numpy as np |
| | from matplotlib import pyplot as plt |
| |
|
| | import scipy |
| | import scipy.stats |
| | from imageio import imsave |
| | import cv2 |
| |
|
| |
|
def concat_images(images, image_width, spacer_size):
    """ Concat images horizontally, separated by white (255) vertical spacers.

    Args:
        images: Non-empty sequence of equally sized image arrays of shape
            (height, width[, channels]).
        image_width: Kept for backward compatibility. The spacer height is
            taken from the images themselves, which equals this value for
            square images (the only case the old code supported).
        spacer_size: Width in pixels of each spacer column.

    Returns:
        The images joined left to right with a spacer between neighbours
        (no spacer before the first or after the last image).

    Raises:
        ValueError: If ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("concat_images needs at least one image")
    first = np.asarray(images[0])
    # Match the spacer to the images' real height and channel layout. The old
    # spacer was image_width x spacer_size x 4, which broke hstack for
    # non-square or non-RGBA images.
    spacer = np.full((first.shape[0], spacer_size) + first.shape[2:], 255,
                     dtype=np.uint8)
    pieces = [images[0]]
    for image in images[1:]:
        pieces.extend((spacer, image))
    return np.hstack(pieces)
| |
|
| |
|
def concat_images_in_rows(images, row_size, image_width, spacer_size=4):
    """ Tile images into ``row_size`` rows separated by white (255) spacers.

    Each row holds ``len(images) // row_size`` consecutive images joined
    horizontally with ``concat_images``. Images beyond
    ``row_size * (len(images) // row_size)`` are dropped.

    Args:
        images: Sequence of equally shaped image arrays.
        row_size: Number of rows; must satisfy ``1 <= row_size <= len(images)``.
        image_width: Forwarded to ``concat_images``.
        spacer_size: Thickness in pixels of the spacers between images and
            between rows.

    Returns:
        The tiled image, rows stacked top to bottom.

    Raises:
        ValueError: If ``row_size`` is outside ``[1, len(images)]``. Previously
            this surfaced as a ZeroDivisionError or an empty-hstack error.
    """
    if not 1 <= row_size <= len(images):
        raise ValueError(
            f"row_size must be between 1 and len(images)={len(images)}, "
            f"got {row_size}")
    column_size = len(images) // row_size

    rows = []
    for row in range(row_size):
        start = column_size * row
        row_image = concat_images(images[start:start + column_size],
                                  image_width, spacer_size)
        if rows:
            # Size the horizontal spacer from the actual row, so its width and
            # channel count always match what vstack is joining.
            rows.append(np.full((spacer_size,) + row_image.shape[1:], 255,
                                dtype=np.uint8))
        rows.append(row_image)
    return np.vstack(rows)
| |
|
| |
|
def convert_to_colormap(im, cmap):
    """Apply ``cmap`` to ``im`` and turn the resulting floats into uint8.

    Args:
        im: Array of values to colorize (presumably already scaled to [0, 1]).
        cmap: Callable colormap, e.g. a matplotlib ``Colormap``.

    Returns:
        ``np.uint8(cmap(im) * 255)``. Values are truncated, not rounded.
    """
    return np.uint8(cmap(im) * 255)
| |
|
| |
|
def rgb(im, cmap='jet', smooth=True):
    """Min-max normalize a map, optionally blur it, and colorize it as uint8.

    Args:
        im: 2-D array of values (e.g. one rate map).
        cmap: Colormap name or matplotlib ``Colormap`` instance.
        smooth: If True, apply a 3x3 Gaussian blur (sigmaX=1) after
            normalization.

    Returns:
        uint8 array ``cmap(normalized) * 255`` (RGBA for matplotlib colormaps).
    """
    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9.
    # plt.get_cmap accepts names and Colormap objects on old and new versions.
    colormap = plt.get_cmap(cmap)
    # A constant map yields 0/0 here. Silence that warning locally instead of
    # flipping numpy's process-wide error state with np.seterr. The resulting
    # NaNs are presumably drawn with the colormap's "bad" color.
    with np.errstate(invalid='ignore'):
        im = (im - np.min(im)) / (np.max(im) - np.min(im))
        if smooth:
            im = cv2.GaussianBlur(im, (3, 3), sigmaX=1, sigmaY=0)
        return convert_to_colormap(im, colormap)
| |
|
| |
|
def plot_ratemaps(activations, n_plots, cmap='jet', smooth=True, width=16):
    """Render the first ``n_plots`` rate maps as a single tiled image.

    Args:
        activations: Array of maps, [n_maps, height, width], as produced by
            ``compute_ratemaps``.
        n_plots: Maximum number of maps to render.
        cmap: Colormap passed to ``rgb``.
        smooth: Whether ``rgb`` blurs each map.
        width: Target number of maps per row. The grid has
            ``max(1, n_rendered // width)`` rows, each holding
            ``n_rendered // rows`` maps. Leftover maps are dropped by
            ``concat_images_in_rows``.

    Returns:
        The tiled image produced by ``concat_images_in_rows``.
    """
    images = [rgb(im, cmap, smooth) for im in activations[:n_plots]]
    # Base the row count on the maps actually rendered, and keep at least one
    # row. Previously n_plots < width gave 0 rows and a ZeroDivisionError.
    n_rows = max(1, len(images) // width)
    return concat_images_in_rows(images, n_rows, activations.shape[-1])
| |
|
| |
|
def compute_ratemaps(model, trajectory_generator, options, res=20, n_avg=None, Ng=512, idxs=None, return_raw=False):
    '''Compute spatial firing fields by averaging unit activations per position bin.

    Args:
        model: The RNN model. ``model.g(inputs)`` must return a tensor with
            ``.detach().cpu().numpy()`` and 3 axes, with units on the last
            axis. NOTE(review): the order of the first two axes is assumed to
            match ``pos`` so that both flatten consistently -- confirm.
        trajectory_generator: Generator for test trajectories.
            ``get_test_batch()`` returns ``(inputs, positions, _)``.
        options: Training options. Reads sequence_length, batch_size,
            box_width and box_height.
        res: Resolution of the rate map grid (res x res bins over the box).
        n_avg: Number of batches to average over. Falsy means
            ``1000 // options.sequence_length``.
        Ng: Maximum number of units to analyze.
        idxs: Indices of specific units to analyze; only the first ``Ng`` are
            kept. ``None`` or empty selects ``range(Ng)``.
        return_raw: If True, also return raw activations (g) and positions (pos).
            Warning: This uses significant memory for large batch_size/n_avg.
            If False, returns None for g and pos to save memory.

    Returns:
        activations: Mean activation per bin, [n_cells, res, res], with
            n_cells = min(Ng, len(idxs)). Index [c, i, j] is x-bin i, y-bin j.
            Bins that were never visited stay 0.
        rate_map: ``activations`` flattened to [n_cells, res*res] (a view).
        g: Raw activations [n_avg*batch_size*sequence_length, n_cells]
            (None if return_raw=False).
        pos: Raw positions [n_avg*batch_size*sequence_length, 2]
            (None if return_raw=False).
    '''
    if not n_avg:
        n_avg = 1000 // options.sequence_length

    # Test for None explicitly. The old ``not np.any(idxs)`` treated a
    # selection such as idxs=[0] as "nothing selected" and analyzed all units.
    if idxs is None or len(idxs) == 0:
        idxs = np.arange(Ng)
    idxs = np.asarray(idxs)[:Ng]
    # Use the real selection size. Reshaping to Ng when fewer indices were
    # given crashed or silently scrambled the data.
    n_cells = len(idxs)

    n_samples = options.batch_size * options.sequence_length
    if return_raw:
        g = np.zeros([n_avg, n_samples, n_cells])
        pos = np.zeros([n_avg, n_samples, 2])
    else:
        g = None
        pos = None

    n_bins = res * res
    # Per-bin sums are stored as [bin, cell] so np.add.at can scatter whole
    # rows. Samples are still added in order, matching the old per-sample loop.
    sums = np.zeros([n_bins, n_cells])
    counts = np.zeros(n_bins)

    for index in range(n_avg):
        inputs, pos_batch, _ = trajectory_generator.get_test_batch()
        g_batch = model.g(inputs).detach().cpu().numpy()
        pos_batch = pos_batch.detach().cpu().numpy().reshape(-1, 2)
        g_batch = g_batch[:, :, idxs].reshape(-1, n_cells)

        if return_raw:
            g[index] = g_batch
            pos[index] = pos_batch

        # Map positions in a box centred on the origin to fractional bin
        # coordinates in [0, res).
        x = (pos_batch[:, 0] + options.box_width / 2) / options.box_width * res
        y = (pos_batch[:, 1] + options.box_height / 2) / options.box_height * res
        inside = (x >= 0) & (x < res) & (y >= 0) & (y < res)
        flat = x[inside].astype(int) * res + y[inside].astype(int)

        np.add.at(counts, flat, 1)
        np.add.at(sums, flat, g_batch[inside])

    visited = counts > 0
    sums[visited] /= counts[visited, None]
    activations = np.ascontiguousarray(sums.T).reshape(n_cells, res, res)

    if return_raw:
        g = g.reshape([-1, n_cells])
        pos = pos.reshape([-1, 2])

    rate_map = activations.reshape(n_cells, -1)

    return activations, rate_map, g, pos
| |
|
| |
|
def save_ratemaps(model, trajectory_generator, options, step, res=20, n_avg=None):
    """Compute rate maps for all analyzed units and write them as one image.

    The tiled figure is written to
    ``options.save_dir + "/" + options.run_ID + "/" + str(step) + ".png"``.
    A falsy ``n_avg`` defaults to ``1000 // options.sequence_length`` batches.
    """
    n_avg = n_avg or 1000 // options.sequence_length
    ratemaps = compute_ratemaps(model, trajectory_generator, options,
                                res=res, n_avg=n_avg)[0]
    figure = plot_ratemaps(ratemaps, n_plots=len(ratemaps))
    out_path = options.save_dir + "/" + options.run_ID + "/" + str(step) + ".png"
    imsave(out_path, figure)
| |
|
| |
|
def save_autocorr(sess, model, save_name, trajectory_generator, step, flags):
    """Score spatial autocorrelograms of ``model.g`` and save them as a PDF.

    Legacy TensorFlow-1 style path: it runs ``sess.run`` over 100 batches
    built from ``trajectory_generator.feed_dict`` and accumulates
    ``model.target_pos`` / ``model.g`` with ``utils.concat_dict``. The
    accumulated arrays are then scored and plotted by
    ``utils.get_scores_and_plot``.

    NOTE(review): ``scores`` and ``utils`` do not appear in this module's
    visible import block. Unless they are provided at module scope elsewhere,
    calling this raises NameError -- confirm.

    Args:
        sess: Session used to evaluate the model tensors.
        model: Model exposing ``target_pos`` and ``g``.
        save_name: Prefix for the output file name.
        trajectory_generator: Object with ``feed_dict(box_width, box_height)``.
        step: Training step, used in the file name.
        flags: Options providing box_width, box_height and save_dir.

    Returns:
        Whatever ``utils.get_scores_and_plot`` returns. This value was
        previously computed and then discarded.
    """
    starts = [0.2] * 10
    ends = np.linspace(0.4, 1.0, num=10)
    coord_range = ((-1.1, 1.1), (-1.1, 1.1))
    # Materialize the (start, end) ring-mask pairs; a bare zip() can only be
    # iterated once.
    masks_parameters = list(zip(starts, ends.tolist()))
    latest_epoch_scorer = scores.GridScorer(20, coord_range, masks_parameters)

    results = dict()
    n_batches = 100
    for _ in range(n_batches):
        feed_dict = trajectory_generator.feed_dict(flags.box_width, flags.box_height)
        mb_res = sess.run({
            'pos_xy': model.target_pos,
            'bottleneck': model.g,
        }, feed_dict=feed_dict)
        results = utils.concat_dict(results, mb_res)

    filename = save_name + '/autocorrs_' + str(step) + '.pdf'
    imdir = flags.save_dir + '/'
    return utils.get_scores_and_plot(
        latest_epoch_scorer, results['pos_xy'], results['bottleneck'],
        imdir, filename)
| |
|