A Comprehensive Guide to MRI Analysis through Deep Learning models in PyTorch
Introduction
First of all, I’d like to introduce myself. My name is Carla Pitarch, and I am a PhD candidate in AI. My research centers on developing an automated brain tumor grade classification system by extracting information from Magnetic Resonance Images (MRIs) using Deep Learning (DL) models, particularly Convolutional Neural Networks (CNNs).
At the start of my PhD journey, MRI data and DL were a whole new world to me, and the initial steps for running models in this field were not as straightforward as I expected. After spending some time researching the domain, I found few comprehensive resources that guide newcomers through both MRI and DL. I therefore decided to share some of the knowledge I have gained over this period, hoping it makes your journey a tad smoother.
Computer Vision (CV) tasks tackled with DL often rely on standard public image datasets such as ImageNet, which consist of 3-channel RGB natural images. Pretrained models in PyTorch’s torchvision library are designed around this format and expect input images to follow it. However, when our image data comes from a different domain, such as medicine, and differs from these natural images in both format and features, this presents challenges. This post addresses the issue by focusing on two crucial preparatory steps before implementing a model: aligning the data with the model’s requirements and preparing the model to process our data effectively.
Background
Let’s start with a brief overview of the fundamental aspects of CNNs and MRI.
Convolutional Neural Networks
In this section, we delve into CNNs, assuming readers have a foundational understanding of DL. CNNs have long been the standard architectures in CV and can process both 2D and 3D image data. In this post, we focus on 2D image data.
Image classification, which assigns output classes or labels to input images, is a core CNN task. The pioneering LeNet architecture introduced by LeCun et al.¹ in 1989 laid the groundwork for CNNs; its best-known version, LeNet-5, followed in 1998. This architecture can be summarized as follows:
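As a complement to that summary, the following is a minimal PyTorch sketch of a LeNet-5-style network for 32 x 32 grayscale inputs: two convolution and subsampling stages followed by three fully connected layers. It is not a faithful reproduction of the original design: it swaps in ReLU and max pooling, which are modern substitutions, and the class name LeNet5 is our own.

```python
import torch
import torch.nn as nn


class LeNet5(nn.Module):
    """LeNet-5-style CNN for 1x32x32 inputs (modernized with ReLU and max pooling)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5),   # 1x32x32 -> 6x28x28
            nn.ReLU(),
            nn.MaxPool2d(2),                  # 6x28x28 -> 6x14x14
            nn.Conv2d(6, 16, kernel_size=5),  # 6x14x14 -> 16x10x10
            nn.ReLU(),
            nn.MaxPool2d(2),                  # 16x10x10 -> 16x5x5
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),                     # 16x5x5 -> 400 values
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),       # raw class scores (logits)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))


model = LeNet5()
logits = model(torch.randn(8, 1, 32, 32))  # a batch of 8 random "images"
print(logits.shape)  # torch.Size([8, 10])
```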
2D CNN architectures receive image pixels as input, expecting each image to be a tensor with shape Height x Width x Channels (PyTorch itself uses the channels-first order, Channels x Height x Width). Color images typically have 3 channels, red, green and blue (RGB), while grayscale images have a single channel.
A fundamental operation in CNNs is convolution, executed by applying a set of filters or kernels across all areas of the input data. The figure below shows an example of how convolution works in a 2D context.
The filter slides across the image, left to right and top to bottom, and at each position the weighted sum of the covered pixels is computed, producing a convolved feature map. Each value of this map indicates how strongly a specific visual pattern, for instance an edge, is present at that location in the input image. Each convolutional layer is followed by an activation function that introduces non-linearity. Popular choices include ReLU (Rectified Linear Unit), Leaky ReLU, Sigmoid, Tanh and Softmax (the latter is typically applied only at the output layer). For further details on each activation function, see the clear explanations in “Activation Functions in Neural Networks” by Sagar Sharma on Towards Data Science.
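To make the sliding-window description concrete, the sketch below computes a 2D convolution by hand with NumPy (stride 1, no padding), checks it against torch.nn.functional.conv2d, and then applies ReLU. Like most DL libraries, PyTorch actually computes cross-correlation (the kernel is not flipped), so the manual version does the same. The helper name conv2d_manual, the 5 x 5 toy image and the vertical-edge kernel are our own illustrative choices.

```python
import numpy as np
import torch
import torch.nn.functional as F


def conv2d_manual(image: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Slide `kernel` over `image` (stride 1, no padding) and take weighted sums."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w), dtype=np.float32)
    for i in range(out_h):          # move down one row at a time
        for j in range(out_w):      # move right one column at a time
            patch = image[i:i + kh, j:j + kw]
            out[i, j] = np.sum(patch * kernel)
    return out


# Toy image: dark left part, bright right part (a vertical edge)
image = np.array([[0, 0, 1, 1, 1]] * 5, dtype=np.float32)
# Simple vertical-edge detector
kernel = np.array([[-1, 0, 1]] * 3, dtype=np.float32)

feature_map = conv2d_manual(image, kernel)

# Same computation in PyTorch: tensors are [batch, channels, H, W]
torch_map = F.conv2d(
    torch.from_numpy(image)[None, None], torch.from_numpy(kernel)[None, None]
)[0, 0].numpy()
assert np.allclose(feature_map, torch_map)

activated = np.maximum(feature_map, 0)  # ReLU keeps only positive responses
print(activated)  # every row is [3. 3. 0.]: strong response where the window covers the edge
```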
Different types of layers contribute to the construction of CNNs, each playing a distinct role in defining the network’s functionality. Alongside convolutional layers, several other prominent layers used in CNNs include:
- Pooling layers, like max-pooling or average-pooling, effectively reduce the spatial dimensions of feature maps while preserving essential information.
- Dropout layers are used to prevent overfitting by randomly deactivating a fraction of neurons during training, thereby enhancing the network’s generalization ability.
- Batch normalization layers standardize each layer’s inputs over the mini-batch, which stabilizes and typically accelerates training.
- Fully connected (FC) layers connect every neuron in one layer to every activation in the preceding layer, integrating the learned features to produce the final classification.
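The layer types listed above map directly onto torch.nn modules. The following sketch chains them in a common conv, batch norm, ReLU, pool order and prints how the tensor shape changes at each step; the channel counts and the 32 x 32 input size are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),  # 3 -> 16 channels; padding=1 keeps H and W
    nn.BatchNorm2d(16),                          # standardizes each channel over the mini-batch
    nn.ReLU(),                                   # non-linearity
    nn.MaxPool2d(kernel_size=2),                 # halves H and W
    nn.Dropout(p=0.5),                           # randomly zeroes activations (training mode only)
    nn.Flatten(),                                # [N, 16, 16, 16] -> [N, 4096]
    nn.Linear(16 * 16 * 16, 10),                 # fully connected layer -> 10 class scores
)

x = torch.randn(4, 3, 32, 32)  # [batch, channels, height, width]
shapes = []
for layer in block:
    x = layer(x)
    shapes.append(tuple(x.shape))
    print(f"{type(layer).__name__:<12} -> {tuple(x.shape)}")
```

Calling block.eval() disables Dropout and makes BatchNorm use its running statistics instead of batch statistics, which is the intended behavior at inference time.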
CNNs learn to identify patterns hierarchically: initial layers capture low-level features, and deeper layers progressively capture more abstract ones. After the final FC layer, the Softmax function turns the outputs into class probabilities for the input image.
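In PyTorch, classification models usually return raw scores (logits) from the last FC layer, and Softmax is applied separately when probabilities are needed. A small sketch with made-up logits for one image and two classes:

```python
import torch

logits = torch.tensor([[2.0, 0.5]])      # hypothetical scores: 1 image, 2 classes
probs = torch.softmax(logits, dim=1)     # exponentiate and normalize along the class axis
predicted_class = probs.argmax(dim=1)    # index of the most probable class

print(probs)            # tensor([[0.8176, 0.1824]])
print(predicted_class)  # tensor([0])
```

Note that nn.CrossEntropyLoss expects the raw logits and applies log-softmax internally, so Softmax is not applied before the loss during training.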
Beyond LeNet, notable CNN architectures such as AlexNet², GoogLeNet³, VGGNet⁴ and ResNet⁵, along with the more recent attention-based Transformer⁶ (not a CNN, but later adapted to vision tasks), have significantly shaped the field of DL.
Natural Images Overview
Exploring natural 2D images provides a foundational understanding of image data. To begin, we will delve into some examples.
For our first example we will select an image from the widely known MNIST dataset.
This image has shape [28, 28], representing a grayscale image with a single channel, so the input to a neural network would be (28*28*1).
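The shape arithmetic can be checked with NumPy. In this sketch a zero-filled array stands in for an MNIST digit (nothing is downloaded); adding an explicit channel axis, in the channels-first order PyTorch uses, and flattening shows that a fully connected layer would see 28 * 28 * 1 = 784 input values.

```python
import numpy as np

digit = np.zeros((28, 28), dtype=np.uint8)   # stand-in for one grayscale MNIST image
with_channel = digit[np.newaxis, :, :]       # explicit channel axis -> (1, 28, 28)
flat = with_channel.reshape(-1)              # what a fully connected input layer would see

print(digit.shape, with_channel.shape, flat.size)  # (28, 28) (1, 28, 28) 784
```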
Now let’s explore an image from the ImageNet dataset. You can access this dataset directly from ImageNet’s website (image-net.org) or explore a subset available on Kaggle (ImageNet Object Localization Challenge).
We can decompose this image into its RGB channels:
Since the shape of this image is [500, 402, 3], the input to a neural network would be represented as (500*402*3).
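Splitting an image into its channels amounts to indexing the last axis. The sketch below uses a random array with the same [500, 402, 3] shape in place of the actual ImageNet photo, and also shows the permute that converts the Height x Width x Channels layout into the channels-first layout PyTorch models expect.

```python
import numpy as np
import torch

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(500, 402, 3), dtype=np.uint8)  # stand-in RGB image, H x W x C

# Each channel is a 500 x 402 grayscale map
red, green, blue = image[..., 0], image[..., 1], image[..., 2]
print(red.shape, image.size)  # (500, 402) 603000

# PyTorch expects channels first: [C, H, W]
tensor = torch.from_numpy(image).permute(2, 0, 1)
print(tuple(tensor.shape))  # (3, 500, 402)
```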
Magnetic Resonance Imaging
MRI is the most widely used test in neurology and neurosurgery, offering a non-invasive method with good soft tissue contrast⁷. Beyond structural detail, MR imaging can also provide valuable insight into the functional aspects of the brain.
MRIs are 3D volumes that can be visualized across the three anatomical planes: axial, coronal and sagittal. These volumes are made up of 3D elements called voxels, in contrast to standard 2D images, which are made up of 2D squares called pixels. While 3D volumes offer comprehensive data, they can also be decomposed into 2D slices.
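Decomposing a volume into 2D slices amounts to indexing along one axis. A small sketch, assuming the volume is a NumPy array; slices_along is a helper name of our own, and a random 6 x 5 x 4 array stands in for a real scan.

```python
import numpy as np


def slices_along(volume: np.ndarray, axis: int) -> list:
    """Return the list of 2D slices of a 3D volume taken along `axis`."""
    # Move the chosen axis to the front, then iterate over it
    return list(np.moveaxis(volume, axis, 0))


volume = np.random.rand(6, 5, 4)  # toy volume: 6 x 5 x 4 voxels
for axis in range(3):
    s = slices_along(volume, axis)
    print(f"axis {axis}: {len(s)} slices of shape {s[0].shape}")
# axis 0: 6 slices of shape (5, 4)
# axis 1: 5 slices of shape (6, 4)
# axis 2: 4 slices of shape (6, 5)
```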
Several MRI modalities or sequences, such as T1, T1 with gadolinium contrast enhancement (T1ce), T2 and FLAIR (Fluid Attenuated Inversion Recovery), are commonly used for diagnosis. These sequences help differentiate tumor compartments because they show distinct signal intensities for specific regions or tissues. Below is an illustration of these four sequences from a single patient diagnosed with glioblastoma, the most aggressive form of glioma; gliomas are the most common malignant primary brain tumors.
Brain Tumor Segmentation Data
The Brain Tumor Segmentation (BraTS) Challenge, with editions from 2012 to 2023, has made available one of the most extensive multi-modal brain MRI datasets of glioma patients. The primary goal of the BraTS competition is to evaluate state-of-the-art methods for segmenting brain tumors in multi-modal MRI scans, although additional tasks have been added over time.
The BraTS dataset provides clinical information about the tumors, including a binary label indicating the tumor grade (low grade or high grade). BraTS scans are available as NIfTI files and cover the T1, T1ce, T2 and FLAIR modalities. The scans are provided after some pre-processing steps, including co-registration to the same anatomical template, interpolation to a uniform isotropic resolution (1 mm³) and skull-stripping.
In this post we will use the 2020 dataset from Kaggle, “BraTS2020 Dataset (Training + Validation)”, to classify glioma MRIs as low or high grade.
The following images display examples of low-grade and high-grade gliomas:
The Kaggle repository comprises 369 directories, each representing a patient and containing the corresponding image data. It also contains two .csv files: name_mapping.csv and survival_info.csv. For our purpose, we will use name_mapping.csv, which links BraTS patient names to the TCGA-LGG and TCGA-GBM public datasets available from the Cancer Imaging Archive. Besides this name mapping, the file also provides the tumor grade labels (LGG or HGG).
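One possible way to turn name_mapping.csv into binary labels with pandas is sketched below. The column names Grade and BraTS_2020_subject_ID are what we believe the Kaggle file uses, but treat them as an assumption and check them against your copy; the encoding LGG = 0, HGG = 1 is also our own choice. So that it runs without the dataset, the sketch parses a small in-memory CSV with made-up rows and placeholder subject IDs.

```python
import io

import pandas as pd

# Made-up rows standing in for name_mapping.csv (ASSUMED column names)
csv_text = """Grade,BraTS_2020_subject_ID
HGG,subject_A
LGG,subject_B
HGG,subject_C
"""


def load_grade_labels(csv_source) -> dict:
    """Map each BraTS 2020 subject ID to 0 (LGG) or 1 (HGG)."""
    df = pd.read_csv(csv_source)
    grade_to_label = {"LGG": 0, "HGG": 1}
    labels = df["Grade"].map(grade_to_label)
    # Convert NumPy integers to plain Python ints for a clean dict
    return {sid: int(lab) for sid, lab in zip(df["BraTS_2020_subject_ID"], labels)}


labels = load_grade_labels(io.StringIO(csv_text))
print(labels)  # {'subject_A': 1, 'subject_B': 0, 'subject_C': 1}
# With the real file: load_grade_labels("name_mapping.csv")
```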
Let’s explore the contents of each patient folder, taking Patient 006 as an example. Within the BraTS20_Training_006 folder, we find 5 files: one for each of the four MRI modalities plus the segmentation mask:
- BraTS20_Training_006_flair.nii
- BraTS20_Training_006_t1.nii
- BraTS20_Training_006_t1ce.nii
- BraTS20_Training_006_t2.nii
- BraTS20_Training_006_seg.nii
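Because every folder follows the same naming pattern, the paths for one patient can be built from the patient ID alone. A minimal sketch; the root directory "path/to/training_data" and the helper name modality_paths are placeholders of our own, to be replaced with wherever you extracted the Kaggle data.

```python
from pathlib import Path

MODALITIES = ("flair", "t1", "t1ce", "t2", "seg")


def modality_paths(root: str, patient_id: str) -> dict:
    """Return {modality: path} following the <patient_id>_<modality>.nii pattern."""
    patient_dir = Path(root) / patient_id
    return {m: patient_dir / f"{patient_id}_{m}.nii" for m in MODALITIES}


paths = modality_paths("path/to/training_data", "BraTS20_Training_006")
for modality, path in paths.items():
    print(modality, path.as_posix())
```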
These files use the .nii extension of the NIfTI format, one of the most widely used formats in neuroimaging.
MRI Data Preparation
To handle NIfTI images in Python, we can use the NiBabel package. After loading an image, its get_fdata() method returns the image data as a NumPy array.
The array shape, [240, 240, 155], describes a 3D volume with 240 voxels along the x and y axes and 155 along the z axis. These axes define distinct anatomical views: the axial view (x-y plane), the coronal view (x-z plane) and the sagittal view (y-z plane).
For simplicity, we will only use 2D slices in the axial plane, so each resulting image has shape [240, 240].
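With a volume of shape [240, 240, 155] indexed as (x, y, z), fixing one index selects one 2D view. The sketch uses a zero-filled stand-in volume; taking the middle axial slice (155 // 2 = 77) and the central coronal and sagittal slices (index 120) are only illustrative choices.

```python
import numpy as np

volume = np.zeros((240, 240, 155), dtype=np.float32)  # stand-in for one BraTS modality

z = volume.shape[2] // 2          # middle axial slice: 77
axial = volume[:, :, z]           # x-y plane
coronal = volume[:, 120, :]       # x-z plane (fixed y)
sagittal = volume[120, :, :]      # y-z plane (fixed x)

print(axial.shape, coronal.shape, sagittal.shape)  # (240, 240) (240, 155) (240, 155)
```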
PyTorch models expect input tensors with shape [batch_size, num_channels, height, width]. In our MRI data, each of the four modalities (FLAIR, T1ce, T1, T2) emphasizes distinct features within the image, much like the channels of a color image. To align our data with PyTorch’s requirements, we’ll stack these sequences as channels, obtaining a [4, 240, 240] tensor.
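Stacking the modalities as channels can be sketched with torch.stack. The random arrays below are placeholders for the same axial slice taken from each modality’s volume; the order FLAIR, T1ce, T1, T2 follows the text, and any order works as long as it is used consistently. Adding a batch dimension then gives the [batch_size, num_channels, height, width] layout.

```python
import numpy as np
import torch

rng = np.random.default_rng(0)
# Placeholders for the same axial slice from each modality (FLAIR, T1ce, T1, T2)
slices = [rng.random((240, 240), dtype=np.float32) for _ in range(4)]

image = torch.stack([torch.from_numpy(s) for s in slices], dim=0)  # [4, 240, 240]
batch = image.unsqueeze(0)                                         # [1, 4, 240, 240]
print(tuple(image.shape), tuple(batch.shape))  # (4, 240, 240) (1, 4, 240, 240)
```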
PyTorch provides two data utilities, torch.utils.data.Dataset and torch.utils.data.DataLoader, for iterating over datasets and loading data in batches. The torchvision.datasets module offers ready-made Dataset subclasses for standard datasets such as MNIST, CIFAR or ImageNet. Using one of them only involves instantiating the corresponding class and wrapping it in a DataLoader.
Consider an example with the MNIST dataset:
With a batch size of 32, iterating over an MNIST DataLoader yields image tensors with dimensions [batch_size=32, num_channels=1, H=28, W=28].
Since our MRI data is not one of these standard datasets, we need to create a custom Dataset class before using the DataLoader. While the detailed steps for creating this custom dataset are not covered in this post, readers are referred to the PyTorch tutorial on Datasets & DataLoaders for comprehensive guidance.
PyTorch Model Preparation
PyTorch is a DL framework developed by Facebook AI researchers in 2017. The torchvision package contains popular datasets, image transformations and model architectures, and torchvision.models provides implementations of DL architectures for tasks such as classification, segmentation and object detection.
For our application we will load the ResNet18 architecture.
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
In a neural network, the input and output layers are inherently tied to the specific problem being addressed. torchvision’s classification models expect 3-channel RGB images as input, as we can see in the configuration of the first convolutional layer, Conv2d(in_channels=3, out_channels=64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False). Additionally, the final layer, Linear(in_features=512, out_features=1000, bias=True), has an output size of 1000, the number of classes in the ImageNet dataset.
Before training a classification model, we must adjust in_channels and out_features to match our dataset: 4 input channels (one per MRI modality) and 2 output classes (low and high grade). Since a layer’s weights are created with fixed shapes, we replace the first convolutional layer, accessible through resnet.conv1, with one that accepts 4 input channels. Similarly, we replace the last fully connected layer, accessible through resnet.fc, with one that has 2 output features.
ResNet(
  (conv1): Conv2d(4, 64, kernel_size=(7, 7), stride=(1, 1))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=2, bias=True)
)
Image Classification
With our model and data prepared for training, we can proceed to its practical application.
Next, we use our ResNet18 to classify an image as low or high grade and compute the loss with respect to its label. Since the model expects a batch dimension, we add a dimension of size 1 to the image tensor (this is normally handled by the DataLoader). The resulting loss is:
tensor(0.5465, device='cuda:0', grad_fn=<NllLossBackward0>)
And that brings us to the end of the post. I hope this has been useful for those venturing into the intersection of MRI and Deep Learning. Thank you for taking the time to read. For deeper insights into my research, feel free to explore my recent paper! 🙂
Pitarch, C.; Ribas, V.; Vellido, A. AI-Based Glioma Grading for a Trustworthy Diagnosis: An Analytical Pipeline for Improved Reliability. Cancers 2023, 15, 3369. https://doi.org/10.3390/cancers15133369.
Unless otherwise noted, all images are by the author.
[1] Y. LeCun et al. “Backpropagation Applied to Handwritten Zip Code Recognition”. In: Neural Computation 1 (4) (Dec. 1989), pp. 541–551. ISSN: 0899-7667. DOI: 10.1162/NECO.1989.1.4.541.
[2] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E. Hinton. “ImageNet Classification with Deep Convolutional Neural Networks”. In: Advances in Neural Information Processing Systems 25 (2012).
[3] Christian Szegedy et al. “Going Deeper with Convolutions”. In: Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition 07–12-June-2015 (Sept. 2014), pp. 1–9. ISSN: 1063-6919. DOI: 10.1109/CVPR.2015.7298594.
[4] Karen Simonyan and Andrew Zisserman. “Very Deep Convolutional Networks for Large-Scale Image Recognition”. In: 3rd International Conference on Learning Representations, ICLR 2015, Conference Track Proceedings (Sept. 2014).
[5] Kaiming He et al. “Deep Residual Learning for Image Recognition”. In: Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition 2016-December (Dec. 2015), pp. 770–778. ISSN: 1063-6919. DOI: 10.1109/CVPR.2016.90.
[6] Ashish Vaswani et al. “Attention Is All You Need”. In: Advances in Neural Information Processing Systems 2017-December (June 2017), pp. 5999–6009. ISSN: 1049-5258. DOI: 10.48550/arxiv.1706.03762.
[7] Lisa M. DeAngelis. “Brain Tumors”. In: New England Journal of Medicine 344 (2) (2001), pp. 114–123. ISSN: 0028-4793. DOI: 10.1056/NEJM200101113440207.
Dealing with MRI and Deep Learning with Python was originally published in Towards Data Science on Medium.