nnUNet Tutorial Notebook

This notebook guides you through training an nnUNet model on the BraTS dataset to segment adult gliomas. It covers the basic steps of a complete nnUNet experiment, from downloading the data to training the model and making predictions.

Data Downloading

First, we download the BraTS dataset (Task01_BrainTumour) of the Medical Segmentation Decathlon. The archive is available at https://drive.google.com/uc?id=1A2IU8Sgea1h3fYLpYtFb2v7NYdMjvEhU

[ ]:
!pip install gdown
[ ]:
import gdown

output_tar = gdown.download("https://drive.google.com/uc?id=1A2IU8Sgea1h3fYLpYtFb2v7NYdMjvEhU")
[ ]:
import tarfile

# Extract the archive into the current directory (this creates Task01_BrainTumour/)
with tarfile.open(output_tar) as tar:
    tar.extractall()

Multi-Modal to Single-Modality Conversion

nnUNet expects each imaging modality to be stored in a separate file. The Decathlon BraTS dataset, by contrast, stores all four image modalities in a single 4D multi-channel file. We therefore split each multi-modal image into four single-modality images.

[ ]:
import SimpleITK as sitk
import os
from pathlib import Path
import numpy as np
from tqdm.notebook import tqdm

data_dir = "Task01_BrainTumour"
file_extension = ".nii.gz"

# Collect the 4D training images, skipping hidden files (such as macOS "._" metadata files)
data_list = [
    f.name
    for f in os.scandir(Path(data_dir).joinpath("imagesTr"))
    if f.is_file() and not f.name.startswith(".")
]

output_dir = str(Path(data_dir).joinpath("imagesTr_Single"))
Path(output_dir).mkdir(parents=True, exist_ok=True)

# Output file suffix -> modality name, in the channel order of the 4D Decathlon images
modality_dict = {
    "_001.nii.gz": "FLAIR",
    "_002.nii.gz": "T1w",
    "_003.nii.gz": "t1gd",
    "_004.nii.gz": "T2w"
}

for data in tqdm(data_list):
    image = sitk.ReadImage(str(Path(data_dir).joinpath("imagesTr", data)))
    # For a 4D image the array axes are (channel, z, y, x)
    data_array = sitk.GetArrayFromImage(image)
    direction = image.GetDirection()  # flattened, row-major 4x4 matrix
    for idx, modality in enumerate(modality_dict):
        single_image = sitk.GetImageFromArray(data_array[idx])
        # Copy only the three spatial components of the 4D geometry
        single_image.SetSpacing(image.GetSpacing()[:3])
        single_image.SetOrigin(image.GetOrigin()[:3])
        single_image.SetDirection(direction[:3] + direction[4:7] + direction[8:11])
        # e.g. BRATS_001.nii.gz -> BRATS_001_001.nii.gz (FLAIR channel)
        filename = str(Path(output_dir).joinpath(data[:-len(file_extension)] + modality))
        sitk.WriteImage(single_image, filename)
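The SetDirection call in the conversion loop depends on how SimpleITK exposes orientation. GetDirection() returns the direction matrix as a flat, row-major tuple, so a 4D image yields 16 values. The slice [:3] + [4:7] + [8:11] keeps the upper-left 3x3 block, which is the spatial part. The pure-Python sketch below does the same extraction explicitly for a flattened dim x dim matrix and checks it against the slicing. The helper name spatial_direction and the example matrix are hypothetical and only for illustration:

```python
def spatial_direction(direction, dim=4, keep=3):
    """Return the upper-left keep x keep block of a flattened, row-major dim x dim matrix."""
    if len(direction) != dim * dim:
        raise ValueError(f"expected {dim * dim} values, got {len(direction)}")
    # Row r, column c of the full matrix sits at index r * dim + c
    return tuple(direction[row * dim + col] for row in range(keep) for col in range(keep))


# A 4D direction matrix whose spatial block is a 90-degree rotation about z
d4 = (
    0.0, -1.0, 0.0, 0.0,
    1.0, 0.0, 0.0, 0.0,
    0.0, 0.0, 1.0, 0.0,
    0.0, 0.0, 0.0, 1.0,
)
print(spatial_direction(d4))
print(spatial_direction(d4) == d4[:3] + d4[4:7] + d4[8:11])
```

Both expressions select indices 0-2, 4-6 and 8-10, so the comparison prints True.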

Configuration File

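Next, we create a PyMAIA configuration file for the experiment. It stores the experiment name, the random seed, the label file suffix, the modality names, the label regions to segment, the number of cross-validation folds, the file extension, and the region class order: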

[ ]:
import json

brats_config = {
    "Experiment Name": "BraTS",
    "Seed": 12345,
    "label_suffix": ".nii.gz",
    "Modalities": modality_dict,
    # Overlapping tumor regions, each defined by the label values it contains
    "label_dict": {
        "background": 0,
        "whole_tumor": [1, 2, 3],
        "tumor_core": [2, 3],
        "enhancing_tumor": 3
    },
    # Number of cross-validation folds
    "n_folds": 5,
    "FileExtension": ".nii.gz",
    "RegionClassOrder": [1, 2, 3]
}

with open("BraTS_config.json", "w") as f:
    json.dump(brats_config, f, indent=4)
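The label_dict entries above describe overlapping regions rather than mutually exclusive classes. In the Decathlon BrainTumour labels, 1 is edema, 2 is non-enhancing tumor and 3 is enhancing tumor. The NumPy sketch below runs on a toy label map and shows how such regions can be derived from a label map. It also shows how an ordered list like RegionClassOrder can turn region masks back into labels: regions are painted in order, so the smaller, later regions overwrite the larger ones. It illustrates the idea only and is not PyMAIA's or nnUNet's code. The helpers labels_to_regions and regions_to_labels are hypothetical. As far as we understand, nnUNet v2 uses the region class order in a similar way when converting predicted regions back to a label map.

```python
import numpy as np

# Region definitions mirroring "label_dict" in BraTS_config.json
REGIONS = {
    "whole_tumor": [1, 2, 3],
    "tumor_core": [2, 3],
    "enhancing_tumor": [3],
}
# Label written for each region when converting back, mirroring "RegionClassOrder"
REGION_CLASS_ORDER = [1, 2, 3]


def labels_to_regions(label_map: np.ndarray) -> np.ndarray:
    """Stack one binary mask per region; the masks may overlap."""
    return np.stack([np.isin(label_map, values) for values in REGIONS.values()])


def regions_to_labels(region_masks: np.ndarray) -> np.ndarray:
    """Paint the regions in order, so later (smaller) regions overwrite earlier ones."""
    labels = np.zeros(region_masks.shape[1:], dtype=np.uint8)
    for mask, label in zip(region_masks, REGION_CLASS_ORDER):
        labels[mask] = label
    return labels


toy = np.array([0, 1, 2, 3])
regions = labels_to_regions(toy)
print(regions.astype(int))
print(regions_to_labels(regions))
```

For the toy label map [0, 1, 2, 3], the three masks are [0, 1, 1, 1], [0, 0, 1, 1] and [0, 0, 0, 1], and converting them back recovers [0, 1, 2, 3].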

Decathlon Dataset File

Next, we create a dataset.json file that lists the training and test cases. Each entry maps the four modality names to their image paths and adds the path to the label image. In this tutorial, the test list is left empty. The file has the following structure:

{
    "train": [
        {
            "FLAIR": "Path to FLAIR Image",
            "T1w": "Path to T1w Image",
            "t1gd": "Path to t1gd Image",
            "T2w": "Path to T2w Image",
            "label": "Path to Label Image"
        }
    ],
    "test": [
        {
            "FLAIR": "Path to FLAIR Image",
            "T1w": "Path to T1w Image",
            "t1gd": "Path to t1gd Image",
            "T2w": "Path to T2w Image",
            "label": "Path to Label Image"
        }
    ]
}
[ ]:
# Case IDs, e.g. "BRATS_001" from "BRATS_001_001.nii.gz"
suffix_length = len("_000.nii.gz")
cases = np.unique([
    f.name[:-suffix_length]
    for f in os.scandir(Path(data_dir).joinpath("imagesTr_Single"))
    if f.is_file() and f.name.endswith(file_extension)
])


def case_entry(case):
    """Map each modality name to its image path and add the label path."""
    entry = {
        modality: str(Path(data_dir).joinpath("imagesTr_Single", case + suffix))
        for suffix, modality in modality_dict.items()
    }
    entry["label"] = str(Path(data_dir).joinpath("labelsTr", case + brats_config["label_suffix"]))
    return entry


# All cases go to the training split; the test split is left empty
data_list = {
    "train": [case_entry(case) for case in cases],
    "test": []
}

with open("dataset.json", "w") as f:
    json.dump(data_list, f, indent=4)
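Before creating the pipeline, it can save time to confirm that every path written to dataset.json actually exists, since a missing label or modality file would otherwise surface only later. The helper find_missing_paths below is a hypothetical sketch. It assumes the structure shown above (splits mapping to lists of key-to-path dictionaries) and uses only the standard library:

```python
import json
from pathlib import Path


def find_missing_paths(dataset_json: str) -> list:
    """Return (split, key, path) for every listed file that does not exist."""
    with open(dataset_json) as f:
        dataset = json.load(f)
    missing = []
    for split, entries in dataset.items():
        for entry in entries:
            for key, path in entry.items():
                if not Path(path).is_file():
                    missing.append((split, key, path))
    return missing


# Only run the check if dataset.json has been written in the current directory
if Path("dataset.json").is_file():
    missing = find_missing_paths("dataset.json")
    print(f"{len(missing)} missing file(s)")
    for split, key, path in missing[:10]:
        print(split, key, path)
```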

Create Pipeline

[ ]:
%%bash

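# The following cells set ROOT_FOLDER=/opt/code/PyMAIA/Tutorials, i.e. they assume this notebook runs from that directory
export ROOT_FOLDER=./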

nnunet_create_pipeline.py --input-data-folder dataset.json --config-file BraTS_config.json --task-ID 100 --test-split 0

Prepare Data

[ ]:
%%bash

export ROOT_FOLDER=/opt/code/PyMAIA/Tutorials
nnunet_prepare_data_folder --input-data-folder dataset.json --task-ID 100 --task-name BraTS --config-file BraTS_config.json --test-split 0

Pre-Processing

[ ]:
%%bash

export ROOT_FOLDER=/opt/code/PyMAIA/Tutorials

nnunet_run_plan_and_preprocessing --config-file /opt/code/PyMAIA/Tutorials/BraTS/BraTS_results/Dataset100_BraTS.json -np 4

During pre-processing, nnUNet automatically handles the following steps:

  • Resampling to the target spacing, together with a transposition of the image axes

  • Normalization (optionally restricted to the non-zero mask, with a normalization scheme chosen per modality that can be customized)
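To our understanding, nnUNet normalizes MRI channels with a per-channel z-score. The non-zero mask option restricts the mean and standard deviation to the foreground, which matters for skull-stripped brain MRI where much of the volume is zero background. The NumPy sketch below illustrates both variants. It assumes the mask is simply image != 0, which is a simplification: nnUNet derives its mask from its own crop-to-non-zero step. The helper zscore_normalize is hypothetical.

```python
import numpy as np


def zscore_normalize(image: np.ndarray, use_nonzero_mask: bool) -> np.ndarray:
    """Z-score an intensity image, optionally computing statistics on non-zero voxels only."""
    image = image.astype(np.float32)  # float copy; the input is left untouched
    if use_nonzero_mask:
        mask = image != 0
        mean, std = image[mask].mean(), image[mask].std()
        # Voxels outside the mask keep their original value (0)
        image[mask] = (image[mask] - mean) / max(std, 1e-8)
    else:
        image = (image - image.mean()) / max(image.std(), 1e-8)
    return image


toy = np.array([0.0, 0.0, 100.0, 300.0])
print(zscore_normalize(toy, use_nonzero_mask=True))
print(zscore_normalize(toy, use_nonzero_mask=False))
```

With the mask, only 100 and 300 enter the statistics (mean 200, standard deviation 100), so they map to -1 and 1 while the background stays at 0. Without it, the zeros pull the mean down and the background is shifted as well.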

Model Training

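To customize nnUNet training, we can create a custom trainer class that inherits from nnUNetTrainer. The class overrides the default training configuration, such as the learning rate, weight decay, number of epochs, optimizer, and loss function. nnUNet looks up trainer classes by name inside the nnunetv2.training.nnUNetTrainer package, so save the class in a Python file within that package before passing its name to the training command with -tr: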

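# Save this class in a Python file inside the nnunetv2/training/nnUNetTrainer/ package
# so that nnUNet can find it by its class name
import torch
from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer


class nnUNetTrainerDemo(nnUNetTrainer):
    def __init__(
            self,
            plans: dict,
            configuration: str,
            fold: int,
            dataset_json: dict,
            unpack_dataset: bool = True,
            device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        # nnUNet default schedule; reduce (e.g. 10 / 10 / 5) for a quick smoke test
        self.num_iterations_per_epoch = 250
        self.num_val_iterations_per_epoch = 50
        self.num_epochs = 1000
        self.initial_lr = 1e-2
        self.weight_decay = 3e-5
        self.oversample_foreground_percent = 0.33
        # Single output head, so the loss below needs no deep-supervision wrapper
        self.enable_deep_supervision = False

    def configure_optimizers(self):
        # nnUNet expects an (optimizer, lr_scheduler) pair
        optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay)
        lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)
        return optimizer, lr_scheduler

    def _build_loss(self):
        # Region-based training (RegionClassOrder) uses overlapping binary targets,
        # one channel per region, so a per-channel sigmoid + BCE loss fits here
        return torch.nn.BCEWithLogitsLoss()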
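nnUNet's default trainer pairs its optimizer with PolyLRScheduler, which updates the learning rate once per epoch as lr = initial_lr * (1 - epoch / num_epochs) ** 0.9. The pure-Python sketch below reproduces that formula so you can see how the rate decays for a given num_epochs. The function name poly_lr is our own, and the exponent 0.9 is the default we assume nnUNet uses:

```python
def poly_lr(epoch: int, num_epochs: int, initial_lr: float = 1e-2, exponent: float = 0.9) -> float:
    """Polynomial learning-rate decay from initial_lr (epoch 0) to 0 (epoch num_epochs)."""
    if not 0 <= epoch <= num_epochs:
        raise ValueError("epoch must lie in [0, num_epochs]")
    return initial_lr * (1 - epoch / num_epochs) ** exponent


# Learning rate at the start of each epoch of a short 5-epoch run
for epoch in range(5):
    print(epoch, round(poly_lr(epoch, num_epochs=5), 5))
```

Because the decay is tied to num_epochs, shortening training also compresses the schedule rather than cutting it off early.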

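To customize the batch size and patch size, edit the batch_size and patch_size entries of the corresponding configuration (for example 3d_fullres) in the nnUNetPlans.json file generated during pre-processing.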
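A minimal sketch of that edit with the standard json module follows. It assumes the nnUNet v2 plans layout, in which each entry under "configurations" holds its own "patch_size" and "batch_size". It also assumes the plans file sits in the dataset's folder under nnUNet_preprocessed (e.g. Dataset100_BraTS/nnUNetPlans.json). The helper set_patch_and_batch_size is hypothetical. Because the network topology stored in the plans was derived for the original patch size, each patch dimension should likely stay divisible by the network's total downsampling factor along that axis.

```python
import json
from pathlib import Path


def set_patch_and_batch_size(plans_file, configuration, patch_size, batch_size):
    """Overwrite patch_size and batch_size of one configuration in a plans JSON file."""
    plans_path = Path(plans_file)
    plans = json.loads(plans_path.read_text())
    config = plans["configurations"][configuration]  # KeyError if the configuration is absent
    config["patch_size"] = [int(s) for s in patch_size]
    config["batch_size"] = int(batch_size)
    plans_path.write_text(json.dumps(plans, indent=4))
    return plans
```

For example, set_patch_and_batch_size(plans_path, "3d_fullres", (96, 96, 96), 2) would rewrite only those two entries and keep everything else in the file. Keeping a copy of the original plans file makes it easy to revert.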

[ ]:
%%bash

export ROOT_FOLDER=/opt/code/PyMAIA/Tutorials
export N_THREADS=4
nnunet_run_training --config-file /opt/code/PyMAIA/Tutorials/BraTS/BraTS_results/Dataset100_BraTS.json --run-fold 0 -tr nnUNetTrainerDemo

Export nnUNet Model

Once all five cross-validation folds have been trained, we can export the nnUNet model to a zip file containing the model weights, the configuration file, and the training logs. The cell above trains fold 0 only, so rerun it with --run-fold set to 1, 2, 3 and 4 first.

[ ]:
%%bash

export ROOT_FOLDER=/opt/code/PyMAIA/Tutorials
export N_THREADS=4
nnunet_run_training --config-file /opt/code/PyMAIA/Tutorials/BraTS/BraTS_results/Dataset100_BraTS.json --run-fold -1 --output-model-file BraTS_nnuNet.zip -tr nnUNetTrainerDemo
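To check what the exported archive actually contains before sharing it, Python's standard zipfile module can list its entries. The sketch below makes no assumption about the archive's internal layout. The helper list_model_archive is hypothetical, and the file name matches the --output-model-file argument above.

```python
import zipfile
from pathlib import Path


def list_model_archive(zip_path):
    """Return the names of all entries stored in a zip archive, sorted."""
    with zipfile.ZipFile(zip_path) as archive:
        return sorted(archive.namelist())


# Only inspect the archive if the export step has produced it
if Path("BraTS_nnuNet.zip").is_file():
    for name in list_model_archive("BraTS_nnuNet.zip"):
        print(name)
```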

Convert nnUNet to MONAI Bundle [Coming Soon]

Export Trained Model and Upload to MLflow [Coming Soon]

Run Inference [Coming Soon]