# unet_code/src/arch.py
# Uploaded by weatherforecast1024 using huggingface_hub (commit f3b050a, verified)
import os
import numpy as np
import glob
import math
import torch
import torchvision
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, Linear, MSELoss
from torch.nn import ConvTranspose2d, Conv2d, MaxPool2d, BatchNorm2d
# For our model
import torchvision.models as models
from torchvision import datasets, transforms
from torchvision.io import read_image
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.autograd import Variable
from torchsummary import summary
class Nothing(nn.Module):
    """Identity baseline: returns both inputs completely untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, radar, satellite):
        # No-op pass-through; useful as a control/baseline model.
        return radar, satellite
class ConvBlock(nn.Module):
    """Two stacked 3x3 'same' convolutions, each followed by BatchNorm and ReLU.

    Spatial size is preserved; channel count goes in_channels -> out_channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # First conv maps in_channels -> out_channels, second keeps out_channels.
        layers = []
        for c_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(c_in, out_channels, kernel_size=3, stride=1, padding='same', bias=True),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
class UpConv(nn.Module):
    """Double the spatial resolution, then apply a 3x3 'same' conv + BN + ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding='same', bias=True)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.ReLU(inplace=True)
        self.up = nn.Sequential(upsample, conv, norm, act)

    def forward(self, x):
        return self.up(x)
class AttentionBlock(nn.Module):
    """Additive attention gate (Attention U-Net style).

    Projects the gating signal and the skip connection into a common space,
    derives a single per-pixel coefficient map in (0, 1), and scales the skip
    connection by it.
    """

    def __init__(self, F_g, F_l, n_coefficients):
        """
        :param F_g: channels of the gating signal (previous decoder layer)
        :param F_l: channels of the skip connection (encoder layer)
        :param n_coefficients: channels of the intermediate attention space
        """
        super().__init__()

        def projection(c_in, c_out):
            # 1x1 conv + BatchNorm, shared shape for both input projections.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(c_out),
            )

        self.W_gate = projection(F_g, n_coefficients)
        self.W_x = projection(F_l, n_coefficients)
        # Collapse the joint features to a single sigmoid attention map.
        self.psi = nn.Sequential(
            nn.Conv2d(n_coefficients, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, gate, skip_connection):
        """Return the skip connection scaled by the learned attention map."""
        combined = self.relu(self.W_gate(gate) + self.W_x(skip_connection))
        coefficients = self.psi(combined)
        return skip_connection * coefficients
class Recurrent_block(nn.Module):
    """Recurrent conv block from R2U-Net.

    Computes ``y = conv(x)`` followed by ``t`` refinement steps
    ``y = conv(x + y)``, where ``conv`` is a shared 3x3 'same' conv + BN + ReLU.
    Channel count and spatial size are preserved.
    """

    def __init__(self, ch_out, t=2):
        """
        :param ch_out: number of input/output channels (block is channel-preserving)
        :param t: number of recurrent refinement steps (t >= 0)
        """
        super(Recurrent_block, self).__init__()
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding='same', bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Fix: the original loop only assigned x1 inside `if i == 0`, so t == 0
        # raised UnboundLocalError. Initialize explicitly; for t >= 1 the
        # computation is identical to the original:
        #   x1 = conv(x), then t applications of x1 = conv(x + x1).
        x1 = self.conv(x)
        for _ in range(self.t):
            x1 = self.conv(x + x1)
        return x1
class RRCNN_block(nn.Module):
    """Residual recurrent block: 1x1 channel projection, two stacked
    Recurrent_blocks, and a residual connection around the recurrent stack."""

    def __init__(self, ch_in, ch_out, t=2):
        super(RRCNN_block, self).__init__()
        self.RCNN = nn.Sequential(
            Recurrent_block(ch_out, t=t),
            Recurrent_block(ch_out, t=t),
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding='same')

    def forward(self, x):
        projected = self.Conv_1x1(x)
        refined = self.RCNN(projected)
        # Residual add: output = projection + recurrent refinement.
        return projected + refined
class single_conv(nn.Module):
    """A single 3x3 'same' convolution followed by BatchNorm and ReLU."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding='same', bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)
