import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv2d


class Nothing(nn.Module):
    """Identity baseline: returns the radar and satellite inputs unchanged."""

    def __init__(self):
        super(Nothing, self).__init__()

    def forward(self, radar, satellite):
        return radar, satellite


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        # The number of input channels equals the number of filters in the previous
        # layer; the number of output channels is the number of filters in this layer.
        # Both convolutions use "same" padding, so spatial size is preserved.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding='same', bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding='same', bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class UpConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpConv, self).__init__()
        # Doubles the spatial resolution, then adjusts the channel count.
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding='same', bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.up(x)


class AttentionBlock(nn.Module):
    """Attention block with learnable parameters."""

    def __init__(self, F_g, F_l, n_coefficients):
        """
        :param F_g: number of feature maps (channels) in the previous layer
        :param F_l: number of feature maps in the corresponding encoder layer,
            transferred via the skip connection
        :param n_coefficients: number of learnable multi-dimensional attention coefficients
        """
        super(AttentionBlock, self).__init__()
        self.W_gate = nn.Sequential(
            nn.Conv2d(F_g, n_coefficients, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(n_coefficients)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, n_coefficients, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(n_coefficients)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(n_coefficients, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, gate, skip_connection):
        """
        :param gate: gating signal from the previous layer
        :param skip_connection: activation from the corresponding encoder layer
        :return: attention-weighted skip connection
        """
        g1 = self.W_gate(gate)
        x1 = self.W_x(skip_connection)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        out = skip_connection * psi
        return out


class Recurrent_block(nn.Module):
    def __init__(self, ch_out, t=2):
        super(Recurrent_block, self).__init__()
        # Applies the same convolution t times, feeding the sum of the input and the
        # previous output back in (recurrent residual refinement).
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding='same', bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        for i in range(self.t):
            if i == 0:
                x1 = self.conv(x)
            x1 = self.conv(x + x1)
        return x1
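# Hedged usage sketch (not part of the original model code): a quick shape check for the
# building blocks above. The channel counts and the 64x64 input size are illustrative
# assumptions only.
def _block_shape_sketch():
    x = torch.randn(1, 4, 64, 64)                  # hypothetical 4-channel feature map
    g = torch.randn(1, 4, 64, 64)                  # hypothetical gating signal of the same size
    print(ConvBlock(4, 8)(x).shape)                # (1, 8, 64, 64): "same" padding keeps H and W
    print(UpConv(4, 2)(x).shape)                   # (1, 2, 128, 128): upsampling doubles H and W
    print(Recurrent_block(4, t=2)(x).shape)        # (1, 4, 64, 64): recurrent refinement
    print(AttentionBlock(4, 4, 2)(g, x).shape)     # (1, 4, 64, 64): skip re-weighted by psi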
class RRCNN_block(nn.Module):
    def __init__(self, ch_in, ch_out, t=2):
        super(RRCNN_block, self).__init__()
        # 1x1 convolution to match channels, followed by two recurrent blocks; the
        # 1x1 output is added back as a residual connection in forward().
        self.RCNN = nn.Sequential(
            Recurrent_block(ch_out, t=t),
            Recurrent_block(ch_out, t=t)
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding='same')

    def forward(self, x):
        x = self.Conv_1x1(x)
        x1 = self.RCNN(x)
        return x + x1


class single_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(single_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding='same', bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class Unet(nn.Module):
    def __init__(self, rad_channel=1, sat_channel=1, rad_size=640, sat_size=20):
        super(Unet, self).__init__()
        assert rad_size % sat_size == 0, "rad_size must be divisible by sat_size"
        ratio = rad_size // sat_size
        assert (ratio & (ratio - 1)) == 0, "rad_size/sat_size must be a power of 2"
        # Pool after every encoder block so the bottleneck resolution equals sat_size.
        self.n_pool = int(math.log2(ratio))

        # Encoder: channel width doubles at every level.
        self.encoder_blocks = nn.ModuleList()
        self.pools = nn.ModuleList()
        for i in range(self.n_pool):
            in_c = rad_channel * (2 ** i)
            out_c = rad_channel * (2 ** (i + 1))
            self.encoder_blocks.append(ConvBlock(in_c, out_c))
            self.pools.append(nn.MaxPool2d(kernel_size=2, stride=2))

        # Bottleneck: fuse radar features with the satellite input.
        self.mid_conv_1 = single_conv(out_c, out_c)
        self.mid_conv_2 = single_conv(sat_channel, out_c)
        self.mid_merge = ConvBlock(2 * out_c, out_c)

        # Decoder: mirrors the encoder with skip connections.
        self.up_convs = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()
        for i in reversed(range(self.n_pool)):
            up_in = rad_channel * (2 ** (i + 2))
            up_out = rad_channel * (2 ** (i + 1))
            self.up_convs.append(UpConv(up_in, up_out))
            self.decoder_blocks.append(ConvBlock(up_in, up_out))

        self.final_decoder = ConvBlock(4 * rad_channel, 2 * rad_channel)
        self.out_conv_R = nn.Conv2d(2 * rad_channel, rad_channel, kernel_size=1, padding='same')
        self.out_conv_S = nn.Conv2d(out_c, sat_channel, kernel_size=1, padding='same')

    def forward(self, radar, satellite):
        # Encoding
        enc_feats = []
        x = radar
        for i, block in enumerate(self.encoder_blocks):
            x = block(x)
            enc_feats.append(x)
            x = self.pools[i](x)

        # Bottleneck
        x = F.relu(self.mid_conv_1(x))
        y = F.relu(self.mid_conv_2(satellite))
        x = torch.cat((x, y), dim=1)
        mid_out = self.mid_merge(x)
        pred_sat = self.out_conv_S(mid_out)

        # Decoding: starts from the concatenated bottleneck features x (before mid_merge).
        for i in range(self.n_pool):
            x = self.up_convs[i](x)
            x = torch.cat((enc_feats[self.n_pool - 1 - i], x), dim=1)
            x = self.decoder_blocks[i](x)

        x = torch.cat((enc_feats[0], x), dim=1)
        x = self.final_decoder(x)
        pred_rad = self.out_conv_R(x)
        return pred_rad, pred_sat
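# Hedged usage sketch for Unet (assumptions: batch size 1, 1-channel radar at 640x640 and
# 1-channel satellite at 20x20, matching the defaults above).
def _unet_shape_sketch():
    net = Unet(rad_channel=1, sat_channel=1, rad_size=640, sat_size=20)
    radar = torch.randn(1, 1, 640, 640)
    satellite = torch.randn(1, 1, 20, 20)
    pred_rad, pred_sat = net(radar, satellite)
    print(pred_rad.shape)   # (1, 1, 640, 640): radar prediction at full resolution
    print(pred_sat.shape)   # (1, 1, 20, 20): satellite prediction at the bottleneck resolution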
class R2Unet(nn.Module):
    def __init__(self, num_channel=1, t=2):
        super(R2Unet, self).__init__()
        # Note: this variant keeps a hard-coded interface (5-channel radar input and
        # 2-channel satellite input); num_channel only scales the network width.
        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.RRCNN1 = RRCNN_block(5, 2 * num_channel, t=t)
        self.RRCNN2 = RRCNN_block(2 * num_channel, 4 * num_channel, t=t)
        self.RRCNN3 = RRCNN_block(4 * num_channel, 8 * num_channel, t=t)
        self.RRCNN4 = RRCNN_block(8 * num_channel, 16 * num_channel, t=t)
        self.RRCNN5 = RRCNN_block(16 * num_channel, 32 * num_channel, t=t)

        self.mid_conv_1 = single_conv(32 * num_channel, 32 * num_channel)
        self.mid_conv_2 = single_conv(2, 32 * num_channel)
        self.MidConv = RRCNN_block(64 * num_channel, 32 * num_channel, t=t)
        self.out_conv_S = Conv2d(32 * num_channel, 1, (1, 1), padding='same')

        self.Up5 = UpConv(64 * num_channel, 32 * num_channel)
        self.UpRRCNN5 = RRCNN_block(64 * num_channel, 32 * num_channel, t=t)
        self.Up4 = UpConv(32 * num_channel, 16 * num_channel)
        self.UpRRCNN4 = RRCNN_block(32 * num_channel, 16 * num_channel, t=t)
        self.Up3 = UpConv(16 * num_channel, 8 * num_channel)
        self.UpRRCNN3 = RRCNN_block(16 * num_channel, 8 * num_channel, t=t)
        self.Up2 = UpConv(8 * num_channel, 4 * num_channel)
        self.UpRRCNN2 = RRCNN_block(8 * num_channel, 4 * num_channel, t=t)
        self.Up1 = UpConv(4 * num_channel, 2 * num_channel)
        self.UpRRCNN1 = RRCNN_block(4 * num_channel, 2 * num_channel, t=t)
        self.out_conv_R = Conv2d(2 * num_channel, 1, (1, 1), padding='same')

    def forward(self, radar, satellite):
        # Encoder
        e1 = self.RRCNN1(radar)
        e2 = self.MaxPool(e1)
        e2 = self.RRCNN2(e2)
        e3 = self.MaxPool(e2)
        e3 = self.RRCNN3(e3)
        e4 = self.MaxPool(e3)
        e4 = self.RRCNN4(e4)
        e5 = self.MaxPool(e4)
        e5 = self.RRCNN5(e5)
        e6 = self.MaxPool(e5)

        # Bottleneck: fuse radar features with the satellite input.
        X = F.relu(self.mid_conv_1(e6))
        Y = F.relu(self.mid_conv_2(satellite))
        X = torch.cat((X, Y), 1)
        Y = self.MidConv(X)
        pred_satellite = self.out_conv_S(Y)

        # Decoder with skip connections.
        d5 = self.Up5(X)
        d5 = torch.cat((e5, d5), dim=1)
        d5 = self.UpRRCNN5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((e4, d4), dim=1)
        d4 = self.UpRRCNN4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((e3, d3), dim=1)
        d3 = self.UpRRCNN3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((e2, d2), dim=1)
        d2 = self.UpRRCNN2(d2)

        d1 = self.Up1(d2)
        d0 = torch.cat((e1, d1), dim=1)
        d0 = self.UpRRCNN1(d0)

        pred_radar = self.out_conv_R(d0)
        return pred_radar, pred_satellite
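# Hedged usage sketch for R2Unet (assumption: the hard-coded interface above, i.e.
# 5-channel radar at 640x640 and 2-channel satellite at 20x20 = 640 / 2**5).
def _r2unet_shape_sketch():
    net = R2Unet(num_channel=1, t=2)
    pred_rad, pred_sat = net(torch.randn(1, 5, 640, 640), torch.randn(1, 2, 20, 20))
    print(pred_rad.shape)   # (1, 1, 640, 640)
    print(pred_sat.shape)   # (1, 1, 20, 20)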
class AttUnet(nn.Module):
    def __init__(self, num_channel=1):
        super(AttUnet, self).__init__()
        # Same hard-coded interface as R2Unet (5-channel radar input, 2-channel
        # satellite input); attention blocks gate the skip connections.
        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = ConvBlock(5, 2 * num_channel)
        self.Conv2 = ConvBlock(2 * num_channel, 4 * num_channel)
        self.Conv3 = ConvBlock(4 * num_channel, 8 * num_channel)
        self.Conv4 = ConvBlock(8 * num_channel, 16 * num_channel)
        self.Conv5 = ConvBlock(16 * num_channel, 32 * num_channel)

        self.mid_conv_1 = single_conv(32 * num_channel, 32 * num_channel)
        self.mid_conv_2 = single_conv(2, 32 * num_channel)
        self.MidConv = ConvBlock(64 * num_channel, 32 * num_channel)
        self.out_conv_S = Conv2d(32 * num_channel, 1, (1, 1), padding='same')

        self.Up5 = UpConv(64 * num_channel, 32 * num_channel)
        self.Att5 = AttentionBlock(F_g=32 * num_channel, F_l=32 * num_channel, n_coefficients=16 * num_channel)
        self.UpConv5 = ConvBlock(64 * num_channel, 32 * num_channel)
        self.Up4 = UpConv(32 * num_channel, 16 * num_channel)
        self.Att4 = AttentionBlock(F_g=16 * num_channel, F_l=16 * num_channel, n_coefficients=8 * num_channel)
        self.UpConv4 = ConvBlock(32 * num_channel, 16 * num_channel)
        self.Up3 = UpConv(16 * num_channel, 8 * num_channel)
        self.Att3 = AttentionBlock(F_g=8 * num_channel, F_l=8 * num_channel, n_coefficients=4 * num_channel)
        self.UpConv3 = ConvBlock(16 * num_channel, 8 * num_channel)
        self.Up2 = UpConv(8 * num_channel, 4 * num_channel)
        self.Att2 = AttentionBlock(F_g=4 * num_channel, F_l=4 * num_channel, n_coefficients=2 * num_channel)
        self.UpConv2 = ConvBlock(8 * num_channel, 4 * num_channel)
        self.Up1 = UpConv(4 * num_channel, 2 * num_channel)
        self.Att1 = AttentionBlock(F_g=2 * num_channel, F_l=2 * num_channel, n_coefficients=1 * num_channel)
        self.UpConv1 = ConvBlock(4 * num_channel, 2 * num_channel)
        self.out_conv_R = Conv2d(2 * num_channel, 1, (1, 1), padding='same')

    def forward(self, radar, satellite):
        # Encoder
        e1 = self.Conv1(radar)
        e2 = self.MaxPool(e1)
        e2 = self.Conv2(e2)
        e3 = self.MaxPool(e2)
        e3 = self.Conv3(e3)
        e4 = self.MaxPool(e3)
        e4 = self.Conv4(e4)
        e5 = self.MaxPool(e4)
        e5 = self.Conv5(e5)
        e6 = self.MaxPool(e5)

        # Bottleneck: fuse radar features with the satellite input.
        X = F.relu(self.mid_conv_1(e6))
        Y = F.relu(self.mid_conv_2(satellite))
        X = torch.cat((X, Y), 1)
        Y = self.MidConv(X)
        pred_satellite = self.out_conv_S(Y)

        # Decoder: concatenate the attention-weighted skip connection with the
        # upsampled features at every level.
        d5 = self.Up5(X)
        s4 = self.Att5(gate=d5, skip_connection=e5)
        d5 = torch.cat((s4, d5), dim=1)
        d5 = self.UpConv5(d5)

        d4 = self.Up4(d5)
        s3 = self.Att4(gate=d4, skip_connection=e4)
        d4 = torch.cat((s3, d4), dim=1)
        d4 = self.UpConv4(d4)

        d3 = self.Up3(d4)
        s2 = self.Att3(gate=d3, skip_connection=e3)
        d3 = torch.cat((s2, d3), dim=1)
        d3 = self.UpConv3(d3)

        d2 = self.Up2(d3)
        s1 = self.Att2(gate=d2, skip_connection=e2)
        d2 = torch.cat((s1, d2), dim=1)
        d2 = self.UpConv2(d2)

        d1 = self.Up1(d2)
        s0 = self.Att1(gate=d1, skip_connection=e1)
        d0 = torch.cat((s0, d1), dim=1)
        d0 = self.UpConv1(d0)

        pred_radar = self.out_conv_R(d0)
        return pred_radar, pred_satellite
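# Hedged usage sketch for AttUnet (same hard-coded input assumption as R2Unet:
# 5-channel radar at 640x640, 2-channel satellite at 20x20).
def _attunet_shape_sketch():
    net = AttUnet(num_channel=1)
    pred_rad, pred_sat = net(torch.randn(1, 5, 640, 640), torch.randn(1, 2, 20, 20))
    print(pred_rad.shape, pred_sat.shape)   # (1, 1, 640, 640) (1, 1, 20, 20)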
class AttR2Unet(nn.Module):
    def __init__(self, num_channel=1, t=2):
        super(AttR2Unet, self).__init__()
        # Combines the recurrent residual blocks of R2Unet with the attention-gated
        # skip connections of AttUnet; same hard-coded 5-channel radar / 2-channel
        # satellite interface.
        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.RRCNN1 = RRCNN_block(5, 2 * num_channel, t=t)
        self.RRCNN2 = RRCNN_block(2 * num_channel, 4 * num_channel, t=t)
        self.RRCNN3 = RRCNN_block(4 * num_channel, 8 * num_channel, t=t)
        self.RRCNN4 = RRCNN_block(8 * num_channel, 16 * num_channel, t=t)
        self.RRCNN5 = RRCNN_block(16 * num_channel, 32 * num_channel, t=t)

        self.mid_conv_1 = single_conv(32 * num_channel, 32 * num_channel)
        self.mid_conv_2 = single_conv(2, 32 * num_channel)
        self.MidConv = RRCNN_block(64 * num_channel, 32 * num_channel, t=t)
        self.out_conv_S = Conv2d(32 * num_channel, 1, (1, 1), padding='same')

        self.Up5 = UpConv(64 * num_channel, 32 * num_channel)
        self.Att5 = AttentionBlock(F_g=32 * num_channel, F_l=32 * num_channel, n_coefficients=16 * num_channel)
        self.UpRRCNN5 = RRCNN_block(64 * num_channel, 32 * num_channel, t=t)
        self.Up4 = UpConv(32 * num_channel, 16 * num_channel)
        self.Att4 = AttentionBlock(F_g=16 * num_channel, F_l=16 * num_channel, n_coefficients=8 * num_channel)
        self.UpRRCNN4 = RRCNN_block(32 * num_channel, 16 * num_channel, t=t)
        self.Up3 = UpConv(16 * num_channel, 8 * num_channel)
        self.Att3 = AttentionBlock(F_g=8 * num_channel, F_l=8 * num_channel, n_coefficients=4 * num_channel)
        self.UpRRCNN3 = RRCNN_block(16 * num_channel, 8 * num_channel, t=t)
        self.Up2 = UpConv(8 * num_channel, 4 * num_channel)
        self.Att2 = AttentionBlock(F_g=4 * num_channel, F_l=4 * num_channel, n_coefficients=2 * num_channel)
        self.UpRRCNN2 = RRCNN_block(8 * num_channel, 4 * num_channel, t=t)
        self.Up1 = UpConv(4 * num_channel, 2 * num_channel)
        self.Att1 = AttentionBlock(F_g=2 * num_channel, F_l=2 * num_channel, n_coefficients=1 * num_channel)
        self.UpRRCNN1 = RRCNN_block(4 * num_channel, 2 * num_channel, t=t)
        self.out_conv_R = Conv2d(2 * num_channel, 1, (1, 1), padding='same')

    def forward(self, radar, satellite):
        # Encoder
        e1 = self.RRCNN1(radar)
        e2 = self.MaxPool(e1)
        e2 = self.RRCNN2(e2)
        e3 = self.MaxPool(e2)
        e3 = self.RRCNN3(e3)
        e4 = self.MaxPool(e3)
        e4 = self.RRCNN4(e4)
        e5 = self.MaxPool(e4)
        e5 = self.RRCNN5(e5)
        e6 = self.MaxPool(e5)

        # Bottleneck: fuse radar features with the satellite input.
        X = F.relu(self.mid_conv_1(e6))
        Y = F.relu(self.mid_conv_2(satellite))
        X = torch.cat((X, Y), 1)
        Y = self.MidConv(X)
        pred_satellite = self.out_conv_S(Y)

        # Decoder: concatenate the attention-weighted skip connection with the
        # upsampled features at every level.
        d5 = self.Up5(X)
        s4 = self.Att5(gate=d5, skip_connection=e5)
        d5 = torch.cat((s4, d5), dim=1)
        d5 = self.UpRRCNN5(d5)

        d4 = self.Up4(d5)
        s3 = self.Att4(gate=d4, skip_connection=e4)
        d4 = torch.cat((s3, d4), dim=1)
        d4 = self.UpRRCNN4(d4)

        d3 = self.Up3(d4)
        s2 = self.Att3(gate=d3, skip_connection=e3)
        d3 = torch.cat((s2, d3), dim=1)
        d3 = self.UpRRCNN3(d3)

        d2 = self.Up2(d3)
        s1 = self.Att2(gate=d2, skip_connection=e2)
        d2 = torch.cat((s1, d2), dim=1)
        d2 = self.UpRRCNN2(d2)

        d1 = self.Up1(d2)
        s0 = self.Att1(gate=d1, skip_connection=e1)
        d0 = torch.cat((s0, d1), dim=1)
        d0 = self.UpRRCNN1(d0)

        pred_radar = self.out_conv_R(d0)
        return pred_radar, pred_satellite


class Network(nn.Module):
    def __init__(self, model_type: str, rad_channel: int, sat_channel: int, rad_size: int, sat_size: int):
        super(Network, self).__init__()
        print(model_type)
        if model_type == "Nothing":
            self.net = Nothing()
        elif model_type == "Unet":
            self.net = Unet(rad_channel=rad_channel, sat_channel=sat_channel,
                            rad_size=rad_size, sat_size=sat_size)
        elif model_type == "R2Unet":
            # R2Unet, AttUnet and AttR2Unet keep their original hard-coded interface
            # (5-channel radar, 2-channel satellite), so they are built with their own
            # defaults rather than the rad_/sat_ arguments.
            self.net = R2Unet()
        elif model_type == "AttUnet":
            self.net = AttUnet()
        elif model_type == "AttR2Unet":
            self.net = AttR2Unet()
        else:
            raise ValueError(f"unknown model_type: {model_type}")

    def forward(self, radar, satellite):
        pred_radar, pred_satellite = self.net(radar, satellite)
        return pred_radar, pred_satellite
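# Hedged smoke test (assumptions: the input shapes below follow the defaults hard-coded in
# this file; run the module directly to check that the variants build and produce outputs).
if __name__ == "__main__":
    unet = Network("Unet", rad_channel=1, sat_channel=1, rad_size=640, sat_size=20)
    r, s = unet(torch.randn(1, 1, 640, 640), torch.randn(1, 1, 20, 20))
    print("Unet:", r.shape, s.shape)

    att = Network("AttUnet", rad_channel=1, sat_channel=1, rad_size=640, sat_size=20)
    r, s = att(torch.randn(1, 5, 640, 640), torch.randn(1, 2, 20, 20))
    print("AttUnet:", r.shape, s.shape)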