Chayanat committed
Commit 5d8a836 · verified · 1 Parent(s): b626913

Upload 4 files

Files changed (4):
  1. .gitattributes +35 -35
  2. README.md +15 -12
  3. app.py +303 -0
  4. requirements.txt +7 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text

+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,15 @@
- ---
- title: ChestX Ray CTR
- emoji: 🏃
- colorFrom: pink
- colorTo: indigo
- sdk: gradio
- sdk_version: 5.47.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

+ ---
+ title: Chest x-ray HybridGNet Segmentation
+ emoji:
+ colorFrom: yellow
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 5.42.0
+ app_file: app.py
+ pinned: false
+ license: gpl-3.0
+ ---
+
+ Demo of the HybridGNet model with 2 image-to-graph skip connections from: arxiv.org/abs/2203.10977
+ Original HybridGNet model: arxiv.org/abs/2106.09832
+ The training procedure was taken from: arxiv.org/abs/2211.07395
app.py ADDED
@@ -0,0 +1,303 @@
+ import numpy as np
+ import gradio as gr
+ import cv2
+ import os  # added: needed to ensure the tmp/ output directory exists
+
+ from models.HybridGNet2IGSC import Hybrid
+ from utils.utils import scipy_to_torch_sparse, genMatrixesLungsHeart
+ import scipy.sparse as sp
+ import torch
+ from zipfile import ZipFile
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ hybrid = None
+
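+ # Rasterize the 120 predicted landmarks into a single dense label map:
+ # 1 = lung fields (right and left), 2 = heart.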
+ def getDenseMask(landmarks, h, w):
+
+     RL = landmarks[0:44]
+     LL = landmarks[44:94]
+     H = landmarks[94:]
+
+     img = np.zeros([h, w], dtype='uint8')
+
+     RL = RL.reshape(-1, 1, 2).astype('int')
+     LL = LL.reshape(-1, 1, 2).astype('int')
+     H = H.reshape(-1, 1, 2).astype('int')
+
+     img = cv2.drawContours(img, [RL], -1, 1, -1)
+     img = cv2.drawContours(img, [LL], -1, 1, -1)
+     img = cv2.drawContours(img, [H], -1, 2, -1)
+
+     return img
+
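+ # Rasterize the same contours into three separate 0/255 binary masks
+ # (right lung, left lung, heart) so each structure can be downloaded on its own.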
+ def getMasks(landmarks, h, w):
+
+     RL = landmarks[0:44]
+     LL = landmarks[44:94]
+     H = landmarks[94:]
+
+     RL = RL.reshape(-1, 1, 2).astype('int')
+     LL = LL.reshape(-1, 1, 2).astype('int')
+     H = H.reshape(-1, 1, 2).astype('int')
+
+     RL_mask = np.zeros([h, w], dtype='uint8')
+     LL_mask = np.zeros([h, w], dtype='uint8')
+     H_mask = np.zeros([h, w], dtype='uint8')
+
+     RL_mask = cv2.drawContours(RL_mask, [RL], -1, 255, -1)
+     LL_mask = cv2.drawContours(LL_mask, [LL], -1, 255, -1)
+     H_mask = cv2.drawContours(H_mask, [H], -1, 255, -1)
+
+     return RL_mask, LL_mask, H_mask
+
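+ # Tint the grayscale radiograph with the label map (lungs boosted in red,
+ # heart in green) and draw each landmark as a dot: magenta for lungs, yellow for heart.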
+ def drawOnTop(img, landmarks, original_shape):
+     h, w = original_shape
+     output = getDenseMask(landmarks, h, w)
+
+     image = np.zeros([h, w, 3])
+     image[:, :, 0] = img + 0.3 * (output == 1).astype('float') - 0.1 * (output == 2).astype('float')
+     image[:, :, 1] = img + 0.3 * (output == 2).astype('float') - 0.1 * (output == 1).astype('float')
+     image[:, :, 2] = img - 0.1 * (output == 1).astype('float') - 0.2 * (output == 2).astype('float')
+
+     image = np.clip(image, 0, 1)
+
+     RL, LL, H = landmarks[0:44], landmarks[44:94], landmarks[94:]
+
+     # Draw the landmarks as dots
+     for l in RL:
+         image = cv2.circle(image, (int(l[0]), int(l[1])), 5, (1, 0, 1), -1)
+     for l in LL:
+         image = cv2.circle(image, (int(l[0]), int(l[1])), 5, (1, 0, 1), -1)
+     for l in H:
+         image = cv2.circle(image, (int(l[0]), int(l[1])), 5, (1, 1, 0), -1)
+
+     return image
+
+
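+ # Build the graph matrices for the two mesh resolutions (adjacency A, coarse
+ # adjacency AD, downsampling D, upsampling U), convert them to torch sparse
+ # tensors on the target device, and load the pretrained HybridGNet weights.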
+ def loadModel(device):
+     A, AD, D, U = genMatrixesLungsHeart()
+     N1 = A.shape[0]
+     N2 = AD.shape[0]
+
+     A = sp.csc_matrix(A).tocoo()
+     AD = sp.csc_matrix(AD).tocoo()
+     D = sp.csc_matrix(D).tocoo()
+     U = sp.csc_matrix(U).tocoo()
+
+     D_ = [D.copy()]
+     U_ = [U.copy()]
+
+     config = {}
+
+     config['n_nodes'] = [N1, N1, N1, N2, N2, N2]
+     A_ = [A.copy(), A.copy(), A.copy(), AD.copy(), AD.copy(), AD.copy()]
+
+     A_t, D_t, U_t = ([scipy_to_torch_sparse(x).to(device) for x in X] for X in (A_, D_, U_))
+
+     config['latents'] = 64
+     config['inputsize'] = 1024
+
+     f = 32
+     config['filters'] = [2, f, f, f, f//2, f//2, f//2]
+     config['skip_features'] = f
+
+     hybrid = Hybrid(config.copy(), D_t, U_t, A_t).to(device)
+     hybrid.load_state_dict(torch.load("weights/weights.pt", map_location=torch.device(device)))
+     hybrid.eval()
+
+     return hybrid
+
+
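+ # Zero-pad the shorter axis symmetrically so the image becomes square,
+ # returning the padding amounts so the transform can be undone later.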
+ def pad_to_square(img):
+     h, w = img.shape[:2]
+
+     if h > w:
+         padw = h - w
+         auxw = padw % 2
+         img = np.pad(img, ((0, 0), (padw//2, padw//2 + auxw)), 'constant')
+
+         padh = 0
+         auxh = 0
+     else:
+         padh = w - h
+         auxh = padh % 2
+         img = np.pad(img, ((padh//2, padh//2 + auxh), (0, 0)), 'constant')
+
+         padw = 0
+         auxw = 0
+
+     return img, (padh, padw, auxh, auxw)
+
+
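+ # Pad to square and resize to the 1024x1024 input the network expects;
+ # returns the padded size and padding so the output can be mapped back.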
+ def preprocess(input_img):
+     img, padding = pad_to_square(input_img)
+
+     h, w = img.shape[:2]
+     if h != 1024 or w != 1024:
+         img = cv2.resize(img, (1024, 1024), interpolation=cv2.INTER_CUBIC)
+
+     return img, (h, w, padding)
+
+
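+ # Invert the preprocessing: the network outputs coordinates normalized to
+ # [0, 1], so scale them back to the padded size (h == w after squaring)
+ # and subtract the padding offsets.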
+ def removePreprocess(output, info):
+     h, w, padding = info
+
+     if h != 1024 or w != 1024:
+         output = output * h
+     else:
+         output = output * 1024
+
+     padh, padw, auxh, auxw = padding
+
+     output[:, 0] = output[:, 0] - padw//2
+     output[:, 1] = output[:, 1] - padh//2
+
+     return output
+
+
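+ # Bundle all result files into a single zip archive for download.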
+ def zip_files(files):
+     with ZipFile("complete_results.zip", "w") as zipObj:
+         for file in files:
+             zipObj.write(file, arcname=file.split("/")[-1])
+     return "complete_results.zip"
+
+
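+ # Cardiothoracic ratio from the landmark x-extents:
+ #     CTR = cardiac width / thoracic width.
+ # Worked example with hypothetical values: a heart spanning x = 400..690 and a
+ # thorax spanning x = 150..780 gives CTR = 290 / 630 ≈ 0.46; on a PA film,
+ # a CTR above ~0.5 is commonly read as cardiomegaly.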
+ def calculate_ctr(landmarks):
+     H = landmarks[94:]
+     RL = landmarks[0:44]
+     LL = landmarks[44:94]
+     cardiac_width = np.max(H[:, 0]) - np.min(H[:, 0])
+     thoracic_width = max(np.max(RL[:, 0]), np.max(LL[:, 0])) - min(np.min(RL[:, 0]), np.min(LL[:, 0]))
+     ctr = cardiac_width / thoracic_width if thoracic_width > 0 else 0
+     return ctr
+
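+ # Full pipeline: lazily load the model, preprocess the image, run inference,
+ # map landmarks back to the original resolution, save the landmark files,
+ # masks and overlay to tmp/, zip everything, and compute the CTR.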
+ def segment(input_img):
+     global hybrid, device
+
+     if hybrid is None:
+         hybrid = loadModel(device)
+
+     input_img = cv2.imread(input_img, 0) / 255.0
+     original_shape = input_img.shape[:2]
+
+     img, (h, w, padding) = preprocess(input_img)
+
+     data = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device).float()
+
+     with torch.no_grad():
+         output = hybrid(data)[0].cpu().numpy().reshape(-1, 2)
+
+     output = removePreprocess(output, (h, w, padding))
+
+     output = output.astype('int')
+
+     outseg = drawOnTop(input_img, output, original_shape)
+
+     os.makedirs("tmp", exist_ok=True)  # added: make sure the output directory exists
+
+     seg_to_save = (outseg.copy() * 255).astype('uint8')
+     cv2.imwrite("tmp/overlap_segmentation.png", cv2.cvtColor(seg_to_save, cv2.COLOR_RGB2BGR))
+
+     RL = output[0:44]
+     LL = output[44:94]
+     H = output[94:]
+
+     np.savetxt("tmp/RL_landmarks.txt", RL, delimiter=" ", fmt="%d")
+     np.savetxt("tmp/LL_landmarks.txt", LL, delimiter=" ", fmt="%d")
+     np.savetxt("tmp/H_landmarks.txt", H, delimiter=" ", fmt="%d")
+
+     RL_mask, LL_mask, H_mask = getMasks(output, original_shape[0], original_shape[1])
+
+     cv2.imwrite("tmp/RL_mask.png", RL_mask)
+     cv2.imwrite("tmp/LL_mask.png", LL_mask)
+     cv2.imwrite("tmp/H_mask.png", H_mask)
+
+     zip_path = zip_files(["tmp/RL_landmarks.txt", "tmp/LL_landmarks.txt", "tmp/H_landmarks.txt", "tmp/RL_mask.png", "tmp/LL_mask.png", "tmp/H_mask.png", "tmp/overlap_segmentation.png"])
+
+     ctr_value = calculate_ctr(output)
+
+     return outseg, ["tmp/RL_landmarks.txt", "tmp/LL_landmarks.txt", "tmp/H_landmarks.txt", "tmp/RL_mask.png", "tmp/LL_mask.png", "tmp/H_mask.png", "tmp/overlap_segmentation.png", zip_path], ctr_value
+
+
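+ # Gradio front end: one tab with the input image, Segment/Clear buttons and
+ # example images on the left; the overlay, downloadable files and the CTR value on the right.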
+ if __name__ == "__main__":
+
+     with gr.Blocks() as demo:
+
+         gr.Markdown("""
+ # Chest X-ray HybridGNet Segmentation
+
+ Demo of the HybridGNet model introduced in "Improving anatomical plausibility in medical image segmentation via hybrid graph neural networks: applications to chest x-ray analysis."
+
+ Instructions:
+ 1. Upload a chest X-ray image (PA or AP) in PNG or JPEG format.
+ 2. Click on "Segment Image".
+
+ Note: no manual pre-processing is needed; padding and resizing are applied automatically and undone after segmentation.
+
+ Please check the citations below.
+ """)
+
+         with gr.Tab("Segment Image"):
+             with gr.Row():
+                 with gr.Column():
+                     image_input = gr.Image(type="filepath", height=750)
+
+                     with gr.Row():
+                         clear_button = gr.Button("Clear")
+                         image_button = gr.Button("Segment Image")
+
+                     gr.Examples(inputs=image_input, examples=['utils/example1.jpg', 'utils/example2.jpg', 'utils/example3.png', 'utils/example4.jpg'])
+
+                 with gr.Column():
+                     image_output = gr.Image(type="filepath", height=750)
+                     results = gr.File()
+                     ctr_output = gr.Number(label="CTR (Cardiothoracic Ratio)")
+
+         gr.Markdown("""
+ If you use this code, please cite:
+
+ ```
+ @article{gaggion2022TMI,
+     doi = {10.1109/tmi.2022.3224660},
+     url = {https://doi.org/10.1109%2Ftmi.2022.3224660},
+     year = 2022,
+     publisher = {Institute of Electrical and Electronics Engineers ({IEEE})},
+     author = {Nicolas Gaggion and Lucas Mansilla and Candelaria Mosquera and Diego H. Milone and Enzo Ferrante},
+     title = {Improving anatomical plausibility in medical image segmentation via hybrid graph neural networks: applications to chest x-ray analysis},
+     journal = {{IEEE} Transactions on Medical Imaging}
+ }
+ ```
+
+ This model was trained following the procedure explained in:
+
+ ```
+ @INPROCEEDINGS{gaggion2022ISBI,
+     author={Gaggion, Nicolás and Vakalopoulou, Maria and Milone, Diego H. and Ferrante, Enzo},
+     booktitle={2023 IEEE 20th International Symposium on Biomedical Imaging (ISBI)},
+     title={Multi-Center Anatomical Segmentation with Heterogeneous Labels Via Landmark-Based Models},
+     year={2023},
+     volume={},
+     number={},
+     pages={1-5},
+     doi={10.1109/ISBI53787.2023.10230691}
+ }
+ ```
+
+ Example images extracted from Wikipedia, released under:
+ 1. CC0 Universal Public Domain. Source: https://commons.wikimedia.org/wiki/File:Normal_posteroanterior_(PA)_chest_radiograph_(X-ray).jpg
+ 2. Creative Commons Attribution-Share Alike 4.0 International. Source: https://commons.wikimedia.org/wiki/File:Chest_X-ray.jpg
+ 3. Creative Commons Attribution 3.0 Unported. Source: https://commons.wikimedia.org/wiki/File:Implantable_cardioverter_defibrillator_chest_X-ray.jpg
+ 4. Creative Commons Attribution-Share Alike 3.0 Unported. Source: https://commons.wikimedia.org/wiki/File:Medical_X-Ray_imaging_PRD06_nevit.jpg
+
+ Author: Nicolás Gaggion
+ Website: [ngaggion.github.io](https://ngaggion.github.io/)
+ """)
+
+         clear_button.click(lambda: None, None, image_input, queue=False)
+         clear_button.click(lambda: None, None, image_output, queue=False)
+         clear_button.click(lambda: None, None, ctr_output, queue=False)
+
+         image_button.click(segment, inputs=image_input, outputs=[image_output, results, ctr_output], queue=False)
+
+     demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch==2.0.1
+ numpy==1.25.0
+ opencv-python==4.8.0.74
+ scipy==1.10.1
+ torch_geometric==2.3.0
+ torchvision
+ gradio==4.15.0