class Unet(nn.Module):
    """U-Net that fuses a radar image with a coarser satellite image.

    The radar input is encoded through ``n_pool`` conv + 2x-pool stages until
    its spatial size matches the satellite input; the two streams are
    concatenated at the bottleneck (which also yields the satellite
    prediction), then decoded back up with skip connections to produce the
    radar prediction.
    """

    def __init__(self, rad_channel=1, sat_channel=1, rad_size=640, sat_size=20):
        """
        :param rad_channel: channels of the radar input (and radar output)
        :param sat_channel: channels of the satellite input (and satellite output)
        :param rad_size: spatial size of the radar input
        :param sat_size: spatial size of the satellite input; rad_size/sat_size
            must be a power of 2 and determines the number of pooling stages
        """
        super(Unet, self).__init__()
        assert rad_size % sat_size == 0, "rad_size must be divisible by sat_size"
        ratio = rad_size // sat_size
        assert (ratio & (ratio - 1)) == 0, "rad_size/sat_size must be a power of 2"
        self.n_pool = int(math.log2(ratio))
        # The decoder/final stage indexes enc_feats[0], so at least one
        # encoder stage is required (the original crashed later anyway when
        # ratio == 1, via a NameError on the leaked loop variable).
        assert self.n_pool >= 1, "rad_size must be strictly larger than sat_size"
        # Encoder: channel width doubles at every stage.
        self.encoder_blocks = nn.ModuleList()
        self.pools = nn.ModuleList()
        for i in range(self.n_pool):
            self.encoder_blocks.append(
                ConvBlock(rad_channel * 2 ** i, rad_channel * 2 ** (i + 1))
            )
            # One pool per encoder stage (the original guarded this with an
            # always-true `if i < self.n_pool`).
            self.pools.append(nn.MaxPool2d(kernel_size=2, stride=2))
        # Bottleneck width, computed explicitly instead of relying on the
        # encoder loop variable leaking out of the loop.
        bottleneck_c = rad_channel * 2 ** self.n_pool
        self.mid_conv_1 = single_conv(bottleneck_c, bottleneck_c)
        self.mid_conv_2 = single_conv(sat_channel, bottleneck_c)
        self.mid_merge = ConvBlock(2 * bottleneck_c, bottleneck_c)
        # Decoder: mirrors the encoder; each stage halves the channel width.
        self.up_convs = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()
        for i in reversed(range(self.n_pool)):
            up_in = rad_channel * 2 ** (i + 2)
            up_out = rad_channel * 2 ** (i + 1)
            self.up_convs.append(UpConv(up_in, up_out))
            self.decoder_blocks.append(ConvBlock(up_in, up_out))
        # Final stage concatenates the first skip connection a second time.
        self.final_decoder = ConvBlock(4 * rad_channel, 2 * rad_channel)
        self.out_conv_R = nn.Conv2d(2 * rad_channel, rad_channel, kernel_size=1, padding='same')
        self.out_conv_S = nn.Conv2d(bottleneck_c, sat_channel, kernel_size=1, padding='same')

    def forward(self, radar, satellite):
        """Return ``(pred_radar, pred_satellite)``."""
        # Encoder: keep each stage's activation for the skip connections.
        enc_feats = []
        x = radar
        for block, pool in zip(self.encoder_blocks, self.pools):
            x = block(x)
            enc_feats.append(x)
            x = pool(x)
        # Bottleneck: fuse radar features with the satellite input.
        x = F.relu(self.mid_conv_1(x))
        y = F.relu(self.mid_conv_2(satellite))
        x = torch.cat((x, y), dim=1)  # 2 * bottleneck channels
        mid_out = self.mid_merge(x)
        pred_sat = self.out_conv_S(mid_out)
        # Decoder starts from the concatenated bottleneck tensor x,
        # NOT from mid_merge's output (the original had a no-op `x = x` here).
        for i in range(self.n_pool):
            x = self.up_convs[i](x)
            x = torch.cat((enc_feats[self.n_pool - 1 - i], x), dim=1)
            x = self.decoder_blocks[i](x)
        # enc_feats[0] is deliberately concatenated a second time here; the
        # channel counts of final_decoder are sized for it.
        x = torch.cat((enc_feats[0], x), dim=1)
        x = self.final_decoder(x)
        pred_rad = self.out_conv_R(x)
        return pred_rad, pred_sat
