Aim: How robust is my network to shifts in the data?

We often find ourselves in a situation where we have trained a model for a certain task. The network performs well on the training and validation data. It also performs well on the hold-out test set. Still, you wonder: what happens when you feed it data that is systematically different from all three sets? Does it make sense to re-train with more extensive data augmentation?

In this case study we demonstrate how misas helps answer these questions with a concrete example:

  • Model: Custom U-Net
  • Data: Small set of transversal CMR images

Prepare Model for misas

from misas.core import default_cmap
from misas.fastai_model import Fastai2_model

from PIL import Image
import itertools
import matplotlib.pyplot as plt
import numpy as np
import altair as alt
# load the pre-trained model from the chfc-cmi/transversal-cmr-seg torch hub repository
trainedModel = Fastai2_model('chfc-cmi/transversal-cmr-seg', 'b0_transversal_5_5', force_reload=False)
Using cache found in /home/markus/.cache/torch/hub/chfc-cmi_transversal-cmr-seg_master

Prepare Dataset for misas

Data is available as PNG images and masks, which is just fine for misas.

img = Image.open("example/b0/images/train_example.png").convert("RGB")
plt.imshow(img)
img.size
(128, 128)
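
The corresponding ground-truth mask can be displayed in the same way. The path used below is an assumption, following the replace("image","mask") naming convention used further down in this notebook:

# assumed mask location, derived from the image path via replace("image","mask")
mask = Image.open("example/b0/images/train_example.png".replace("image","mask")).convert("I")
plt.imshow(img)
_ = plt.imshow(mask, cmap=default_cmap, alpha=.5, interpolation="nearest")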

How does the trained model perform on this (training) example?

Time to apply the model to the example image and see how it works (we need to call prepareSize manually here):

img = trainedModel.prepareSize(img)
fig,ax = plt.subplots(figsize=(4,4))
plt.imshow(img)
_ = plt.imshow(trainedModel.predict(img), cmap=default_cmap, alpha=.5, interpolation="nearest")

So how does it perform on validation data?

from glob import glob
files = sorted(glob("example/b0/images/val*.png"))
# predict on the first five validation images and overlay the resulting masks
fig, axs = plt.subplots(1,5, figsize=(20,10))
for i, ax in enumerate(axs.flatten()):
    fname = files[i]
    tmp = trainedModel.prepareSize(Image.open(fname).convert("RGB"))
    ax.imshow(tmp)
    ax.imshow(trainedModel.predict(tmp), cmap=default_cmap, alpha=.5, interpolation="nearest")

This is not great, but given the limited training data it looks decent. So let's have a closer look at how robust this model is, in particular to differences we might encounter when applying this network to new data.

Robustness to basic transformations

from misas.core import *
# wrap image and mask in lambdas so every experiment starts from a freshly loaded copy
img = lambda: Image.open(files[0]).resize((256,256)).convert("RGB")
trueMask = lambda: Image.open(files[0].replace("image","mask")).convert("I")

Sensitivity to orientation

Changes in orientation are very common, not because it is common to acquire images in different orientations, but because the way data is stored differs between file formats like NIfTI and DICOM. So it is interesting to see how the model performs in all possible orientations (including flips).

plot_series(get_dihedral_series(img(),trainedModel), nrow=2, figsize=(20,12))
results = eval_dihedral_series(img(),trueMask(),trainedModel,components=["bg","LV","MY"])
results
k bg LV MY
0 0 0.995824 0.782609 0.765604
1 1 0.963374 0.000000 0.000000
2 2 0.977961 0.000000 0.070302
3 3 0.986643 0.735463 0.237958
4 4 0.979158 0.000000 0.000000
5 5 0.987976 0.625101 0.439119
6 6 0.979158 0.000000 0.000000
7 7 0.992454 0.740951 0.650998

Not surprisingly, the model is very sensitive to changes in orientation. So when using this model it is very important to feed the images in the proper orientation.

Another really interesting observation is that the heart is never properly segmented when images are flipped horizontally, so that the heart appears on the left instead of the right side of the image. This is a strong indication that the location within the image is one of the features the network has learned. Depending on your use case this might indeed be a sensible feature to use for segmentation of the heart, as in the vast majority of cases the left ventricle is on the left side of the chest (and therefore shows up on the right side of transversal slices).
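
If the input comes as NIfTI volumes, one way to guard against orientation mix-ups is to reorient everything to a canonical orientation before slicing. The snippet below is only a sketch with a placeholder file name, and you would still need to verify that the canonical (RAS+) orientation matches the orientation used during training:

import nibabel as nib

# placeholder file name; reorient the volume to the canonical RAS+ orientation
vol = nib.as_closest_canonical(nib.load("new_scan.nii.gz"))
data = vol.get_fdata()
# take a transversal slice and rescale it to 8-bit before handing it to the model
slc = data[:, :, data.shape[2] // 2]
slc = (255 * (slc - slc.min()) / (np.ptp(slc) + 1e-8)).astype(np.uint8)
img_canonical = Image.fromarray(slc).convert("RGB")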

Sensitivity to rotation

There should not be a huge variation in rotation (by small angles) when working with transversal slices. Still, it is a good idea to get an impression of how quickly segmentation performance degrades with deviations in rotation.

plot_series(get_rotation_series(img(),trainedModel, step=30), nrow=2)
results = eval_rotation_series(img(),trueMask(),trainedModel,start=-180,end=180,components=["bg","LV","MY"])
(alt
 .Chart(results.melt(id_vars=['deg'],value_vars=['LV','MY']))
 .mark_line()
 .encode(
     x=alt.X("deg",axis=alt.Axis(title=None)),
     y=alt.Y("value",axis=alt.Axis(title=None)),
     color=alt.Color("variable",legend=None),
     tooltip="value"
 )
 .properties(width=700,height=300)
 .interactive()
)

So there is quite a range (from -40 to 80 degrees) where prediction performance remains stable. This range is large enough that we do not need to worry about minor deviations.
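
To put a number on this range, we can filter the results table for angles where both structures keep a reasonable Dice score. The 0.7 threshold below is an arbitrary choice for illustration, not something prescribed by misas:

# angles where both LV and MY stay above an (arbitrary) Dice threshold of 0.7
stable = results[(results["LV"] > 0.7) & (results["MY"] > 0.7)]
print(stable["deg"].min(), stable["deg"].max())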

Let's have another look at the network moving to the wrong side of the thorax when predicting on rotated images:

gif_series(
    get_rotation_series(img(),trainedModel, start=0, end=360,step=10),
    "example/b0/rotation.gif",
    param_name="deg",
    duration=400
)

segmentation sensitivity to rotation

Sensitivity to cropping

Another variation that might occur in real life is a difference in field of view. This can happen due to different settings when acquiring the images or due to pre-processing steps in an analysis pipeline.

plot_series(get_crop_series(img(),trainedModel, start = 10, end = 120, step = 20), nrow=2)
gif_series(
    get_crop_series(img(),trainedModel, start=120, end=10,step=-10),
    "example/b0/crop.gif",
    param_name="pixels",
    duration=400
)

segmentation sensitivity to cropping

This looks quite good. It seems to be okay to crop the image as long as the whole heart remains intact. As soon as we start to crop part of the heart, the model is no longer able to find it (this is expected). It also does not start to predict the heart somewhere it should not when cropping even further.

results = eval_crop_series(img(),trueMask(),trainedModel,start = 5, components=["bg","LV","MY"])
(alt
 .Chart(results.melt(id_vars=['pixels'],value_vars=['LV','MY']))
 .mark_line()
 .encode(
     x="pixels",
     y="value",
     color="variable",
     tooltip="value"
 )
 .properties(width=700,height=300)
 .interactive()
)

The Dice score of 1 for very small crop sizes arises because the model is not supposed to predict anything and indeed predicts nothing. The score drops to 0 when the heart starts to appear in the image but the model is still unable to locate it, and then rises to the final performance it has on the whole image, reaching a plateau at a size of 160 px.
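
For reference, here is a minimal sketch of the Dice coefficient with the empty-mask convention assumed in this discussion (an illustration only, not the exact implementation used by misas):

def dice_score(pred, true):
    # Dice coefficient; two empty masks count as a perfect match (score 1)
    pred, true = np.asarray(pred).astype(bool), np.asarray(true).astype(bool)
    if not pred.any() and not true.any():
        return 1.0
    return 2 * (pred & true).sum() / (pred.sum() + true.sum())

dice_score(np.zeros((4, 4)), np.zeros((4, 4)))  # 1.0, as on the small-crop plateau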

Sensitivity to brightness

plot_series(get_brightness_series(img(),trainedModel, start=np.sqrt(2)/8), nrow=2)