Aim: Can I use a given model on a given dataset?

We often find ourselves in a situation where we have a pre-trained model for a certain task (e.g. cardiac segmentation) and a dataset on which we want to perform that task. We know the model was not trained on that specific dataset, so we wonder: how will the model perform? Is it usable at all? Do we need to pre-process our data in a certain way to use it?

In this case study we use a concrete example to demonstrate how misas helps answer these questions:

Prepare Model for misas

The model used here was trained on UK Biobank cardiac imaging data to segment short-axis images of the heart into left ventricle (LV), right ventricle (RV) and myocardium (MY). For details about the model please read the paper (Bai et al. 2018) and cite it if you use the model. For implementation, training and usage see the GitHub repository. We downloaded the pre-trained model for short-axis images from https://www.doc.ic.ac.uk/~wbai/data/ukbb_cardiac/trained_model/ (a local copy is in example/kaggle/FCN_sa). In order to use it with misas, we need to wrap it in a class that implements the required interface (prepareSize and predict, both taking an Image as input; see the main documentation for more details).

ukbb_cardiac is written in TensorFlow v1. When running TensorFlow v2, make sure to import the compat module.
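
For example, under TensorFlow v2 this typically looks as follows (a sketch; the exact setup depends on your environment):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()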

from misas.tensorflow_model import ukbb_model, crop_pad_pil
model = ukbb_model('example/kaggle/FCN_sa')
INFO:tensorflow:Restoring parameters from example/kaggle/FCN_sa

The model requires image dimensions to be multiples of 16, so prepareSize pads images accordingly. Additionally, image_to_input takes care of transforming a three-channel image into a single-item batch of single-channel images. In predict, the network output is converted to an ImageSegment.
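
To illustrate the interface, here is a minimal sketch of such a wrapper (the actual implementation lives in misas.tensorflow_model; the method bodies below are simplified assumptions, not the real code):

import numpy as np
from PIL import Image

class WrapperSketch:
    def prepareSize(self, img):
        # pad width and height up to the next multiple of 16
        w, h = img.size
        new_w, new_h = -(-w // 16) * 16, -(-h // 16) * 16
        padded = Image.new("RGB", (new_w, new_h))
        padded.paste(img, ((new_w - w) // 2, (new_h - h) // 2))
        return padded

    def image_to_input(self, img):
        # collapse the three identical RGB channels into one and
        # add batch and channel axes -> shape (1, height, width, 1)
        arr = np.asarray(img.convert("L"), dtype=np.float32)
        return arr[np.newaxis, :, :, np.newaxis]

    def predict(self, img):
        batch = self.image_to_input(img)
        # feed `batch` through the network here; the real wrapper turns
        # the per-pixel argmax of the output into an ImageSegment
        raise NotImplementedError("network inference omitted in this sketch")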

Prepare Dataset for misas

The Data Science Bowl Cardiac Challenge Data consists of cine MRI images from 1140 patients in DICOM format. Multiple short-axis slices are available for each patient. Additionally, end-systolic and end-diastolic volumes are given (the original Kaggle challenge asked participants to predict these from the images).

You can download and unpack the dataset from the above website. Some example images are included in the example/kaggle/dicom folder.

We use pydicom to read the images and convert them to Pillow Image objects.

from pydicom import dcmread
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from misas.core import default_cmap, default_cmap_true_mask

We scale pixel intensities to the full 8-bit range and convert the result to a three-channel image.

def prepareImage(fname):
    """Read a DICOM file and return an 8-bit RGB PIL image."""
    ds = dcmread(fname)
    img = ds.pixel_array.astype(np.int16)
    # scale pixel intensities to the full 8-bit range
    img = (img / img.max() * 255).astype(np.uint8)
    return Image.fromarray(img).convert("RGB")

Okay, let's look at an example:

img = prepareImage("example/kaggle/117_sax_76_IM-11654-0005.dcm")
plt.imshow(img)
print(np.array(img).shape)
(256, 192, 3)

How does the model perform out of the box?

Time to apply the model to the example image and see how it performs (note that we need to call prepareSize manually here):

img = model.prepareSize(img)
fig,ax = plt.subplots(figsize=(8,8))
plt.imshow(img)
pred = model.predict(img)
plt.imshow(pred, cmap=default_cmap, alpha=.5, interpolation="nearest")
<matplotlib.image.AxesImage at 0x7f341aefa790>

The model partially identified the left ventricle and myocardium but failed to identify the right ventricle. This result is promising in that it shows some success, but it is not usable as is.

Still, we might be able to use the model on the Kaggle dataset if we understand why it fails and pre-process the data accordingly. This is where misas comes in. But first, let's look at some more examples to check whether we picked a bad example by chance.

from glob import glob
dcm_files = glob("example/kaggle/sample_images/*.dcm")
fig, axs = plt.subplots(2,5, figsize=(20,10))
for i, ax in enumerate(axs.flatten()):
    tmp = model.prepareSize(prepareImage(dcm_files[i]))
    ax.imshow(tmp)
    ax.imshow(model.predict(tmp), cmap=default_cmap, vmax=3, alpha=.5, interpolation="nearest")

Apparently our first impression holds for most images: the model does not work out of the box. This overview also shows some of the variety in image content and quality, as well as inconsistencies in orientation.

Analysis of one image

To gain more detailed insight into what is happening, we select a specific example and define two helper functions that open the image and the true mask for that example. We use functions that re-read the image from file, instead of loading the image once and passing it around, to avoid accidentally working with a modified version of the image.

In this case we have the true mask from a previous experiment, but it would also be possible to create the mask manually. It just needs to be saved as a PNG with pixel values 0 for the background and 1 to n for the n classes. As there are only three classes in this case, the PNG will look purely black in a standard image viewer.
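
For illustration, creating and saving such a mask could look like this (a sketch; the array contents and filename are hypothetical):

import numpy as np
from PIL import Image

# hypothetical mask: 0 = background, 1 to 3 for the three classes
mask = np.zeros((256, 192), dtype=np.uint8)
# ... set pixel values 1, 2 and 3 for the labeled regions ...
Image.fromarray(mask).save("my_true_mask.png")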

from misas.core import *
def img():
    """
    Opens the sample image as a PIL image
    """
    return Image.open("example/kaggle/images/1-frame014-slice006.png").convert("RGB")

def trueMask():
    """
    Opens the true mask as a PIL image
    """
    return Image.open("example/kaggle/masks_full/1-frame014-slice006.png").convert("I")
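
As a quick sanity check, we can overlay the true mask on the image (a sketch following the plotting pattern used above, with the imports already in place):

fig, ax = plt.subplots(figsize=(8,8))
ax.imshow(img())
ax.imshow(trueMask(), cmap=default_cmap_true_mask, alpha=.5, interpolation="nearest")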

Sensitivity to orientation

There are eight possible orientations an image can be in (four rotations by 90° and a horizontally flipped variant of each). In misas, a series with these 8 items is available via get_dihedral_series:

dihedral_series = get_dihedral_series(img(),model, truth=trueMask())
plot_series(dihedral_series, nrow=2, figsize=(20,12), param_name="orientation", overlay_truth = True)
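
For intuition, the eight variants could also be generated with plain PIL (a sketch, independent of the misas implementation):

from PIL import Image

def dihedral_variants(img):
    # the four rotations by 90 degrees ...
    rotations = [
        img,
        img.transpose(Image.ROTATE_90),
        img.transpose(Image.ROTATE_180),
        img.transpose(Image.ROTATE_270),
    ]
    # ... each paired with its horizontally flipped variant
    return [v for rot in rotations
            for v in (rot, rot.transpose(Image.FLIP_LEFT_RIGHT))]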