# class Unet(nn.Module):
# def __init__(self,num_channel=1,rad_size=640,sat_size=20):
# super(Unet, self).__init__()
# self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)
# self.Conv1 = ConvBlock(1, 2*num_channel)
# self.Conv2 = ConvBlock(2*num_channel, 4*num_channel)
# self.Conv3 = ConvBlock(4*num_channel, 8*num_channel)
# self.Conv4 = ConvBlock(8*num_channel, 16*num_channel)
# self.Conv5 = ConvBlock(16*num_channel, 32*num_channel)
# self.mid_conv_1 = single_conv(32*num_channel,32*num_channel)
# self.mid_conv_2 = single_conv(2, 32*num_channel)
# self.MidConv = ConvBlock(64*num_channel, 32*num_channel)
# self.out_conv_S = Conv2d(32*num_channel, 1, (1, 1), padding= 'same')
# self.Up5 = UpConv(64*num_channel, 32*num_channel)
# self.UpConv5 = ConvBlock(64*num_channel, 32*num_channel)
# self.Up4 = UpConv(32*num_channel, 16*num_channel)
# self.UpConv4 = ConvBlock(32*num_channel, 16*num_channel)
# self.Up3 = UpConv(16*num_channel, 8*num_channel)
# self.UpConv3 = ConvBlock(16*num_channel, 8*num_channel)
# self.Up2 = UpConv(8*num_channel, 4*num_channel)
# self.UpConv2 = ConvBlock(8*num_channel, 4*num_channel)
# self.Up1 = UpConv(4*num_channel, 2*num_channel)
# self.UpConv1 = ConvBlock(4*num_channel, 2*num_channel)
# self.out_conv_R = Conv2d(2*num_channel, 1, (1, 1), padding= 'same')
# def forward(self, radar,satellite):
# e1 = self.Conv1(radar)
# e2 = self.MaxPool(e1)
# e2 = self.Conv2(e2)
# e3 = self.MaxPool(e2)
# e3 = self.Conv3(e3)
# e4 = self.MaxPool(e3)
# e4 = self.Conv4(e4)
# e5 = self.MaxPool(e4)
# e5 = self.Conv5(e5)
# e6 = self.MaxPool(e5)
# X = F.relu(self.mid_conv_1(e6))
# Y = F.relu(self.mid_conv_2(satellite))
# X = torch.cat((X,Y),1)
# Y = self.MidConv(X)
# pred_satellite = self.out_conv_S(Y)
# d5 = self.Up5(X)
# d5 = torch.cat((e5, d5), dim=1)
# d5 = self.UpConv5(d5)
# d4 = self.Up4(d5)
# d4 = torch.cat((e4, d4), dim=1)
# d4 = self.UpConv4(d4)
# d3 = self.Up3(d4)
# d3 = torch.cat((e3, d3), dim=1)
# d3 = self.UpConv3(d3)
# d2 = self.Up2(d3)
# d2 = torch.cat((e2, d2), dim=1)
# d2 = self.UpConv2(d2)
# d1 = self.Up1(d2)
# d0 = torch.cat((e1, d1), dim=1)
# d0 = self.UpConv1(d0)
# pred_radar = self.out_conv_R(d0)
# return pred_radar, pred_satellite
class R2Unet(nn.Module):
    """Recurrent residual U-Net (R2U-Net) predicting radar and satellite fields.

    Encoder: five RRCNN blocks separated by 2x max-pooling. The satellite
    stream joins at the bottleneck, which also produces the satellite
    prediction; the decoder upsamples with skip connections to produce the
    radar prediction.

    NOTE(review): the radar input is hard-coded to 5 channels (RRCNN1) and
    the satellite input to 2 channels (mid_conv_2) — confirm against callers.
    """

    def __init__(self, num_channel=1, t=2):
        """
        :param num_channel: base channel-width multiplier
        :param t: number of recurrent steps inside each RRCNN block
        """
        super(R2Unet, self).__init__()
        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder: channel width doubles at each of the five stages.
        self.RRCNN1 = RRCNN_block(5, 2*num_channel, t=t)
        self.RRCNN2 = RRCNN_block(2*num_channel, 4*num_channel, t=t)
        self.RRCNN3 = RRCNN_block(4*num_channel, 8*num_channel, t=t)
        self.RRCNN4 = RRCNN_block(8*num_channel, 16*num_channel, t=t)
        self.RRCNN5 = RRCNN_block(16*num_channel, 32*num_channel, t=t)
        # Bottleneck: fuse radar features with the satellite input.
        self.mid_conv_1 = single_conv(32*num_channel, 32*num_channel)
        self.mid_conv_2 = single_conv(2, 32*num_channel)
        # Fix: t was not forwarded to MidConv or the decoder RRCNN blocks, so
        # they silently used the default t=2. Forwarding t is identical for
        # the default and honors the parameter otherwise.
        self.MidConv = RRCNN_block(64*num_channel, 32*num_channel, t=t)
        self.out_conv_S = Conv2d(32*num_channel, 1, (1, 1), padding='same')
        # Decoder: upsample, concatenate the encoder skip, then refine.
        self.Up5 = UpConv(64*num_channel, 32*num_channel)
        self.UpRRCNN5 = RRCNN_block(64*num_channel, 32*num_channel, t=t)
        self.Up4 = UpConv(32*num_channel, 16*num_channel)
        self.UpRRCNN4 = RRCNN_block(32*num_channel, 16*num_channel, t=t)
        self.Up3 = UpConv(16*num_channel, 8*num_channel)
        self.UpRRCNN3 = RRCNN_block(16*num_channel, 8*num_channel, t=t)
        self.Up2 = UpConv(8*num_channel, 4*num_channel)
        self.UpRRCNN2 = RRCNN_block(8*num_channel, 4*num_channel, t=t)
        self.Up1 = UpConv(4*num_channel, 2*num_channel)
        self.UpRRCNN1 = RRCNN_block(4*num_channel, 2*num_channel, t=t)
        self.out_conv_R = Conv2d(2*num_channel, 1, (1, 1), padding='same')

    def forward(self, radar, satellite):
        """
        :param radar: radar tensor; assumes 5 channels — TODO confirm
        :param satellite: satellite tensor; assumes 2 channels at 1/32 of the
            radar's spatial size — TODO confirm
        :return: (pred_radar, pred_satellite)
        """
        # Encoder: keep each stage's output for the skip connections.
        e1 = self.RRCNN1(radar)
        e2 = self.MaxPool(e1)
        e2 = self.RRCNN2(e2)
        e3 = self.MaxPool(e2)
        e3 = self.RRCNN3(e3)
        e4 = self.MaxPool(e3)
        e4 = self.RRCNN4(e4)
        e5 = self.MaxPool(e4)
        e5 = self.RRCNN5(e5)
        e6 = self.MaxPool(e5)
        # Bottleneck: concatenate radar and satellite features.
        X = F.relu(self.mid_conv_1(e6))
        Y = F.relu(self.mid_conv_2(satellite))
        X = torch.cat((X, Y), 1)
        Y = self.MidConv(X)
        pred_satellite = self.out_conv_S(Y)
        # Decoder starts from the concatenated features X (not MidConv's output).
        d5 = self.Up5(X)
        d5 = torch.cat((e5, d5), dim=1)
        d5 = self.UpRRCNN5(d5)
        d4 = self.Up4(d5)
        d4 = torch.cat((e4, d4), dim=1)
        d4 = self.UpRRCNN4(d4)
        d3 = self.Up3(d4)
        d3 = torch.cat((e3, d3), dim=1)
        d3 = self.UpRRCNN3(d3)
        d2 = self.Up2(d3)
        d2 = torch.cat((e2, d2), dim=1)
        d2 = self.UpRRCNN2(d2)
        d1 = self.Up1(d2)
        d0 = torch.cat((e1, d1), dim=1)
        d0 = self.UpRRCNN1(d0)
        pred_radar = self.out_conv_R(d0)
        return pred_radar, pred_satellite
