Skip to content

Commit c2172d0

Browse files
authored
Dev (#66)
* feat: add foreground masking for ZNormalization, update patch tutorials * chore: bump version to 0.8.1 and run nbdev_prepare
1 parent b2bdd07 commit c2172d0

8 files changed

Lines changed: 279 additions & 195 deletions

fastMONAI/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.8.0"
1+
__version__ = "0.8.1"

fastMONAI/_modidx.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,8 @@
291291
'fastMONAI/vision_augmentation.py'),
292292
'fastMONAI.vision_augmentation._create_ellipsoid_mask': ( 'vision_augment.html#_create_ellipsoid_mask',
293293
'fastMONAI/vision_augmentation.py'),
294+
'fastMONAI.vision_augmentation._foreground_masking': ( 'vision_augment.html#_foreground_masking',
295+
'fastMONAI/vision_augmentation.py'),
294296
'fastMONAI.vision_augmentation.do_pad_or_crop': ( 'vision_augment.html#do_pad_or_crop',
295297
'fastMONAI/vision_augmentation.py'),
296298
'fastMONAI.vision_augmentation.suggest_patch_augmentations': ( 'vision_augment.html#suggest_patch_augmentations',

fastMONAI/vision_augmentation.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,20 @@ def tio_transform(self):
9191
def encodes(self, o: (MedImage, MedMask)):
9292
return type(o)(self.pad_or_crop(o))
9393

94+
# %% ../nbs/03_vision_augment.ipynb #534509q2nn
95+
def _foreground_masking(tensor):
96+
"""Mask for non-zero voxels (nnU-Net-style foreground normalization)."""
97+
return tensor > 0
98+
9499
# %% ../nbs/03_vision_augment.ipynb #ca95a690
95100
class ZNormalization(DisplayedTransform):
96101
"""Apply TorchIO `ZNormalization`."""
97102

98103
order = 0
99104

100105
def __init__(self, masking_method=None, channel_wise=True):
106+
if masking_method == 'foreground':
107+
masking_method = _foreground_masking
101108
self.z_normalization = tio.ZNormalization(masking_method=masking_method)
102109
self.channel_wise = channel_wise
103110

nbs/03_vision_augment.ipynb

Lines changed: 9 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -157,58 +157,21 @@
157157
" return type(o)(self.pad_or_crop(o))"
158158
]
159159
},
160+
{
161+
"cell_type": "code",
162+
"execution_count": null,
163+
"id": "534509q2nn",
164+
"metadata": {},
165+
"outputs": [],
166+
"source": "#| export\ndef _foreground_masking(tensor):\n \"\"\"Mask for non-zero voxels (nnU-Net-style foreground normalization).\"\"\"\n return tensor > 0"
167+
},
160168
{
161169
"cell_type": "code",
162170
"execution_count": null,
163171
"id": "ca95a690",
164172
"metadata": {},
165173
"outputs": [],
166-
"source": [
167-
"# | export\n",
168-
"class ZNormalization(DisplayedTransform):\n",
169-
" \"\"\"Apply TorchIO `ZNormalization`.\"\"\"\n",
170-
"\n",
171-
" order = 0\n",
172-
"\n",
173-
" def __init__(self, masking_method=None, channel_wise=True):\n",
174-
" self.z_normalization = tio.ZNormalization(masking_method=masking_method)\n",
175-
" self.channel_wise = channel_wise\n",
176-
"\n",
177-
" @property\n",
178-
" def tio_transform(self):\n",
179-
" \"\"\"Return the underlying TorchIO transform.\"\"\"\n",
180-
" return self.z_normalization\n",
181-
"\n",
182-
" def encodes(self, o: MedImage):\n",
183-
" try:\n",
184-
" if self.channel_wise:\n",
185-
" o = torch.stack([self.z_normalization(c[None])[0] for c in o])\n",
186-
" else: \n",
187-
" o = self.z_normalization(o)\n",
188-
" except RuntimeError as e:\n",
189-
" if \"Standard deviation is 0\" in str(e):\n",
190-
" # Calculate mean for debugging information\n",
191-
" mean = float(o.mean())\n",
192-
" \n",
193-
" error_msg = (\n",
194-
" f\"Standard deviation is 0 for image (mean={mean:.3f}).\\n\"\n",
195-
" f\"This indicates uniform pixel values.\\n\\n\"\n",
196-
" f\"Possible causes:\\n\"\n",
197-
" f\"• Corrupted or blank image\\n\"\n",
198-
" f\"• Oversaturated regions\\n\" \n",
199-
" f\"• Background-only regions\\n\"\n",
200-
" f\"• All-zero mask being processed as image\\n\\n\"\n",
201-
" f\"Suggested solutions:\\n\"\n",
202-
" f\"• Check image quality and acquisition\\n\"\n",
203-
" f\"• Verify image vs mask data loading\"\n",
204-
" )\n",
205-
" raise RuntimeError(error_msg) from e\n",
206-
"\n",
207-
" return MedImage.create(o)\n",
208-
"\n",
209-
" def encodes(self, o: MedMask):\n",
210-
" return o"
211-
]
174+
"source": "# | export\nclass ZNormalization(DisplayedTransform):\n \"\"\"Apply TorchIO `ZNormalization`.\"\"\"\n\n order = 0\n\n def __init__(self, masking_method=None, channel_wise=True):\n if masking_method == 'foreground':\n masking_method = _foreground_masking\n self.z_normalization = tio.ZNormalization(masking_method=masking_method)\n self.channel_wise = channel_wise\n\n @property\n def tio_transform(self):\n \"\"\"Return the underlying TorchIO transform.\"\"\"\n return self.z_normalization\n\n def encodes(self, o: MedImage):\n try:\n if self.channel_wise:\n o = torch.stack([self.z_normalization(c[None])[0] for c in o])\n else: \n o = self.z_normalization(o)\n except RuntimeError as e:\n if \"Standard deviation is 0\" in str(e):\n # Calculate mean for debugging information\n mean = float(o.mean())\n \n error_msg = (\n f\"Standard deviation is 0 for image (mean={mean:.3f}).\\n\"\n f\"This indicates uniform pixel values.\\n\\n\"\n f\"Possible causes:\\n\"\n f\"• Corrupted or blank image\\n\"\n f\"• Oversaturated regions\\n\" \n f\"• Background-only regions\\n\"\n f\"• All-zero mask being processed as image\\n\\n\"\n f\"Suggested solutions:\\n\"\n f\"• Check image quality and acquisition\\n\"\n f\"• Verify image vs mask data loading\"\n )\n raise RuntimeError(error_msg) from e\n\n return MedImage.create(o)\n\n def encodes(self, o: MedMask):\n return o"
212175
},
213176
{
214177
"cell_type": "code",

nbs/12a_tutorial_patch_training.ipynb

Lines changed: 225 additions & 143 deletions
Large diffs are not rendered by default.

nbs/12b_tutorial_patch_inference.ipynb

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,15 +258,43 @@
258258
"cell_type": "markdown",
259259
"id": "cell-model-header",
260260
"metadata": {},
261-
"source": "### Load trained model\n\nLoad the exported learner which contains both the model architecture and best weights.\n\n**Two options:**\n1. **Local file**: Load `best_learner.pkl` exported during training\n2. **MLflow**: Download from MLflow artifacts (for experiment tracking workflows)"
261+
"source": [
262+
"### Load trained model\n",
263+
"\n",
264+
"Load the exported learner which contains both the model architecture and best weights.\n",
265+
"\n",
266+
"**Two options:**\n",
267+
"1. **Local file**: Load `best_learner.pkl` exported during training\n",
268+
"2. **MLflow**: Download from MLflow artifacts (for experiment tracking workflows)"
269+
]
262270
},
263271
{
264272
"cell_type": "code",
265273
"execution_count": null,
266274
"id": "cell-learner",
267275
"metadata": {},
268276
"outputs": [],
269-
"source": "from fastai.learner import load_learner\nimport torch\n\n# Option 1: Load from local file\nlearn = load_learner('models/best_learner.pkl')\n\n# Option 2: Load from MLflow (uncomment to use)\n# import mlflow\n# run_id = \"your_run_id\" # Get from MLflow UI\n# mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=\"model/best_learner.pkl\", dst_path=\"./\")\n# learn = load_learner('best_learner.pkl')\n\nmodel = learn.model\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel.to(device)\nmodel.eval()\n\nprint(f\"Loaded model: {model.__class__.__name__}\")\nprint(f\"Device: {device}\")"
277+
"source": [
278+
"from fastai.learner import load_learner\n",
279+
"import torch\n",
280+
"\n",
281+
"# Option 1: Load from local file\n",
282+
"learn = load_learner('models/best_learner.pkl')\n",
283+
"\n",
284+
"# Option 2: Load from MLflow (uncomment to use)\n",
285+
"# import mlflow\n",
286+
"# run_id = \"your_run_id\" # Get from MLflow UI\n",
287+
"# mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=\"model/best_learner.pkl\", dst_path=\"./\")\n",
288+
"# learn = load_learner('best_learner.pkl')\n",
289+
"\n",
290+
"model = learn.model\n",
291+
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
292+
"model.to(device)\n",
293+
"model.eval()\n",
294+
"\n",
295+
"print(f\"Loaded model: {model.__class__.__name__}\")\n",
296+
"print(f\"Device: {device}\")"
297+
]
270298
},
271299
{
272300
"cell_type": "markdown",
@@ -288,7 +316,7 @@
288316
"outputs": [],
289317
"source": [
290318
"# MUST match training pre_patch_tfms\n",
291-
"pre_inference_tfms = [ZNormalization()]"
319+
"pre_inference_tfms = [ZNormalization(masking_method='foreground')]\n"
292320
]
293321
},
294322
{

nbs/sidebar.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,6 @@ website:
2020
- 11c_tutorial_regression.ipynb
2121
- 11d_tutorial_binary_segmentation.ipynb
2222
- 11e_tutorial_multiclass_segmentation.ipynb
23-
- 11f_tutorial_inference.ipynb
23+
- 11f_tutorial_inference.ipynb
24+
- 12a_tutorial_patch_training.ipynb
25+
- 12b_tutorial_patch_inference.ipynb

settings.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
### Python Library ###
66
lib_name = fastMONAI
77
min_python = 3.10
8-
version = 0.8.0
8+
version = 0.8.1
99
### OPTIONAL ###
1010

1111
requirements = fastai==2.8.6 monai==1.5.2 torchio==0.21.2 xlrd>=1.2.0 scikit-image==0.26.0 imagedata==3.8.14 mlflow==3.9.0 huggingface-hub gdown gradio opencv-python plum-dispatch

0 commit comments

Comments
 (0)