| |
|
|
| |
|
|
| import matplotlib.pyplot as plt |
| |
| |
|
|
| |
| |
|
|
|
|
| from torchvision import transforms |
| from dataset_creation import normal_transforms |
| from model import MakiAlexNet |
| import numpy as np |
| import cv2, torch, os |
| from tqdm import tqdm |
| import time |
|
|
# Path of a single sample frame; not referenced by the code visible in this section.
TEST_IMAGE = "dataset/root/train/left1_frame_0.jpg"
# Checkpoint file holding the trained network weights.
MODEL_PARAMS = "alexnet_cognitive.pth"
# Names of every entry in the training folder; each one is visualised below.
all_processing_files = os.listdir(os.path.join(os.getcwd(), "./dataset/root/train"))

model = MakiAlexNet()
# map_location="cpu" lets a checkpoint saved from a CUDA device load on a
# CPU-only host; the model is built and evaluated on the CPU in this script,
# so the loaded values end up in the same place either way.
model.load_state_dict(torch.load(MODEL_PARAMS, map_location="cpu"))
model.eval()
print("Model armed and ready for evaluation.")

# Dump every parameter/buffer name with its tensor size for a quick sanity check.
print("Model's state_dict:")
# Build the state dict once instead of re-creating it for every entry.
for param_name, param_value in model.state_dict().items():
    print(param_name, "\t", param_value.size())
|
|
|
|
|
|
|
|
# Folder scanned for input frames (same location used to build all_processing_files).
INPUT_DIR = os.path.join(os.getcwd(), "./dataset/root/train")
# Root folder receiving one sub-folder of rendered images per input frame.
DATASET_OUTPUT_PATH = "./dataset/visualisation"
# Side length, in pixels, of the square canvas the channel maps are resized to.
HEATMAP_SIZE = 224
# Activations at or above this fraction of the peak are rendered at full intensity.
SATURATION_FRACTION = 0.8


def _scale_to_uint8(accumulated):
    """Map an accumulated activation array onto the 0-255 ``uint8`` range.

    Values are divided by ``SATURATION_FRACTION`` times the peak, clipped to
    [0, 1] and stretched to [0, 255]. Casting the un-stretched ratio directly
    to ``uint8`` would collapse it to the values 0 and 1, which makes the
    colour-mapped heatmap nearly uniform.

    Returns an all-zero array when the peak is not strictly positive (this
    also covers a NaN peak), avoiding a division by zero.
    """
    peak = float(np.max(accumulated))
    if not peak > 0:
        return np.zeros(accumulated.shape, dtype=np.uint8)
    scaled = np.clip(accumulated / (peak * SATURATION_FRACTION), 0.0, 1.0)
    return (scaled * 255).astype(np.uint8)


# For every input frame: run it through the first convolution stage, sum the
# resized channel maps, then save a grayscale mask and a JET heatmap.
for image_file in tqdm(all_processing_files):
    abs_file_path = os.path.join(INPUT_DIR, image_file)
    image = cv2.imread(abs_file_path)
    if image is None:
        # cv2.imread signals an unreadable/non-image file by returning None
        # rather than raising, so skip it instead of crashing on `.shape`.
        print(f"Skipping unreadable image file: {abs_file_path}")
        continue

    print("Image input shape of the matrix before: ", image.shape)
    # Add a leading batch axis, then move the channel axis in front of the
    # two spatial axes (channels-last -> channels-first).
    image = torch.unsqueeze(torch.tensor(image.astype(np.float32)), 0)
    image = torch.einsum("BWHC->BCWH", image)
    print("Image input shape of the matrix after: ", image.shape)
    # Inference only: no autograd graph is needed for visualisation.
    with torch.no_grad():
        conv1_output = model.conv1(image)
    print("Output shape of the matrix: ", conv1_output.shape)

    # Collapse the batch axis and move channels last so that each channel is
    # a 2-D slice along the final axis.
    conv1_formatted = torch.einsum("BCWH->WHC", conv1_output)
    print(f"Formatted shape of matrix is: {conv1_formatted.shape}")
    num_channels = conv1_formatted.shape[2]

    # One output folder per frame, named after the file without its extension.
    image_file_dir = os.path.splitext(image_file)[0]
    output_dir = os.path.join(os.getcwd(), DATASET_OUTPUT_PATH, image_file_dir)
    # makedirs also creates the visualisation root when it does not exist yet.
    os.makedirs(output_dir, exist_ok=True)

    # Accumulate every channel, resized to a common canvas, into one map.
    merged_frames = np.zeros((HEATMAP_SIZE, HEATMAP_SIZE))
    for channel_idx in range(num_channels):
        channel_data = conv1_formatted[:, :, channel_idx].detach().numpy()
        print(f"Channel Data shape dimension: {channel_data.shape}")
        # The slice is a strided view; hand OpenCV a contiguous buffer.
        channel_data = cv2.resize(
            np.ascontiguousarray(channel_data), (HEATMAP_SIZE, HEATMAP_SIZE)
        )
        merged_frames += channel_data

    merged_frames_gray = _scale_to_uint8(merged_frames)

    image_path = os.path.join(output_dir, image_file_dir + "conv1_mask.jpg")
    # Fixed vmin/vmax keep the saved brightness tied to the 0-255 scale
    # instead of being re-stretched to the data range.
    plt.imsave(image_path, merged_frames_gray, cmap='gray', vmin=0, vmax=255)

    heatmap_color = cv2.applyColorMap(merged_frames_gray, cv2.COLORMAP_JET)
    # OpenCV produces BGR while matplotlib expects RGB; convert so that
    # red and blue are not swapped in the saved heatmap.
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
    image_path = os.path.join(output_dir, image_file_dir + "conv1_heatmap.jpg")
    plt.imsave(image_path, heatmap_color)

# NOTE(review): the original indentation was lost, so it is unclear whether
# this exit() ran after the first image or after the whole loop; it is kept
# at top level here - confirm the intended placement.
exit()
|
|
| |
| |
| |
| |
|
|
|
|
|
|