InPeerReview committed on
Commit
a050a18
·
verified ·
1 Parent(s): 4027bb6

Upload eval.py

Browse files
Files changed (1) hide show
  1. eval.py +222 -0
eval.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from model.trainer import Trainer
3
+
4
+ sys.path.insert(0, '.')
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.backends.cudnn as cudnn
9
+ from torch.nn.parallel import gather
10
+ import torch.optim.lr_scheduler
11
+
12
+ import dataset.dataset as myDataLoader
13
+ import dataset.Transforms as myTransforms
14
+ from model.metric_tool import ConfuseMatrixMeter
15
+ from model.utils import BCEDiceLoss, init_seed
16
+ from PIL import Image
17
+ import os
18
+ import time
19
+ import numpy as np
20
+ from argparse import ArgumentParser
21
+ from tqdm import tqdm
22
+
23
+
24
@torch.no_grad()
def validate(args, val_loader, model, save_masks=False):
    """Evaluate `model` over `val_loader` and return (average_loss, scores).

    Args:
        args: namespace with at least `savedir`, `batch_size`, `onGPU`.
        val_loader: DataLoader over the test split. NOTE(review): the
            per-batch filename lookup below assumes shuffle=False so batch
            order matches `file_list` order — confirm at the call site.
        model: change-detection network taking (pre_img, post_img).
        save_masks: when True, write each predicted mask as a PNG under
            `<args.savedir>/pred_masks`.

    Returns:
        average_loss: mean BCE+Dice loss over all batches (0-safe for an
            empty loader).
        scores: metrics dict from ConfuseMatrixMeter.get_scores().
    """
    model.eval()

    # Force every BatchNorm layer to use its global running statistics
    # rather than per-batch statistics.
    for m in model.modules():
        if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
            m.track_running_stats = True
            m.eval()

    salEvalVal = ConfuseMatrixMeter(n_class=2)
    epoch_loss = []

    if save_masks:
        mask_dir = f"{args.savedir}/pred_masks"
        os.makedirs(mask_dir, exist_ok=True)
        print(f"Saving prediction masks to: {mask_dir}")

    pbar = tqdm(enumerate(val_loader), total=len(val_loader), desc="Validating")

    for batch_idx, batched_inputs in pbar:
        img, target = batched_inputs
        # Filenames for this batch, recovered by slicing the dataset's
        # file_list with the batch index (valid only with shuffle=False).
        batch_file_names = val_loader.sampler.data_source.file_list[
            batch_idx * args.batch_size : (batch_idx + 1) * args.batch_size
        ]

        # The 6-channel input stacks the pre-change and post-change images.
        pre_img = img[:, 0:3]
        post_img = img[:, 3:6]

        if args.onGPU:
            pre_img = pre_img.cuda()
            post_img = post_img.cuda()
            target = target.cuda()

        target = target.float()
        output = model(pre_img, post_img)
        loss = BCEDiceLoss(output, target)
        # Binarize at 0.5 — assumes `output` is already in [0, 1]
        # (post-sigmoid); TODO confirm against the model head.
        pred = (output > 0.5).long()

        if save_masks:
            _save_batch_masks(pred, batch_file_names, batch_idx, mask_dir)

        if args.onGPU and torch.cuda.device_count() > 1:
            pred = gather(pred, 0, dim=0)

        f1 = salEvalVal.update_cm(pr=pred.cpu().numpy(), gt=target.cpu().numpy())
        epoch_loss.append(loss.item())

        pbar.set_postfix({'Loss': f"{loss.item():.4f}", 'F1': f"{f1:.4f}"})

    # Guard against an empty loader so we never divide by zero.
    average_loss = sum(epoch_loss) / max(len(epoch_loss), 1)
    scores = salEvalVal.get_scores()
    return average_loss, scores


def _save_batch_masks(pred, batch_file_names, batch_idx, mask_dir):
    """Write each predicted mask in `pred` to `mask_dir` as a PNG.

    Best-effort: any failure is reported on stdout without aborting the
    validation loop (matching the original behavior).
    """
    pred_np = pred.cpu().numpy().astype(np.uint8)

    print(f"\nDebug - Batch {batch_idx}: {len(batch_file_names)} files, Mask shape: {pred_np.shape}")

    try:
        for i in range(pred_np.shape[0]):
            if i >= len(batch_file_names):  # guard: fewer filenames than masks
                print(f"Warning: Missing filename for mask {i}, using default")
                base_name = f"batch_{batch_idx}_mask_{i}"
            else:
                base_name = os.path.splitext(os.path.basename(batch_file_names[i]))[0]

            # Drop the channel axis: (1, H, W) -> (H, W).
            single_mask = pred_np[i, 0]

            if single_mask.ndim != 2:
                raise ValueError(f"Invalid mask shape: {single_mask.shape}")

            mask_path = f"{mask_dir}/{base_name}_pred.png"
            # Scale {0, 1} to {0, 255} so the mask is visible as an image.
            Image.fromarray(single_mask * 255).save(mask_path)
            print(f"Saved: {mask_path}")

    except Exception as e:
        print(f"\nError saving batch {batch_idx}: {str(e)}")
        print(f"Current mask shape: {single_mask.shape if 'single_mask' in locals() else 'N/A'}")
        print(f"Current file: {base_name if 'base_name' in locals() else 'N/A'}")
