We often find ourselves in a situation where we have a pre-trained model for a certain task (e.g. cardiac segmentation) and a dataset on which we want to perform that task. We know the model was not trained on that specific dataset. So we wonder: how will the model perform? Is it usable at all? Do we need to pre-process our data in a certain way to use it?
In this case study we demonstrate how misas helps answer these questions with a concrete example:
The used model was trained on UK Biobank cardiac imaging data to segment short-axis images of the heart into left ventricle (LV), right ventricle (RV) and myocardium (MY). For details about the model please read the paper (Bai et al. 2018) and cite it if you use it. For implementation, training and usage see the GitHub repository. We downloaded the pre-trained model for short-axis images from https://www.doc.ic.ac.uk/~wbai/data/ukbb_cardiac/trained_model/ (local copy in
example/kaggle/FCN_sa). In order to use it with misas, we need to wrap it in a class that implements the desired interface (a PIL Image as input; see the main documentation for more details).
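To make the expected interface concrete, here is a minimal sketch of such a wrapper. The method names prepareSize and predict follow the ones used below; the padding logic and the dummy prediction are assumptions for illustration, not the real ukbb_cardiac network.

```python
from PIL import Image
import numpy as np

class WrappedModel:
    """Sketch of a model wrapper exposing the interface misas expects.
    The dummy predict() is a placeholder; a real wrapper would run the
    network and return its segmentation mask."""

    def prepareSize(self, img):
        # Pad each dimension up to the next multiple of 16 (assumed strategy).
        w, h = img.size
        new_w = -(-w // 16) * 16  # ceiling division
        new_h = -(-h // 16) * 16
        padded = Image.new("RGB", (new_w, new_h))
        padded.paste(img, ((new_w - w) // 2, (new_h - h) // 2))
        return padded

    def predict(self, img):
        # Placeholder: return an all-background mask of matching shape.
        return np.zeros(img.size[::-1], dtype=np.uint8)

model = WrappedModel()
img = model.prepareSize(Image.new("RGB", (100, 200)))
print(img.size)  # (112, 208)
```

Any class with these two methods can be passed to the misas helpers used in the rest of this case study.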
ukbb_cardiac is written in tensorflow v1. With tensorflow v2, make sure to import the compat module.
from misas.tensorflow_model import ukbb_model, crop_pad_pil
model = ukbb_model('example/kaggle/FCN_sa')
INFO:tensorflow:Restoring parameters from example/kaggle/FCN_sa
The model requires image dimensions to be multiples of 16. We pad images accordingly in prepareSize. Additionally, code in image_to_input takes care of the specifics of transforming a three-channel image into a single-item batch of single-channel images. In predict the model output is converted back to a segmentation mask.
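The channel handling can be illustrated with a small sketch. The exact normalisation inside image_to_input is an assumption here; the point is the shape transformation from a three-channel image to a batch of one single-channel image.

```python
import numpy as np
from PIL import Image

def image_to_input(img):
    """Sketch of the conversion described above (assumed logic):
    a three-channel PIL image becomes a single-item batch of
    single-channel images, shaped (1, height, width, 1)."""
    arr = np.array(img.convert("L"), dtype=np.float32)  # collapse RGB to one channel
    if arr.max() > 0:
        arr = arr / arr.max()  # normalise intensities (assumption)
    return arr[np.newaxis, ..., np.newaxis]

batch = image_to_input(Image.new("RGB", (192, 256), color=(128, 128, 128)))
print(batch.shape)  # (1, 256, 192, 1)
```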
The Data Science Bowl Cardiac Challenge Data consists of MRI cine images from 1140 patients in dicom format. Multiple slices in short axis are available for each patient. Additionally, end-systolic and end-diastolic volumes are given (the original Kaggle challenge asked participants to predict these from the images).
You can download and unpack the dataset from the above website. Some example images are included in the example/kaggle/sample_images folder. We use pydicom to read the images and convert them to pillow images.
from pydicom import dcmread
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from misas.core import default_cmap, default_cmap_true_mask
#import warnings
#warnings.filterwarnings('ignore')
We scale the pixel intensities to the full 8-bit range before converting to a pillow image.
def prepareImage(fname):
    ds = dcmread(fname)
    img = ds.pixel_array.astype(np.int16)
    img = (img / img.max() * 255).astype(np.uint8)
    img = Image.fromarray(img)
    return img.convert("RGB")
Okay, let's look at an example:
img = prepareImage("example/kaggle/117_sax_76_IM-11654-0005.dcm")
plt.imshow(img)
print(np.array(img).shape)
(256, 192, 3)
img = model.prepareSize(img)
fig, ax = plt.subplots(figsize=(8, 8))
plt.imshow(img)
pred = model.predict(img)
plt.imshow(pred, cmap=default_cmap, alpha=.5, interpolation="nearest")
<matplotlib.image.AxesImage at 0x7f341aefa790>
The model partially identified the left ventricle and the myocardium, but failed to identify the right ventricle. So this result is promising in that it shows some success, but it is not usable as it is.
This is not too surprising given the origin of the ukbb_cardiac network. That network was trained specifically on UK Biobank images and was not intended to be applied generally.
This is where misas comes in. But first, let's look at some more examples to see whether we selected a bad example by chance.
from glob import glob
dcm_files = glob("example/kaggle/sample_images/*.dcm")
fig, axs = plt.subplots(2, 5, figsize=(20, 10))
for i, ax in enumerate(axs.flatten()):
    tmp = model.prepareSize(prepareImage(dcm_files[i]))
    ax.imshow(tmp)
    #model.predict(tmp).show(ax=ax, cmap=default_cmap)
    ax.imshow(model.predict(tmp), cmap=default_cmap, vmax=3, alpha=.5, interpolation="nearest")
Apparently our first impression that the model does not work out of the box applies to most images. This also shows some of the variety in image content and quality, as well as some inconsistencies in orientation.
To get more detailed insights into what is happening, we select a specific example and define two helper functions that open the image and the true mask for that example. We use functions that load the image from file each time, instead of loading it once and passing it around, so that we never accidentally work with a modified version of the image.
In this case we have the true mask from a previous experiment, but it would also be possible to create that mask manually; it just needs to be saved as a png with pixel value 0 for the background and values 1 to n for the n classes. As there are only three classes in this case, the png will look like a purely black image in a standard image viewer.
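The mask format can be sketched as follows. The mapping of integer labels to classes (e.g. 1 for LV, 2 for MY, 3 for RV) is an assumption here; the essential point is that the labels survive a png round-trip unchanged.

```python
import io
import numpy as np
from PIL import Image

# Build a tiny mask with class labels 0 (background) and 1-3 for the
# three classes; which integer maps to which class is assumed here.
mask = np.zeros((64, 64), dtype=np.uint8)
mask[10:20, 10:20] = 1
mask[20:30, 10:20] = 2
mask[10:20, 30:40] = 3

# Save as png and read it back the way trueMask() does below.
buf = io.BytesIO()
Image.fromarray(mask).save(buf, format="PNG")
buf.seek(0)
restored = np.array(Image.open(buf).convert("I"))
print(np.unique(restored))  # [0 1 2 3]
```

With a maximum pixel value of 3 out of 255, such an image indeed looks black in a standard viewer.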
from misas.core import *
def img():
    """ Opens the sample image as a PIL image """
    return Image.open("example/kaggle/images/1-frame014-slice006.png").convert("RGB")

def trueMask():
    """ Opens the true mask as a PIL image """
    return Image.open("example/kaggle/masks_full/1-frame014-slice006.png").convert("I")
There are eight possible orientations an image can be in (four rotations by 90° and a flipped variant of each). In misas a series with these eight items is available via
dihedral_series = get_dihedral_series(img(),model, truth=trueMask())
plot_series(dihedral_series, nrow=2, figsize=(20, 12), param_name="orientation", overlay_truth=True)
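For reference, the eight dihedral orientations can be generated with plain Pillow operations. This is a sketch of the transformation set, not the internals of get_dihedral_series:

```python
from PIL import Image

def dihedral_variants(img):
    """Sketch: the eight dihedral orientations of an image, i.e. the
    four 90-degree rotations plus a mirrored version of each."""
    rotations = [img.rotate(angle, expand=True) for angle in (0, 90, 180, 270)]
    return rotations + [r.transpose(Image.FLIP_LEFT_RIGHT) for r in rotations]

variants = dihedral_variants(Image.new("RGB", (192, 256)))
print(len(variants))  # 8
```

Feeding each variant through the model shows how sensitive the predictions are to orientation, which is exactly what the series plot above visualises.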