class AttUnet(nn.Module):
    """Attention U-Net predicting radar and satellite fields.

    Same layout as the U-Net variants in this file, but each decoder skip
    connection is gated by an AttentionBlock before concatenation.

    NOTE(review): the radar input is hard-coded to 5 channels (Conv1) and the
    satellite input to 2 channels (mid_conv_2) — confirm against callers.
    """

    def __init__(self, num_channel=1):
        """
        :param num_channel: base channel-width multiplier
        """
        super(AttUnet, self).__init__()
        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder: channel width doubles at each of the five stages.
        self.Conv1 = ConvBlock(5, 2*num_channel)
        self.Conv2 = ConvBlock(2*num_channel, 4*num_channel)
        self.Conv3 = ConvBlock(4*num_channel, 8*num_channel)
        self.Conv4 = ConvBlock(8*num_channel, 16*num_channel)
        self.Conv5 = ConvBlock(16*num_channel, 32*num_channel)
        # Bottleneck: fuse radar features with the satellite input.
        self.mid_conv_1 = single_conv(32*num_channel, 32*num_channel)
        self.mid_conv_2 = single_conv(2, 32*num_channel)
        self.MidConv = ConvBlock(64*num_channel, 32*num_channel)
        self.out_conv_S = Conv2d(32*num_channel, 1, (1, 1), padding='same')
        # Decoder: upsample, attention-gate the encoder skip, concatenate, refine.
        self.Up5 = UpConv(64*num_channel, 32*num_channel)
        self.Att5 = AttentionBlock(F_g=32*num_channel, F_l=32*num_channel, n_coefficients=16*num_channel)
        self.UpConv5 = ConvBlock(64*num_channel, 32*num_channel)
        self.Up4 = UpConv(32*num_channel, 16*num_channel)
        self.Att4 = AttentionBlock(F_g=16*num_channel, F_l=16*num_channel, n_coefficients=8*num_channel)
        self.UpConv4 = ConvBlock(32*num_channel, 16*num_channel)
        self.Up3 = UpConv(16*num_channel, 8*num_channel)
        self.Att3 = AttentionBlock(F_g=8*num_channel, F_l=8*num_channel, n_coefficients=4*num_channel)
        self.UpConv3 = ConvBlock(16*num_channel, 8*num_channel)
        self.Up2 = UpConv(8*num_channel, 4*num_channel)
        self.Att2 = AttentionBlock(F_g=4*num_channel, F_l=4*num_channel, n_coefficients=2*num_channel)
        self.UpConv2 = ConvBlock(8*num_channel, 4*num_channel)
        self.Up1 = UpConv(4*num_channel, 2*num_channel)
        self.Att1 = AttentionBlock(F_g=2*num_channel, F_l=2*num_channel, n_coefficients=1*num_channel)
        self.UpConv1 = ConvBlock(4*num_channel, 2*num_channel)
        self.out_conv_R = Conv2d(2*num_channel, 1, (1, 1), padding='same')

    def forward(self, radar, satellite):
        """
        :param radar: radar tensor; assumes 5 channels — TODO confirm
        :param satellite: satellite tensor; assumes 2 channels at 1/32 of the
            radar's spatial size — TODO confirm
        :return: (pred_radar, pred_satellite)
        """
        # Encoder: keep each stage's output for the skip connections.
        e1 = self.Conv1(radar)
        e2 = self.MaxPool(e1)
        e2 = self.Conv2(e2)
        e3 = self.MaxPool(e2)
        e3 = self.Conv3(e3)
        e4 = self.MaxPool(e3)
        e4 = self.Conv4(e4)
        e5 = self.MaxPool(e4)
        e5 = self.Conv5(e5)
        e6 = self.MaxPool(e5)
        # Bottleneck: concatenate radar and satellite features.
        X = F.relu(self.mid_conv_1(e6))
        Y = F.relu(self.mid_conv_2(satellite))
        X = torch.cat((X, Y), 1)
        Y = self.MidConv(X)
        pred_satellite = self.out_conv_S(Y)
        # Decoder starts from the concatenated features X (not MidConv's output).
        d5 = self.Up5(X)
        s4 = self.Att5(gate=d5, skip_connection=e5)
        d5 = torch.cat((s4, d5), dim=1)  # concatenate attention-weighted skip connection with previous layer output
        d5 = self.UpConv5(d5)
        d4 = self.Up4(d5)
        s3 = self.Att4(gate=d4, skip_connection=e4)
        d4 = torch.cat((s3, d4), dim=1)
        d4 = self.UpConv4(d4)
        d3 = self.Up3(d4)
        s2 = self.Att3(gate=d3, skip_connection=e3)
        d3 = torch.cat((s2, d3), dim=1)
        d3 = self.UpConv3(d3)
        d2 = self.Up2(d3)
        s1 = self.Att2(gate=d2, skip_connection=e2)
        d2 = torch.cat((s1, d2), dim=1)
        d2 = self.UpConv2(d2)
        d1 = self.Up1(d2)
        s0 = self.Att1(gate=d1, skip_connection=e1)
        d0 = torch.cat((s0, d1), dim=1)
        d0 = self.UpConv1(d0)
        pred_radar = self.out_conv_R(d0)
        return pred_radar, pred_satellite
