Upload vista3d version 0.5.8
Browse files- LICENSE +247 -0
- configs/batch_inference.json +7 -0
- configs/evaluate.json +173 -0
- configs/inference.json +204 -0
- configs/inference_trt.json +19 -0
- configs/logging.conf +27 -0
- configs/metadata.json +229 -0
- configs/mgpu_evaluate.json +29 -0
- configs/msd_task09_spleen_folds.json +271 -0
- configs/multi_gpu_train.json +42 -0
- configs/train.json +406 -0
- configs/train_continual.json +107 -0
- docs/README.md +310 -0
- docs/data_license.txt +6 -0
- docs/labels.json +126 -0
- models/model.pt +3 -0
- scripts/__init__.py +15 -0
- scripts/early_stop_score_function.py +15 -0
- scripts/evaluator.py +297 -0
- scripts/inferer.py +117 -0
- scripts/trainer.py +211 -0
LICENSE
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Code License
|
| 2 |
+
|
| 3 |
+
This license applies to all files except the model weights in the directory.
|
| 4 |
+
|
| 5 |
+
Apache License
|
| 6 |
+
Version 2.0, January 2004
|
| 7 |
+
http://www.apache.org/licenses/
|
| 8 |
+
|
| 9 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 10 |
+
|
| 11 |
+
1. Definitions.
|
| 12 |
+
|
| 13 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 14 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 15 |
+
|
| 16 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 17 |
+
the copyright owner that is granting the License.
|
| 18 |
+
|
| 19 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 20 |
+
other entities that control, are controlled by, or are under common
|
| 21 |
+
control with that entity. For the purposes of this definition,
|
| 22 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 23 |
+
direction or management of such entity, whether by contract or
|
| 24 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 25 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 26 |
+
|
| 27 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 28 |
+
exercising permissions granted by this License.
|
| 29 |
+
|
| 30 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 31 |
+
including but not limited to software source code, documentation
|
| 32 |
+
source, and configuration files.
|
| 33 |
+
|
| 34 |
+
"Object" form shall mean any form resulting from mechanical
|
| 35 |
+
transformation or translation of a Source form, including but
|
| 36 |
+
not limited to compiled object code, generated documentation,
|
| 37 |
+
and conversions to other media types.
|
| 38 |
+
|
| 39 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 40 |
+
Object form, made available under the License, as indicated by a
|
| 41 |
+
copyright notice that is included in or attached to the work
|
| 42 |
+
(an example is provided in the Appendix below).
|
| 43 |
+
|
| 44 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 45 |
+
form, that is based on (or derived from) the Work and for which the
|
| 46 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 47 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 48 |
+
of this License, Derivative Works shall not include works that remain
|
| 49 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 50 |
+
the Work and Derivative Works thereof.
|
| 51 |
+
|
| 52 |
+
"Contribution" shall mean any work of authorship, including
|
| 53 |
+
the original version of the Work and any modifications or additions
|
| 54 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 55 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 56 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 57 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 58 |
+
means any form of electronic, verbal, or written communication sent
|
| 59 |
+
to the Licensor or its representatives, including but not limited to
|
| 60 |
+
communication on electronic mailing lists, source code control systems,
|
| 61 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 62 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 63 |
+
excluding communication that is conspicuously marked or otherwise
|
| 64 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 65 |
+
|
| 66 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 67 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 68 |
+
subsequently incorporated within the Work.
|
| 69 |
+
|
| 70 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 71 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 72 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 73 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 74 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 75 |
+
Work and such Derivative Works in Source or Object form.
|
| 76 |
+
|
| 77 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 78 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 79 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 80 |
+
(except as stated in this section) patent license to make, have made,
|
| 81 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 82 |
+
where such license applies only to those patent claims licensable
|
| 83 |
+
by such Contributor that are necessarily infringed by their
|
| 84 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 85 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 86 |
+
institute patent litigation against any entity (including a
|
| 87 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 88 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 89 |
+
or contributory patent infringement, then any patent licenses
|
| 90 |
+
granted to You under this License for that Work shall terminate
|
| 91 |
+
as of the date such litigation is filed.
|
| 92 |
+
|
| 93 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 94 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 95 |
+
modifications, and in Source or Object form, provided that You
|
| 96 |
+
meet the following conditions:
|
| 97 |
+
|
| 98 |
+
(a) You must give any other recipients of the Work or
|
| 99 |
+
Derivative Works a copy of this License; and
|
| 100 |
+
|
| 101 |
+
(b) You must cause any modified files to carry prominent notices
|
| 102 |
+
stating that You changed the files; and
|
| 103 |
+
|
| 104 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 105 |
+
that You distribute, all copyright, patent, trademark, and
|
| 106 |
+
attribution notices from the Source form of the Work,
|
| 107 |
+
excluding those notices that do not pertain to any part of
|
| 108 |
+
the Derivative Works; and
|
| 109 |
+
|
| 110 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 111 |
+
distribution, then any Derivative Works that You distribute must
|
| 112 |
+
include a readable copy of the attribution notices contained
|
| 113 |
+
within such NOTICE file, excluding those notices that do not
|
| 114 |
+
pertain to any part of the Derivative Works, in at least one
|
| 115 |
+
of the following places: within a NOTICE text file distributed
|
| 116 |
+
as part of the Derivative Works; within the Source form or
|
| 117 |
+
documentation, if provided along with the Derivative Works; or,
|
| 118 |
+
within a display generated by the Derivative Works, if and
|
| 119 |
+
wherever such third-party notices normally appear. The contents
|
| 120 |
+
of the NOTICE file are for informational purposes only and
|
| 121 |
+
do not modify the License. You may add Your own attribution
|
| 122 |
+
notices within Derivative Works that You distribute, alongside
|
| 123 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 124 |
+
that such additional attribution notices cannot be construed
|
| 125 |
+
as modifying the License.
|
| 126 |
+
|
| 127 |
+
You may add Your own copyright statement to Your modifications and
|
| 128 |
+
may provide additional or different license terms and conditions
|
| 129 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 130 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 131 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 132 |
+
the conditions stated in this License.
|
| 133 |
+
|
| 134 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 135 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 136 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 137 |
+
this License, without any additional terms or conditions.
|
| 138 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 139 |
+
the terms of any separate license agreement you may have executed
|
| 140 |
+
with Licensor regarding such Contributions.
|
| 141 |
+
|
| 142 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 143 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 144 |
+
except as required for reasonable and customary use in describing the
|
| 145 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 146 |
+
|
| 147 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 148 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 149 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 150 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 151 |
+
implied, including, without limitation, any warranties or conditions
|
| 152 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 153 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 154 |
+
appropriateness of using or redistributing the Work and assume any
|
| 155 |
+
risks associated with Your exercise of permissions under this License.
|
| 156 |
+
|
| 157 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 158 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 159 |
+
unless required by applicable law (such as deliberate and grossly
|
| 160 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 161 |
+
liable to You for damages, including any direct, indirect, special,
|
| 162 |
+
incidental, or consequential damages of any character arising as a
|
| 163 |
+
result of this License or out of the use or inability to use the
|
| 164 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 165 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 166 |
+
other commercial damages or losses), even if such Contributor
|
| 167 |
+
has been advised of the possibility of such damages.
|
| 168 |
+
|
| 169 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 170 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 171 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 172 |
+
or other liability obligations and/or rights consistent with this
|
| 173 |
+
License. However, in accepting such obligations, You may act only
|
| 174 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 175 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 176 |
+
defend, and hold each Contributor harmless for any liability
|
| 177 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 178 |
+
of your accepting any such warranty or additional liability.
|
| 179 |
+
|
| 180 |
+
END OF TERMS AND CONDITIONS
|
| 181 |
+
|
| 182 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 183 |
+
|
| 184 |
+
To apply the Apache License to your work, attach the following
|
| 185 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 186 |
+
replaced with your own identifying information. (Don't include
|
| 187 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 188 |
+
comment syntax for the file format. We also recommend that a
|
| 189 |
+
file or class name and description of purpose be included on the
|
| 190 |
+
same "printed page" as the copyright notice for easier
|
| 191 |
+
identification within third-party archives.
|
| 192 |
+
|
| 193 |
+
Copyright [yyyy] [name of copyright owner]
|
| 194 |
+
|
| 195 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 196 |
+
you may not use this file except in compliance with the License.
|
| 197 |
+
You may obtain a copy of the License at
|
| 198 |
+
|
| 199 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 200 |
+
|
| 201 |
+
Unless required by applicable law or agreed to in writing, software
|
| 202 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 203 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 204 |
+
See the License for the specific language governing permissions and
|
| 205 |
+
limitations under the License.
|
| 206 |
+
|
| 207 |
+
------------------------------------------------------------------------------
|
| 208 |
+
|
| 209 |
+
Model Weights License
|
| 210 |
+
|
| 211 |
+
This license applies to model weights in the directory.
|
| 212 |
+
|
| 213 |
+
NVIDIA License
|
| 214 |
+
|
| 215 |
+
1. Definitions
|
| 216 |
+
|
| 217 |
+
“Licensor” means any person or entity that distributes its Work.
|
| 218 |
+
“Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.
|
| 219 |
+
The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
|
| 220 |
+
Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.
|
| 221 |
+
|
| 222 |
+
2. License Grant
|
| 223 |
+
|
| 224 |
+
2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
|
| 225 |
+
|
| 226 |
+
3. Limitations
|
| 227 |
+
|
| 228 |
+
3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
|
| 229 |
+
|
| 230 |
+
3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
|
| 231 |
+
|
| 232 |
+
3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
|
| 233 |
+
|
| 234 |
+
3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
|
| 235 |
+
|
| 236 |
+
3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.
|
| 237 |
+
|
| 238 |
+
3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.
|
| 239 |
+
|
| 240 |
+
4. Disclaimer of Warranty.
|
| 241 |
+
|
| 242 |
+
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
|
| 243 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
|
| 244 |
+
|
| 245 |
+
5. Limitation of Liability.
|
| 246 |
+
|
| 247 |
+
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
configs/batch_inference.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"input_dir": "@bundle_root",
|
| 3 |
+
"input_suffix": "*.nii.gz",
|
| 4 |
+
"input_list": "$sorted(glob.glob(os.path.join(@input_dir, @input_suffix)))",
|
| 5 |
+
"input_dicts": "$[{'image': x, 'label_prompt': @everything_labels} for x in @input_list]",
|
| 6 |
+
"dataset#data": "@input_dicts"
|
| 7 |
+
}
|
configs/evaluate.json
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"data_list_file_path": "$@bundle_root + '/configs/msd_task09_spleen_folds.json'",
|
| 3 |
+
"dataset_dir": "/data/Task09_Spleen",
|
| 4 |
+
"output_dir": "$@bundle_root + '/eval'",
|
| 5 |
+
"ckpt_path": "$@bundle_root + '/models/model.pt'",
|
| 6 |
+
"patch_size": [
|
| 7 |
+
128,
|
| 8 |
+
128,
|
| 9 |
+
128
|
| 10 |
+
],
|
| 11 |
+
"resample_to_spacing": [
|
| 12 |
+
1.5,
|
| 13 |
+
1.5,
|
| 14 |
+
1.5
|
| 15 |
+
],
|
| 16 |
+
"label_mappings": {
|
| 17 |
+
"default": [
|
| 18 |
+
[
|
| 19 |
+
1,
|
| 20 |
+
3
|
| 21 |
+
]
|
| 22 |
+
]
|
| 23 |
+
},
|
| 24 |
+
"label_set": "$list(x[1] for x in @label_mappings#default)",
|
| 25 |
+
"validate#evaluator#hyper_kwargs#val_head": "auto",
|
| 26 |
+
"validate#preprocessing": {
|
| 27 |
+
"_target_": "Compose",
|
| 28 |
+
"transforms": [
|
| 29 |
+
{
|
| 30 |
+
"_target_": "LoadImaged",
|
| 31 |
+
"keys": [
|
| 32 |
+
"image",
|
| 33 |
+
"label"
|
| 34 |
+
],
|
| 35 |
+
"image_only": true,
|
| 36 |
+
"ensure_channel_first": true
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"_target_": "CropForegroundd",
|
| 40 |
+
"keys": [
|
| 41 |
+
"image"
|
| 42 |
+
],
|
| 43 |
+
"source_key": "image",
|
| 44 |
+
"margin": 10,
|
| 45 |
+
"allow_smaller": true,
|
| 46 |
+
"start_coord_key": null,
|
| 47 |
+
"end_coord_key": null
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"_target_": "ScaleIntensityRanged",
|
| 51 |
+
"keys": "image",
|
| 52 |
+
"a_min": -963.8247715525971,
|
| 53 |
+
"a_max": 1053.678477684517,
|
| 54 |
+
"b_min": 0.0,
|
| 55 |
+
"b_max": 1.0,
|
| 56 |
+
"clip": true
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"_target_": "Orientationd",
|
| 60 |
+
"keys": [
|
| 61 |
+
"image"
|
| 62 |
+
],
|
| 63 |
+
"axcodes": "RAS"
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"_target_": "Spacingd",
|
| 67 |
+
"keys": [
|
| 68 |
+
"image"
|
| 69 |
+
],
|
| 70 |
+
"pixdim": "$@resample_to_spacing",
|
| 71 |
+
"mode": [
|
| 72 |
+
"bilinear"
|
| 73 |
+
]
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"_target_": "CastToTyped",
|
| 77 |
+
"keys": [
|
| 78 |
+
"image",
|
| 79 |
+
"label"
|
| 80 |
+
],
|
| 81 |
+
"dtype": [
|
| 82 |
+
"$torch.float32",
|
| 83 |
+
"$torch.uint8"
|
| 84 |
+
]
|
| 85 |
+
},
|
| 86 |
+
{
|
| 87 |
+
"_target_": "monai.apps.vista3d.transforms.Relabeld",
|
| 88 |
+
"keys": "label",
|
| 89 |
+
"label_mappings": "@label_mappings",
|
| 90 |
+
"dtype": "$torch.uint8"
|
| 91 |
+
}
|
| 92 |
+
]
|
| 93 |
+
},
|
| 94 |
+
"validate#postprocessing": {
|
| 95 |
+
"_target_": "Compose",
|
| 96 |
+
"transforms": [
|
| 97 |
+
{
|
| 98 |
+
"_target_": "EnsureTyped",
|
| 99 |
+
"keys": [
|
| 100 |
+
"pred",
|
| 101 |
+
"label"
|
| 102 |
+
],
|
| 103 |
+
"device": "cpu",
|
| 104 |
+
"_disabled_": true
|
| 105 |
+
},
|
| 106 |
+
{
|
| 107 |
+
"_target_": "monai.apps.vista3d.transforms.VistaPostTransformd",
|
| 108 |
+
"keys": "pred"
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"_target_": "Invertd",
|
| 112 |
+
"keys": "pred",
|
| 113 |
+
"transform": "$copy.deepcopy(@validate#preprocessing)",
|
| 114 |
+
"orig_keys": "image",
|
| 115 |
+
"nearest_interp": true,
|
| 116 |
+
"to_tensor": true
|
| 117 |
+
},
|
| 118 |
+
{
|
| 119 |
+
"_target_": "Lambdad",
|
| 120 |
+
"func": "$lambda x: torch.nan_to_num(x, nan=255)",
|
| 121 |
+
"keys": "pred"
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"_target_": "SaveImaged",
|
| 125 |
+
"keys": "pred",
|
| 126 |
+
"resample": false,
|
| 127 |
+
"output_dir": "@output_dir"
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"_target_": "monai.apps.vista3d.transforms.Relabeld",
|
| 131 |
+
"keys": [
|
| 132 |
+
"pred",
|
| 133 |
+
"label"
|
| 134 |
+
],
|
| 135 |
+
"label_mappings": "${'default': [[c, i+1] for i, c in enumerate(@label_set)]}",
|
| 136 |
+
"dtype": "$torch.uint8"
|
| 137 |
+
}
|
| 138 |
+
]
|
| 139 |
+
},
|
| 140 |
+
"validate#handlers": [
|
| 141 |
+
{
|
| 142 |
+
"_target_": "CheckpointLoader",
|
| 143 |
+
"load_path": "@ckpt_path",
|
| 144 |
+
"load_dict": {
|
| 145 |
+
"model": "@network"
|
| 146 |
+
}
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"_target_": "StatsHandler",
|
| 150 |
+
"iteration_log": true,
|
| 151 |
+
"name": "validate_stats"
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"_target_": "MetricsSaver",
|
| 155 |
+
"_disabled_": false,
|
| 156 |
+
"save_dir": "@output_dir",
|
| 157 |
+
"metrics": [
|
| 158 |
+
"val_mean_dice"
|
| 159 |
+
],
|
| 160 |
+
"batch_transform": "$lambda x: [xx['image'].meta for xx in x]",
|
| 161 |
+
"metric_details": "*",
|
| 162 |
+
"summary_ops": "*"
|
| 163 |
+
}
|
| 164 |
+
],
|
| 165 |
+
"validate#dataset": {
|
| 166 |
+
"_target_": "Dataset",
|
| 167 |
+
"data": "$list(@val_datalist)",
|
| 168 |
+
"transform": "@validate#preprocessing"
|
| 169 |
+
},
|
| 170 |
+
"run": [
|
| 171 |
+
"$@validate#evaluator.run()"
|
| 172 |
+
]
|
| 173 |
+
}
|
configs/inference.json
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"imports": [
|
| 3 |
+
"$import glob",
|
| 4 |
+
"$import os",
|
| 5 |
+
"$import scripts",
|
| 6 |
+
"$import numpy as np",
|
| 7 |
+
"$import copy",
|
| 8 |
+
"$import json",
|
| 9 |
+
"$import pathlib"
|
| 10 |
+
],
|
| 11 |
+
"bundle_root": "./",
|
| 12 |
+
"image_key": "image",
|
| 13 |
+
"output_dir": "$@bundle_root + '/eval'",
|
| 14 |
+
"output_ext": ".nii.gz",
|
| 15 |
+
"output_dtype": "$np.float32",
|
| 16 |
+
"output_postfix": "trans",
|
| 17 |
+
"separate_folder": true,
|
| 18 |
+
"input_dict": "${'image': '/data/Task09_Spleen/imagesTr/spleen_10.nii.gz', 'label_prompt': [3]}",
|
| 19 |
+
"everything_labels": "$list(set([i+1 for i in range(132)]) - set([2,16,18,20,21,23,24,25,26,27,128,129,130,131,132]))",
|
| 20 |
+
"metadata_path": "$@bundle_root + '/configs/metadata.json'",
|
| 21 |
+
"metadata": "$json.loads(pathlib.Path(@metadata_path).read_text())",
|
| 22 |
+
"labels_dict": "$@metadata['network_data_format']['outputs']['pred']['channel_def']",
|
| 23 |
+
"subclass": {
|
| 24 |
+
"2": [
|
| 25 |
+
14,
|
| 26 |
+
5
|
| 27 |
+
],
|
| 28 |
+
"20": [
|
| 29 |
+
28,
|
| 30 |
+
29,
|
| 31 |
+
30,
|
| 32 |
+
31,
|
| 33 |
+
32
|
| 34 |
+
],
|
| 35 |
+
"21": "$list(range(33, 57)) + list(range(63, 98)) + [114, 120, 122]"
|
| 36 |
+
},
|
| 37 |
+
"input_channels": 1,
|
| 38 |
+
"resample_spacing": [
|
| 39 |
+
1.5,
|
| 40 |
+
1.5,
|
| 41 |
+
1.5
|
| 42 |
+
],
|
| 43 |
+
"sw_batch_size": 1,
|
| 44 |
+
"patch_size": [
|
| 45 |
+
128,
|
| 46 |
+
128,
|
| 47 |
+
128
|
| 48 |
+
],
|
| 49 |
+
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
|
| 50 |
+
"use_point_window": true,
|
| 51 |
+
"network_def": "$monai.networks.nets.vista3d132(in_channels=@input_channels)",
|
| 52 |
+
"network": "$@network_def.to(@device)",
|
| 53 |
+
"preprocessing_transforms": [
|
| 54 |
+
{
|
| 55 |
+
"_target_": "LoadImaged",
|
| 56 |
+
"keys": "@image_key",
|
| 57 |
+
"image_only": true
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"_target_": "EnsureChannelFirstd",
|
| 61 |
+
"keys": "@image_key"
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"_target_": "EnsureTyped",
|
| 65 |
+
"keys": "@image_key",
|
| 66 |
+
"device": "@device",
|
| 67 |
+
"track_meta": true
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"_target_": "Spacingd",
|
| 71 |
+
"keys": "@image_key",
|
| 72 |
+
"pixdim": "@resample_spacing",
|
| 73 |
+
"mode": "bilinear"
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"_target_": "CropForegroundd",
|
| 77 |
+
"keys": "@image_key",
|
| 78 |
+
"allow_smaller": true,
|
| 79 |
+
"margin": 10,
|
| 80 |
+
"source_key": "@image_key"
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"_target_": "monai.apps.vista3d.transforms.VistaPreTransformd",
|
| 84 |
+
"keys": "@image_key",
|
| 85 |
+
"subclass": "@subclass",
|
| 86 |
+
"labels_dict": "@labels_dict"
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"_target_": "ScaleIntensityRanged",
|
| 90 |
+
"keys": "@image_key",
|
| 91 |
+
"a_min": -963.8247715525971,
|
| 92 |
+
"a_max": 1053.678477684517,
|
| 93 |
+
"b_min": 0,
|
| 94 |
+
"b_max": 1,
|
| 95 |
+
"clip": true
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"_target_": "Orientationd",
|
| 99 |
+
"keys": "@image_key",
|
| 100 |
+
"axcodes": "RAS"
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"_target_": "CastToTyped",
|
| 104 |
+
"keys": "@image_key",
|
| 105 |
+
"dtype": "$torch.float32"
|
| 106 |
+
}
|
| 107 |
+
],
|
| 108 |
+
"preprocessing": {
|
| 109 |
+
"_target_": "Compose",
|
| 110 |
+
"transforms": "$@preprocessing_transforms "
|
| 111 |
+
},
|
| 112 |
+
"dataset": {
|
| 113 |
+
"_target_": "Dataset",
|
| 114 |
+
"data": "$[@input_dict]",
|
| 115 |
+
"transform": "@preprocessing"
|
| 116 |
+
},
|
| 117 |
+
"dataloader": {
|
| 118 |
+
"_target_": "ThreadDataLoader",
|
| 119 |
+
"dataset": "@dataset",
|
| 120 |
+
"batch_size": 1,
|
| 121 |
+
"shuffle": false,
|
| 122 |
+
"num_workers": 0
|
| 123 |
+
},
|
| 124 |
+
"inferer": {
|
| 125 |
+
"_target_": "scripts.inferer.Vista3dInferer",
|
| 126 |
+
"roi_size": "@patch_size",
|
| 127 |
+
"overlap": 0.3,
|
| 128 |
+
"sw_batch_size": "@sw_batch_size",
|
| 129 |
+
"use_point_window": "@use_point_window"
|
| 130 |
+
},
|
| 131 |
+
"postprocessing": {
|
| 132 |
+
"_target_": "Compose",
|
| 133 |
+
"transforms": [
|
| 134 |
+
{
|
| 135 |
+
"_target_": "ToDeviced",
|
| 136 |
+
"keys": "pred",
|
| 137 |
+
"device": "cpu",
|
| 138 |
+
"_disabled_": true
|
| 139 |
+
},
|
| 140 |
+
{
|
| 141 |
+
"_target_": "monai.apps.vista3d.transforms.VistaPostTransformd",
|
| 142 |
+
"keys": "pred"
|
| 143 |
+
},
|
| 144 |
+
{
|
| 145 |
+
"_target_": "Invertd",
|
| 146 |
+
"keys": "pred",
|
| 147 |
+
"transform": "$copy.deepcopy(@preprocessing)",
|
| 148 |
+
"orig_keys": "@image_key",
|
| 149 |
+
"nearest_interp": true,
|
| 150 |
+
"to_tensor": true
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"_target_": "Lambdad",
|
| 154 |
+
"func": "$lambda x: torch.nan_to_num(x, nan=255)",
|
| 155 |
+
"keys": "pred"
|
| 156 |
+
},
|
| 157 |
+
{
|
| 158 |
+
"_target_": "SaveImaged",
|
| 159 |
+
"keys": "pred",
|
| 160 |
+
"resample": false,
|
| 161 |
+
"output_dir": "@output_dir",
|
| 162 |
+
"output_ext": "@output_ext",
|
| 163 |
+
"output_dtype": "@output_dtype",
|
| 164 |
+
"output_postfix": "@output_postfix",
|
| 165 |
+
"separate_folder": "@separate_folder"
|
| 166 |
+
}
|
| 167 |
+
]
|
| 168 |
+
},
|
| 169 |
+
"handlers": [
|
| 170 |
+
{
|
| 171 |
+
"_target_": "StatsHandler",
|
| 172 |
+
"iteration_log": false
|
| 173 |
+
}
|
| 174 |
+
],
|
| 175 |
+
"checkpointloader": {
|
| 176 |
+
"_target_": "CheckpointLoader",
|
| 177 |
+
"load_path": "$@bundle_root + '/models/model.pt'",
|
| 178 |
+
"load_dict": {
|
| 179 |
+
"model": "@network"
|
| 180 |
+
},
|
| 181 |
+
"map_location": "@device"
|
| 182 |
+
},
|
| 183 |
+
"evaluator": {
|
| 184 |
+
"_target_": "scripts.evaluator.Vista3dEvaluator",
|
| 185 |
+
"device": "@device",
|
| 186 |
+
"val_data_loader": "@dataloader",
|
| 187 |
+
"network": "@network",
|
| 188 |
+
"inferer": "@inferer",
|
| 189 |
+
"postprocessing": "@postprocessing",
|
| 190 |
+
"val_handlers": "@handlers",
|
| 191 |
+
"amp": true,
|
| 192 |
+
"hyper_kwargs": {
|
| 193 |
+
"user_prompt": true,
|
| 194 |
+
"everything_labels": "@everything_labels"
|
| 195 |
+
}
|
| 196 |
+
},
|
| 197 |
+
"initialize": [
|
| 198 |
+
"$monai.utils.set_determinism(seed=123)",
|
| 199 |
+
"$@checkpointloader(@evaluator)"
|
| 200 |
+
],
|
| 201 |
+
"run": [
|
| 202 |
+
"$@evaluator.run()"
|
| 203 |
+
]
|
| 204 |
+
}
|
configs/inference_trt.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"base_path": null,
|
| 3 |
+
"+imports": [
|
| 4 |
+
"$from monai.networks import trt_compile"
|
| 5 |
+
],
|
| 6 |
+
"max_prompt_size": 4,
|
| 7 |
+
"head_trt_enabled": false,
|
| 8 |
+
"network_trt_args": {
|
| 9 |
+
"dynamic_batchsize": "$[1, @inferer#sw_batch_size, @inferer#sw_batch_size]"
|
| 10 |
+
},
|
| 11 |
+
"network_dev": "$@network_def.to(@device)",
|
| 12 |
+
"encoder": "$trt_compile(@network_dev, @bundle_root + '/models/model.pt' if not @base_path else @base_path, args=@network_trt_args, submodule=['image_encoder.encoder'])",
|
| 13 |
+
"head_trt_args": {
|
| 14 |
+
"dynamic_batchsize": "$[1, 1, @max_prompt_size]",
|
| 15 |
+
"fallback": "$True"
|
| 16 |
+
},
|
| 17 |
+
"head": "$trt_compile(@network_dev, @bundle_root + '/models/model.pt' if not @base_path else @base_path, args=@head_trt_args, submodule=['class_head']) if @head_trt_enabled else @network_dev",
|
| 18 |
+
"network": "$None if @encoder is None else @head"
|
| 19 |
+
}
|
configs/logging.conf
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[loggers]
|
| 2 |
+
keys=root
|
| 3 |
+
|
| 4 |
+
[handlers]
|
| 5 |
+
keys=consoleHandler,fileHandler
|
| 6 |
+
|
| 7 |
+
[formatters]
|
| 8 |
+
keys=fullFormatter
|
| 9 |
+
|
| 10 |
+
[logger_root]
|
| 11 |
+
level=INFO
|
| 12 |
+
handlers=consoleHandler,fileHandler
|
| 13 |
+
|
| 14 |
+
[handler_consoleHandler]
|
| 15 |
+
class=StreamHandler
|
| 16 |
+
level=INFO
|
| 17 |
+
formatter=fullFormatter
|
| 18 |
+
args=(sys.stdout,)
|
| 19 |
+
|
| 20 |
+
[handler_fileHandler]
|
| 21 |
+
class=FileHandler
|
| 22 |
+
level=INFO
|
| 23 |
+
formatter=fullFormatter
|
| 24 |
+
args=('training.log',)
|
| 25 |
+
|
| 26 |
+
[formatter_fullFormatter]
|
| 27 |
+
format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
|
configs/metadata.json
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
|
| 3 |
+
"version": "0.5.8",
|
| 4 |
+
"changelog": {
|
| 5 |
+
"0.5.8": "update to huggingface hosting",
|
| 6 |
+
"0.5.7": "change sw padding mode to replicate",
|
| 7 |
+
"0.5.6": "add mlflow support",
|
| 8 |
+
"0.5.5": "add arg for trt compiler base path",
|
| 9 |
+
"0.5.4": "add undefined label prompt check",
|
| 10 |
+
"0.5.3": "update readme",
|
| 11 |
+
"0.5.2": "fix eval issue",
|
| 12 |
+
"0.5.1": "add description for zero-shot and upate eval",
|
| 13 |
+
"0.5.0": "update json file link and add test",
|
| 14 |
+
"0.4.9": "fix oom issue and update readme",
|
| 15 |
+
"0.4.8": "use 0.3 overlap for inference",
|
| 16 |
+
"0.4.7": "update tensorrt benchmark results",
|
| 17 |
+
"0.4.6": "add tensorrt benchmark result and remove the metric part",
|
| 18 |
+
"0.4.5": "remove wrong path",
|
| 19 |
+
"0.4.4": "enable tensorrt inference",
|
| 20 |
+
"0.4.3": "fix CL and batch infer issues",
|
| 21 |
+
"0.4.2": "use MONAI components for network and utils",
|
| 22 |
+
"0.4.1": "initial OSS version"
|
| 23 |
+
},
|
| 24 |
+
"monai_version": "1.4.0",
|
| 25 |
+
"pytorch_version": "2.4.0",
|
| 26 |
+
"numpy_version": "1.24.4",
|
| 27 |
+
"required_packages_version": {
|
| 28 |
+
"matplotlib": "3.9.1",
|
| 29 |
+
"einops": "0.7.0",
|
| 30 |
+
"scikit-image": "0.23.2",
|
| 31 |
+
"nibabel": "5.2.1",
|
| 32 |
+
"pytorch-ignite": "0.4.11",
|
| 33 |
+
"cucim-cu12": "24.6.0",
|
| 34 |
+
"mlflow": "2.17.2",
|
| 35 |
+
"tensorboard": "2.17.0"
|
| 36 |
+
},
|
| 37 |
+
"supported_apps": {
|
| 38 |
+
"vista3d-nim": ""
|
| 39 |
+
},
|
| 40 |
+
"name": "VISTA3D",
|
| 41 |
+
"task": "Decathlon Spleen segmentation",
|
| 42 |
+
"description": "VISTA3D bundle",
|
| 43 |
+
"authors": "MONAI team",
|
| 44 |
+
"copyright": "Copyright (c) MONAI Consortium",
|
| 45 |
+
"data_source": "Task09_Spleen.tar from http://medicaldecathlon.com/",
|
| 46 |
+
"data_type": "nibabel",
|
| 47 |
+
"image_classes": "1 channel data, intensity scaled to [0, 1]",
|
| 48 |
+
"label_classes": "single channel data",
|
| 49 |
+
"pred_classes": "2 channels OneHot data",
|
| 50 |
+
"intended_use": "This is an example, not to be used for diagnostic purposes",
|
| 51 |
+
"references": [],
|
| 52 |
+
"network_data_format": {
|
| 53 |
+
"inputs": {
|
| 54 |
+
"image": {
|
| 55 |
+
"type": "image",
|
| 56 |
+
"format": "hounsfield",
|
| 57 |
+
"modality": "CT",
|
| 58 |
+
"num_channels": 1,
|
| 59 |
+
"spatial_shape": [
|
| 60 |
+
128,
|
| 61 |
+
128,
|
| 62 |
+
128
|
| 63 |
+
],
|
| 64 |
+
"dtype": "float32",
|
| 65 |
+
"value_range": [
|
| 66 |
+
0,
|
| 67 |
+
1
|
| 68 |
+
],
|
| 69 |
+
"is_patch_data": true,
|
| 70 |
+
"channel_def": {
|
| 71 |
+
"0": "image"
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
},
|
| 75 |
+
"outputs": {
|
| 76 |
+
"pred": {
|
| 77 |
+
"type": "image",
|
| 78 |
+
"format": "segmentation",
|
| 79 |
+
"num_channels": 1,
|
| 80 |
+
"spatial_shape": [
|
| 81 |
+
128,
|
| 82 |
+
128,
|
| 83 |
+
128
|
| 84 |
+
],
|
| 85 |
+
"dtype": "float32",
|
| 86 |
+
"value_range": [
|
| 87 |
+
0,
|
| 88 |
+
1
|
| 89 |
+
],
|
| 90 |
+
"is_patch_data": true,
|
| 91 |
+
"channel_def": {
|
| 92 |
+
"0": "background",
|
| 93 |
+
"1": "liver",
|
| 94 |
+
"2": "kidney",
|
| 95 |
+
"3": "spleen",
|
| 96 |
+
"4": "pancreas",
|
| 97 |
+
"5": "right kidney",
|
| 98 |
+
"6": "aorta",
|
| 99 |
+
"7": "inferior vena cava",
|
| 100 |
+
"8": "right adrenal gland",
|
| 101 |
+
"9": "left adrenal gland",
|
| 102 |
+
"10": "gallbladder",
|
| 103 |
+
"11": "esophagus",
|
| 104 |
+
"12": "stomach",
|
| 105 |
+
"13": "duodenum",
|
| 106 |
+
"14": "left kidney",
|
| 107 |
+
"15": "bladder",
|
| 108 |
+
"16": "prostate or uterus",
|
| 109 |
+
"17": "portal vein and splenic vein",
|
| 110 |
+
"18": "rectum",
|
| 111 |
+
"19": "small bowel",
|
| 112 |
+
"20": "lung",
|
| 113 |
+
"21": "bone",
|
| 114 |
+
"22": "brain",
|
| 115 |
+
"23": "lung tumor",
|
| 116 |
+
"24": "pancreatic tumor",
|
| 117 |
+
"25": "hepatic vessel",
|
| 118 |
+
"26": "hepatic tumor",
|
| 119 |
+
"27": "colon cancer primaries",
|
| 120 |
+
"28": "left lung upper lobe",
|
| 121 |
+
"29": "left lung lower lobe",
|
| 122 |
+
"30": "right lung upper lobe",
|
| 123 |
+
"31": "right lung middle lobe",
|
| 124 |
+
"32": "right lung lower lobe",
|
| 125 |
+
"33": "vertebrae L5",
|
| 126 |
+
"34": "vertebrae L4",
|
| 127 |
+
"35": "vertebrae L3",
|
| 128 |
+
"36": "vertebrae L2",
|
| 129 |
+
"37": "vertebrae L1",
|
| 130 |
+
"38": "vertebrae T12",
|
| 131 |
+
"39": "vertebrae T11",
|
| 132 |
+
"40": "vertebrae T10",
|
| 133 |
+
"41": "vertebrae T9",
|
| 134 |
+
"42": "vertebrae T8",
|
| 135 |
+
"43": "vertebrae T7",
|
| 136 |
+
"44": "vertebrae T6",
|
| 137 |
+
"45": "vertebrae T5",
|
| 138 |
+
"46": "vertebrae T4",
|
| 139 |
+
"47": "vertebrae T3",
|
| 140 |
+
"48": "vertebrae T2",
|
| 141 |
+
"49": "vertebrae T1",
|
| 142 |
+
"50": "vertebrae C7",
|
| 143 |
+
"51": "vertebrae C6",
|
| 144 |
+
"52": "vertebrae C5",
|
| 145 |
+
"53": "vertebrae C4",
|
| 146 |
+
"54": "vertebrae C3",
|
| 147 |
+
"55": "vertebrae C2",
|
| 148 |
+
"56": "vertebrae C1",
|
| 149 |
+
"57": "trachea",
|
| 150 |
+
"58": "left iliac artery",
|
| 151 |
+
"59": "right iliac artery",
|
| 152 |
+
"60": "left iliac vena",
|
| 153 |
+
"61": "right iliac vena",
|
| 154 |
+
"62": "colon",
|
| 155 |
+
"63": "left rib 1",
|
| 156 |
+
"64": "left rib 2",
|
| 157 |
+
"65": "left rib 3",
|
| 158 |
+
"66": "left rib 4",
|
| 159 |
+
"67": "left rib 5",
|
| 160 |
+
"68": "left rib 6",
|
| 161 |
+
"69": "left rib 7",
|
| 162 |
+
"70": "left rib 8",
|
| 163 |
+
"71": "left rib 9",
|
| 164 |
+
"72": "left rib 10",
|
| 165 |
+
"73": "left rib 11",
|
| 166 |
+
"74": "left rib 12",
|
| 167 |
+
"75": "right rib 1",
|
| 168 |
+
"76": "right rib 2",
|
| 169 |
+
"77": "right rib 3",
|
| 170 |
+
"78": "right rib 4",
|
| 171 |
+
"79": "right rib 5",
|
| 172 |
+
"80": "right rib 6",
|
| 173 |
+
"81": "right rib 7",
|
| 174 |
+
"82": "right rib 8",
|
| 175 |
+
"83": "right rib 9",
|
| 176 |
+
"84": "right rib 10",
|
| 177 |
+
"85": "right rib 11",
|
| 178 |
+
"86": "right rib 12",
|
| 179 |
+
"87": "left humerus",
|
| 180 |
+
"88": "right humerus",
|
| 181 |
+
"89": "left scapula",
|
| 182 |
+
"90": "right scapula",
|
| 183 |
+
"91": "left clavicula",
|
| 184 |
+
"92": "right clavicula",
|
| 185 |
+
"93": "left femur",
|
| 186 |
+
"94": "right femur",
|
| 187 |
+
"95": "left hip",
|
| 188 |
+
"96": "right hip",
|
| 189 |
+
"97": "sacrum",
|
| 190 |
+
"98": "left gluteus maximus",
|
| 191 |
+
"99": "right gluteus maximus",
|
| 192 |
+
"100": "left gluteus medius",
|
| 193 |
+
"101": "right gluteus medius",
|
| 194 |
+
"102": "left gluteus minimus",
|
| 195 |
+
"103": "right gluteus minimus",
|
| 196 |
+
"104": "left autochthon",
|
| 197 |
+
"105": "right autochthon",
|
| 198 |
+
"106": "left iliopsoas",
|
| 199 |
+
"107": "right iliopsoas",
|
| 200 |
+
"108": "left atrial appendage",
|
| 201 |
+
"109": "brachiocephalic trunk",
|
| 202 |
+
"110": "left brachiocephalic vein",
|
| 203 |
+
"111": "right brachiocephalic vein",
|
| 204 |
+
"112": "left common carotid artery",
|
| 205 |
+
"113": "right common carotid artery",
|
| 206 |
+
"114": "costal cartilages",
|
| 207 |
+
"115": "heart",
|
| 208 |
+
"116": "left kidney cyst",
|
| 209 |
+
"117": "right kidney cyst",
|
| 210 |
+
"118": "prostate",
|
| 211 |
+
"119": "pulmonary vein",
|
| 212 |
+
"120": "skull",
|
| 213 |
+
"121": "spinal cord",
|
| 214 |
+
"122": "sternum",
|
| 215 |
+
"123": "left subclavian artery",
|
| 216 |
+
"124": "right subclavian artery",
|
| 217 |
+
"125": "superior vena cava",
|
| 218 |
+
"126": "thyroid gland",
|
| 219 |
+
"127": "vertebrae S1",
|
| 220 |
+
"128": "bone lesion",
|
| 221 |
+
"129": "kidney mass",
|
| 222 |
+
"130": "liver tumor",
|
| 223 |
+
"131": "vertebrae L6",
|
| 224 |
+
"132": "airway"
|
| 225 |
+
}
|
| 226 |
+
}
|
| 227 |
+
}
|
| 228 |
+
}
|
| 229 |
+
}
|
configs/mgpu_evaluate.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"device": "$torch.device('cuda:' + os.environ['LOCAL_RANK'])",
|
| 3 |
+
"network": {
|
| 4 |
+
"_target_": "torch.nn.parallel.DistributedDataParallel",
|
| 5 |
+
"module": "$@network_def.to(@device)",
|
| 6 |
+
"device_ids": [
|
| 7 |
+
"@device"
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
"validate#sampler": {
|
| 11 |
+
"_target_": "DistributedSampler",
|
| 12 |
+
"dataset": "@validate#dataset",
|
| 13 |
+
"even_divisible": false,
|
| 14 |
+
"shuffle": false
|
| 15 |
+
},
|
| 16 |
+
"validate#dataloader#sampler": "@validate#sampler",
|
| 17 |
+
"validate#handlers#1#_disabled_": "$dist.get_rank() > 0",
|
| 18 |
+
"initialize": [
|
| 19 |
+
"$import torch.distributed as dist",
|
| 20 |
+
"$dist.is_initialized() or dist.init_process_group(backend='nccl')",
|
| 21 |
+
"$torch.cuda.set_device(@device)"
|
| 22 |
+
],
|
| 23 |
+
"run": [
|
| 24 |
+
"$@validate#evaluator.run()"
|
| 25 |
+
],
|
| 26 |
+
"finalize": [
|
| 27 |
+
"$dist.is_initialized() and dist.destroy_process_group()"
|
| 28 |
+
]
|
| 29 |
+
}
|
configs/msd_task09_spleen_folds.json
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"testing": [
|
| 3 |
+
{
|
| 4 |
+
"image": "imagesTs/spleen_15.nii.gz"
|
| 5 |
+
},
|
| 6 |
+
{
|
| 7 |
+
"image": "imagesTs/spleen_23.nii.gz"
|
| 8 |
+
},
|
| 9 |
+
{
|
| 10 |
+
"image": "imagesTs/spleen_1.nii.gz"
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"image": "imagesTs/spleen_42.nii.gz"
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"image": "imagesTs/spleen_50.nii.gz"
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"image": "imagesTs/spleen_54.nii.gz"
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"image": "imagesTs/spleen_37.nii.gz"
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"image": "imagesTs/spleen_58.nii.gz"
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"image": "imagesTs/spleen_39.nii.gz"
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"image": "imagesTs/spleen_48.nii.gz"
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"image": "imagesTs/spleen_35.nii.gz"
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"image": "imagesTs/spleen_11.nii.gz"
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"image": "imagesTs/spleen_7.nii.gz"
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"image": "imagesTs/spleen_30.nii.gz"
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"image": "imagesTs/spleen_43.nii.gz"
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"image": "imagesTs/spleen_51.nii.gz"
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"image": "imagesTs/spleen_36.nii.gz"
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"image": "imagesTs/spleen_55.nii.gz"
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"image": "imagesTs/spleen_57.nii.gz"
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"image": "imagesTs/spleen_34.nii.gz"
|
| 62 |
+
}
|
| 63 |
+
],
|
| 64 |
+
"training": [
|
| 65 |
+
{
|
| 66 |
+
"fold": 0,
|
| 67 |
+
"image": "imagesTr/spleen_19.nii.gz",
|
| 68 |
+
"label": "labelsTr/spleen_19.nii.gz"
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"fold": 0,
|
| 72 |
+
"image": "imagesTr/spleen_31.nii.gz",
|
| 73 |
+
"label": "labelsTr/spleen_31.nii.gz"
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"fold": 0,
|
| 77 |
+
"image": "imagesTr/spleen_52.nii.gz",
|
| 78 |
+
"label": "labelsTr/spleen_52.nii.gz"
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"fold": 0,
|
| 82 |
+
"image": "imagesTr/spleen_40.nii.gz",
|
| 83 |
+
"label": "labelsTr/spleen_40.nii.gz"
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"fold": 0,
|
| 87 |
+
"image": "imagesTr/spleen_3.nii.gz",
|
| 88 |
+
"label": "labelsTr/spleen_3.nii.gz"
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"fold": 0,
|
| 92 |
+
"image": "imagesTr/spleen_17.nii.gz",
|
| 93 |
+
"label": "labelsTr/spleen_17.nii.gz"
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"fold": 0,
|
| 97 |
+
"image": "imagesTr/spleen_21.nii.gz",
|
| 98 |
+
"label": "labelsTr/spleen_21.nii.gz"
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"fold": 0,
|
| 102 |
+
"image": "imagesTr/spleen_33.nii.gz",
|
| 103 |
+
"label": "labelsTr/spleen_33.nii.gz"
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"fold": 1,
|
| 107 |
+
"image": "imagesTr/spleen_9.nii.gz",
|
| 108 |
+
"label": "labelsTr/spleen_9.nii.gz"
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"fold": 1,
|
| 112 |
+
"image": "imagesTr/spleen_29.nii.gz",
|
| 113 |
+
"label": "labelsTr/spleen_29.nii.gz"
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"fold": 1,
|
| 117 |
+
"image": "imagesTr/spleen_46.nii.gz",
|
| 118 |
+
"label": "labelsTr/spleen_46.nii.gz"
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"fold": 1,
|
| 122 |
+
"image": "imagesTr/spleen_25.nii.gz",
|
| 123 |
+
"label": "labelsTr/spleen_25.nii.gz"
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"fold": 1,
|
| 127 |
+
"image": "imagesTr/spleen_13.nii.gz",
|
| 128 |
+
"label": "labelsTr/spleen_13.nii.gz"
|
| 129 |
+
},
|
| 130 |
+
{
|
| 131 |
+
"fold": 1,
|
| 132 |
+
"image": "imagesTr/spleen_62.nii.gz",
|
| 133 |
+
"label": "labelsTr/spleen_62.nii.gz"
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
"fold": 1,
|
| 137 |
+
"image": "imagesTr/spleen_27.nii.gz",
|
| 138 |
+
"label": "labelsTr/spleen_27.nii.gz"
|
| 139 |
+
},
|
| 140 |
+
{
|
| 141 |
+
"fold": 1,
|
| 142 |
+
"image": "imagesTr/spleen_44.nii.gz",
|
| 143 |
+
"label": "labelsTr/spleen_44.nii.gz"
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"fold": 2,
|
| 147 |
+
"image": "imagesTr/spleen_56.nii.gz",
|
| 148 |
+
"label": "labelsTr/spleen_56.nii.gz"
|
| 149 |
+
},
|
| 150 |
+
{
|
| 151 |
+
"fold": 2,
|
| 152 |
+
"image": "imagesTr/spleen_60.nii.gz",
|
| 153 |
+
"label": "labelsTr/spleen_60.nii.gz"
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"fold": 2,
|
| 157 |
+
"image": "imagesTr/spleen_2.nii.gz",
|
| 158 |
+
"label": "labelsTr/spleen_2.nii.gz"
|
| 159 |
+
},
|
| 160 |
+
{
|
| 161 |
+
"fold": 2,
|
| 162 |
+
"image": "imagesTr/spleen_53.nii.gz",
|
| 163 |
+
"label": "labelsTr/spleen_53.nii.gz"
|
| 164 |
+
},
|
| 165 |
+
{
|
| 166 |
+
"fold": 2,
|
| 167 |
+
"image": "imagesTr/spleen_41.nii.gz",
|
| 168 |
+
"label": "labelsTr/spleen_41.nii.gz"
|
| 169 |
+
},
|
| 170 |
+
{
|
| 171 |
+
"fold": 2,
|
| 172 |
+
"image": "imagesTr/spleen_22.nii.gz",
|
| 173 |
+
"label": "labelsTr/spleen_22.nii.gz"
|
| 174 |
+
},
|
| 175 |
+
{
|
| 176 |
+
"fold": 2,
|
| 177 |
+
"image": "imagesTr/spleen_14.nii.gz",
|
| 178 |
+
"label": "labelsTr/spleen_14.nii.gz"
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"fold": 2,
|
| 182 |
+
"image": "imagesTr/spleen_18.nii.gz",
|
| 183 |
+
"label": "labelsTr/spleen_18.nii.gz"
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"fold": 3,
|
| 187 |
+
"image": "imagesTr/spleen_20.nii.gz",
|
| 188 |
+
"label": "labelsTr/spleen_20.nii.gz"
|
| 189 |
+
},
|
| 190 |
+
{
|
| 191 |
+
"fold": 3,
|
| 192 |
+
"image": "imagesTr/spleen_32.nii.gz",
|
| 193 |
+
"label": "labelsTr/spleen_32.nii.gz"
|
| 194 |
+
},
|
| 195 |
+
{
|
| 196 |
+
"fold": 3,
|
| 197 |
+
"image": "imagesTr/spleen_16.nii.gz",
|
| 198 |
+
"label": "labelsTr/spleen_16.nii.gz"
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
"fold": 3,
|
| 202 |
+
"image": "imagesTr/spleen_12.nii.gz",
|
| 203 |
+
"label": "labelsTr/spleen_12.nii.gz"
|
| 204 |
+
},
|
| 205 |
+
{
|
| 206 |
+
"fold": 3,
|
| 207 |
+
"image": "imagesTr/spleen_63.nii.gz",
|
| 208 |
+
"label": "labelsTr/spleen_63.nii.gz"
|
| 209 |
+
},
|
| 210 |
+
{
|
| 211 |
+
"fold": 3,
|
| 212 |
+
"image": "imagesTr/spleen_28.nii.gz",
|
| 213 |
+
"label": "labelsTr/spleen_28.nii.gz"
|
| 214 |
+
},
|
| 215 |
+
{
|
| 216 |
+
"fold": 3,
|
| 217 |
+
"image": "imagesTr/spleen_24.nii.gz",
|
| 218 |
+
"label": "labelsTr/spleen_24.nii.gz"
|
| 219 |
+
},
|
| 220 |
+
{
|
| 221 |
+
"fold": 3,
|
| 222 |
+
"image": "imagesTr/spleen_59.nii.gz",
|
| 223 |
+
"label": "labelsTr/spleen_59.nii.gz"
|
| 224 |
+
},
|
| 225 |
+
{
|
| 226 |
+
"fold": 4,
|
| 227 |
+
"image": "imagesTr/spleen_47.nii.gz",
|
| 228 |
+
"label": "labelsTr/spleen_47.nii.gz"
|
| 229 |
+
},
|
| 230 |
+
{
|
| 231 |
+
"fold": 4,
|
| 232 |
+
"image": "imagesTr/spleen_8.nii.gz",
|
| 233 |
+
"label": "labelsTr/spleen_8.nii.gz"
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"fold": 4,
|
| 237 |
+
"image": "imagesTr/spleen_6.nii.gz",
|
| 238 |
+
"label": "labelsTr/spleen_6.nii.gz"
|
| 239 |
+
},
|
| 240 |
+
{
|
| 241 |
+
"fold": 4,
|
| 242 |
+
"image": "imagesTr/spleen_61.nii.gz",
|
| 243 |
+
"label": "labelsTr/spleen_61.nii.gz"
|
| 244 |
+
},
|
| 245 |
+
{
|
| 246 |
+
"fold": 4,
|
| 247 |
+
"image": "imagesTr/spleen_10.nii.gz",
|
| 248 |
+
"label": "labelsTr/spleen_10.nii.gz"
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"fold": 4,
|
| 252 |
+
"image": "imagesTr/spleen_38.nii.gz",
|
| 253 |
+
"label": "labelsTr/spleen_38.nii.gz"
|
| 254 |
+
},
|
| 255 |
+
{
|
| 256 |
+
"fold": 4,
|
| 257 |
+
"image": "imagesTr/spleen_45.nii.gz",
|
| 258 |
+
"label": "labelsTr/spleen_45.nii.gz"
|
| 259 |
+
},
|
| 260 |
+
{
|
| 261 |
+
"fold": 4,
|
| 262 |
+
"image": "imagesTr/spleen_26.nii.gz",
|
| 263 |
+
"label": "labelsTr/spleen_26.nii.gz"
|
| 264 |
+
},
|
| 265 |
+
{
|
| 266 |
+
"image": "imagesTr/spleen_49.nii.gz",
|
| 267 |
+
"label": "labelsTr/spleen_49.nii.gz",
|
| 268 |
+
"fold": 0
|
| 269 |
+
}
|
| 270 |
+
]
|
| 271 |
+
}
|
configs/multi_gpu_train.json
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"device": "$torch.device('cuda:' + os.environ['LOCAL_RANK'])",
|
| 3 |
+
"use_tensorboard": "$dist.get_rank() == 0",
|
| 4 |
+
"network": {
|
| 5 |
+
"_target_": "torch.nn.parallel.DistributedDataParallel",
|
| 6 |
+
"module": "$@network_def.to(@device)",
|
| 7 |
+
"find_unused_parameters": true,
|
| 8 |
+
"device_ids": [
|
| 9 |
+
"@device"
|
| 10 |
+
]
|
| 11 |
+
},
|
| 12 |
+
"train#sampler": {
|
| 13 |
+
"_target_": "DistributedSampler",
|
| 14 |
+
"dataset": "@train#dataset",
|
| 15 |
+
"even_divisible": true,
|
| 16 |
+
"shuffle": true
|
| 17 |
+
},
|
| 18 |
+
"train#dataloader#sampler": "@train#sampler",
|
| 19 |
+
"train#dataloader#shuffle": false,
|
| 20 |
+
"train#trainer#train_handlers": "$@train#handlers[: -1 if dist.get_rank() > 0 else None]",
|
| 21 |
+
"validate#sampler": {
|
| 22 |
+
"_target_": "DistributedSampler",
|
| 23 |
+
"dataset": "@validate#dataset",
|
| 24 |
+
"even_divisible": false,
|
| 25 |
+
"shuffle": false
|
| 26 |
+
},
|
| 27 |
+
"validate#dataloader#sampler": "@validate#sampler",
|
| 28 |
+
"validate#evaluator#val_handlers": "$@validate#handlers[: -2 if dist.get_rank() > 0 else None]",
|
| 29 |
+
"initialize": [
|
| 30 |
+
"$import torch.distributed as dist",
|
| 31 |
+
"$dist.is_initialized() or dist.init_process_group(backend='nccl')",
|
| 32 |
+
"$torch.cuda.set_device(@device)",
|
| 33 |
+
"$monai.utils.set_determinism(seed=123)"
|
| 34 |
+
],
|
| 35 |
+
"run": [
|
| 36 |
+
"$@validate#handlers#0.set_trainer(trainer=@train#trainer) if @early_stop else None",
|
| 37 |
+
"$@train#trainer.run()"
|
| 38 |
+
],
|
| 39 |
+
"finalize": [
|
| 40 |
+
"$dist.is_initialized() and dist.destroy_process_group()"
|
| 41 |
+
]
|
| 42 |
+
}
|
configs/train.json
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"imports": [
|
| 3 |
+
"$import glob",
|
| 4 |
+
"$import os",
|
| 5 |
+
"$import scripts",
|
| 6 |
+
"$import ignite",
|
| 7 |
+
"$import copy"
|
| 8 |
+
],
|
| 9 |
+
"bundle_root": ".",
|
| 10 |
+
"ckpt_dir": "$@bundle_root + '/models'",
|
| 11 |
+
"output_dir": "$@bundle_root + '/eval'",
|
| 12 |
+
"data_list_file_path": "$@bundle_root + '/configs/msd_task09_spleen_folds.json'",
|
| 13 |
+
"dataset_dir": "/data/Task09_Spleen",
|
| 14 |
+
"use_tensorboard": true,
|
| 15 |
+
"finetune": false,
|
| 16 |
+
"finetune_model_path": "$@bundle_root + '/models/model.pt'",
|
| 17 |
+
"early_stop": false,
|
| 18 |
+
"use_mlflow": false,
|
| 19 |
+
"mlflow_dir": "$@bundle_root + '/mlruns'",
|
| 20 |
+
"fold": 0,
|
| 21 |
+
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
|
| 22 |
+
"epochs": 5,
|
| 23 |
+
"val_interval": 1,
|
| 24 |
+
"val_at_start": false,
|
| 25 |
+
"sw_overlap": 0.625,
|
| 26 |
+
"learning_rate": 0.0001,
|
| 27 |
+
"num_patches_per_image": 1,
|
| 28 |
+
"input_channels": 1,
|
| 29 |
+
"output_classes": 2,
|
| 30 |
+
"max_point": 5,
|
| 31 |
+
"max_prompt": null,
|
| 32 |
+
"max_backprompt": null,
|
| 33 |
+
"max_foreprompt": null,
|
| 34 |
+
"drop_label_prob": 0.25,
|
| 35 |
+
"drop_point_prob": 0.25,
|
| 36 |
+
"exclude_background": true,
|
| 37 |
+
"label_set": null,
|
| 38 |
+
"val_label_set": "@label_set",
|
| 39 |
+
"amp": true,
|
| 40 |
+
"train_datalist": "$monai.auto3dseg.utils.datafold_read(datalist=@data_list_file_path, basedir=@dataset_dir, fold=@fold)[0]",
|
| 41 |
+
"val_datalist": "$monai.auto3dseg.utils.datafold_read(datalist=@data_list_file_path, basedir=@dataset_dir, fold=@fold)[1]",
|
| 42 |
+
"patch_size": [
|
| 43 |
+
128,
|
| 44 |
+
128,
|
| 45 |
+
128
|
| 46 |
+
],
|
| 47 |
+
"patch_size_valid": "$@patch_size",
|
| 48 |
+
"network_def": "$monai.networks.nets.vista3d132(in_channels=@input_channels)",
|
| 49 |
+
"network": "$@network_def.to(@device)",
|
| 50 |
+
"loss": {
|
| 51 |
+
"_target_": "DiceCELoss",
|
| 52 |
+
"include_background": true,
|
| 53 |
+
"sigmoid": true,
|
| 54 |
+
"smooth_dr": 1e-05,
|
| 55 |
+
"smooth_nr": 0,
|
| 56 |
+
"squared_pred": true,
|
| 57 |
+
"to_onehot_y": false
|
| 58 |
+
},
|
| 59 |
+
"optimizer": {
|
| 60 |
+
"_target_": "torch.optim.AdamW",
|
| 61 |
+
"params": "$@network.parameters()",
|
| 62 |
+
"lr": "@learning_rate",
|
| 63 |
+
"weight_decay": 1e-05
|
| 64 |
+
},
|
| 65 |
+
"lr_schedule": {
|
| 66 |
+
"activate": true,
|
| 67 |
+
"lr_scheduler": {
|
| 68 |
+
"_target_": "monai.optimizers.WarmupCosineSchedule",
|
| 69 |
+
"optimizer": "@optimizer",
|
| 70 |
+
"t_total": "$@epochs",
|
| 71 |
+
"warmup_steps": 3,
|
| 72 |
+
"warmup_multiplier": 0.1
|
| 73 |
+
}
|
| 74 |
+
},
|
| 75 |
+
"resample_to_spacing": [
|
| 76 |
+
1.5,
|
| 77 |
+
1.5,
|
| 78 |
+
1.5
|
| 79 |
+
],
|
| 80 |
+
"train": {
|
| 81 |
+
"deterministic_transforms": [
|
| 82 |
+
{
|
| 83 |
+
"_target_": "LoadImaged",
|
| 84 |
+
"keys": [
|
| 85 |
+
"image",
|
| 86 |
+
"label"
|
| 87 |
+
],
|
| 88 |
+
"image_only": true,
|
| 89 |
+
"ensure_channel_first": true
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"_target_": "CropForegroundd",
|
| 93 |
+
"keys": [
|
| 94 |
+
"image",
|
| 95 |
+
"label"
|
| 96 |
+
],
|
| 97 |
+
"source_key": "image",
|
| 98 |
+
"margin": 10,
|
| 99 |
+
"allow_smaller": true,
|
| 100 |
+
"start_coord_key": null,
|
| 101 |
+
"end_coord_key": null
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"_target_": "ScaleIntensityRanged",
|
| 105 |
+
"keys": "image",
|
| 106 |
+
"a_min": -963.8247715525971,
|
| 107 |
+
"a_max": 1053.678477684517,
|
| 108 |
+
"b_min": 0.0,
|
| 109 |
+
"b_max": 1.0,
|
| 110 |
+
"clip": true
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"_target_": "Orientationd",
|
| 114 |
+
"keys": [
|
| 115 |
+
"image",
|
| 116 |
+
"label"
|
| 117 |
+
],
|
| 118 |
+
"axcodes": "RAS"
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"_target_": "Spacingd",
|
| 122 |
+
"keys": [
|
| 123 |
+
"image",
|
| 124 |
+
"label"
|
| 125 |
+
],
|
| 126 |
+
"pixdim": "$@resample_to_spacing",
|
| 127 |
+
"mode": [
|
| 128 |
+
"bilinear",
|
| 129 |
+
"nearest"
|
| 130 |
+
]
|
| 131 |
+
},
|
| 132 |
+
{
|
| 133 |
+
"_target_": "CastToTyped",
|
| 134 |
+
"keys": [
|
| 135 |
+
"image",
|
| 136 |
+
"label"
|
| 137 |
+
],
|
| 138 |
+
"dtype": [
|
| 139 |
+
"$torch.float32",
|
| 140 |
+
"$torch.uint8"
|
| 141 |
+
]
|
| 142 |
+
},
|
| 143 |
+
{
|
| 144 |
+
"_target_": "EnsureTyped",
|
| 145 |
+
"keys": [
|
| 146 |
+
"image",
|
| 147 |
+
"label"
|
| 148 |
+
],
|
| 149 |
+
"track_meta": true
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"_target_": "SpatialPadd",
|
| 153 |
+
"keys": [
|
| 154 |
+
"image",
|
| 155 |
+
"label"
|
| 156 |
+
],
|
| 157 |
+
"spatial_size": "@patch_size",
|
| 158 |
+
"mode": [
|
| 159 |
+
"constant",
|
| 160 |
+
"constant"
|
| 161 |
+
]
|
| 162 |
+
}
|
| 163 |
+
],
|
| 164 |
+
"random_transforms": [
|
| 165 |
+
{
|
| 166 |
+
"_target_": "RandCropByLabelClassesd",
|
| 167 |
+
"keys": [
|
| 168 |
+
"image",
|
| 169 |
+
"label"
|
| 170 |
+
],
|
| 171 |
+
"label_key": "label",
|
| 172 |
+
"num_classes": "@output_classes",
|
| 173 |
+
"spatial_size": "@patch_size",
|
| 174 |
+
"num_samples": "@num_patches_per_image",
|
| 175 |
+
"warn": false
|
| 176 |
+
},
|
| 177 |
+
{
|
| 178 |
+
"_target_": "ResizeWithPadOrCropd",
|
| 179 |
+
"keys": [
|
| 180 |
+
"image",
|
| 181 |
+
"label"
|
| 182 |
+
],
|
| 183 |
+
"spatial_size": "@patch_size"
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"_target_": "RandScaleIntensityd",
|
| 187 |
+
"keys": "image",
|
| 188 |
+
"prob": 0.1,
|
| 189 |
+
"factors": 0.1
|
| 190 |
+
},
|
| 191 |
+
{
|
| 192 |
+
"_target_": "RandShiftIntensityd",
|
| 193 |
+
"keys": "image",
|
| 194 |
+
"prob": 0.1,
|
| 195 |
+
"offsets": 0.1
|
| 196 |
+
}
|
| 197 |
+
],
|
| 198 |
+
"inferer": {
|
| 199 |
+
"_target_": "SimpleInferer"
|
| 200 |
+
},
|
| 201 |
+
"preprocessing": {
|
| 202 |
+
"_target_": "Compose",
|
| 203 |
+
"transforms": "$@train#deterministic_transforms + @train#random_transforms"
|
| 204 |
+
},
|
| 205 |
+
"dataset": {
|
| 206 |
+
"_target_": "Dataset",
|
| 207 |
+
"data": "@train_datalist",
|
| 208 |
+
"transform": "@train#preprocessing"
|
| 209 |
+
},
|
| 210 |
+
"dataloader": {
|
| 211 |
+
"_target_": "DataLoader",
|
| 212 |
+
"dataset": "@train#dataset",
|
| 213 |
+
"batch_size": 1,
|
| 214 |
+
"shuffle": true,
|
| 215 |
+
"num_workers": 4,
|
| 216 |
+
"pin_memory": true,
|
| 217 |
+
"persistent_workers": true
|
| 218 |
+
},
|
| 219 |
+
"handlers": [
|
| 220 |
+
{
|
| 221 |
+
"_target_": "CheckpointLoader",
|
| 222 |
+
"_disabled_": "$not @finetune",
|
| 223 |
+
"load_path": "@finetune_model_path",
|
| 224 |
+
"load_dict": {
|
| 225 |
+
"model": "@network"
|
| 226 |
+
}
|
| 227 |
+
},
|
| 228 |
+
{
|
| 229 |
+
"_target_": "LrScheduleHandler",
|
| 230 |
+
"_disabled_": "$not @lr_schedule#activate",
|
| 231 |
+
"lr_scheduler": "@lr_schedule#lr_scheduler",
|
| 232 |
+
"print_lr": true
|
| 233 |
+
},
|
| 234 |
+
{
|
| 235 |
+
"_target_": "ValidationHandler",
|
| 236 |
+
"validator": "@validate#evaluator",
|
| 237 |
+
"epoch_level": true,
|
| 238 |
+
"exec_at_start": "@val_at_start",
|
| 239 |
+
"interval": "@val_interval"
|
| 240 |
+
},
|
| 241 |
+
{
|
| 242 |
+
"_target_": "TensorBoardStatsHandler",
|
| 243 |
+
"_disabled_": "$not @use_tensorboard",
|
| 244 |
+
"log_dir": "@output_dir",
|
| 245 |
+
"tag_name": "train_loss",
|
| 246 |
+
"output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
|
| 247 |
+
},
|
| 248 |
+
{
|
| 249 |
+
"_target_": "StatsHandler",
|
| 250 |
+
"tag_name": "train_loss",
|
| 251 |
+
"name": "StatsHandler",
|
| 252 |
+
"output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
|
| 253 |
+
},
|
| 254 |
+
{
|
| 255 |
+
"_target_": "MLFlowHandler",
|
| 256 |
+
"_disabled_": "$not @use_mlflow",
|
| 257 |
+
"tracking_uri": "$os.path.abspath(@mlflow_dir)",
|
| 258 |
+
"output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
|
| 259 |
+
}
|
| 260 |
+
],
|
| 261 |
+
"key_metric": {
|
| 262 |
+
"train_accuracy": {
|
| 263 |
+
"_target_": "ignite.metrics.Accuracy",
|
| 264 |
+
"output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
|
| 265 |
+
}
|
| 266 |
+
},
|
| 267 |
+
"trainer": {
|
| 268 |
+
"_target_": "scripts.trainer.Vista3dTrainer",
|
| 269 |
+
"max_epochs": "@epochs",
|
| 270 |
+
"device": "@device",
|
| 271 |
+
"train_data_loader": "@train#dataloader",
|
| 272 |
+
"network": "@network",
|
| 273 |
+
"loss_function": "@loss",
|
| 274 |
+
"optimizer": "@optimizer",
|
| 275 |
+
"inferer": "@train#inferer",
|
| 276 |
+
"key_train_metric": null,
|
| 277 |
+
"train_handlers": "@train#handlers",
|
| 278 |
+
"amp": "@amp",
|
| 279 |
+
"hyper_kwargs": {
|
| 280 |
+
"output_classes": "@output_classes",
|
| 281 |
+
"max_point": "@max_point",
|
| 282 |
+
"max_prompt": "@max_prompt",
|
| 283 |
+
"max_backprompt": "@max_backprompt",
|
| 284 |
+
"max_foreprompt": "@max_foreprompt",
|
| 285 |
+
"drop_label_prob": "@drop_label_prob",
|
| 286 |
+
"drop_point_prob": "@drop_point_prob",
|
| 287 |
+
"exclude_background": "@exclude_background",
|
| 288 |
+
"label_set": "@label_set",
|
| 289 |
+
"patch_size": "@patch_size",
|
| 290 |
+
"user_prompt": false
|
| 291 |
+
}
|
| 292 |
+
}
|
| 293 |
+
},
|
| 294 |
+
"validate": {
|
| 295 |
+
"preprocessing": {
|
| 296 |
+
"_target_": "Compose",
|
| 297 |
+
"transforms": "$@train#deterministic_transforms"
|
| 298 |
+
},
|
| 299 |
+
"postprocessing": {
|
| 300 |
+
"_target_": "Compose",
|
| 301 |
+
"transforms": [
|
| 302 |
+
{
|
| 303 |
+
"_target_": "AsDiscreted",
|
| 304 |
+
"keys": "pred",
|
| 305 |
+
"threshold": 0.0
|
| 306 |
+
}
|
| 307 |
+
]
|
| 308 |
+
},
|
| 309 |
+
"dataset": {
|
| 310 |
+
"_target_": "Dataset",
|
| 311 |
+
"data": "$@val_datalist",
|
| 312 |
+
"transform": "@validate#preprocessing"
|
| 313 |
+
},
|
| 314 |
+
"dataloader": {
|
| 315 |
+
"_target_": "DataLoader",
|
| 316 |
+
"dataset": "@validate#dataset",
|
| 317 |
+
"batch_size": 1,
|
| 318 |
+
"shuffle": false,
|
| 319 |
+
"num_workers": 4
|
| 320 |
+
},
|
| 321 |
+
"inferer": {
|
| 322 |
+
"_target_": "scripts.inferer.Vista3dInferer",
|
| 323 |
+
"roi_size": "@patch_size_valid",
|
| 324 |
+
"overlap": "@sw_overlap"
|
| 325 |
+
},
|
| 326 |
+
"handlers": [
|
| 327 |
+
{
|
| 328 |
+
"_target_": "EarlyStopHandler",
|
| 329 |
+
"_disabled_": "$not @early_stop",
|
| 330 |
+
"trainer": null,
|
| 331 |
+
"patience": 2,
|
| 332 |
+
"score_function": "$scripts.score_function",
|
| 333 |
+
"min_delta": 0.01
|
| 334 |
+
},
|
| 335 |
+
{
|
| 336 |
+
"_target_": "TensorBoardStatsHandler",
|
| 337 |
+
"_disabled_": "$not @use_tensorboard",
|
| 338 |
+
"log_dir": "@output_dir",
|
| 339 |
+
"iteration_log": false
|
| 340 |
+
},
|
| 341 |
+
{
|
| 342 |
+
"_target_": "StatsHandler",
|
| 343 |
+
"iteration_log": false,
|
| 344 |
+
"name": "StatsHandler"
|
| 345 |
+
},
|
| 346 |
+
{
|
| 347 |
+
"_target_": "CheckpointSaver",
|
| 348 |
+
"save_dir": "@ckpt_dir",
|
| 349 |
+
"save_dict": {
|
| 350 |
+
"model": "@network"
|
| 351 |
+
},
|
| 352 |
+
"save_key_metric": true,
|
| 353 |
+
"key_metric_filename": "model.pt"
|
| 354 |
+
},
|
| 355 |
+
{
|
| 356 |
+
"_target_": "MLFlowHandler",
|
| 357 |
+
"_disabled_": "$not @use_mlflow",
|
| 358 |
+
"iteration_log": false,
|
| 359 |
+
"tracking_uri": "$os.path.abspath(@mlflow_dir)"
|
| 360 |
+
}
|
| 361 |
+
],
|
| 362 |
+
"key_metric": {
|
| 363 |
+
"val_mean_dice": {
|
| 364 |
+
"_target_": "MeanDice",
|
| 365 |
+
"include_background": false,
|
| 366 |
+
"output_transform": "$monai.handlers.from_engine(['pred', 'label'])",
|
| 367 |
+
"num_classes": "@output_classes"
|
| 368 |
+
}
|
| 369 |
+
},
|
| 370 |
+
"additional_metrics": {
|
| 371 |
+
"val_accuracy": {
|
| 372 |
+
"_target_": "ignite.metrics.Accuracy",
|
| 373 |
+
"output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
|
| 374 |
+
}
|
| 375 |
+
},
|
| 376 |
+
"evaluator": {
|
| 377 |
+
"_target_": "scripts.evaluator.Vista3dEvaluator",
|
| 378 |
+
"device": "@device",
|
| 379 |
+
"val_data_loader": "@validate#dataloader",
|
| 380 |
+
"network": "@network",
|
| 381 |
+
"inferer": "@validate#inferer",
|
| 382 |
+
"postprocessing": "@validate#postprocessing",
|
| 383 |
+
"key_val_metric": "@validate#key_metric",
|
| 384 |
+
"additional_metrics": null,
|
| 385 |
+
"val_handlers": "@validate#handlers",
|
| 386 |
+
"amp": true,
|
| 387 |
+
"hyper_kwargs": {
|
| 388 |
+
"output_classes": "@output_classes",
|
| 389 |
+
"drop_label_prob": "@drop_label_prob",
|
| 390 |
+
"drop_point_prob": "@drop_point_prob",
|
| 391 |
+
"exclude_background": "@exclude_background",
|
| 392 |
+
"label_set": "@label_set",
|
| 393 |
+
"val_head": "auto",
|
| 394 |
+
"user_prompt": false
|
| 395 |
+
}
|
| 396 |
+
}
|
| 397 |
+
},
|
| 398 |
+
"initialize": [
|
| 399 |
+
"$monai.utils.set_determinism(seed=0)"
|
| 400 |
+
],
|
| 401 |
+
"run": [
|
| 402 |
+
"$@validate#handlers#0.set_trainer(trainer=@train#trainer) if @early_stop else None",
|
| 403 |
+
"$@train#trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, ignite.handlers.TerminateOnNan())",
|
| 404 |
+
"$@train#trainer.run()"
|
| 405 |
+
]
|
| 406 |
+
}
|
configs/train_continual.json
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"data_list_file_path": "$@bundle_root + '/configs/msd_task09_spleen_folds.json'",
|
| 3 |
+
"dataset_dir": "/data/Task09_Spleen",
|
| 4 |
+
"finetune": true,
|
| 5 |
+
"val_at_start": true,
|
| 6 |
+
"finetune_model_path": "$@bundle_root + '/models/model.pt'",
|
| 7 |
+
"n_train_samples": 10,
|
| 8 |
+
"n_val_samples": 10,
|
| 9 |
+
"val_interval": 1,
|
| 10 |
+
"learning_rate": 5e-05,
|
| 11 |
+
"lr_schedule#activate": false,
|
| 12 |
+
"loss#smooth_dr": 0.01,
|
| 13 |
+
"loss#smooth_nr": 0.0001,
|
| 14 |
+
"train_dataset_cache_rate": 1.0,
|
| 15 |
+
"val_dataset_cache_rate": 1.0,
|
| 16 |
+
"num_cache_workers": 4,
|
| 17 |
+
"label_mappings": {
|
| 18 |
+
"default": [
|
| 19 |
+
[
|
| 20 |
+
1,
|
| 21 |
+
3
|
| 22 |
+
]
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
"patch_size": [
|
| 26 |
+
128,
|
| 27 |
+
128,
|
| 28 |
+
128
|
| 29 |
+
],
|
| 30 |
+
"label_set": "$[0] + list(x[1] for x in @label_mappings#default)",
|
| 31 |
+
"val_label_set": "$[0] + list(x[0] for x in @label_mappings#default)",
|
| 32 |
+
"num_classes": 255,
|
| 33 |
+
"output_classes": "$len(@label_set)",
|
| 34 |
+
"optimizer": {
|
| 35 |
+
"_target_": "torch.optim.AdamW",
|
| 36 |
+
"lr": "@learning_rate",
|
| 37 |
+
"params": "$@network.parameters()"
|
| 38 |
+
},
|
| 39 |
+
"show_cache_progress": true,
|
| 40 |
+
"resample_to_spacing": [
|
| 41 |
+
1.5,
|
| 42 |
+
1.5,
|
| 43 |
+
1.5
|
| 44 |
+
],
|
| 45 |
+
"cache_cls_idx": {
|
| 46 |
+
"activate": true,
|
| 47 |
+
"indices_key": "$'label_cls_indices' if @cache_cls_idx#activate else None"
|
| 48 |
+
},
|
| 49 |
+
"train#random_transforms": [
|
| 50 |
+
{
|
| 51 |
+
"_target_": "ClassesToIndicesd",
|
| 52 |
+
"_disabled_": "$not @cache_cls_idx#activate",
|
| 53 |
+
"keys": "label",
|
| 54 |
+
"num_classes": "@num_classes",
|
| 55 |
+
"indices_postfix": "_cls_indices",
|
| 56 |
+
"max_samples_per_class": "$int(10 * @epochs)"
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"_target_": "RandCropByLabelClassesd",
|
| 60 |
+
"keys": [
|
| 61 |
+
"image",
|
| 62 |
+
"label"
|
| 63 |
+
],
|
| 64 |
+
"label_key": "label",
|
| 65 |
+
"num_classes": "@num_classes",
|
| 66 |
+
"spatial_size": "@patch_size",
|
| 67 |
+
"num_samples": "@num_patches_per_image",
|
| 68 |
+
"ratios": "$tuple(float(i>=0) for i in range(@num_classes))",
|
| 69 |
+
"indices_key": "$@cache_cls_idx#indices_key",
|
| 70 |
+
"warn": false
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"_target_": "monai.apps.vista3d.transforms.Relabeld",
|
| 74 |
+
"keys": "label",
|
| 75 |
+
"label_mappings": "@label_mappings",
|
| 76 |
+
"dtype": "$torch.uint8"
|
| 77 |
+
}
|
| 78 |
+
],
|
| 79 |
+
"train#handlers#0#strict": false,
|
| 80 |
+
"train#dataset": {
|
| 81 |
+
"_target_": "CacheDataset",
|
| 82 |
+
"data": "$@train_datalist[:@n_train_samples]",
|
| 83 |
+
"transform": "@train#preprocessing",
|
| 84 |
+
"cache_rate": "@train_dataset_cache_rate",
|
| 85 |
+
"hash_as_key": true,
|
| 86 |
+
"num_workers": "@num_cache_workers",
|
| 87 |
+
"progress": "@show_cache_progress"
|
| 88 |
+
},
|
| 89 |
+
"validate#dataset": {
|
| 90 |
+
"_target_": "CacheDataset",
|
| 91 |
+
"data": "$@val_datalist[:@n_val_samples]",
|
| 92 |
+
"transform": "@validate#preprocessing",
|
| 93 |
+
"cache_rate": "@val_dataset_cache_rate",
|
| 94 |
+
"hash_as_key": true,
|
| 95 |
+
"num_workers": "@num_cache_workers",
|
| 96 |
+
"progress": "@show_cache_progress"
|
| 97 |
+
},
|
| 98 |
+
"validate#evaluator#hyper_kwargs#val_label_set": "$list(range(len(@val_label_set)))",
|
| 99 |
+
"validate#preprocessing#transforms": "$@train#deterministic_transforms + [@valid_remap]",
|
| 100 |
+
"valid_remap": {
|
| 101 |
+
"_target_": "monai.apps.vista3d.transforms.Relabeld",
|
| 102 |
+
"keys": "label",
|
| 103 |
+
"label_mappings": "${'default': [[c, i] for i, c in enumerate(@val_label_set)]}",
|
| 104 |
+
"dtype": "$torch.uint8"
|
| 105 |
+
},
|
| 106 |
+
"validate#handlers#3#key_metric_filename": "model_finetune.pt"
|
| 107 |
+
}
|
docs/README.md
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model Overview
|
| 2 |
+
Vista3D model fintuning/evaluation/inference pipeline. VISTA3D is trained using over 20 partial datasets with more complicated pipeline. To avoid confusion, we will only provide finetuning/continual learning APIs for users to finetune on their
|
| 3 |
+
own datasets.
|
| 4 |
+
|
| 5 |
+
## Continual learning
|
| 6 |
+
|
| 7 |
+
For continual learning, user can change `configs/train_continual.json`. More advanced users can change configurations in `configs/train.json`. The hyperparameters in `configs/train_continual.json` will overwrite ones in `configs/train.json`. Most hyperparameters are straighforward and user can tell based on their names. We list hyperparameters that needs to be modified.
|
| 8 |
+
|
| 9 |
+
### Data
|
| 10 |
+
|
| 11 |
+
The spleen Task from the Medical Segmentation Decathalon is selected as an example to show how to continuous learning. Users can find more details on the datasets at http://medicaldecathlon.com/.
|
| 12 |
+
|
| 13 |
+
To train with other datasets, users need to provide a json data split for training and continuous learning (`configs/msd_task09_spleen_folds.json` is an example for reference). The data split should meet the following format ('testing' labels are optional):
|
| 14 |
+
```json
|
| 15 |
+
{
|
| 16 |
+
"training": [
|
| 17 |
+
{"image": "img0001.nii.gz", "label": "label0001.nii.gz", "fold": 0},
|
| 18 |
+
{"image": "img0002.nii.gz", "label": "label0002.nii.gz", "fold": 2},
|
| 19 |
+
...
|
| 20 |
+
],
|
| 21 |
+
"testing": [
|
| 22 |
+
{"image": "img0003.nii.gz", "label": "label0003.nii.gz"},
|
| 23 |
+
{"image": "img0004.nii.gz", "label": "label0004.nii.gz"},
|
| 24 |
+
...
|
| 25 |
+
]
|
| 26 |
+
}
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
```
|
| 30 |
+
Note the data is not the absolute path to the image and label file. The actual image file will be `os.path.join(dataset_dir, data["training"][item]["image"])`, where `dataset_dir` is defined in `configs/train_continual.json`. Also 5-fold cross-validation is not required! `fold=0` is defined in train.json, which means any data item with fold==0 will be used as validation and other fold will be used for training. So if you only have 2 data, you can manually set one data to be validation by setting "fold": 0 in its datalist and the other to be training by setting "fold" to any number other than 0.
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
### Best practice to generate data list
|
| 34 |
+
User can use monai to generate the 5-fold data lists. Full exampls can be found in VISTA3D open source [codebase](https://github.com/Project-MONAI/VISTA/blob/main/vista3d/data/make_datalists.py)
|
| 35 |
+
```python
|
| 36 |
+
from monai.data.utils import partition_dataset
|
| 37 |
+
from monai.bundle import ConfigParser
|
| 38 |
+
base_url = "/path_to_your_folder/"
|
| 39 |
+
json_name = "./your_5_folds.json"
|
| 40 |
+
# create matching image and label lists.
|
| 41 |
+
# The code to generate the lists is based on your local data structure.
|
| 42 |
+
# You can use glob.glob("**.nii.gz") e.t.c.
|
| 43 |
+
image_list = ['images/1.nii.gz', 'images/2.nii.gz', ...]
|
| 44 |
+
label_list = ['labels/1.nii.gz', 'labels/2.nii.gz', ...]
|
| 45 |
+
items = [{"image": img, "label": lab} for img, lab in zip(image_list, label_list)]
|
| 46 |
+
# 80% for training 20% for testing.
|
| 47 |
+
train_test = partition_dataset(items, ratios=[0.8, 0.2], shuffle=True, seed=0)
|
| 48 |
+
print(f"training: {len(train_test[0])}, testing: {len(train_test[1])}")
|
| 49 |
+
# num_partitions-fold split for the training set.
|
| 50 |
+
train_val = partition_dataset(train_test[0], num_partitions=5, shuffle=True, seed=0)
|
| 51 |
+
print(f"training validation folds sizes: {[len(x) for x in train_val]}")
|
| 52 |
+
# add the fold index to each training data.
|
| 53 |
+
training = []
|
| 54 |
+
for f, x in enumerate(train_val):
|
| 55 |
+
for item in x:
|
| 56 |
+
item["fold"] = f
|
| 57 |
+
training.append(item)
|
| 58 |
+
# save json file
|
| 59 |
+
parser = ConfigParser({})
|
| 60 |
+
parser["training"] = training
|
| 61 |
+
parser["testing"] = train_test[1]
|
| 62 |
+
print(f"writing {json_name}\n\n")
|
| 63 |
+
if os.path.exists(json_name):
|
| 64 |
+
logger.warning(f"rewrite existing datalist file: {json_name}")
|
| 65 |
+
ConfigParser.export_config_file(parser.config, json_name, indent=4)
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
### Configurations
|
| 69 |
+
|
| 70 |
+
#### `label_mappings`
|
| 71 |
+
The core concept of label_mapping is to convert ground-truth label index of each dataset to a unified class index. For example, "Spleen" in MSD09 groundtruth will be represented by 1, while in AbdomenCT-1K it's 3. We unified a global label index (`docs/labels.json`) to represent all 132 classes, and create a label mapping to map those local index to this global index. So when a user is training on their own dataset, we need to know this mapping.
|
| 72 |
+
|
| 73 |
+
The current label mapping `[[1, 3]]` indicates that training labels' class indices `1` is mapped
|
| 74 |
+
to the VISTA model's class `3` (format `[[src_class_0, dst_class_0], [src_class_1, dst_class_1], ...]`). So during inference, "3" is used to segment spleen.
|
| 75 |
+
|
| 76 |
+
Since it's finetuning, you can map your local class to any global class. If you use [[1,4]], where "4" represents pancreas, the finetuning can still work but requires more training data and epoch because the class "4" is already assigned and trained with pancreas. If you use [[1,3]], where "3" already represents spleen, the finetuning will converge much faster.
|
| 77 |
+
|
| 78 |
+
#### Best practice to set label_mapping
|
| 79 |
+
|
| 80 |
+
For a class that represent the same or similar class as the global index, directly map it to the global index. For example, "mouse left lung" (e.g. index 2 in the mouse dataset) can be mapped to the 28 "left lung upper lobe"(or 29 "left lung lower lobe") with [[2,28]]. After finetuning, 28 now represents "mouse left lung" and will be used for segmentation. If you want to segment 4 substructures of aorta, you can map one of the substructuress to 6 aorta and the rest to any unused classes (class > 132), [[1,6],[2,133],[3,134],[4,135]]. For a completely novel class that none of the VISTA global classes are related, directly map to unused classes (class > 132).
|
| 81 |
+
```
|
| 82 |
+
NOTE: Do not map to global index value >= 255. `num_classes=255` in the config only represent the maximum mapping index, while the actual output class number only depends on your label_mapping definition. The 255 value in the inference output is also used to represent 'NaN' value.
|
| 83 |
+
```
|
| 84 |
+
#### `n_train_samples` and `n_val_samples`
|
| 85 |
+
In `train_continual.json`, only `n_train_samples` and `n_val_samples` are used for training and validation. Remember to change these two values.
|
| 86 |
+
|
| 87 |
+
#### `patch_size`
|
| 88 |
+
The patch size parameter is defined in `configs/train_continual.json`: `"patch_size": [128, 128, 128]`. For finetuning purposes, this value needs to be changed acccording to user's task and GPU memory. Usually a larger patch_size will give better final results.
|
| 89 |
+
|
| 90 |
+
#### `resample_to_spacing`
|
| 91 |
+
The resample_to_spacing parameter is defined in `configs/train_continual.json` and it represents the resolution the model will be trained on. The `1.5,1.5,1.5` mm default is suitable for large CT organs, but for other tasks, this value should be changed to achive the optimal performance.
|
| 92 |
+
|
| 93 |
+
#### Advanced user: `drop_label_prob` and `drop_point_prob` (in train.json)
|
| 94 |
+
VISTA3D is trained to perform both automatic (class prompts) and interactive point segmentation.
|
| 95 |
+
`drop_label_prob` and `drop_point_prob` means percentage to remove class prompts and point prompts during training respectively. If `drop_point_prob=1`, the
|
| 96 |
+
model is only finetuning for automatic segmentation, while `drop_label_prob=1` means only finetuning for interactive segmentation. The VISTA3D foundation
|
| 97 |
+
model is trained with interactive only (drop_label_prob=1) and then froze the point branch and trained with fully automatic segmentation (`drop_point_prob=1`).
|
| 98 |
+
In this bundle, the training is simplified by jointly training with class prompts and point prompts and both of the drop ratio is set to 0.25.
|
| 99 |
+
```
|
| 100 |
+
NOTE: If user doesn't use interactive segmentation, set `drop_point_prob=1` and `drop_label_prob=0` in train.json might provide a faster and easier finetuning process.
|
| 101 |
+
```
|
| 102 |
+
#### Other explanatory items
|
| 103 |
+
In `train.json`, `validate[evaluator][val_head]` can be `auto` and `point`. If `auto`, the validation results will be automatic segmentation. If `point`,
|
| 104 |
+
the validation results will be sampling one positive point per object per patch. The validation scheme of combining auto and point is deprecated due to
|
| 105 |
+
speed issue.
|
| 106 |
+
|
| 107 |
+
In `train_continual.json`, `valid_remap` is a transform that maps the groundtruth label indexes, e.g. [0,2,3,5,6] to sequential and continuous labels [0,1,2,3,4]. This is
|
| 108 |
+
required by monai dice calculation. It is not related to mapping label index to VISTA3D defined global class index. The validation data is not mapped
|
| 109 |
+
to the VISTA3D global class index.
|
| 110 |
+
|
| 111 |
+
`label_set` is used to identify the VISTA model classes for providing training prompts.
|
| 112 |
+
`val_label_set` is used to identify the original training label classes for computing foreground/background mask during validation.
|
| 113 |
+
The default configs for both variables are derived from the `label_mappings` config and include `[0]`:
|
| 114 |
+
```
|
| 115 |
+
"label_set": "$[0] + list(x[1] for x in @label_mappings#default)"
|
| 116 |
+
"val_label_set": "$[0] + list(x[0] for x in @label_mappings#default)"
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
Note: Please ensure the input data header is correct. The output file will use the same header as the input data, but if the input data is missing header information, MONAI will automatically provide some default values for missing values (e.g. `np.eye(4)` will be used if affine information is absent). This may cause a visualization misalignment depending on the visualization tool.
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
### Commands
|
| 123 |
+
|
| 124 |
+
Single-GPU:
|
| 125 |
+
```bash
|
| 126 |
+
python -m monai.bundle run \
|
| 127 |
+
--config_file="['configs/train.json','configs/train_continual.json']" --epochs=320 --learning_rate=0.00005
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
Multi-GPU:
|
| 131 |
+
```bash
|
| 132 |
+
torchrun --nnodes=1 --nproc_per_node=8 -m monai.bundle run \
|
| 133 |
+
--config_file="['configs/train.json','configs/train_continual.json','configs/multi_gpu_train.json']" --epochs=320 --learning_rate=0.00005
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
### MLFlow support
|
| 137 |
+
|
| 138 |
+
MLflow can be enabled to track and manage your machine learning experiments. To enable MLflow, set the `use_mlflow` parameter to `True`. Below is an example of how to run a single-GPU training command with MLflow enabled:
|
| 139 |
+
|
| 140 |
+
```bash
|
| 141 |
+
python -m monai.bundle run \
|
| 142 |
+
--config_file="['configs/train.json','configs/train_continual.json']" --epochs=320 --learning_rate=0.00005 --use_mlflow True
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
By default, the data of MLflow is stored in the `mlruns/` folder under the bundle's root directory. To launch the MLflow UI and track your experiment data, follow these steps:
|
| 146 |
+
|
| 147 |
+
1. Open a terminal and navigate to the root directory of your bundle where the `mlruns/` folder is located.
|
| 148 |
+
|
| 149 |
+
2. Execute the following command to start the MLflow server. This will make the MLflow UI accessible.
|
| 150 |
+
|
| 151 |
+
```Bash
|
| 152 |
+
mlflow ui
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
## Evaluation
|
| 156 |
+
Evaluation can be used to calculate dice scores for the model or a finetuned model. Change the `ckpt_path` to the checkpoint you wish to evaluate. The dice score is calculated on the original image spacing using `invertd`, while the dice score during finetuning is calculated on resampled space.
|
| 157 |
+
|
| 158 |
+
```
|
| 159 |
+
NOTE: Evaluation does not support point evaluation.`"validate#evaluator#hyper_kwargs#val_head` is always set to `auto`.
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
Single-GPU:
|
| 163 |
+
```
|
| 164 |
+
python -m monai.bundle run \
|
| 165 |
+
--config_file="['configs/train.json','configs/train_continual.json','configs/evaluate.json']"
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
Multi-GPU:
|
| 169 |
+
```
|
| 170 |
+
torchrun --nnodes=1 --nproc_per_node=8 -m monai.bundle run \
|
| 171 |
+
--config_file="['configs/train.json','configs/train_continual.json','configs/evaluate.json','configs/mgpu_evaluate.json']"
|
| 172 |
+
```
|
| 173 |
+
#### Other explanatory items
|
| 174 |
+
The `label_mapping` in `evaluation.json` does not include `0` because the postprocessing step performs argmax (`VistaPostTransformd`), and a `0` prediction would negatively impact performance. In continuous learning, however, `0` is included for validation because no argmax is performed, and validation is done channel-wise (include_background=False). Additionally, `Relabeld` in `postprocessing` is required to map `label` and `pred` back to sequential indexes like `0, 1, 2, 3, 4` for dice calculation, as they are not in one-hot format. Evaluation does not support `point`, but finetuning does, as it does not perform argmax.
|
| 175 |
+
|
| 176 |
+
## Inference:
|
| 177 |
+
For inference, VISTA3d bundle requires at least one prompt for segmentation. It supports label prompt, which is the index of the class for automatic segmentation.
|
| 178 |
+
It also supports point click prompts for binary interactive segmentation. User can provide both prompts at the same time.
|
| 179 |
+
|
| 180 |
+
All the configurations for inference is stored in inference.json, change those parameters:
|
| 181 |
+
### `input_dict`
|
| 182 |
+
`input_dict` defines the image to segment and the prompt for segmentation.
|
| 183 |
+
```
|
| 184 |
+
"input_dict": "$[{'image': '/data/Task09_Spleen/imagesTs/spleen_15.nii.gz', 'label_prompt':[1]}]",
|
| 185 |
+
"input_dict": "$[{'image': '/data/Task09_Spleen/imagesTs/spleen_15.nii.gz', 'points':[[138,245,18], [271,343,27]], 'point_labels':[1,0]}]"
|
| 186 |
+
```
|
| 187 |
+
- The input_dict must include the key `image` which contain the absolute path to the nii image file, and includes prompt keys of `label_prompt`, `points` and `point_labels`.
|
| 188 |
+
- The `label_prompt` is a list of length `B`, which can perform `B` foreground objects segmentation, e.g. `[2,3,4,5]`. If `B>1`, Point prompts must NOT be provided.
|
| 189 |
+
- The `points` is of shape `[N, 3]` like `[[x1,y1,z1],[x2,y2,z2],...[xN,yN,zN]]`, representing `N` point coordinates **IN THE ORIGINAL IMAGE SPACE** of a single foreground object. `point_labels` is a list of length [N] like [1,1,0,-1,...], which
|
| 190 |
+
matches the `points`. 0 means background, 1 means foreground, -1 means ignoring this point. `points` and `point_labels` must pe provided together and match length.
|
| 191 |
+
- **B must be 1 if label_prompt and points are provided together**. The inferer only supports SINGLE OBJECT point click segmentatation.
|
| 192 |
+
- If no prompt is provided, the model will use `everything_labels` to segment 117 classes:
|
| 193 |
+
|
| 194 |
+
```Python
|
| 195 |
+
list(set([i+1 for i in range(132)]) - set([2,16,18,20,21,23,24,25,26,27,128,129,130,131,132]))
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
- The `points` together with `label_prompts` for "Kidney", "Lung", "Bone" (class index [2, 20, 21]) are not allowed since those prompts will be divided into sub-categories (e.g. left kidney and right kidney). Use `points` for the sub-categories as defined in the `inference.json`.
|
| 199 |
+
- To specify a new class for zero-shot segmentation, set the `label_prompt` to a value between 133 and 254. Ensure that `points` and `point_labels` are also provided; otherwise, the inference result will be a tensor of zeros.
|
| 200 |
+
|
| 201 |
+
### `label_prompt` and `label_dict`
|
| 202 |
+
The `label_dict` defined in `docs/labels.json` has in total 132 classes. However, there are 5 we do not support and we keep them due to legacy issue. So in total
|
| 203 |
+
VISTA3D support 127 classes.
|
| 204 |
+
|
| 205 |
+
```
|
| 206 |
+
"16, # prostate or uterus" since we already have "prostate" class,
|
| 207 |
+
"18, # rectum", insufficient data or dataset excluded.
|
| 208 |
+
"130, # liver tumor" already have hepatic tumor.
|
| 209 |
+
"129, # kidney mass" insufficient data or dataset excluded.
|
| 210 |
+
"131, # vertebrae L6", insufficient data or dataset excluded.
|
| 211 |
+
```
|
| 212 |
+
|
| 213 |
+
These 5 are excluded in the `everything_labels`. Another 7 tumor and vessel classes are also removed since they will overlap with other organs and make the output messy. To segment those 7 classes, we recommend users to directly set `label_prompt` to those indexes and avoid using them in `everything_labels`. For "Kidney", "Lung", "Bone" (class index [2, 20, 21]), VISTA3D did not directly use the class index for segmentation, but instead convert them to their subclass indexes as defined by `subclass` dict. For example, "2-Kidney" is converted to "14-Left Kidney" + "5-Right Kidney" since "2" is defined in `subclasss` dict.
|
| 214 |
+
|
| 215 |
+
```
|
| 216 |
+
Note: if the finetuning mapped the local user data index to global index "2, 20, 21", remove the `subclass` dict from inference.json since those values defined in `subclass` will trigger the wrong subclass segmentation.
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+
### `resample_spacing`
|
| 220 |
+
The optimal inference resample spacing should be changed according to the task. For monkey data, a high resolution of [1,1,1] showed better automatic inference results. This spacing applies to both automatic and interactive segmentation. For zero-shot interactive segmentation for non-human CTs e.g. mouse CT or even rock/stone CT, using original resolution (set `resample_spacing` to [-1,-1,-1]) may give better interactive results.
|
| 221 |
+
|
| 222 |
+
### `use_point_window`
|
| 223 |
+
When user click a point, there is no need to perform whole image sliding window inference. Set "use_point_window" to true in the inference.json to enable this function.
|
| 224 |
+
A window centered at the clicked points will be used for inference. All values outside of the window will set to be "NaN" unless "prev_mask" is passed to the inferer (255 is used to represent NaN).
|
| 225 |
+
If no point click exists, this function will not be used. Notice if "use_point_window" is true and user provided point clicks, there will be obvious cut-off box artefacts.
|
| 226 |
+
|
| 227 |
+
### Inference GPU benchmarks
|
| 228 |
+
Benchmarks on a 16GB V100 GPU with 400G system cpu memory.
|
| 229 |
+
| Volume size at 1.5x1.5x1.5 mm | 333x333x603 | 512x512x512 | 512x512x768 | 1024x1024x512 | 1024x1024x768 |
|
| 230 |
+
| :---: | :---: | :---: | :---: | :---: | :---: |
|
| 231 |
+
|RunTime| 1m07s | 2m09s | 3m25s| 9m20s| killed |
|
| 232 |
+
## Commands
|
| 233 |
+
The bundle only provides single-gpu inference.
|
| 234 |
+
### Single image inference
|
| 235 |
+
```
|
| 236 |
+
python -m monai.bundle run --config_file configs/inference.json
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
### Batch inference for segmenting everything
|
| 240 |
+
```
|
| 241 |
+
python -m monai.bundle run --config_file="['configs/inference.json', 'configs/batch_inference.json']" --input_dir="/data/Task09_Spleen/imagesTr" --output_dir="./eval_task09"
|
| 242 |
+
```
|
| 243 |
+
|
| 244 |
+
`configs/batch_inference.json` by default runs the segment everything workflow (classes defined by `everything_labels`) on all (`*.nii.gz`) files in `input_dir`.
|
| 245 |
+
This default is overridable by changing the input folder `input_dir`, or the input image name suffix `input_suffix`, or directly setting the list of filenames `input_list`.
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
### Execute inference with the TensorRT model:
|
| 249 |
+
|
| 250 |
+
```
|
| 251 |
+
python -m monai.bundle run --config_file "['configs/inference.json', 'configs/inference_trt.json']"
|
| 252 |
+
```
|
| 253 |
+
|
| 254 |
+
By default, the argument `head_trt_enabled` is set to `false` in `configs/inference_trt.json`. This means that the `class_head` module of the network will not be converted into a TensorRT model. Setting this to `true` may accelerate the process, but there are some limitations:
|
| 255 |
+
|
| 256 |
+
Since the `label_prompt` will be converted into a tensor and input into the `class_head` module, the batch size of this input tensor will equal the length of the original `label_prompt` list (if no prompt is provided, the length is 117). To make the TensorRT model work on the `class_head` module, you should set a suitable dynamic batch size range. The maximum dynamic batch size can be configured using the argument `max_prompt_size` in `configs/inference_trt.json`. If the length of the `label_prompt` list exceeds `max_prompt_size`, the engine will fall back to using the normal PyTorch model for inference. Setting a larger `max_prompt_size` can cover more input cases but may require more GPU memory (the default value is 4, which requires 16 GB of GPU memory). Therefore, please set it to a reasonable value according to your actual requirements.
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
### TroubleShoot for Out-of-Memory
|
| 260 |
+
- Changing `patch_size` to a smaller value such as `"patch_size": [96, 96, 96]` would reduce the training/inference memory footprint.
|
| 261 |
+
- Changing `train_dataset_cache_rate` and `val_dataset_cache_rate` to a smaller value like `0.1` can solve the out-of-cpu memory issue when using huge finetuning dataset.
|
| 262 |
+
- Set `"postprocessing#transforms#0#_disabled_": false` to move the postprocessing to cpu to reduce the GPU memory footprint.
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
### TensorRT speedup
|
| 267 |
+
The `vista3d` bundle supports acceleration with TensorRT. The table below displays the speedup ratios observed on an A100 80G GPU. Please note for 32bit precision models, they are benchmarked with tf32 weight format.
|
| 268 |
+
|
| 269 |
+
| method | torch_tf32(ms) | torch_amp(ms) | trt_tf32(ms) | trt_fp16(ms) | speedup amp | speedup tf32 | speedup fp16 | amp vs fp16|
|
| 270 |
+
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
|
| 271 |
+
| model computation | 108.53| 91.9 | 106.84 | 60.02 | 1.18 | 1.02 | 1.81 | 1.53 |
|
| 272 |
+
| end2end | 6740 | 5166 | 5242 | 3386 | 1.30 | 1.29 | 1.99 | 1.53 |
|
| 273 |
+
|
| 274 |
+
Where:
|
| 275 |
+
- `model computation` means the speedup ratio of model's inference with a random input without preprocessing and postprocessing
|
| 276 |
+
- `end2end` means run the bundle end-to-end with the TensorRT based model.
|
| 277 |
+
- `torch_tf32` and `torch_amp` are for the PyTorch models with or without `amp` mode.
|
| 278 |
+
- `trt_tf32` and `trt_fp16` are for the TensorRT based models converted in corresponding precision.
|
| 279 |
+
- `speedup amp`, `speedup tf32` and `speedup fp16` are the speedup ratios of corresponding models versus the PyTorch float32 model
|
| 280 |
+
- `amp vs fp16` is the speedup ratio between the PyTorch amp model and the TensorRT float16 based model.
|
| 281 |
+
|
| 282 |
+
This result is benchmarked under:
|
| 283 |
+
- TensorRT: 10.3.0+cuda12.6
|
| 284 |
+
- Torch-TensorRT Version: 2.4.0
|
| 285 |
+
- CPU Architecture: x86-64
|
| 286 |
+
- OS: ubuntu 20.04
|
| 287 |
+
- Python version:3.10.12
|
| 288 |
+
- CUDA version: 12.6
|
| 289 |
+
- GPU models and configuration: A100 80G
|
| 290 |
+
|
| 291 |
+
# References
|
| 292 |
+
- Antonelli, M., Reinke, A., Bakas, S. et al. The Medical Segmentation Decathlon. Nat Commun 13, 4128 (2022). https://doi.org/10.1038/s41467-022-30695-9
|
| 293 |
+
|
| 294 |
+
- VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography. arxiv (2024) https://arxiv.org/abs/2406.05285
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
# License
|
| 298 |
+
|
| 299 |
+
## Code License
|
| 300 |
+
|
| 301 |
+
This project includes code licensed under the Apache License 2.0.
|
| 302 |
+
You may obtain a copy of the License at
|
| 303 |
+
|
| 304 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 305 |
+
|
| 306 |
+
## Model Weights License
|
| 307 |
+
|
| 308 |
+
The model weights included in this project are licensed under the NCLS v1 License.
|
| 309 |
+
|
| 310 |
+
Both licenses' full texts have been combined into a single `LICENSE` file. Please refer to this `LICENSE` file for more details about the terms and conditions of both licenses.
|
docs/data_license.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Third Party Licenses
|
| 2 |
+
-----------------------------------------------------------------------
|
| 3 |
+
|
| 4 |
+
/*********************************************************************/
|
| 5 |
+
i. Medical Segmentation Decathlon
|
| 6 |
+
http://medicaldecathlon.com/
|
docs/labels.json
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"liver": 1,
|
| 3 |
+
"spleen": 3,
|
| 4 |
+
"pancreas": 4,
|
| 5 |
+
"right kidney": 5,
|
| 6 |
+
"aorta": 6,
|
| 7 |
+
"inferior vena cava": 7,
|
| 8 |
+
"right adrenal gland": 8,
|
| 9 |
+
"left adrenal gland": 9,
|
| 10 |
+
"gallbladder": 10,
|
| 11 |
+
"esophagus": 11,
|
| 12 |
+
"stomach": 12,
|
| 13 |
+
"duodenum": 13,
|
| 14 |
+
"left kidney": 14,
|
| 15 |
+
"bladder": 15,
|
| 16 |
+
"portal vein and splenic vein": 17,
|
| 17 |
+
"small bowel": 19,
|
| 18 |
+
"brain": 22,
|
| 19 |
+
"lung tumor": 23,
|
| 20 |
+
"pancreatic tumor": 24,
|
| 21 |
+
"hepatic vessel": 25,
|
| 22 |
+
"hepatic tumor": 26,
|
| 23 |
+
"colon cancer primaries": 27,
|
| 24 |
+
"left lung upper lobe": 28,
|
| 25 |
+
"left lung lower lobe": 29,
|
| 26 |
+
"right lung upper lobe": 30,
|
| 27 |
+
"right lung middle lobe": 31,
|
| 28 |
+
"right lung lower lobe": 32,
|
| 29 |
+
"vertebrae L5": 33,
|
| 30 |
+
"vertebrae L4": 34,
|
| 31 |
+
"vertebrae L3": 35,
|
| 32 |
+
"vertebrae L2": 36,
|
| 33 |
+
"vertebrae L1": 37,
|
| 34 |
+
"vertebrae T12": 38,
|
| 35 |
+
"vertebrae T11": 39,
|
| 36 |
+
"vertebrae T10": 40,
|
| 37 |
+
"vertebrae T9": 41,
|
| 38 |
+
"vertebrae T8": 42,
|
| 39 |
+
"vertebrae T7": 43,
|
| 40 |
+
"vertebrae T6": 44,
|
| 41 |
+
"vertebrae T5": 45,
|
| 42 |
+
"vertebrae T4": 46,
|
| 43 |
+
"vertebrae T3": 47,
|
| 44 |
+
"vertebrae T2": 48,
|
| 45 |
+
"vertebrae T1": 49,
|
| 46 |
+
"vertebrae C7": 50,
|
| 47 |
+
"vertebrae C6": 51,
|
| 48 |
+
"vertebrae C5": 52,
|
| 49 |
+
"vertebrae C4": 53,
|
| 50 |
+
"vertebrae C3": 54,
|
| 51 |
+
"vertebrae C2": 55,
|
| 52 |
+
"vertebrae C1": 56,
|
| 53 |
+
"trachea": 57,
|
| 54 |
+
"left iliac artery": 58,
|
| 55 |
+
"right iliac artery": 59,
|
| 56 |
+
"left iliac vena": 60,
|
| 57 |
+
"right iliac vena": 61,
|
| 58 |
+
"colon": 62,
|
| 59 |
+
"left rib 1": 63,
|
| 60 |
+
"left rib 2": 64,
|
| 61 |
+
"left rib 3": 65,
|
| 62 |
+
"left rib 4": 66,
|
| 63 |
+
"left rib 5": 67,
|
| 64 |
+
"left rib 6": 68,
|
| 65 |
+
"left rib 7": 69,
|
| 66 |
+
"left rib 8": 70,
|
| 67 |
+
"left rib 9": 71,
|
| 68 |
+
"left rib 10": 72,
|
| 69 |
+
"left rib 11": 73,
|
| 70 |
+
"left rib 12": 74,
|
| 71 |
+
"right rib 1": 75,
|
| 72 |
+
"right rib 2": 76,
|
| 73 |
+
"right rib 3": 77,
|
| 74 |
+
"right rib 4": 78,
|
| 75 |
+
"right rib 5": 79,
|
| 76 |
+
"right rib 6": 80,
|
| 77 |
+
"right rib 7": 81,
|
| 78 |
+
"right rib 8": 82,
|
| 79 |
+
"right rib 9": 83,
|
| 80 |
+
"right rib 10": 84,
|
| 81 |
+
"right rib 11": 85,
|
| 82 |
+
"right rib 12": 86,
|
| 83 |
+
"left humerus": 87,
|
| 84 |
+
"right humerus": 88,
|
| 85 |
+
"left scapula": 89,
|
| 86 |
+
"right scapula": 90,
|
| 87 |
+
"left clavicula": 91,
|
| 88 |
+
"right clavicula": 92,
|
| 89 |
+
"left femur": 93,
|
| 90 |
+
"right femur": 94,
|
| 91 |
+
"left hip": 95,
|
| 92 |
+
"right hip": 96,
|
| 93 |
+
"sacrum": 97,
|
| 94 |
+
"left gluteus maximus": 98,
|
| 95 |
+
"right gluteus maximus": 99,
|
| 96 |
+
"left gluteus medius": 100,
|
| 97 |
+
"right gluteus medius": 101,
|
| 98 |
+
"left gluteus minimus": 102,
|
| 99 |
+
"right gluteus minimus": 103,
|
| 100 |
+
"left autochthon": 104,
|
| 101 |
+
"right autochthon": 105,
|
| 102 |
+
"left iliopsoas": 106,
|
| 103 |
+
"right iliopsoas": 107,
|
| 104 |
+
"left atrial appendage": 108,
|
| 105 |
+
"brachiocephalic trunk": 109,
|
| 106 |
+
"left brachiocephalic vein": 110,
|
| 107 |
+
"right brachiocephalic vein": 111,
|
| 108 |
+
"left common carotid artery": 112,
|
| 109 |
+
"right common carotid artery": 113,
|
| 110 |
+
"costal cartilages": 114,
|
| 111 |
+
"heart": 115,
|
| 112 |
+
"left kidney cyst": 116,
|
| 113 |
+
"right kidney cyst": 117,
|
| 114 |
+
"prostate": 118,
|
| 115 |
+
"pulmonary vein": 119,
|
| 116 |
+
"skull": 120,
|
| 117 |
+
"spinal cord": 121,
|
| 118 |
+
"sternum": 122,
|
| 119 |
+
"left subclavian artery": 123,
|
| 120 |
+
"right subclavian artery": 124,
|
| 121 |
+
"superior vena cava": 125,
|
| 122 |
+
"thyroid gland": 126,
|
| 123 |
+
"vertebrae S1": 127,
|
| 124 |
+
"bone lesion": 128,
|
| 125 |
+
"airway": 132
|
| 126 |
+
}
|
models/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c92bab26d00b4a5d89fa8a383900cdeb88302fd318e5e816df0bbec7106d9a1b
|
| 3 |
+
size 871970895
|
scripts/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
# from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator
|
| 13 |
+
# from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer
|
| 14 |
+
|
| 15 |
+
from .early_stop_score_function import score_function
|
scripts/early_stop_score_function.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def score_function(engine):
|
| 8 |
+
val_metric = engine.state.metrics["val_mean_dice"]
|
| 9 |
+
if dist.is_initialized():
|
| 10 |
+
device = torch.device("cuda:" + os.environ["LOCAL_RANK"])
|
| 11 |
+
val_metric = torch.tensor([val_metric]).to(device)
|
| 12 |
+
dist.all_reduce(val_metric, op=dist.ReduceOp.SUM)
|
| 13 |
+
val_metric /= dist.get_world_size()
|
| 14 |
+
return val_metric.item()
|
| 15 |
+
return val_metric
|
scripts/evaluator.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
from monai.engines.evaluator import SupervisedEvaluator
|
| 19 |
+
from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
|
| 20 |
+
from monai.inferers import Inferer, SimpleInferer
|
| 21 |
+
from monai.transforms import Transform, reset_ops_id
|
| 22 |
+
from monai.utils import ForwardMode, IgniteInfo, RankFilter, min_version, optional_import
|
| 23 |
+
from monai.utils.enums import CommonKeys as Keys
|
| 24 |
+
from torch.utils.data import DataLoader
|
| 25 |
+
|
| 26 |
+
rearrange, _ = optional_import("einops", name="rearrange")
|
| 27 |
+
|
| 28 |
+
if TYPE_CHECKING:
|
| 29 |
+
from ignite.engine import Engine, EventEnum
|
| 30 |
+
from ignite.metrics import Metric
|
| 31 |
+
else:
|
| 32 |
+
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
|
| 33 |
+
Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
|
| 34 |
+
EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
|
| 35 |
+
|
| 36 |
+
__all__ = ["Vista3dEvaluator"]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Vista3dEvaluator(SupervisedEvaluator):
|
| 40 |
+
"""
|
| 41 |
+
Supervised detection evaluation method with image and label, inherits from ``SupervisedEvaluator`` and ``Workflow``.
|
| 42 |
+
Args:
|
| 43 |
+
device: an object representing the device on which to run.
|
| 44 |
+
val_data_loader: Ignite engine use data_loader to run, must be Iterable, typically be torch.DataLoader.
|
| 45 |
+
network: detector to evaluate in the evaluator, should be regular PyTorch `torch.nn.Module`.
|
| 46 |
+
epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
|
| 47 |
+
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
|
| 48 |
+
with respect to the host. For other cases, this argument has no effect.
|
| 49 |
+
prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
|
| 50 |
+
from `engine.state.batch` for every iteration, for more details please refer to:
|
| 51 |
+
https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
|
| 52 |
+
iteration_update: the callable function for every iteration, expect to accept `engine`
|
| 53 |
+
and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
|
| 54 |
+
if not provided, use `self._iteration()` instead. for more details please refer to:
|
| 55 |
+
https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
|
| 56 |
+
inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
|
| 57 |
+
postprocessing: execute additional transformation for the model output data.
|
| 58 |
+
Typically, several Tensor based transforms composed by `Compose`.
|
| 59 |
+
key_val_metric: compute metric when every iteration completed, and save average value to
|
| 60 |
+
engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
|
| 61 |
+
checkpoint into files.
|
| 62 |
+
additional_metrics: more Ignite metrics that also attach to Ignite Engine.
|
| 63 |
+
metric_cmp_fn: function to compare current key metric with previous best key metric value,
|
| 64 |
+
it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
|
| 65 |
+
`best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
|
| 66 |
+
val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
|
| 67 |
+
CheckpointHandler, StatsHandler, etc.
|
| 68 |
+
amp: whether to enable auto-mixed-precision evaluation, default is False.
|
| 69 |
+
mode: model forward mode during evaluation, should be 'eval' or 'train',
|
| 70 |
+
which maps to `model.eval()` or `model.train()`, default to 'eval'.
|
| 71 |
+
event_names: additional custom ignite events that will register to the engine.
|
| 72 |
+
new events can be a list of str or `ignite.engine.events.EventEnum`.
|
| 73 |
+
event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
|
| 74 |
+
for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
|
| 75 |
+
#ignite.engine.engine.Engine.register_events.
|
| 76 |
+
decollate: whether to decollate the batch-first data to a list of data after model computation,
|
| 77 |
+
recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
|
| 78 |
+
default to `True`.
|
| 79 |
+
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
|
| 80 |
+
`device`, `non_blocking`.
|
| 81 |
+
amp_kwargs: dict of the args for `torch.amp.autocast()` API, for more details:
|
| 82 |
+
https://pytorch.org/docs/stable/amp.html#torch.amp.autocast.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
def __init__(
|
| 86 |
+
self,
|
| 87 |
+
device: torch.device,
|
| 88 |
+
val_data_loader: Iterable | DataLoader,
|
| 89 |
+
network: torch.nn.Module,
|
| 90 |
+
epoch_length: int | None = None,
|
| 91 |
+
non_blocking: bool = False,
|
| 92 |
+
prepare_batch: Callable = default_prepare_batch,
|
| 93 |
+
iteration_update: Callable[[Engine, Any], Any] | None = None,
|
| 94 |
+
inferer: Inferer | None = None,
|
| 95 |
+
postprocessing: Transform | None = None,
|
| 96 |
+
key_val_metric: dict[str, Metric] | None = None,
|
| 97 |
+
additional_metrics: dict[str, Metric] | None = None,
|
| 98 |
+
metric_cmp_fn: Callable = default_metric_cmp_fn,
|
| 99 |
+
val_handlers: Sequence | None = None,
|
| 100 |
+
amp: bool = False,
|
| 101 |
+
mode: ForwardMode | str = ForwardMode.EVAL,
|
| 102 |
+
event_names: list[str | EventEnum | type[EventEnum]] | None = None,
|
| 103 |
+
event_to_attr: dict | None = None,
|
| 104 |
+
decollate: bool = True,
|
| 105 |
+
to_kwargs: dict | None = None,
|
| 106 |
+
amp_kwargs: dict | None = None,
|
| 107 |
+
hyper_kwargs: dict | None = None,
|
| 108 |
+
) -> None:
|
| 109 |
+
super().__init__(
|
| 110 |
+
device=device,
|
| 111 |
+
val_data_loader=val_data_loader,
|
| 112 |
+
network=network,
|
| 113 |
+
epoch_length=epoch_length,
|
| 114 |
+
non_blocking=non_blocking,
|
| 115 |
+
prepare_batch=prepare_batch,
|
| 116 |
+
iteration_update=iteration_update,
|
| 117 |
+
postprocessing=postprocessing,
|
| 118 |
+
key_val_metric=key_val_metric,
|
| 119 |
+
additional_metrics=additional_metrics,
|
| 120 |
+
metric_cmp_fn=metric_cmp_fn,
|
| 121 |
+
val_handlers=val_handlers,
|
| 122 |
+
amp=amp,
|
| 123 |
+
mode=mode,
|
| 124 |
+
event_names=event_names,
|
| 125 |
+
event_to_attr=event_to_attr,
|
| 126 |
+
decollate=decollate,
|
| 127 |
+
to_kwargs=to_kwargs,
|
| 128 |
+
amp_kwargs=amp_kwargs,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
self.network = network
|
| 132 |
+
self.device = device
|
| 133 |
+
self.inferer = SimpleInferer() if inferer is None else inferer
|
| 134 |
+
self.hyper_kwargs = hyper_kwargs
|
| 135 |
+
self.logger.addFilter(RankFilter())
|
| 136 |
+
|
| 137 |
+
def transform_points(self, point, affine):
|
| 138 |
+
"""transform point to the coordinates of the transformed image
|
| 139 |
+
point: numpy array [bs, N, 3]
|
| 140 |
+
"""
|
| 141 |
+
bs, n = point.shape[:2]
|
| 142 |
+
point = np.concatenate((point, np.ones((bs, n, 1))), axis=-1)
|
| 143 |
+
point = rearrange(point, "b n d -> d (b n)")
|
| 144 |
+
point = affine @ point
|
| 145 |
+
point = rearrange(point, "d (b n)-> b n d", b=bs)[:, :, :3]
|
| 146 |
+
return point
|
| 147 |
+
|
| 148 |
+
def check_prompts_format(self, label_prompt, points, point_labels):
|
| 149 |
+
"""check the format of user prompts
|
| 150 |
+
label_prompt: [1,2,3,4,...,B] List of tensors
|
| 151 |
+
points: [[[x,y,z], [x,y,z], ...]] List of coordinates of a single object
|
| 152 |
+
point_labels: [[1,1,0,...]] List of scalar that matches number of points
|
| 153 |
+
"""
|
| 154 |
+
# check prompt is given
|
| 155 |
+
if label_prompt is None and points is None:
|
| 156 |
+
everything_labels = self.hyper_kwargs.get("everything_labels", None)
|
| 157 |
+
if everything_labels is not None:
|
| 158 |
+
label_prompt = [torch.tensor(_) for _ in everything_labels]
|
| 159 |
+
return label_prompt, points, point_labels
|
| 160 |
+
else:
|
| 161 |
+
raise ValueError("Prompt must be given for inference.")
|
| 162 |
+
# check label_prompt
|
| 163 |
+
if label_prompt is not None:
|
| 164 |
+
if isinstance(label_prompt, list):
|
| 165 |
+
if not np.all([len(_) == 1 for _ in label_prompt]):
|
| 166 |
+
raise ValueError("Label prompt must be a list of single scalar, [1,2,3,4,...,].")
|
| 167 |
+
if not np.all([(x < 255).item() for x in label_prompt]):
|
| 168 |
+
raise ValueError("Current bundle only supports label prompt smaller than 255.")
|
| 169 |
+
if points is None:
|
| 170 |
+
supported_list = list({i + 1 for i in range(132)} - {16, 18, 129, 130, 131})
|
| 171 |
+
if not np.all([x in supported_list for x in label_prompt]):
|
| 172 |
+
raise ValueError("Undefined label prompt detected. Provide point prompts for zero-shot.")
|
| 173 |
+
else:
|
| 174 |
+
raise ValueError("Label prompt must be a list, [1,2,3,4,...,].")
|
| 175 |
+
# check points
|
| 176 |
+
if points is not None:
|
| 177 |
+
if point_labels is None:
|
| 178 |
+
raise ValueError("Point labels must be given if points are given.")
|
| 179 |
+
if not np.all([len(_) == 3 for _ in points]):
|
| 180 |
+
raise ValueError("Points must be three dimensional (x,y,z) in the shape of [[x,y,z],...,[x,y,z]].")
|
| 181 |
+
if len(points) != len(point_labels):
|
| 182 |
+
raise ValueError("Points must match point labels.")
|
| 183 |
+
if not np.all([_ in [-1, 0, 1, 2, 3] for _ in point_labels]):
|
| 184 |
+
raise ValueError("Point labels can only be -1,0,1 and 2,3 for special flags.")
|
| 185 |
+
if label_prompt is not None and points is not None:
|
| 186 |
+
if len(label_prompt) != 1:
|
| 187 |
+
raise ValueError("Label prompt can only be a single object if provided with point prompts.")
|
| 188 |
+
# check point_labels
|
| 189 |
+
if point_labels is not None:
|
| 190 |
+
if points is None:
|
| 191 |
+
raise ValueError("Points must be given if point labels are given.")
|
| 192 |
+
return label_prompt, points, point_labels
|
| 193 |
+
|
| 194 |
+
def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict:
|
| 195 |
+
"""
|
| 196 |
+
callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
|
| 197 |
+
Return below items in a dictionary:
|
| 198 |
+
- IMAGE: image Tensor data for model input, already moved to device.
|
| 199 |
+
- LABEL: label Tensor data corresponding to the image, already moved to device.
|
| 200 |
+
- PRED: prediction result of model.
|
| 201 |
+
|
| 202 |
+
Args:
|
| 203 |
+
engine: `SupervisedEvaluator` to execute operation for an iteration.
|
| 204 |
+
batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
|
| 205 |
+
|
| 206 |
+
Raises:
|
| 207 |
+
ValueError: When ``batchdata`` is None.
|
| 208 |
+
|
| 209 |
+
"""
|
| 210 |
+
if batchdata is None:
|
| 211 |
+
raise ValueError("Must provide batch data for current iteration.")
|
| 212 |
+
label_set = engine.hyper_kwargs.get("label_set", None)
|
| 213 |
+
# this validation label set should be consistent with 'labels.unique()', used to generate fg/bg points
|
| 214 |
+
val_label_set = engine.hyper_kwargs.get("val_label_set", label_set)
|
| 215 |
+
# If user provide prompts in the inference, input image must contain original affine.
|
| 216 |
+
# the point coordinates are from the original_affine space, while image here is after preprocess transforms.
|
| 217 |
+
if engine.hyper_kwargs["user_prompt"]:
|
| 218 |
+
inputs, label_prompt, points, point_labels = (
|
| 219 |
+
batchdata["image"],
|
| 220 |
+
batchdata.get("label_prompt", None),
|
| 221 |
+
batchdata.get("points", None),
|
| 222 |
+
batchdata.get("point_labels", None),
|
| 223 |
+
)
|
| 224 |
+
labels = None
|
| 225 |
+
label_prompt, points, point_labels = self.check_prompts_format(label_prompt, points, point_labels)
|
| 226 |
+
inputs = inputs.to(engine.device)
|
| 227 |
+
# For N foreground object, label_prompt is [1, N], but the batch number 1 needs to be removed. Convert to [N, 1]
|
| 228 |
+
label_prompt = (
|
| 229 |
+
torch.as_tensor([label_prompt]).to(inputs.device)[0].unsqueeze(-1) if label_prompt is not None else None
|
| 230 |
+
)
|
| 231 |
+
# For points, the size can only be [1, K, 3], where K is the number of points for this single foreground object.
|
| 232 |
+
if points is not None:
|
| 233 |
+
points = torch.as_tensor([points])
|
| 234 |
+
points = self.transform_points(
|
| 235 |
+
points, np.linalg.inv(inputs.affine[0]) @ inputs.meta["original_affine"][0].numpy()
|
| 236 |
+
)
|
| 237 |
+
points = torch.from_numpy(points).to(inputs.device)
|
| 238 |
+
point_labels = torch.as_tensor([point_labels]).to(inputs.device) if point_labels is not None else None
|
| 239 |
+
|
| 240 |
+
# If validation with ground truth label available.
|
| 241 |
+
else:
|
| 242 |
+
inputs, labels = engine.prepare_batch(
|
| 243 |
+
batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs
|
| 244 |
+
)
|
| 245 |
+
# create label prompt, this should be consistent with the label prompt used for training.
|
| 246 |
+
if label_set is None:
|
| 247 |
+
output_classes = engine.hyper_kwargs["output_classes"]
|
| 248 |
+
label_set = np.arange(output_classes).tolist()
|
| 249 |
+
label_prompt = torch.tensor(label_set).to(engine.state.device).unsqueeze(-1)
|
| 250 |
+
# point prompt is generated withing vista3d, provide empty points
|
| 251 |
+
points = torch.zeros(label_prompt.shape[0], 1, 3).to(inputs.device)
|
| 252 |
+
point_labels = -1 + torch.zeros(label_prompt.shape[0], 1).to(inputs.device)
|
| 253 |
+
# validation for either auto or point.
|
| 254 |
+
if engine.hyper_kwargs.get("val_head", "auto") == "auto":
|
| 255 |
+
# automatic only validation
|
| 256 |
+
# remove val_label_set, vista3d will not sample points from gt labels.
|
| 257 |
+
val_label_set = None
|
| 258 |
+
else:
|
| 259 |
+
# point only validation
|
| 260 |
+
label_prompt = None
|
| 261 |
+
|
| 262 |
+
# put iteration outputs into engine.state
|
| 263 |
+
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: labels}
|
| 264 |
+
# execute forward computation
|
| 265 |
+
with engine.mode(engine.network):
|
| 266 |
+
if engine.amp:
|
| 267 |
+
with torch.amp.autocast("cuda", **engine.amp_kwargs):
|
| 268 |
+
engine.state.output[Keys.PRED] = engine.inferer(
|
| 269 |
+
inputs=inputs,
|
| 270 |
+
network=engine.network,
|
| 271 |
+
point_coords=points,
|
| 272 |
+
point_labels=point_labels,
|
| 273 |
+
class_vector=label_prompt,
|
| 274 |
+
labels=labels,
|
| 275 |
+
label_set=val_label_set,
|
| 276 |
+
)
|
| 277 |
+
else:
|
| 278 |
+
engine.state.output[Keys.PRED] = engine.inferer(
|
| 279 |
+
inputs=inputs,
|
| 280 |
+
network=engine.network,
|
| 281 |
+
point_coords=points,
|
| 282 |
+
point_labels=point_labels,
|
| 283 |
+
class_vector=label_prompt,
|
| 284 |
+
labels=labels,
|
| 285 |
+
label_set=val_label_set,
|
| 286 |
+
)
|
| 287 |
+
inputs = reset_ops_id(inputs)
|
| 288 |
+
# Add dim 0 for decollate batch
|
| 289 |
+
engine.state.output["label_prompt"] = label_prompt.unsqueeze(0) if label_prompt is not None else None
|
| 290 |
+
engine.state.output["points"] = points.unsqueeze(0) if points is not None else None
|
| 291 |
+
engine.state.output["point_labels"] = point_labels.unsqueeze(0) if point_labels is not None else None
|
| 292 |
+
engine.fire_event(IterationEvents.FORWARD_COMPLETED)
|
| 293 |
+
engine.fire_event(IterationEvents.MODEL_COMPLETED)
|
| 294 |
+
if torch.cuda.is_available():
|
| 295 |
+
torch.cuda.empty_cache()
|
| 296 |
+
|
| 297 |
+
return engine.state.output
|
scripts/inferer.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
import copy
|
| 13 |
+
from typing import List, Union
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from monai.apps.vista3d.inferer import point_based_window_inferer
|
| 17 |
+
from monai.inferers import Inferer, SlidingWindowInfererAdapt
|
| 18 |
+
from torch import Tensor
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Vista3dInferer(Inferer):
|
| 22 |
+
"""
|
| 23 |
+
Vista3D Inferer
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
roi_size: the sliding window patch size.
|
| 27 |
+
overlap: sliding window overlap ratio.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, roi_size, overlap, use_point_window=False, sw_batch_size=1) -> None:
|
| 31 |
+
Inferer.__init__(self)
|
| 32 |
+
self.roi_size = roi_size
|
| 33 |
+
self.overlap = overlap
|
| 34 |
+
self.sw_batch_size = sw_batch_size
|
| 35 |
+
self.use_point_window = use_point_window
|
| 36 |
+
|
| 37 |
+
def __call__(
|
| 38 |
+
self,
|
| 39 |
+
inputs: Union[List[Tensor], Tensor],
|
| 40 |
+
network,
|
| 41 |
+
point_coords,
|
| 42 |
+
point_labels,
|
| 43 |
+
class_vector,
|
| 44 |
+
labels=None,
|
| 45 |
+
label_set=None,
|
| 46 |
+
prev_mask=None,
|
| 47 |
+
):
|
| 48 |
+
"""
|
| 49 |
+
Unified callable function API of Inferers.
|
| 50 |
+
Notice: The point_based_window_inferer currently only supports SINGLE OBJECT INFERENCE with B=1.
|
| 51 |
+
It only used in interactive segmentation.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
inputs: input tensor images.
|
| 55 |
+
network: vista3d model.
|
| 56 |
+
point_coords: point click coordinates. [B, N, 3].
|
| 57 |
+
point_labels: point click labels (0 for negative, 1 for positive) [B, N].
|
| 58 |
+
class_vector: class vector of length B.
|
| 59 |
+
labels: groundtruth labels. Used for sampling validation points.
|
| 60 |
+
label_set: [0,1,2,3,...,output_classes].
|
| 61 |
+
prev_mask: [1, B, H, W, D], THE VALUE IS BEFORE SIGMOID!
|
| 62 |
+
|
| 63 |
+
"""
|
| 64 |
+
prompt_class = copy.deepcopy(class_vector)
|
| 65 |
+
if class_vector is not None:
|
| 66 |
+
# Check if network has attribute 'point_head' directly or within its 'module'
|
| 67 |
+
if hasattr(network, "point_head"):
|
| 68 |
+
point_head = network.point_head
|
| 69 |
+
elif hasattr(network, "module") and hasattr(network.module, "point_head"):
|
| 70 |
+
point_head = network.module.point_head
|
| 71 |
+
else:
|
| 72 |
+
raise AttributeError("Network does not have attribute 'point_head'.")
|
| 73 |
+
|
| 74 |
+
if torch.any(class_vector > point_head.last_supported):
|
| 75 |
+
class_vector = None
|
| 76 |
+
val_outputs = None
|
| 77 |
+
torch.cuda.empty_cache()
|
| 78 |
+
if self.use_point_window and point_coords is not None:
|
| 79 |
+
if isinstance(inputs, list):
|
| 80 |
+
device = inputs[0].device
|
| 81 |
+
else:
|
| 82 |
+
device = inputs.device
|
| 83 |
+
val_outputs = point_based_window_inferer(
|
| 84 |
+
inputs=inputs,
|
| 85 |
+
roi_size=self.roi_size,
|
| 86 |
+
sw_batch_size=self.sw_batch_size,
|
| 87 |
+
transpose=True,
|
| 88 |
+
with_coord=True,
|
| 89 |
+
predictor=network,
|
| 90 |
+
mode="gaussian",
|
| 91 |
+
sw_device=device,
|
| 92 |
+
device=device,
|
| 93 |
+
overlap=self.overlap,
|
| 94 |
+
point_coords=point_coords,
|
| 95 |
+
point_labels=point_labels,
|
| 96 |
+
class_vector=class_vector,
|
| 97 |
+
prompt_class=prompt_class,
|
| 98 |
+
prev_mask=prev_mask,
|
| 99 |
+
labels=labels,
|
| 100 |
+
label_set=label_set,
|
| 101 |
+
)
|
| 102 |
+
else:
|
| 103 |
+
val_outputs = SlidingWindowInfererAdapt(
|
| 104 |
+
roi_size=self.roi_size, sw_batch_size=self.sw_batch_size, with_coord=True, padding_mode="replicate"
|
| 105 |
+
)(
|
| 106 |
+
inputs,
|
| 107 |
+
network,
|
| 108 |
+
transpose=True,
|
| 109 |
+
point_coords=point_coords,
|
| 110 |
+
point_labels=point_labels,
|
| 111 |
+
class_vector=class_vector,
|
| 112 |
+
prompt_class=prompt_class,
|
| 113 |
+
prev_mask=prev_mask,
|
| 114 |
+
labels=labels,
|
| 115 |
+
label_set=label_set,
|
| 116 |
+
)
|
| 117 |
+
return val_outputs
|
scripts/trainer.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
from monai.apps.vista3d.sampler import sample_prompt_pairs
|
| 19 |
+
from monai.engines.trainer import Trainer
|
| 20 |
+
from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
|
| 21 |
+
from monai.inferers import Inferer, SimpleInferer
|
| 22 |
+
from monai.transforms import Transform
|
| 23 |
+
from monai.utils import IgniteInfo, RankFilter, min_version, optional_import
|
| 24 |
+
from monai.utils.enums import CommonKeys as Keys
|
| 25 |
+
from torch.optim.optimizer import Optimizer
|
| 26 |
+
from torch.utils.data import DataLoader
|
| 27 |
+
|
| 28 |
+
if TYPE_CHECKING:
|
| 29 |
+
from ignite.engine import Engine, EventEnum
|
| 30 |
+
from ignite.metrics import Metric
|
| 31 |
+
else:
|
| 32 |
+
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
|
| 33 |
+
Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
|
| 34 |
+
EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
|
| 35 |
+
|
| 36 |
+
__all__ = ["Vista3dTrainer"]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Vista3dTrainer(Trainer):
|
| 40 |
+
"""
|
| 41 |
+
Supervised detection training method with image and label, inherits from ``Trainer`` and ``Workflow``.
|
| 42 |
+
Args:
|
| 43 |
+
device: an object representing the device on which to run.
|
| 44 |
+
max_epochs: the total epoch number for trainer to run.
|
| 45 |
+
train_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader.
|
| 46 |
+
detector: detector to train in the trainer, should be regular PyTorch `torch.nn.Module`.
|
| 47 |
+
optimizer: the optimizer associated to the detector, should be regular PyTorch optimizer from `torch.optim`
|
| 48 |
+
or its subclass.
|
| 49 |
+
epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
|
| 50 |
+
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
|
| 51 |
+
with respect to the host. For other cases, this argument has no effect.
|
| 52 |
+
prepare_batch: function to parse expected data (usually `image`,`box`, `label` and other detector args)
|
| 53 |
+
from `engine.state.batch` for every iteration, for more details please refer to:
|
| 54 |
+
https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
|
| 55 |
+
iteration_update: the callable function for every iteration, expect to accept `engine`
|
| 56 |
+
and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
|
| 57 |
+
if not provided, use `self._iteration()` instead. for more details please refer to:
|
| 58 |
+
https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
|
| 59 |
+
inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
|
| 60 |
+
postprocessing: execute additional transformation for the model output data.
|
| 61 |
+
Typically, several Tensor based transforms composed by `Compose`.
|
| 62 |
+
key_train_metric: compute metric when every iteration completed, and save average value to
|
| 63 |
+
engine.state.metrics when epoch completlabel_set = np.arange(output_classes).tolist().
|
| 64 |
+
key_train_metric is the main metric to compare and save the checkpoint into files.
|
| 65 |
+
additional_metrics: more Ignite metrics that also attach to Ignite Engine.
|
| 66 |
+
metric_cmp_fn: function to compare current key metric with previous best key metric value,
|
| 67 |
+
it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
|
| 68 |
+
`best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
|
| 69 |
+
train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
|
| 70 |
+
CheckpointHandler, StatsHandler, etc.
|
| 71 |
+
amp: whether to enable auto-mixed-precision training, default is False.
|
| 72 |
+
event_names: additional custom ignite events that will register to the engine.
|
| 73 |
+
new events can be a list of str or `ignite.engine.events.EventEnum`.
|
| 74 |
+
event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
|
| 75 |
+
for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
|
| 76 |
+
#ignite.engine.engine.Engine.register_events.
|
| 77 |
+
decollate: whether to decollate the batch-first data to a list of data after model computation,
|
| 78 |
+
recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
|
| 79 |
+
default to `True`.
|
| 80 |
+
optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
|
| 81 |
+
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
|
| 82 |
+
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
|
| 83 |
+
`device`, `non_blocking`.
|
| 84 |
+
amp_kwargs: dict of the args for `torch.amp.autocast()` API, for more details:
|
| 85 |
+
https://pytorch.org/docs/stable/amp.html#torch.amp.autocast.
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
device: torch.device,
|
| 91 |
+
max_epochs: int,
|
| 92 |
+
train_data_loader: Iterable | DataLoader,
|
| 93 |
+
network: torch.nn.Module,
|
| 94 |
+
optimizer: Optimizer,
|
| 95 |
+
loss_function: Callable,
|
| 96 |
+
epoch_length: int | None = None,
|
| 97 |
+
non_blocking: bool = False,
|
| 98 |
+
prepare_batch: Callable = default_prepare_batch,
|
| 99 |
+
iteration_update: Callable[[Engine, Any], Any] | None = None,
|
| 100 |
+
inferer: Inferer | None = None,
|
| 101 |
+
postprocessing: Transform | None = None,
|
| 102 |
+
key_train_metric: dict[str, Metric] | None = None,
|
| 103 |
+
additional_metrics: dict[str, Metric] | None = None,
|
| 104 |
+
metric_cmp_fn: Callable = default_metric_cmp_fn,
|
| 105 |
+
train_handlers: Sequence | None = None,
|
| 106 |
+
amp: bool = False,
|
| 107 |
+
event_names: list[str | EventEnum] | None = None,
|
| 108 |
+
event_to_attr: dict | None = None,
|
| 109 |
+
decollate: bool = True,
|
| 110 |
+
optim_set_to_none: bool = False,
|
| 111 |
+
to_kwargs: dict | None = None,
|
| 112 |
+
amp_kwargs: dict | None = None,
|
| 113 |
+
hyper_kwargs: dict | None = None,
|
| 114 |
+
) -> None:
|
| 115 |
+
super().__init__(
|
| 116 |
+
device=device,
|
| 117 |
+
max_epochs=max_epochs,
|
| 118 |
+
data_loader=train_data_loader,
|
| 119 |
+
epoch_length=epoch_length,
|
| 120 |
+
non_blocking=non_blocking,
|
| 121 |
+
prepare_batch=prepare_batch,
|
| 122 |
+
iteration_update=iteration_update,
|
| 123 |
+
postprocessing=postprocessing,
|
| 124 |
+
key_metric=key_train_metric,
|
| 125 |
+
additional_metrics=additional_metrics,
|
| 126 |
+
metric_cmp_fn=metric_cmp_fn,
|
| 127 |
+
handlers=train_handlers,
|
| 128 |
+
amp=amp,
|
| 129 |
+
event_names=event_names,
|
| 130 |
+
event_to_attr=event_to_attr,
|
| 131 |
+
decollate=decollate,
|
| 132 |
+
to_kwargs=to_kwargs,
|
| 133 |
+
amp_kwargs=amp_kwargs,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
self.network = network
|
| 137 |
+
self.optimizer = optimizer
|
| 138 |
+
self.loss_function = loss_function
|
| 139 |
+
self.inferer = SimpleInferer() if inferer is None else inferer
|
| 140 |
+
self.optim_set_to_none = optim_set_to_none
|
| 141 |
+
self.hyper_kwargs = hyper_kwargs
|
| 142 |
+
self.logger.addFilter(RankFilter())
|
| 143 |
+
|
| 144 |
+
def _iteration(self, engine, batchdata: dict[str, torch.Tensor]):
|
| 145 |
+
"""
|
| 146 |
+
Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
|
| 147 |
+
Return below items in a dictionary:
|
| 148 |
+
- IMAGE: image Tensor data for model input, already moved to device.
|
| 149 |
+
Args:
|
| 150 |
+
engine: `Vista3DTrainer` to execute operation for an iteration.
|
| 151 |
+
batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
|
| 152 |
+
Raises:
|
| 153 |
+
ValueError: When ``batchdata`` is None.
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
if batchdata is None:
|
| 157 |
+
raise ValueError("Must provide batch data for current iteration.")
|
| 158 |
+
|
| 159 |
+
inputs, labels = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
|
| 160 |
+
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: labels}
|
| 161 |
+
|
| 162 |
+
label_set = engine.hyper_kwargs["label_set"]
|
| 163 |
+
output_classes = engine.hyper_kwargs["output_classes"]
|
| 164 |
+
if label_set is None:
|
| 165 |
+
label_set = np.arange(output_classes).tolist()
|
| 166 |
+
label_prompt, point, point_label, prompt_class = sample_prompt_pairs(
|
| 167 |
+
labels,
|
| 168 |
+
label_set,
|
| 169 |
+
image_size=engine.hyper_kwargs["patch_size"],
|
| 170 |
+
max_point=engine.hyper_kwargs["max_point"],
|
| 171 |
+
max_prompt=engine.hyper_kwargs["max_prompt"],
|
| 172 |
+
max_backprompt=engine.hyper_kwargs["max_backprompt"],
|
| 173 |
+
max_foreprompt=engine.hyper_kwargs["max_foreprompt"],
|
| 174 |
+
drop_label_prob=engine.hyper_kwargs["drop_label_prob"],
|
| 175 |
+
drop_point_prob=engine.hyper_kwargs["drop_point_prob"],
|
| 176 |
+
include_background=not engine.hyper_kwargs["exclude_background"],
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
def _compute_pred_loss():
|
| 180 |
+
outputs = engine.network(
|
| 181 |
+
input_images=inputs, point_coords=point, point_labels=point_label, class_vector=label_prompt
|
| 182 |
+
)
|
| 183 |
+
# engine.state.output[Keys.PRED] = outputs
|
| 184 |
+
engine.fire_event(IterationEvents.FORWARD_COMPLETED)
|
| 185 |
+
loss, loss_n = torch.tensor(0.0, device=engine.state.device), torch.tensor(0.0, device=engine.state.device)
|
| 186 |
+
for id in range(len(prompt_class)):
|
| 187 |
+
loss += engine.loss_function(outputs[[id]].float(), labels == prompt_class[id])
|
| 188 |
+
loss_n += 1.0
|
| 189 |
+
loss /= max(loss_n, 1.0)
|
| 190 |
+
engine.state.output[Keys.LOSS] = loss
|
| 191 |
+
outputs = None
|
| 192 |
+
torch.cuda.empty_cache()
|
| 193 |
+
engine.fire_event(IterationEvents.LOSS_COMPLETED)
|
| 194 |
+
|
| 195 |
+
engine.network.train()
|
| 196 |
+
engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
|
| 197 |
+
|
| 198 |
+
if engine.amp and engine.scaler is not None:
|
| 199 |
+
with torch.amp.autocast("cuda", **engine.amp_kwargs):
|
| 200 |
+
_compute_pred_loss()
|
| 201 |
+
engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
|
| 202 |
+
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
|
| 203 |
+
engine.scaler.step(engine.optimizer)
|
| 204 |
+
engine.scaler.update()
|
| 205 |
+
else:
|
| 206 |
+
_compute_pred_loss()
|
| 207 |
+
engine.state.output[Keys.LOSS].backward()
|
| 208 |
+
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
|
| 209 |
+
engine.optimizer.step()
|
| 210 |
+
engine.fire_event(IterationEvents.MODEL_COMPLETED)
|
| 211 |
+
return engine.state.output
|