from misas.core import default_cmap
from misas.fastai_model import Fastai2_model
from PIL import Image
import itertools
import matplotlib.pyplot as plt
import numpy as np
import altair as alt
Case Study - Model Robustness
Aim: How robust is my network to shifts in the data?
We often find ourselves in a situation where we have trained a model for a certain task. The network performs well on the training and validation data. It also performs well on the hold-out test set. Still, you wonder: what happens when you feed it data that is systematically different from all three sets? Does it make sense to re-train with more extensive data augmentation?
In this case study we demonstrate how misas helps answer these questions with a concrete example:
- Model: Custom U-Net
- Data: Small set of transversal CMR images
Prepare Model for misas
trainedModel = Fastai2_model('chfc-cmi/transversal-cmr-seg', 'b0_transversal_5_5', force_reload=False)
Using cache found in /home/markus/.cache/torch/hub/chfc-cmi_transversal-cmr-seg_master
Prepare Dataset for misas
Data is available as png images and masks, which is just fine for misas.
img = Image.open("example/b0/images/train_example.png").convert("RGB")
plt.imshow(img)
img.size
(128, 128)
How does the trained model perform on this (training) example?
Time to apply the model to the example image and see how it works (we need to call prepareSize
manually here):
img = trainedModel.prepareSize(img)
fig, ax = plt.subplots(figsize=(4,4))
plt.imshow(img)
_ = plt.imshow(trainedModel.predict(img), cmap=default_cmap, alpha=.5, interpolation="nearest")
So how does it perform on validation data?
from glob import glob
files = sorted(glob("example/b0/images/val*.png"))
fig, axs = plt.subplots(1,5, figsize=(20,10))
for i, ax in enumerate(axs.flatten()):
    fname = files[i]
    tmp = trainedModel.prepareSize(Image.open(fname).convert("RGB"))
    ax.imshow(tmp)
    ax.imshow(trainedModel.predict(tmp), cmap=default_cmap, alpha=.5, interpolation="nearest")
This is not great, but given the limited training data it looks decent. So let’s have a closer look at how robust this model is, in particular to differences we might encounter when applying this network to new data.
Robustness to basic transformations
from misas.core import *
img = lambda: Image.open(files[0]).convert("RGB")
trueMask = lambda: Image.open(files[0].replace("image","mask")).convert("I")
Sensitivity to orientation
Changes in orientation are very common, not because it is common to acquire images in different orientations, but because the way data is stored differs between file formats like NIfTI and DICOM. So it is interesting to see how the model works in all possible orientations (including flips).
plot_series(get_dihedral_series(img(),trainedModel), nrow=2, figsize=(20,12))
results = eval_dihedral_series(img(),trueMask(),trainedModel,components=["bg","LV","MY"])
results
|   | k | bg       | LV       | MY       |
|---|---|----------|----------|----------|
| 0 | 0 | 0.995824 | 0.782609 | 0.765604 |
| 1 | 1 | 0.963374 | 0.000000 | 0.000000 |
| 2 | 2 | 0.977227 | 0.000000 | 0.029690 |
| 3 | 3 | 0.985537 | 0.657269 | 0.206577 |
| 4 | 4 | 0.979158 | 0.000000 | 0.000000 |
| 5 | 5 | 0.987976 | 0.625101 | 0.439119 |
| 6 | 6 | 0.979158 | 0.000000 | 0.000000 |
| 7 | 7 | 0.992134 | 0.728850 | 0.640967 |
Not surprisingly, the model is very sensitive to changes in orientation. So when using this model it is very important to feed the images in the proper orientation.
Another really interesting observation is that the heart is never properly segmented when images are flipped horizontally, so that the heart appears on the left instead of the right side of the image. This is a strong indication that the location within the image is one of the features the network has learned. Depending on your use case this might indeed be a sensible feature for segmenting the heart, as in the vast majority of cases the left ventricle of the heart is on the left side of the chest (so it shows up on the right side in transversal slices).
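To check this observation directly, one can segment a left-right mirrored copy of the validation image. This snippet is not part of the original analysis; it only uses PIL's standard transpose and the model calls shown above:

```python
# Segment a horizontally flipped copy of the validation image to test whether
# the model relies on the heart's typical position (right side of the image).
flipped = trainedModel.prepareSize(img().transpose(Image.FLIP_LEFT_RIGHT))
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(flipped)
ax.imshow(trainedModel.predict(flipped), cmap=default_cmap, alpha=.5, interpolation="nearest")
```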
Sensitivity to rotation
There should not be huge variation in rotation (only small angles) when working with transversal slices. Still, it is a good idea to get an impression of how quickly segmentation performance decreases with deviations in rotation.
plot_series(get_rotation_series(img(),trainedModel, step=30, end=360), nrow=2)
results = eval_rotation_series(img(),trueMask(),trainedModel,start=-180,end=180,components=["bg","LV","MY"])
(alt
    .Chart(results.melt(id_vars=['deg'],value_vars=['LV','MY']))
    .mark_line()
    .encode(
        x=alt.X("deg",axis=alt.Axis(title=None)),
        y=alt.Y("value",axis=alt.Axis(title=None)),
        color=alt.Color("variable",legend=None),
        tooltip="value"
    )
    .properties(width=700,height=300)
    .interactive()
)
So there is quite a range (from -40 to 80 degrees) where prediction performance remains stable. That is sufficient, so minor deviations are not a concern.
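To put a number on that stable range, one could filter the evaluation results, e.g. with a Dice threshold of 0.7 (the threshold is an arbitrary choice for illustration; the column names follow the eval_rotation_series output used above):

```python
# Degrees of rotation for which both LV and MY Dice stay above 0.7.
stable = results[(results["LV"] > 0.7) & (results["MY"] > 0.7)]
stable["deg"].agg(["min", "max"])
```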
Let’s have another look at the network moving to the wrong side of the thorax when predicting on rotated images:
gif_series(
    get_rotation_series(img(),trainedModel, start=0, end=360,step=10),
    "example/b0/rotation.gif",
    param_name="deg",
    duration=400
)
Sensitivity to cropping
Another variation that might occur in real life is a difference in field of view. This can happen due to different settings when acquiring the images or due to pre-processing steps in an analysis pipeline.
plot_series(get_crop_series(img(),trainedModel, start = 5, end = 60, step = 10), nrow=2)
gif_series(
    get_crop_series(img(),trainedModel, start=60, end=5, step=-5),
    "example/b0/crop.gif",
    param_name="pixels",
    duration=400
)
This looks quite good. It seems to be okay to crop the image as long as the whole heart remains intact. As soon as we start to crop into the heart, the model is no longer able to find it (this is expected). It also does not start to predict heart tissue where it should not when cropping even further.
results = eval_crop_series(img(),trueMask(),trainedModel,start = 5, components=["bg","LV","MY"])
(alt
    .Chart(results.melt(id_vars=['pixels'],value_vars=['LV','MY']))
    .mark_line()
    .encode(
        x="pixels",
        y="value",
        color="variable",
        tooltip="value"
    )
    .properties(width=700,height=300)
    .interactive()
)
The Dice scores of 1 for very small crop sizes occur because the model is not supposed to predict anything and indeed does not predict anything. The score drops to 0 once the heart starts to appear in the image while the model is still unable to locate it, and then rises to the final performance on the whole image, reaching a plateau at a size of 160 px.
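The Dice score of 1 for two empty masks is a common convention; as a minimal sketch (the exact implementation inside misas may differ), it looks like this:

```python
def dice(pred, true):
    """Dice coefficient for binary masks, returning 1.0 for two empty masks."""
    pred, true = np.asarray(pred, dtype=bool), np.asarray(true, dtype=bool)
    if not pred.any() and not true.any():
        return 1.0  # nothing to find and nothing predicted: perfect agreement
    return 2 * np.logical_and(pred, true).sum() / (pred.sum() + true.sum())
```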
Sensitivity to brightness
plot_series(get_brightness_series(img(),trainedModel, start=np.sqrt(2)/8), nrow=2)
results = eval_bright_series(img(),trueMask(),trainedModel, components=["bg","LV","MY"])
(alt
    .Chart(results.melt(id_vars=['brightness'],value_vars=['LV','MY']))
    .mark_line()
    .encode(
        x="brightness",
        y="value",
        color="variable",
        tooltip="value"
    )
    .properties(width=700,height=300)
    .interactive()
)
Sensitivity to contrast
plot_series(get_contrast_series(img(),trainedModel, start=1/8, end=np.sqrt(2)*8), nrow = 2)
results = eval_contrast_series(img(),trueMask(),trainedModel, components=["bg","LV","MY"])
(alt
    .Chart(results.melt(id_vars=['contrast'],value_vars=['LV','MY']))
    .mark_line(point=True)
    .encode(
        x=alt.X(
            "contrast",
            scale=alt.Scale(type="log")),
        y="value",
        color="variable",
        tooltip="value",
    )
    .properties(width=700,height=300)
    .interactive()
)
Sensitivity to zoom
plot_series(get_zoom_series(img(),trainedModel))
results = eval_zoom_series(img(),trueMask(),trainedModel,components=["bg","LV","MY"])
(alt
    .Chart(results.melt(id_vars=['scale'],value_vars=['LV','MY']))
    .mark_line()
    .encode(
        x="scale",
        y="value",
        color="variable",
        tooltip="value"
    )
    .properties(width=700,height=300)
    .interactive()
)
gif_series(
    get_zoom_series(img(),trainedModel),
    "example/b0/zoom.gif",
    param_name="scale",
    duration=400
)
Robustness to MR artifacts
Spike artifact
Spike artifacts can happen with different intensities and at different locations in k-space. It is even possible to have multiple spikes.
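To build intuition for what such a spike does, here is a conceptual sketch using only numpy: a spike is a single high-intensity point added in the Fourier domain of the image. This is not the misas implementation (the spikeTransform used further below is), and the chosen position and intensity are arbitrary illustration values:

```python
# Add a single k-space spike to the example image and look at the resulting
# herringbone pattern in image space.
arr = np.array(trainedModel.prepareSize(img()).convert("L"), dtype=float)
kspace = np.fft.fftshift(np.fft.fft2(arr))
pos = (np.array(kspace.shape) * 0.6).astype(int)      # spike position as a fraction of k-space
kspace[pos[0], pos[1]] += np.abs(kspace).max() * 0.5   # spike intensity relative to the strongest coefficient
plt.imshow(np.abs(np.fft.ifft2(np.fft.ifftshift(kspace))), cmap="gray")
```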
from misas.mri import *
First we consider a single spike quite far from the center of k-space.
plot_series(get_spike_series(img(),trainedModel))
Segmentation performance is heavily impacted by this kind of artifact. The training examples did not contain a single image with this herringbone pattern, so we even get striped predictions for the myocardium.
Let’s have a look at how the location of the spike in k-space changes the artifact and the model performance. From top to bottom, the spike moves farther from the center.
for i in [.51,.55,.6,.75]:
    plot_series(get_spike_series(img(),trainedModel,spikePosition=[i,i]), param_name="intensity")
So spikes closer to the center of k-space create artifacts with lower frequency and have a less severe impact on segmentation performance.
So far we have only looked at spikes on the diagonal of k-space; for the sake of completeness we can also look at arbitrary locations in k-space.
fig, axs = plt.subplots(7,7,figsize=(16,16))
values = [.4,.45,.49,.5,.51,.55,.6]
values = list(itertools.product(values, values))
for x, ax in zip(values, axs.flatten()):
    im = trainedModel.prepareSize(img())
    im = spikeTransform(im,.5,list(x))
    ax.imshow(np.array(im))
    ax.imshow(np.array(trainedModel.predict(im)), cmap=default_cmap, alpha=.5, interpolation="nearest")
    ax.axes.xaxis.set_visible(False)
    ax.axes.yaxis.set_visible(False)
\(B_0\)-Field inhomogeneity
\(B_0\)-field inhomogeneities are quite common, particularly at higher field strengths. Adjusting for these inhomogeneities at ultra-high field strengths (shimming) is an active field of research (Hock et al. 2020). So what is the impact of this so-called bias field?
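For intuition, a bias field is a smooth, multiplicative intensity modulation across the image. A minimal sketch with a hand-made Gaussian field is shown below; the get_biasfield_series used in the next cell is the actual misas implementation and may parameterize the field differently:

```python
# Multiply the example image with a smooth intensity field that is brightest
# in the top-left corner and falls off towards the opposite corner.
arr = np.array(trainedModel.prepareSize(img()).convert("L"), dtype=float)
yy, xx = np.mgrid[0:arr.shape[0], 0:arr.shape[1]]
sigma = 0.7 * arr.shape[0]
field = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))
plt.imshow(arr * field, cmap="gray")
```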
plot_series(get_biasfield_series(img(),trainedModel))
The model works reliably (at least on this example) even for very intense field inhomogeneities.