class AttR2Unet(nn.Module):
    """Attention R2U-Net: recurrent residual blocks plus attention-gated skips.

    Encoder: five RRCNN blocks separated by 2x max-pooling. The satellite
    stream joins at the bottleneck (also producing the satellite prediction);
    the decoder upsamples, attention-gates each encoder skip, concatenates
    and refines to produce the radar prediction.

    NOTE(review): the radar input is hard-coded to 5 channels (RRCNN1) and
    the satellite input to 2 channels (mid_conv_2) — confirm against callers.
    """

    def __init__(self, num_channel=1, t=2):
        """
        :param num_channel: base channel-width multiplier
        :param t: number of recurrent steps inside each RRCNN block
        """
        super(AttR2Unet, self).__init__()
        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Fix: t was accepted but never forwarded, so every RRCNN_block
        # silently used its default t=2. Forwarding t is identical for the
        # default and honors the parameter otherwise.
        self.RRCNN1 = RRCNN_block(5, 2*num_channel, t=t)
        self.RRCNN2 = RRCNN_block(2*num_channel, 4*num_channel, t=t)
        self.RRCNN3 = RRCNN_block(4*num_channel, 8*num_channel, t=t)
        self.RRCNN4 = RRCNN_block(8*num_channel, 16*num_channel, t=t)
        self.RRCNN5 = RRCNN_block(16*num_channel, 32*num_channel, t=t)
        # Bottleneck: fuse radar features with the satellite input.
        self.mid_conv_1 = single_conv(32*num_channel, 32*num_channel)
        self.mid_conv_2 = single_conv(2, 32*num_channel)
        self.MidConv = RRCNN_block(64*num_channel, 32*num_channel, t=t)
        self.out_conv_S = Conv2d(32*num_channel, 1, (1, 1), padding='same')
        # Decoder: upsample, attention-gate the encoder skip, concatenate, refine.
        self.Up5 = UpConv(64*num_channel, 32*num_channel)
        self.Att5 = AttentionBlock(F_g=32*num_channel, F_l=32*num_channel, n_coefficients=16*num_channel)
        self.UpRRCNN5 = RRCNN_block(64*num_channel, 32*num_channel, t=t)
        self.Up4 = UpConv(32*num_channel, 16*num_channel)
        self.Att4 = AttentionBlock(F_g=16*num_channel, F_l=16*num_channel, n_coefficients=8*num_channel)
        self.UpRRCNN4 = RRCNN_block(32*num_channel, 16*num_channel, t=t)
        self.Up3 = UpConv(16*num_channel, 8*num_channel)
        self.Att3 = AttentionBlock(F_g=8*num_channel, F_l=8*num_channel, n_coefficients=4*num_channel)
        self.UpRRCNN3 = RRCNN_block(16*num_channel, 8*num_channel, t=t)
        self.Up2 = UpConv(8*num_channel, 4*num_channel)
        self.Att2 = AttentionBlock(F_g=4*num_channel, F_l=4*num_channel, n_coefficients=2*num_channel)
        self.UpRRCNN2 = RRCNN_block(8*num_channel, 4*num_channel, t=t)
        self.Up1 = UpConv(4*num_channel, 2*num_channel)
        self.Att1 = AttentionBlock(F_g=2*num_channel, F_l=2*num_channel, n_coefficients=1*num_channel)
        self.UpRRCNN1 = RRCNN_block(4*num_channel, 2*num_channel, t=t)
        self.out_conv_R = Conv2d(2*num_channel, 1, (1, 1), padding='same')

    def forward(self, radar, satellite):
        """
        :param radar: radar tensor; assumes 5 channels — TODO confirm
        :param satellite: satellite tensor; assumes 2 channels at 1/32 of the
            radar's spatial size — TODO confirm
        :return: (pred_radar, pred_satellite)
        """
        # Encoder: keep each stage's output for the skip connections.
        e1 = self.RRCNN1(radar)
        e2 = self.MaxPool(e1)
        e2 = self.RRCNN2(e2)
        e3 = self.MaxPool(e2)
        e3 = self.RRCNN3(e3)
        e4 = self.MaxPool(e3)
        e4 = self.RRCNN4(e4)
        e5 = self.MaxPool(e4)
        e5 = self.RRCNN5(e5)
        e6 = self.MaxPool(e5)
        # Bottleneck: concatenate radar and satellite features.
        X = F.relu(self.mid_conv_1(e6))
        Y = F.relu(self.mid_conv_2(satellite))
        X = torch.cat((X, Y), 1)
        Y = self.MidConv(X)
        pred_satellite = self.out_conv_S(Y)
        # Decoder starts from the concatenated features X (not MidConv's output).
        d5 = self.Up5(X)
        s4 = self.Att5(gate=d5, skip_connection=e5)
        d5 = torch.cat((s4, d5), dim=1)  # concatenate attention-weighted skip connection with previous layer output
        d5 = self.UpRRCNN5(d5)
        d4 = self.Up4(d5)
        s3 = self.Att4(gate=d4, skip_connection=e4)
        d4 = torch.cat((s3, d4), dim=1)
        d4 = self.UpRRCNN4(d4)
        d3 = self.Up3(d4)
        s2 = self.Att3(gate=d3, skip_connection=e3)
        d3 = torch.cat((s2, d3), dim=1)
        d3 = self.UpRRCNN3(d3)
        d2 = self.Up2(d3)
        s1 = self.Att2(gate=d2, skip_connection=e2)
        d2 = torch.cat((s1, d2), dim=1)
        d2 = self.UpRRCNN2(d2)
        d1 = self.Up1(d2)
        s0 = self.Att1(gate=d1, skip_connection=e1)
        d0 = torch.cat((s0, d1), dim=1)
        d0 = self.UpRRCNN1(d0)
        pred_radar = self.out_conv_R(d0)
        return pred_radar, pred_satellite
