Case Study - Model Robustness

Aim: How robust is my network to shifts in the data?

We often find ourselves in a situation where we have trained a model for a certain task. The network performs well on the training and validation data. It also performs well on the hold-out test set. Still, you wonder: what happens when you feed it data that is systematically different from all three sets? Does it make sense to re-train with more extensive data augmentation?

In this case study we demonstrate how misas helps answer these questions with a concrete example:

- Model: Custom U-Net
- Data: Small set of transversal CMR images

Prepare Model for misas

from misas.core import default_cmap
from misas.fastai_model import Fastai2_model

from PIL import Image
import itertools
import matplotlib.pyplot as plt
import numpy as np
import altair as alt
trainedModel = Fastai2_model('chfc-cmi/transversal-cmr-seg', 'b0_transversal_5_5', force_reload=False)
Using cache found in /home/markus/.cache/torch/hub/chfc-cmi_transversal-cmr-seg_master
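
The second argument ('b0_transversal_5_5') is presumably the torch hub entry point of the pretrained model. If you are curious which other entry points the hub repository exposes, torch.hub.list can show them (a quick sketch, assuming the repository follows the standard hubconf convention):

import torch
# list the entry points exposed by the hub repository (standard hubconf.py assumed)
print(torch.hub.list('chfc-cmi/transversal-cmr-seg'))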

Prepare Dataset for misas

Data is available as PNG images and masks, which is just fine for misas.

img = Image.open("example/b0/images/train_example.png").convert("RGB")
plt.imshow(img)
img.size
(128, 128)
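
misas will need the matching ground-truth masks later for quantitative evaluation. As a quick sanity check we can overlay the mask on the image; a minimal sketch, assuming the mask paths mirror the image paths in the same way as for the validation files used further below:

mask = Image.open("example/b0/images/train_example.png".replace("image", "mask")).convert("I")
fig, ax = plt.subplots(figsize=(4,4))
ax.imshow(img)
_ = ax.imshow(mask, cmap=default_cmap, alpha=.5, interpolation="nearest")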

How does the trained model perform on this (training) example?

Time to apply the model to the example image and see how it works (we need to call prepareSize manually here):

img = trainedModel.prepareSize(img)
fig,ax = plt.subplots(figsize=(4,4))
plt.imshow(img)
_ = plt.imshow(trainedModel.predict(img), cmap=default_cmap, alpha=.5, interpolation="nearest")

So how does it perform on validation data?

from glob import glob
files = sorted(glob("example/b0/images/val*.png"))
fig, axs = plt.subplots(1,5, figsize=(20,10))
for i, ax in enumerate(axs.flatten()):
    fname = files[i]
    tmp = trainedModel.prepareSize(Image.open(fname).convert("RGB"))
    ax.imshow(tmp)
    ax.imshow(trainedModel.predict(tmp), cmap=default_cmap, alpha=.5, interpolation="nearest")

This is not great, but given the limited training data it looks decent. So let's take a closer look at how robust this model is, in particular to differences we might encounter when applying this network to new data.

Robustness to basic transformations

from misas.core import *
# use functions so that every series starts from a fresh copy of the image and mask
img = lambda: Image.open(files[0]).convert("RGB")
trueMask = lambda: Image.open(files[0].replace("image","mask")).convert("I")

Sensitivity to orientation

Changes in orientation are very common, not because it is common to acquire images in different orientations but because the way data is stored differs between file formats such as NIfTI and DICOM. So it is interesting to see how the model works in all possible orientations (including flips).

plot_series(get_dihedral_series(img(),trainedModel), nrow=2, figsize=(20,12))

results = eval_dihedral_series(img(),trueMask(),trainedModel,components=["bg","LV","MY"])
results
   k        bg        LV        MY
0  0  0.995824  0.782609  0.765604
1  1  0.963374  0.000000  0.000000
2  2  0.977227  0.000000  0.029690
3  3  0.985537  0.657269  0.206577
4  4  0.979158  0.000000  0.000000
5  5  0.987976  0.625101  0.439119
6  6  0.979158  0.000000  0.000000
7  7  0.992134  0.728850  0.640967

Not surprisingly, the model is very sensitive to changes in orientation. So when using this model, it is very important to feed it images in the proper orientation.

Another really interesting observation is that the heart is never properly segmented when the images are flipped horizontally, i.e. when the heart appears on the left instead of the right side of the image. This is a strong indication that location within the image is one of the features the network has learned. Depending on your use case this might indeed be a sensible feature for segmenting the heart, as in the vast majority of cases the left ventricle is on the left side of the chest (and therefore shows up on the right side of transversal slices).
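
To make this concrete, here is a small sketch (outside of the dihedral series) that feeds only the horizontally flipped image to the model:

# flip the raw image so the heart ends up on the left, then predict as usual
flipped = Image.open(files[0]).convert("RGB").transpose(Image.FLIP_LEFT_RIGHT)
flipped = trainedModel.prepareSize(flipped)
fig, ax = plt.subplots(figsize=(4,4))
ax.imshow(flipped)
_ = ax.imshow(trainedModel.predict(flipped), cmap=default_cmap, alpha=.5, interpolation="nearest")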

Sensitivity to rotation

There should not be huge variations in rotation (beyond small angles) when working with transversal slices. Still, it is a good idea to get an impression of how quickly segmentation performance degrades with deviations in rotation.

plot_series(get_rotation_series(img(),trainedModel, step=30, end=360), nrow=2)

results = eval_rotation_series(img(),trueMask(),trainedModel,start=-180,end=180,components=["bg","LV","MY"])
(alt
 .Chart(results.melt(id_vars=['deg'],value_vars=['LV','MY']))
 .mark_line()
 .encode(
     x=alt.X("deg",axis=alt.Axis(title=None)),
     y=alt.Y("value",axis=alt.Axis(title=None)),
     color=alt.Color("variable",legend=None),
     tooltip="value"
 )
 .properties(width=700,height=300)
 .interactive()
)

So there is quite a range (from -40 to 80 degrees) in which prediction performance remains stable. This is sufficient margin not to have to worry about minor deviations.
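
The same results data frame can be used to quantify this range directly, e.g. with an (arbitrarily chosen) Dice threshold of 0.7 for the left ventricle:

# rotations for which the LV Dice score stays above an arbitrary threshold of 0.7
stable = results[results["LV"] > 0.7]
print(stable["deg"].min(), stable["deg"].max())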

Let's have another look at the network moving to the wrong side of the thorax when predicting on rotated images:

gif_series(
    get_rotation_series(img(),trainedModel, start=0, end=360,step=10),
    "example/b0/rotation.gif",
    param_name="deg",
    duration=400
)

segmentation sensitivity to rotation

Sensitivity to cropping

Another variation that might occur in real life is a difference in field of view. This can happen due to different settings when acquiring the images or due to pre-processing steps in an analysis pipeline.

plot_series(get_crop_series(img(),trainedModel, start = 5, end = 60, step = 10), nrow=2)

gif_series(
    get_crop_series(img(),trainedModel, start=60, end=5, step=-5),
    "example/b0/crop.gif",
    param_name="pixels",
    duration=400
)

segmentation sensitivity to cropping

This looks quite good. It seems to be okay to crop the image as long as the whole heart remains intact. As soon as we start to crop into the heart, the model is no longer able to find it (this is expected). It also does not start to predict the heart somewhere it should not when cropping even further.

results = eval_crop_series(img(),trueMask(),trainedModel,start = 5, components=["bg","LV","MY"])
(alt
 .Chart(results.melt(id_vars=['pixels'],value_vars=['LV','MY']))
 .mark_line()
 .encode(
     x="pixels",
     y="value",
     color="variable",
     tooltip="value"
 )
 .properties(width=700,height=300)
 .interactive()
)

The Dice scores of 1 for very small sizes occur because the model is not supposed to predict anything and indeed predicts nothing. The score drops to 0 when the heart starts to appear in the image while the model is still unable to locate it, and then rises to the final performance on the whole image, reaching a plateau at a size of 160 px.
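
For reference, a Dice score with this convention (an empty prediction compared to an empty ground truth counts as a perfect match) can be written like this; a minimal illustrative version for a single class, not the implementation used by misas:

def dice(pred, truth):
    # binarize, then apply the Sørensen–Dice formula 2|A∩B| / (|A| + |B|)
    pred, truth = np.asarray(pred) > 0, np.asarray(truth) > 0
    if not pred.any() and not truth.any():
        return 1.0  # nothing to find and nothing found
    return 2 * (pred & truth).sum() / (pred.sum() + truth.sum())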

Sensitivity to brightness

plot_series(get_brightness_series(img(),trainedModel, start=np.sqrt(2)/8), nrow=2)

results = eval_bright_series(img(),trueMask(),trainedModel, components=["bg","LV","MY"])
(alt
 .Chart(results.melt(id_vars=['brightness'],value_vars=['LV','MY']))
 .mark_line()
 .encode(
     x="brightness",
     y="value",
     color="variable",
     tooltip="value"
 )
 .properties(width=700,height=300)
 .interactive()
)

Sensitivity to contrast

plot_series(get_contrast_series(img(),trainedModel, start=1/8, end=np.sqrt(2)*8), nrow = 2)

results = eval_contrast_series(img(),trueMask(),trainedModel, components=["bg","LV","MY"])
(alt
 .Chart(results.melt(id_vars=['contrast'],value_vars=['LV','MY']))
 .mark_line(point=True)
 .encode(
     x=alt.X(
     "contrast",
     scale=alt.Scale(type="log")),
     y="value",
     color="variable",
     tooltip="value",
 )
 .properties(width=700,height=300)
 .interactive()
)

Sensitivity to zoom

plot_series(get_zoom_series(img(),trainedModel))

results = eval_zoom_series(img(),trueMask(),trainedModel,components=["bg","LV","MY"])
(alt
 .Chart(results.melt(id_vars=['scale'],value_vars=['LV','MY']))
 .mark_line()
 .encode(
     x="scale",
     y="value",
     color="variable",
     tooltip="value"
 )
 .properties(width=700,height=300)
 .interactive()
)
gif_series(
    get_zoom_series(img(),trainedModel),
    "example/b0/zoom.gif",
    param_name="scale",
    duration=400
)

segmentation sensitivity to zoom

Robustness to MR artifacts

Spike artifact

Spike artifacts can happen with different intensities and at different locations in k-space. It is even possible to have multiple spikes.

from misas.mri import *

First we consider a single spike quite far from the center of k-space.

plot_series(get_spike_series(img(),trainedModel))

Segmentation performance is heavily impacted by this kind of artifact. The training set did not contain a single example with this herringbone pattern, so we even get striped predictions for the myocardium.

Let's have a look at how the location of the spike in k-space changes the artifact and the model's performance. From top to bottom, the spike moves farther from the center.

for i in [.51,.55,.6,.75]:
    plot_series(get_spike_series(img(),trainedModel,spikePosition=[i,i]), param_name="intensity")

So spikes closer to the center of k-space create artifacts with lower frequency and have a less severe impact on segmentation performance.

So far we have only looked at spikes on the diagonal of k-space; for the sake of completeness we can also look at arbitrary locations in k-space.

fig, axs = plt.subplots(7,7,figsize=(16,16))
values = [.4,.45,.49,.5,.51,.55,.6]
values=list(itertools.product(values, values))
for x, ax in zip(values, axs.flatten()):
    im = trainedModel.prepareSize(img())
    im = spikeTransform(im,.5,list(x))
    ax.imshow(np.array(im))
    ax.imshow(np.array(trainedModel.predict(im)), cmap=default_cmap, alpha=.5, interpolation="nearest")
    ax.axes.xaxis.set_visible(False)
    ax.axes.yaxis.set_visible(False)
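
As mentioned above, multiple spikes are possible as well. Two spikes could simply be applied one after the other; a sketch, assuming spikeTransform accepts its own output and the same (image, intensity, position) arguments as in the grid above:

im = trainedModel.prepareSize(img())
im = spikeTransform(im, .5, [.55, .55])  # first spike
im = spikeTransform(im, .5, [.6, .45])   # second spike
fig, ax = plt.subplots(figsize=(4,4))
ax.imshow(np.array(im))
_ = ax.imshow(np.array(trainedModel.predict(im)), cmap=default_cmap, alpha=.5, interpolation="nearest")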

\(B_0\)-Field inhomogeneity

\(B_0\)-field inhomogeneities are quite common, particularly at higher field strengths. Correcting these inhomogeneities at ultra-high field strength (shimming) is an active field of research (Hock et al. 2020). So what is the impact of this so-called bias field?

plot_series(get_biasfield_series(img(),trainedModel))

The model works reliably (at least on this example) even for very intense field inhomogeneities.
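
As with the other transformations, the bias-field series can also be rendered as a GIF for closer inspection; a sketch analogous to the rotation, crop and zoom GIFs above (the param_name label is an assumption here):

gif_series(
    get_biasfield_series(img(), trainedModel),
    "example/b0/biasfield.gif",
    param_name="intensity",
    duration=400
)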