102
+
103
def ValidateSegmentation(args):
    """Main validation pipeline: build the data loader, load the best
    checkpoint, evaluate on the test split, and log the metrics.

    Args:
        args: parsed command-line namespace (see the ArgumentParser in
            `__main__`). `args.savedir` and `args.file_root` are rewritten
            in place to their resolved values.

    Raises:
        FileNotFoundError: if `best_model.pth` is missing from the run dir.
    """
    # Environment setup.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    torch.backends.cudnn.benchmark = True
    init_seed(args.seed)  # fix the random seed for reproducibility

    # Run directory mirrors the training naming scheme so the checkpoint
    # saved by training is found here.
    args.savedir = os.path.join(args.savedir,
                                f"{args.file_root}_iter_{args.max_steps}_lr_{args.lr}")
    os.makedirs(args.savedir, exist_ok=True)

    # Map dataset shorthand names to their on-disk roots; unknown names
    # are treated as literal paths.
    dataset_mapping = {
        'LEVIR': './levir_cd_256',
        'WHU': './whu_cd_256',
        'CLCD': './clcd_256',
        'SYSU': './sysu_256',
        'OSCD': './oscd_256'
    }
    args.file_root = dataset_mapping.get(args.file_root, args.file_root)

    # Model setup.
    model = Trainer(args.model_type).float()
    if args.onGPU:
        model = model.cuda()

    # Preprocessing — must match the validation-time transforms used
    # during training.
    mean = [0.406, 0.456, 0.485, 0.406, 0.456, 0.485]
    std = [0.225, 0.224, 0.229, 0.225, 0.224, 0.229]

    valDataset = myTransforms.Compose([
        myTransforms.Normalize(mean=mean, std=std),
        myTransforms.Scale(args.inWidth, args.inHeight),
        myTransforms.ToTensor()
    ])

    # Data loading (no shuffling: `validate` relies on file_list order).
    test_data = myDataLoader.Dataset(file_root=args.file_root, mode="test", transform=valDataset)
    testLoader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )

    # Log setup. BUG FIX: capture existence BEFORE open() — opening in 'w'
    # creates the file, so the original post-open existence check was
    # always True and the header row was never written.
    logFileLoc = os.path.join(args.savedir, args.logFile)
    log_existed = os.path.exists(logFileLoc)
    logger = open(logFileLoc, 'a' if log_existed else 'w')
    if not log_existed:
        logger.write("\n%s\t%s\t%s\t%s\t%s\t%s\t%s" %
                     ('Epoch', 'Kappa', 'IoU', 'F1', 'Recall', 'Precision', 'OA'))
    logger.flush()

    # Load the best checkpoint produced by training.
    model_file_name = os.path.join(args.savedir, 'best_model.pth')
    if not os.path.exists(model_file_name):
        raise FileNotFoundError(f"Model file not found: {model_file_name}")

    state_dict = torch.load(model_file_name)
    model.load_state_dict(state_dict)
    print(f"Loaded model from {model_file_name}")

    # Run validation.
    loss_test, score_test = validate(args, testLoader, model, save_masks=args.save_masks)

    # Report results.
    print("\nTest Results:")
    print(f"Loss: {loss_test:.4f}")
    print(f"Kappa: {score_test['Kappa']:.4f}")
    print(f"IoU: {score_test['IoU']:.4f}")
    print(f"F1: {score_test['F1']:.4f}")
    print(f"Recall: {score_test['recall']:.4f}")
    print(f"Precision: {score_test['precision']:.4f}")
    print(f"OA: {score_test['OA']:.4f}")

    # Append the metrics row to the log file.
    logger.write("\n%s\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f" %
                 ('Test', score_test['Kappa'], score_test['IoU'], score_test['F1'],
                  score_test['recall'], score_test['precision'], score_test['OA']))
    logger.close()
185
+
186
+
187
if __name__ == '__main__':
    # Command-line interface for standalone evaluation.
    cli = ArgumentParser()

    # Dataset and input geometry.
    cli.add_argument('--file_root', default="LEVIR",
                     help='Data directory | LEVIR | WHU | CLCD | SYSU | OSCD')
    cli.add_argument('--inWidth', type=int, default=256, help='Width of input image')
    cli.add_argument('--inHeight', type=int, default=256, help='Height of input image')

    # Run-directory naming (must match the training run's settings).
    cli.add_argument('--max_steps', type=int, default=80000,
                     help='Max. number of iterations (for path naming)')
    cli.add_argument('--lr', type=float, default=2e-4,
                     help='Learning rate (for path naming)')

    # Model and loader configuration.
    cli.add_argument('--model_type', type=str, default='small',
                     help='Model type | tiny | small')
    cli.add_argument('--batch_size', type=int, default=16,
                     help='Batch size for validation')
    cli.add_argument('--num_workers', type=int, default=4,
                     help='Number of data loading workers')
    cli.add_argument('--seed', type=int, default=16,
                     help='Random seed for reproducibility')

    # Output locations.
    cli.add_argument('--savedir', default='./results',
                     help='Base directory to save results')
    cli.add_argument('--logFile', default='testLog.txt',
                     help='File to save validation logs')
    cli.add_argument('--save_masks', action='store_true',
                     help='Save predicted masks to disk')

    # Device selection.
    cli.add_argument('--onGPU', default=True,
                     type=lambda x: (str(x).lower() == 'true'),
                     help='Run on GPU if True')
    cli.add_argument('--gpu_id', type=int, default=0,
                     help='GPU device id')

    parsed = cli.parse_args()
    print('Validation with args:')
    print(parsed)

    ValidateSegmentation(parsed)
222
+