class Network(nn.Module):
    """Factory wrapper that selects one of the architectures in this file by name."""

    def __init__(self, model_type: str, rad_channel: int, sat_channel: int, rad_size: int, sat_size: int):
        """
        :param model_type: one of "Nothing", "Unet", "R2Unet", "AttUnet", "AttR2Unet"
        :param rad_channel: radar input channels
        :param sat_channel: satellite input channels
        :param rad_size: radar spatial size
        :param sat_size: satellite spatial size
        :raises ValueError: if model_type is not recognised
        """
        super(Network, self).__init__()
        print(model_type)
        if model_type == "Nothing":
            self.net = Nothing()
        elif model_type == "Unet":
            # (The original had this branch duplicated; the second copy was dead.)
            self.net = Unet(rad_channel=rad_channel, sat_channel=sat_channel,
                            rad_size=rad_size, sat_size=sat_size)
        elif model_type == "R2Unet":
            # Fix: R2Unet/AttUnet/AttR2Unet only accept num_channel (their input
            # channel counts and depth are hard-coded), so passing the Unet
            # kwargs raised TypeError at construction.
            # NOTE(review): confirm num_channel=rad_channel is the intended
            # width multiplier for these variants.
            self.net = R2Unet(num_channel=rad_channel)
        elif model_type == "AttUnet":
            self.net = AttUnet(num_channel=rad_channel)
        elif model_type == "AttR2Unet":
            self.net = AttR2Unet(num_channel=rad_channel)
        else:
            raise ValueError("model_type is wrong")

    def forward(self, radar, satellite):
        """Delegate to the selected sub-model; returns (pred_radar, pred_satellite)."""
        # Call the module (not .forward directly) so hooks keep working.
        pred_radar, pred_satellite = self.net(radar, satellite)
        return pred_radar, pred_satellite