upload the model and the archs and images
Browse files- archs/NAFBlock.py +172 -0
- archs/arch_util.py +73 -0
- archs/model.py +174 -0
- examples/inputs/0010.png +0 -0
- examples/inputs/0060.png +0 -0
- examples/inputs/0075.png +0 -0
- examples/inputs/0087.png +0 -0
- examples/inputs/0088.png +0 -0
- models/NAFourNet16_LOLv2Real.pt +3 -0
- requirements.txt +183 -0
archs/NAFBlock.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from arch_utilNAFNET import LayerNorm2d
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
# Modules from model
|
| 7 |
+
import arch_util as arch_util
|
| 8 |
+
|
| 9 |
+
# Process Block 4 en SFNet y 5 bloques en AmpNet, con el spatial block aplicado en AmpNet (frequency stage)
|
| 10 |
+
# tal y como lo tienen ellos en su github (aunque en el paper es al revés) y no lo aplican el space stage
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SimpleGate(nn.Module):
|
| 14 |
+
def forward(self, x):
|
| 15 |
+
x1, x2 = x.chunk(2, dim=1)
|
| 16 |
+
return x1 * x2
|
| 17 |
+
|
| 18 |
+
class SpaBlock(nn.Module):
|
| 19 |
+
def __init__(self, nc, DW_Expand = 2, FFN_Expand=2, drop_out_rate=0.):
|
| 20 |
+
super(SpaBlock, self).__init__()
|
| 21 |
+
dw_channel = nc * DW_Expand
|
| 22 |
+
self.conv1 = nn.Conv2d(in_channels=nc, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
|
| 23 |
+
self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
|
| 24 |
+
bias=True) # the dconv
|
| 25 |
+
self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=nc, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
|
| 26 |
+
|
| 27 |
+
# Simplified Channel Attention
|
| 28 |
+
self.sca = nn.Sequential(
|
| 29 |
+
nn.AdaptiveAvgPool2d(1),
|
| 30 |
+
nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
|
| 31 |
+
groups=1, bias=True),
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# SimpleGate
|
| 35 |
+
self.sg = SimpleGate()
|
| 36 |
+
|
| 37 |
+
ffn_channel = FFN_Expand * nc
|
| 38 |
+
self.conv4 = nn.Conv2d(in_channels=nc, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
|
| 39 |
+
self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=nc, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
|
| 40 |
+
|
| 41 |
+
self.norm1 = LayerNorm2d(nc)
|
| 42 |
+
self.norm2 = LayerNorm2d(nc)
|
| 43 |
+
|
| 44 |
+
self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
|
| 45 |
+
self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
|
| 46 |
+
|
| 47 |
+
self.beta = nn.Parameter(torch.zeros((1, nc, 1, 1)), requires_grad=True)
|
| 48 |
+
self.gamma = nn.Parameter(torch.zeros((1, nc, 1, 1)), requires_grad=True)
|
| 49 |
+
|
| 50 |
+
def forward(self, x):
|
| 51 |
+
|
| 52 |
+
x = self.norm1(x) # size [B, C, H, W]
|
| 53 |
+
|
| 54 |
+
x = self.conv1(x) # size [B, 2*C, H, W]
|
| 55 |
+
x = self.conv2(x) # size [B, 2*C, H, W]
|
| 56 |
+
x = self.sg(x) # size [B, C, H, W]
|
| 57 |
+
x = x * self.sca(x) # size [B, C, H, W]
|
| 58 |
+
x = self.conv3(x) # size [B, C, H, W]
|
| 59 |
+
|
| 60 |
+
x = self.dropout1(x)
|
| 61 |
+
|
| 62 |
+
y = x + x * self.beta # size [B, C, H, W]
|
| 63 |
+
|
| 64 |
+
x = self.conv4(self.norm2(y)) # size [B, 2*C, H, W]
|
| 65 |
+
x = self.sg(x) # size [B, C, H, W]
|
| 66 |
+
x = self.conv5(x) # size [B, C, H, W]
|
| 67 |
+
|
| 68 |
+
x = self.dropout2(x)
|
| 69 |
+
|
| 70 |
+
return y + x * self.gamma
|
| 71 |
+
|
| 72 |
+
class FreBlock(nn.Module):
|
| 73 |
+
def __init__(self, nc):
|
| 74 |
+
super(FreBlock, self).__init__()
|
| 75 |
+
self.fpre = nn.Conv2d(nc, nc, 1, 1, 0)
|
| 76 |
+
self.process1 = nn.Sequential(
|
| 77 |
+
nn.Conv2d(nc, nc, 1, 1, 0),
|
| 78 |
+
nn.LeakyReLU(0.1, inplace=True),
|
| 79 |
+
nn.Conv2d(nc, nc, 1, 1, 0))
|
| 80 |
+
self.process2 = nn.Sequential(
|
| 81 |
+
nn.Conv2d(nc, nc, 1, 1, 0),
|
| 82 |
+
nn.LeakyReLU(0.1, inplace=True),
|
| 83 |
+
nn.Conv2d(nc, nc, 1, 1, 0))
|
| 84 |
+
|
| 85 |
+
def forward(self, x):
|
| 86 |
+
_, _, H, W = x.shape
|
| 87 |
+
x_freq = torch.fft.rfft2(self.fpre(x), norm='backward')
|
| 88 |
+
mag = torch.abs(x_freq)
|
| 89 |
+
pha = torch.angle(x_freq)
|
| 90 |
+
mag = self.process1(mag)
|
| 91 |
+
pha = self.process2(pha)
|
| 92 |
+
real = mag * torch.cos(pha)
|
| 93 |
+
imag = mag * torch.sin(pha)
|
| 94 |
+
x_out = torch.complex(real, imag)
|
| 95 |
+
x_out = torch.fft.irfft2(x_out, s=(H, W), norm='backward')
|
| 96 |
+
|
| 97 |
+
return x_out+x
|
| 98 |
+
|
| 99 |
+
class ProcessBlock(nn.Module):
|
| 100 |
+
def __init__(self, in_nc, spatial = True):
|
| 101 |
+
super(ProcessBlock,self).__init__()
|
| 102 |
+
self.spatial = spatial
|
| 103 |
+
self.spatial_process = SpaBlock(in_nc) if spatial else nn.Identity()
|
| 104 |
+
self.frequency_process = FreBlock(in_nc)
|
| 105 |
+
self.cat = nn.Conv2d(2*in_nc,in_nc,1,1,0) if spatial else nn.Conv2d(in_nc,in_nc,1,1,0)
|
| 106 |
+
|
| 107 |
+
def forward(self, x):
|
| 108 |
+
xori = x
|
| 109 |
+
x_freq = self.frequency_process(x)
|
| 110 |
+
x_spatial = self.spatial_process(x)
|
| 111 |
+
xcat = torch.cat([x_spatial,x_freq],1)
|
| 112 |
+
x_out = self.cat(xcat) if self.spatial else self.cat(x_freq)
|
| 113 |
+
|
| 114 |
+
return x_out+xori
|
| 115 |
+
|
| 116 |
+
class SFNet(nn.Module):
|
| 117 |
+
|
| 118 |
+
def __init__(self, nc,n=5):
|
| 119 |
+
super(SFNet,self).__init__()
|
| 120 |
+
|
| 121 |
+
self.list_block = list()
|
| 122 |
+
for index in range(n):
|
| 123 |
+
|
| 124 |
+
self.list_block.append(ProcessBlock(nc,spatial=False))
|
| 125 |
+
|
| 126 |
+
self.block = nn.Sequential(*self.list_block)
|
| 127 |
+
|
| 128 |
+
def forward(self, x):
|
| 129 |
+
|
| 130 |
+
x_ori = x
|
| 131 |
+
x_out = self.block(x_ori)
|
| 132 |
+
xout = x_ori + x_out
|
| 133 |
+
|
| 134 |
+
return xout
|
| 135 |
+
|
| 136 |
+
class AmplitudeNet_skip(nn.Module):
|
| 137 |
+
def __init__(self, nc,n=1):
|
| 138 |
+
super(AmplitudeNet_skip,self).__init__()
|
| 139 |
+
|
| 140 |
+
self.conv1 = nn.Sequential(
|
| 141 |
+
nn.Conv2d(3, nc, 1, 1, 0),
|
| 142 |
+
ProcessBlock(nc),
|
| 143 |
+
)
|
| 144 |
+
self.conv2 = ProcessBlock(nc)
|
| 145 |
+
self.conv3 = ProcessBlock(nc)
|
| 146 |
+
self.conv4 = nn.Sequential(
|
| 147 |
+
ProcessBlock(nc * 2),
|
| 148 |
+
nn.Conv2d(nc * 2, nc, 1, 1, 0),
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
self.conv5 = nn.Sequential(
|
| 152 |
+
ProcessBlock(nc * 2),
|
| 153 |
+
nn.Conv2d(nc * 2, nc, 1, 1, 0),
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
self.convout = nn.Sequential(
|
| 157 |
+
ProcessBlock(nc * 2),
|
| 158 |
+
nn.Conv2d(nc * 2, 3, 1, 1, 0),
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
def forward(self, x):
|
| 162 |
+
|
| 163 |
+
x1 = self.conv1(x)
|
| 164 |
+
x2 = self.conv2(x1)
|
| 165 |
+
x3 = self.conv3(x2)
|
| 166 |
+
x4 = self.conv5(torch.cat((x2, x3), dim=1))
|
| 167 |
+
xout = self.convout(torch.cat((x1, x4), dim=1))
|
| 168 |
+
|
| 169 |
+
return xout
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
|
archs/arch_util.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.init as init
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def initialize_weights(net_l, scale=1):
|
| 8 |
+
if not isinstance(net_l, list):
|
| 9 |
+
net_l = [net_l]
|
| 10 |
+
for net in net_l:
|
| 11 |
+
for m in net.modules():
|
| 12 |
+
if isinstance(m, nn.Conv2d):
|
| 13 |
+
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
| 14 |
+
m.weight.data *= scale # for residual block
|
| 15 |
+
if m.bias is not None:
|
| 16 |
+
m.bias.data.zero_()
|
| 17 |
+
elif isinstance(m, nn.Linear):
|
| 18 |
+
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
| 19 |
+
m.weight.data *= scale
|
| 20 |
+
if m.bias is not None:
|
| 21 |
+
m.bias.data.zero_()
|
| 22 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 23 |
+
init.constant_(m.weight, 1)
|
| 24 |
+
init.constant_(m.bias.data, 0.0)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def make_layer(block, n_layers):
|
| 28 |
+
layers = []
|
| 29 |
+
for _ in range(n_layers):
|
| 30 |
+
layers.append(block())
|
| 31 |
+
return nn.Sequential(*layers)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ResidualBlock_noBN(nn.Module):
|
| 35 |
+
'''Residual block w/o BN
|
| 36 |
+
---Conv-ReLU-Conv-+-
|
| 37 |
+
|________________|
|
| 38 |
+
'''
|
| 39 |
+
|
| 40 |
+
def __init__(self, nf=64):
|
| 41 |
+
super(ResidualBlock_noBN, self).__init__()
|
| 42 |
+
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
| 43 |
+
self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
| 44 |
+
|
| 45 |
+
# initialization
|
| 46 |
+
initialize_weights([self.conv1, self.conv2], 0.1)
|
| 47 |
+
|
| 48 |
+
def forward(self, x):
|
| 49 |
+
identity = x
|
| 50 |
+
out = F.relu(self.conv1(x), inplace=True)
|
| 51 |
+
out = self.conv2(out)
|
| 52 |
+
return identity + out
|
| 53 |
+
|
| 54 |
+
class ResidualBlock(nn.Module):
|
| 55 |
+
'''Residual block w/o BN
|
| 56 |
+
---Conv-ReLU-Conv-+-
|
| 57 |
+
|________________|
|
| 58 |
+
'''
|
| 59 |
+
|
| 60 |
+
def __init__(self, nf=64):
|
| 61 |
+
super(ResidualBlock, self).__init__()
|
| 62 |
+
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
| 63 |
+
self.bn = nn.BatchNorm2d(nf)
|
| 64 |
+
self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
| 65 |
+
|
| 66 |
+
# initialization
|
| 67 |
+
initialize_weights([self.conv1, self.conv2], 0.1)
|
| 68 |
+
|
| 69 |
+
def forward(self, x):
|
| 70 |
+
identity = x
|
| 71 |
+
out = F.relu(self.bn(self.conv1(x)), inplace=True)
|
| 72 |
+
out = self.conv2(out)
|
| 73 |
+
return identity + out
|
archs/model.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import functools
|
| 5 |
+
import arch_util as arch_util
|
| 6 |
+
from NAFBlock import *
|
| 7 |
+
import kornia
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchvision.models
|
| 10 |
+
|
| 11 |
+
class VGG19(torch.nn.Module):
|
| 12 |
+
|
| 13 |
+
def __init__(self, requires_grad=False):
|
| 14 |
+
super().__init__()
|
| 15 |
+
vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
|
| 16 |
+
self.slice1 = torch.nn.Sequential()
|
| 17 |
+
self.slice2 = torch.nn.Sequential()
|
| 18 |
+
self.slice3 = torch.nn.Sequential()
|
| 19 |
+
self.slice4 = torch.nn.Sequential()
|
| 20 |
+
self.slice5 = torch.nn.Sequential()
|
| 21 |
+
for x in range(2):
|
| 22 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
| 23 |
+
for x in range(2, 7):
|
| 24 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
| 25 |
+
for x in range(7, 12):
|
| 26 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
| 27 |
+
for x in range(12, 21):
|
| 28 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
| 29 |
+
for x in range(21, 30):
|
| 30 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
| 31 |
+
if not requires_grad:
|
| 32 |
+
for param in self.parameters():
|
| 33 |
+
param.requires_grad = False
|
| 34 |
+
|
| 35 |
+
def forward(self, X):
|
| 36 |
+
h_relu1 = self.slice1(X)
|
| 37 |
+
h_relu2 = self.slice2(h_relu1)
|
| 38 |
+
h_relu3 = self.slice3(h_relu2)
|
| 39 |
+
h_relu4 = self.slice4(h_relu3)
|
| 40 |
+
h_relu5 = self.slice5(h_relu4)
|
| 41 |
+
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
| 42 |
+
return out
|
| 43 |
+
|
| 44 |
+
class VGGLoss(nn.Module):
|
| 45 |
+
|
| 46 |
+
def __init__(self):
|
| 47 |
+
|
| 48 |
+
super(VGGLoss, self).__init__()
|
| 49 |
+
self.vgg = VGG19().cuda()
|
| 50 |
+
# self.criterion = nn.L1Loss()
|
| 51 |
+
self.criterion = nn.L1Loss(reduction='sum')
|
| 52 |
+
self.criterion2 = nn.L1Loss()
|
| 53 |
+
self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
|
| 54 |
+
|
| 55 |
+
def forward(self, x, y):
|
| 56 |
+
|
| 57 |
+
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
|
| 58 |
+
# print(x_vgg.shape, x_vgg.dtype, torch.max(x_vgg), torch.min(x_vgg), y_vgg.shape, y_vgg.dtype, torch.max(y_vgg), torch.min(y_vgg))
|
| 59 |
+
loss = 0
|
| 60 |
+
for i in range(len(x_vgg)):
|
| 61 |
+
# print(x_vgg[i].shape, y_vgg[i].shape, 'hey')
|
| 62 |
+
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
|
| 63 |
+
# print(loss, i, 'hey')
|
| 64 |
+
|
| 65 |
+
return loss
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class FourNet(nn.Module):
|
| 69 |
+
def __init__(self, nf=64):
|
| 70 |
+
super(FourNet, self).__init__()
|
| 71 |
+
|
| 72 |
+
# AMPLITUDE ENHANCEMENT
|
| 73 |
+
self.AmpNet = nn.Sequential(
|
| 74 |
+
AmplitudeNet_skip(8),
|
| 75 |
+
nn.Sigmoid()
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
self.nf = nf
|
| 79 |
+
ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
|
| 80 |
+
|
| 81 |
+
self.conv_first_1 = nn.Conv2d(3 * 2, nf, 3, 1, 1, bias=True)
|
| 82 |
+
self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
|
| 83 |
+
self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
|
| 84 |
+
|
| 85 |
+
self.feature_extraction = arch_util.make_layer(ResidualBlock_noBN_f, 1)
|
| 86 |
+
self.recon_trunk = arch_util.make_layer(ResidualBlock_noBN_f, 1)
|
| 87 |
+
|
| 88 |
+
self.upconv1 = nn.Conv2d(nf*2, nf * 4, 3, 1, 1, bias=True)
|
| 89 |
+
self.upconv2 = nn.Conv2d(nf*2, nf * 4, 3, 1, 1, bias=True)
|
| 90 |
+
self.pixel_shuffle = nn.PixelShuffle(2)
|
| 91 |
+
self.HRconv = nn.Conv2d(nf*2, nf, 3, 1, 1, bias=True)
|
| 92 |
+
self.conv_last = nn.Conv2d(nf, 3, 3, 1, 1, bias=True)
|
| 93 |
+
|
| 94 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
| 95 |
+
self.transformer = SFNet(nf, n = 4)
|
| 96 |
+
self.recon_trunk_light = arch_util.make_layer(ResidualBlock_noBN_f, 6)
|
| 97 |
+
|
| 98 |
+
def get_mask(self,dark): # SNR map
|
| 99 |
+
|
| 100 |
+
light = kornia.filters.gaussian_blur2d(dark, (5, 5), (1.5, 1.5))
|
| 101 |
+
dark = dark[:, 0:1, :, :] * 0.299 + dark[:, 1:2, :, :] * 0.587 + dark[:, 2:3, :, :] * 0.114
|
| 102 |
+
light = light[:, 0:1, :, :] * 0.299 + light[:, 1:2, :, :] * 0.587 + light[:, 2:3, :, :] * 0.114
|
| 103 |
+
noise = torch.abs(dark - light)
|
| 104 |
+
|
| 105 |
+
mask = torch.div(light, noise + 0.0001)
|
| 106 |
+
|
| 107 |
+
batch_size = mask.shape[0]
|
| 108 |
+
height = mask.shape[2]
|
| 109 |
+
width = mask.shape[3]
|
| 110 |
+
mask_max = torch.max(mask.view(batch_size, -1), dim=1)[0]
|
| 111 |
+
mask_max = mask_max.view(batch_size, 1, 1, 1)
|
| 112 |
+
mask_max = mask_max.repeat(1, 1, height, width)
|
| 113 |
+
mask = mask * 1.0 / (mask_max + 0.0001)
|
| 114 |
+
|
| 115 |
+
mask = torch.clamp(mask, min=0, max=1.0)
|
| 116 |
+
return mask.float()
|
| 117 |
+
|
| 118 |
+
def forward(self, x):
|
| 119 |
+
|
| 120 |
+
# AMPLITUDE ENHANCEMENT
|
| 121 |
+
#--------------------------------------------------------Frequency Stage---------------------------------------------------
|
| 122 |
+
_, _, H, W = x.shape
|
| 123 |
+
image_fft = torch.fft.fft2(x, norm='backward')
|
| 124 |
+
mag_image = torch.abs(image_fft)
|
| 125 |
+
pha_image = torch.angle(image_fft)
|
| 126 |
+
curve_amps = self.AmpNet(x)
|
| 127 |
+
mag_image = mag_image / (curve_amps + 0.00000001) # * d4
|
| 128 |
+
real_image_enhanced = mag_image * torch.cos(pha_image)
|
| 129 |
+
imag_image_enhanced = mag_image * torch.sin(pha_image)
|
| 130 |
+
img_amp_enhanced = torch.fft.ifft2(torch.complex(real_image_enhanced, imag_image_enhanced), s=(H, W),
|
| 131 |
+
norm='backward').real
|
| 132 |
+
|
| 133 |
+
x_center = img_amp_enhanced
|
| 134 |
+
|
| 135 |
+
rate = 2 ** 3
|
| 136 |
+
pad_h = (rate - H % rate) % rate
|
| 137 |
+
pad_w = (rate - W % rate) % rate
|
| 138 |
+
if pad_h != 0 or pad_w != 0:
|
| 139 |
+
x_center = F.pad(x_center, (0, pad_w, 0, pad_h), "reflect")
|
| 140 |
+
x = F.pad(x, (0, pad_w, 0, pad_h), "reflect")
|
| 141 |
+
|
| 142 |
+
#------------------------------------------Spatial Stage---------------------------------------------------------------------
|
| 143 |
+
|
| 144 |
+
L1_fea_1 = self.lrelu(self.conv_first_1(torch.cat((x_center,x),dim=1)))
|
| 145 |
+
L1_fea_2 = self.lrelu(self.conv_first_2(L1_fea_1)) # Encoder
|
| 146 |
+
L1_fea_3 = self.lrelu(self.conv_first_3(L1_fea_2))
|
| 147 |
+
|
| 148 |
+
fea = self.feature_extraction(L1_fea_3)
|
| 149 |
+
fea_light = self.recon_trunk_light(fea)
|
| 150 |
+
|
| 151 |
+
h_feature = fea.shape[2]
|
| 152 |
+
w_feature = fea.shape[3]
|
| 153 |
+
mask_image = self.get_mask(x_center) # SNR Map
|
| 154 |
+
mask = F.interpolate(mask_image, size=[h_feature, w_feature], mode='nearest') # Resize and Normalize SNR map
|
| 155 |
+
|
| 156 |
+
fea_unfold = self.transformer(fea)
|
| 157 |
+
|
| 158 |
+
channel = fea.shape[1]
|
| 159 |
+
mask = mask.repeat(1, channel, 1, 1)
|
| 160 |
+
fea = fea_unfold * (1 - mask) + fea_light * mask # SNR-based Interaction
|
| 161 |
+
|
| 162 |
+
out_noise = self.recon_trunk(fea)
|
| 163 |
+
out_noise = torch.cat([out_noise, L1_fea_3], dim=1)
|
| 164 |
+
out_noise = self.lrelu(self.pixel_shuffle(self.upconv1(out_noise)))
|
| 165 |
+
out_noise = torch.cat([out_noise, L1_fea_2], dim=1) # Decoder
|
| 166 |
+
out_noise = self.lrelu(self.pixel_shuffle(self.upconv2(out_noise)))
|
| 167 |
+
out_noise = torch.cat([out_noise, L1_fea_1], dim=1)
|
| 168 |
+
out_noise = self.lrelu(self.HRconv(out_noise))
|
| 169 |
+
out_noise = self.conv_last(out_noise)
|
| 170 |
+
out_noise = out_noise + x
|
| 171 |
+
out_noise = out_noise[:, :, :H, :W]
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
return out_noise, mag_image, x_center, mask_image
|
examples/inputs/0010.png
ADDED
|
examples/inputs/0060.png
ADDED
|
examples/inputs/0075.png
ADDED
|
examples/inputs/0087.png
ADDED
|
examples/inputs/0088.png
ADDED
|
models/NAFourNet16_LOLv2Real.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e6a451ba6d06e262b8f7c4e4336b87b4448c1516f377af725cb1171e8e478a0d
|
| 3 |
+
size 1605726
|
requirements.txt
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
absl-py==2.1.0
|
| 2 |
+
anyio==4.2.0
|
| 3 |
+
appdirs==1.4.4
|
| 4 |
+
argon2-cffi==23.1.0
|
| 5 |
+
argon2-cffi-bindings==21.2.0
|
| 6 |
+
arrow==1.3.0
|
| 7 |
+
asttokens==2.4.1
|
| 8 |
+
async-lru==2.0.4
|
| 9 |
+
attrs==23.2.0
|
| 10 |
+
Babel==2.14.0
|
| 11 |
+
backcall==0.2.0
|
| 12 |
+
beautifulsoup4==4.12.3
|
| 13 |
+
bleach==6.1.0
|
| 14 |
+
certifi==2023.7.22
|
| 15 |
+
cffi==1.16.0
|
| 16 |
+
charset-normalizer==3.3.0
|
| 17 |
+
click==8.1.7
|
| 18 |
+
colorama==0.4.6
|
| 19 |
+
coloredlogs==15.0.1
|
| 20 |
+
comm==0.2.1
|
| 21 |
+
contourpy==1.1.1
|
| 22 |
+
cycler==0.12.0
|
| 23 |
+
Cython==3.0.6
|
| 24 |
+
debugpy==1.8.0
|
| 25 |
+
decorator==4.4.2
|
| 26 |
+
defusedxml==0.7.1
|
| 27 |
+
docker-pycreds==0.4.0
|
| 28 |
+
exceptiongroup==1.2.0
|
| 29 |
+
executing==2.0.1
|
| 30 |
+
fastjsonschema==2.19.1
|
| 31 |
+
filelock==3.12.4
|
| 32 |
+
flatbuffers==23.5.26
|
| 33 |
+
fonttools==4.43.0
|
| 34 |
+
fqdn==1.5.1
|
| 35 |
+
fsspec==2023.9.2
|
| 36 |
+
gitdb==4.0.11
|
| 37 |
+
GitPython==3.1.42
|
| 38 |
+
humanfriendly==10.0
|
| 39 |
+
idna==3.4
|
| 40 |
+
imageio==2.33.0
|
| 41 |
+
imageio-ffmpeg==0.4.9
|
| 42 |
+
importlib-metadata==7.0.1
|
| 43 |
+
importlib-resources==6.1.0
|
| 44 |
+
ipykernel==6.29.0
|
| 45 |
+
ipython==8.12.3
|
| 46 |
+
ipywidgets==8.1.1
|
| 47 |
+
isoduration==20.11.0
|
| 48 |
+
jax==0.4.25
|
| 49 |
+
jedi==0.19.1
|
| 50 |
+
Jinja2==3.1.2
|
| 51 |
+
json5==0.9.14
|
| 52 |
+
jsonpointer==2.4
|
| 53 |
+
jsonschema==4.21.1
|
| 54 |
+
jsonschema-specifications==2023.12.1
|
| 55 |
+
jupyter==1.0.0
|
| 56 |
+
jupyter-console==6.6.3
|
| 57 |
+
jupyter-events==0.9.0
|
| 58 |
+
jupyter-lsp==2.2.2
|
| 59 |
+
jupyter_client==8.6.0
|
| 60 |
+
jupyter_core==5.7.1
|
| 61 |
+
jupyter_server==2.12.5
|
| 62 |
+
jupyter_server_terminals==0.5.2
|
| 63 |
+
jupyterlab==4.0.11
|
| 64 |
+
jupyterlab-widgets==3.0.9
|
| 65 |
+
jupyterlab_pygments==0.3.0
|
| 66 |
+
jupyterlab_server==2.25.2
|
| 67 |
+
kiwisolver==1.4.5
|
| 68 |
+
kornia==0.7.2
|
| 69 |
+
kornia_rs==0.1.3
|
| 70 |
+
lazy_loader==0.3
|
| 71 |
+
lightning-utilities==0.11.2
|
| 72 |
+
lpips==0.1.4
|
| 73 |
+
MarkupSafe==2.1.3
|
| 74 |
+
matplotlib==3.7.3
|
| 75 |
+
matplotlib-inline==0.1.6
|
| 76 |
+
mediapipe==0.10.10
|
| 77 |
+
mistune==3.0.2
|
| 78 |
+
ml-dtypes==0.3.2
|
| 79 |
+
moviepy==1.0.3
|
| 80 |
+
mpmath==1.3.0
|
| 81 |
+
nbclient==0.9.0
|
| 82 |
+
nbconvert==7.14.2
|
| 83 |
+
nbformat==5.9.2
|
| 84 |
+
nest-asyncio==1.6.0
|
| 85 |
+
netron==7.5.3
|
| 86 |
+
networkx==3.1
|
| 87 |
+
notebook==7.0.7
|
| 88 |
+
notebook_shim==0.2.3
|
| 89 |
+
numpy==1.24.4
|
| 90 |
+
nvidia-cublas-cu12==12.1.3.1
|
| 91 |
+
nvidia-cuda-cupti-cu12==12.1.105
|
| 92 |
+
nvidia-cuda-nvrtc-cu12==12.1.105
|
| 93 |
+
nvidia-cuda-runtime-cu12==12.1.105
|
| 94 |
+
nvidia-cudnn-cu12==8.9.2.26
|
| 95 |
+
nvidia-cufft-cu12==11.0.2.54
|
| 96 |
+
nvidia-curand-cu12==10.3.2.106
|
| 97 |
+
nvidia-cusolver-cu12==11.4.5.107
|
| 98 |
+
nvidia-cusparse-cu12==12.1.0.106
|
| 99 |
+
nvidia-nccl-cu12==2.18.1
|
| 100 |
+
nvidia-nvjitlink-cu12==12.3.101
|
| 101 |
+
nvidia-nvtx-cu12==12.1.105
|
| 102 |
+
onnx==1.14.1
|
| 103 |
+
onnxruntime==1.16.0
|
| 104 |
+
opencv-contrib-python==4.9.0.80
|
| 105 |
+
opencv-python==4.8.1.78
|
| 106 |
+
opt-einsum==3.3.0
|
| 107 |
+
overrides==7.7.0
|
| 108 |
+
packaging==23.2
|
| 109 |
+
pandas==2.0.3
|
| 110 |
+
pandocfilters==1.5.1
|
| 111 |
+
parso==0.8.3
|
| 112 |
+
pexpect==4.9.0
|
| 113 |
+
pickleshare==0.7.5
|
| 114 |
+
Pillow==10.0.1
|
| 115 |
+
pkgutil_resolve_name==1.3.10
|
| 116 |
+
platformdirs==4.1.0
|
| 117 |
+
proglog==0.1.10
|
| 118 |
+
prometheus-client==0.19.0
|
| 119 |
+
prompt-toolkit==3.0.43
|
| 120 |
+
protobuf==3.20.3
|
| 121 |
+
psutil==5.9.5
|
| 122 |
+
ptflops==0.7.3
|
| 123 |
+
ptyprocess==0.7.0
|
| 124 |
+
pure-eval==0.2.2
|
| 125 |
+
py-cpuinfo==9.0.0
|
| 126 |
+
pycparser==2.21
|
| 127 |
+
Pygments==2.17.2
|
| 128 |
+
pyparsing==3.1.1
|
| 129 |
+
pyreadline3==3.4.1
|
| 130 |
+
python-dateutil==2.8.2
|
| 131 |
+
python-json-logger==2.0.7
|
| 132 |
+
pytz==2023.3.post1
|
| 133 |
+
PyWavelets==1.4.1
|
| 134 |
+
PyYAML==6.0.1
|
| 135 |
+
pyzmq==25.1.2
|
| 136 |
+
qtconsole==5.5.1
|
| 137 |
+
QtPy==2.4.1
|
| 138 |
+
rawpy==0.18.1
|
| 139 |
+
referencing==0.33.0
|
| 140 |
+
requests==2.31.0
|
| 141 |
+
rfc3339-validator==0.1.4
|
| 142 |
+
rfc3986-validator==0.1.1
|
| 143 |
+
rpds-py==0.17.1
|
| 144 |
+
scikit-image==0.21.0
|
| 145 |
+
scipy==1.10.1
|
| 146 |
+
seaborn==0.13.0
|
| 147 |
+
Send2Trash==1.8.2
|
| 148 |
+
sentry-sdk==1.40.5
|
| 149 |
+
setproctitle==1.3.3
|
| 150 |
+
shapely==2.0.2
|
| 151 |
+
six==1.16.0
|
| 152 |
+
smmap==5.0.1
|
| 153 |
+
sniffio==1.3.0
|
| 154 |
+
sounddevice==0.4.6
|
| 155 |
+
soupsieve==2.5
|
| 156 |
+
stack-data==0.6.3
|
| 157 |
+
sympy==1.12
|
| 158 |
+
terminado==0.18.0
|
| 159 |
+
thop==0.1.1.post2209072238
|
| 160 |
+
tifffile==2023.7.10
|
| 161 |
+
tinycss2==1.2.1
|
| 162 |
+
tk==0.1.0
|
| 163 |
+
tomli==2.0.1
|
| 164 |
+
torch==2.1.0
|
| 165 |
+
torchmetrics==1.4.0.post0
|
| 166 |
+
torchvision==0.16.0
|
| 167 |
+
tornado==6.4
|
| 168 |
+
tqdm==4.66.1
|
| 169 |
+
traitlets==5.14.1
|
| 170 |
+
triton==2.1.0
|
| 171 |
+
types-python-dateutil==2.8.19.20240106
|
| 172 |
+
typing_extensions==4.8.0
|
| 173 |
+
tzdata==2023.3
|
| 174 |
+
ultralytics==8.1.8
|
| 175 |
+
uri-template==1.3.0
|
| 176 |
+
urllib3==2.0.6
|
| 177 |
+
wandb==0.16.3
|
| 178 |
+
wcwidth==0.2.13
|
| 179 |
+
webcolors==1.13
|
| 180 |
+
webencodings==0.5.1
|
| 181 |
+
websocket-client==1.7.0
|
| 182 |
+
widgetsnbextension==4.0.9
|
| 183 |
+
zipp==3.17.0
|