Case Study: MyoPS - myocardial pathology segmentation

Aim: Apply misas to MyoPS data

This is a more complex application of misas to a multi-channel input model with multiple output classes. It uses data from the Myocardial pathology segmentation combining multi-sequence CMR challenge (MyoPS 2020).

General results are published in “Myocardial Pathology Segmentation Combining Multi-Sequence Cardiac Magnetic Resonance Images.” First Challenge, MyoPS 2020, Held in Conjunction with MICCAI 2020, Lima, Peru, October 4, 2020, Proceedings, and the specific model is described in: Ankenbrand M.J., Lohr D., Schreiber L.M. (2020) “Exploring Ensemble Applications for Multi-sequence Myocardial Pathology Segmentation.” In: Zhuang X., Li L. (eds) Myocardial Pathology Segmentation Combining Multi-Sequence Cardiac Magnetic Resonance Images. MyoPS 2020. Lecture Notes in Computer Science, vol 12554. Springer, Cham.

As this model uses a specific development version of fastai v2 and data from the challenge cannot be freely shared, it is much harder to reproduce the results from this notebook. You need to request the data from the challenge website, download the model from Zenodo (see its DOI), and install the specific versions of packages listed at the bottom of the page (other versions might work too but are untested).

Prepare Model for misas

from import *
from import resnet34
import warnings
class AddMaskCodeMapping(Transform):
    "Add mapping of pixel value to class for a `TensorMask`"
    def __init__(self, mapping, codes=None):
        # `mapping` is used as a lookup table in `encodes`: mapping[raw_pixel_value] -> class index.
        # `codes` are the optional class names attached to decoded masks.
        # The damaged original read `self.mapping = mapping = codes`, which replaced the
        # lookup table with the codes; the two are stored separately here.
        self.mapping = mapping
        self.codes = codes
        if codes is not None: self.vocab,self.c = codes,len(codes)

    def encodes(self, o:PILMask):
        "Remap every pixel value of mask `o` through `self.mapping` and return the result as a `PILMask`"
        mo = ToTensor()(o)
        # `index_select` requires an integer index tensor, so cast before the lookup.
        mo = mo.long()
        # Flatten, look each pixel value up in the mapping, then restore the original shape.
        mo = self.mapping.index_select(0,mo.flatten()).reshape(*mo.shape)
        # NOTE(review): the argument of `PILMask.create` was truncated in the original.
        # Casting back to uint8 is assumed so the remapped labels form a regular 8-bit mask - confirm.
        mo = PILMask.create(mo.byte())
        return mo
    def decodes(self, o:TensorMask):
        # decoding of inputs works out of the box, but get_preds are not properly decoded
        # (more than two dimensions is treated as per-class scores and reduced to labels via argmax)
        if len(o.shape) > 2:
            o = o.argmax(dim=0)
        # Attach the class names so downstream display code can label the mask.
        if self.codes is not None: o._meta = {'codes': self.codes}
        return o
def MappedMaskBlock(mapping,codes=None):
    "Build a segmentation-mask `TransformBlock` whose pixel values are remapped to classes via `mapping`, optionally labelled with `codes`"
    # The remapping runs as an item transform, i.e. once per mask before batching.
    remap_tfm = AddMaskCodeMapping(mapping=mapping, codes=codes)
    return TransformBlock(
        type_tfms=PILMask.create,
        item_tfms=remap_tfm,
        batch_tfms=IntToFloatTensor,
    )
def getMappedMaskBlock(predefined_mapping_name):
    "Return a `MappedMaskBlock` for one of the predefined label mappings; raises `KeyError` for an unknown name"
    # Each entry is (lookup, codes): lookup[raw_label] gives the target class index,
    # codes are the names of the target classes in index order.
    predefined_mappings = {
        'full': ([0,1,2,3,4,5],['bg','lv','my','rv','ed','sc']),
        'edOnly': ([0,0,0,0,1,0],['bg','ed']),
        'edScCombined': ([0,0,0,0,1,1],['bg','edSc']),
        'scOnly': ([0,0,0,0,0,1],['bg','sc']),
        'edScOnly': ([0,0,0,0,1,2],['bg','ed','sc']),
    }  # the closing brace was missing in the original, leaving the literal unterminated
    try:
        mapping,codes = predefined_mappings[predefined_mapping_name]
    except KeyError:
        # Same exception type as a plain dict lookup, but tell the caller what is valid.
        raise KeyError(
            f"Unknown mapping {predefined_mapping_name!r}; expected one of {list(predefined_mappings)}"
        ) from None
    return MappedMaskBlock(mapping = torch.LongTensor(mapping), codes=codes)
def get_train_files(path):
    "Collect the image files under `path`, keeping only those selected by the file-name filter below"
    items = get_image_files(path)
    # NOTE(review): the filter expression was truncated in the original (`if"1")`); only the
    # literal "1" survived. It is restored here as a file-name prefix test, which matches
    # example names such as "101-orig-4.png" - confirm against the actual data layout.
    items = L([x for x in items if x.name.startswith("1")])
    return items
def getMyopsDls(mapping_name="full", images="images", path="/storage/biomeds/data/myops/"):
    "Build fastai `DataLoaders` for the MyoPS images under `path`/`images` with masks remapped by `mapping_name`"
    mmb = getMappedMaskBlock(mapping_name)
    # NOTE(review): the original `DataBlock(` call was never closed and some keyword
    # arguments appear to have been lost. `get_items=get_train_files` is restored because that
    # helper is defined for this purpose and otherwise unused; any item/batch transforms
    # (e.g. resizing or augmentation) that may have been here are NOT restored - confirm.
    myopsData = DataBlock(blocks=(ImageBlock, mmb),
        get_items=get_train_files,
        # The splitter function never returns True, so no item is put into the validation split.
        splitter=FuncSplitter(lambda o: False),
        # The mask path is the image path with the `images` folder name replaced by "masks".
        get_y=lambda o: str(o).replace(images,"masks"),
    )
    dls = myopsData.dataloaders(f'{path}/{images}',num_workers=4,batch_size=12)
    # Force the batch size of the second (validation) loader to 12 as well.
    dls[1].bs = 12
    return dls
def multi_dice(input:Tensor, targs:Tensor, class_id=0, inverse=False)->Tensor:
    """Mean per-sample Dice score for a single class.

    `input` holds per-class scores with the class axis at dim 1 (it is reduced with
    `argmax(dim=1)`); `targs` holds integer labels with the batch axis at dim 0.
    Both are binarised as "equals `class_id`". With `inverse=True` the binary masks are
    flipped, so the score is computed for "everything except `class_id`".
    A sample where the class is absent from both prediction and target scores 1.
    Returns a scalar tensor: the mean over the batch.
    """
    n = targs.shape[0]
    # `reshape` instead of `view`: `view` raises on non-contiguous tensors
    # (e.g. a permuted target), `reshape` copies only when it has to.
    pred = (input.argmax(dim=1).reshape(n,-1) == class_id).float()
    true = (targs.reshape(n,-1) == class_id).float()
    if inverse:
        pred, true = 1 - pred, 1 - true
    intersect = (pred * true).sum(dim=1)
    # Sum of both mask sizes (|A| + |B|), the Dice denominator.
    total = (pred + true).sum(dim=1)
    res = 2. * intersect / total
    # 0/0 happens only when both masks are empty; count that as perfect agreement.
    res[torch.isnan(res)] = 1
    return res.mean()

# Per-class Dice metrics. Class ids follow the 'full' mapping: 0=bg, 1=lv, 2=my, 3=rv, 4=ed, 5=sc.
def diceFG(input, targs):
    "Dice of the whole foreground, i.e. every pixel that is not background (class 0)"
    # Previously this used class_id=1 and was therefore identical to `diceLV`;
    # the foreground is the inverse of the background class.
    return multi_dice(input,targs,class_id=0,inverse=True)
def diceLV(input, targs):
    "Dice of class 1 (lv)"
    return multi_dice(input,targs,class_id=1)
def diceMY(input, targs):
    "Dice of class 2 (my)"
    return multi_dice(input,targs,class_id=2)
def diceRV(input, targs):
    "Dice of class 3 (rv)"
    return multi_dice(input,targs,class_id=3)
def diceEd(input, targs):
    "Dice of class 4 (ed)"
    return multi_dice(input,targs,class_id=4)
def diceSc(input, targs):
    "Dice of class 5 (sc)"
    return multi_dice(input,targs,class_id=5)
# Per-class metrics only; `diceFG` is not part of this list.
dices = [diceLV,diceMY,diceRV,diceEd,diceSc]
getMyopsDls("full", "images")
learn = unet_learner(
        getMyopsDls("full", "images"),
[W NNPACK.cpp:51] Could not initialize NNPACK! Reason: Unsupported hardware.

Prepare Dataset for misas

Data is available as png images and masks which is just fine for misas

from misas.core import default_cmap
from PIL import ImageOps
img = lambda:"101-orig-4.png").convert("RGB")
trueMask = lambda:"101-orig-4.png").convert("I")
plt.imshow(np.array(trueMask()), cmap=default_cmap)

class Fastai2_model:
    "Wrap a fastai learner behind `prepareSize` and `predict` so it can be passed as `model` to the misas calls below"
    def __init__(self, learner):
        # The learner is stored as-is; `predict` calls its `no_bar()` and `predict()` methods.
        self.trainedModel = learner
    def prepareSize(self, item): #, asPIL=False):
        "Crop `item` (a PIL image) with `ImageOps.crop` and return the cropped image"
        # The border tuple is (left, top, right, bottom): floor() is removed on the left/top and
        # ceil() on the right/bottom.
        # NOTE(review): `to_cut_w` and `to_cut_h` are not assigned anywhere in this class; as written
        # they must exist as module-level names or this line raises NameError - confirm where they are defined.
        image = ImageOps.crop(item, (np.floor(to_cut_w), np.floor(to_cut_h), np.ceil(to_cut_w), np.ceil(to_cut_h)))
        return image
    def predict(self, image):
        "Crop `image` via `prepareSize`, run the learner on it and return the predicted mask as a uint8 PIL image"
        image = self.prepareSize(image)#, #asPIL=True)
        # Convert to a fastai `PILImage` by going through a numpy array.
        image = PILImage.create(np.array(image))
        # Suppress the progress bar during prediction.
        with self.trainedModel.no_bar():
            # Only the first element of the prediction result (the mask) is kept.
            mask = self.trainedModel.predict(image)[0]#(pilimg) #mask,probs,rest 
        output = Image.fromarray(np.array(mask).astype(np.uint8))
        return output #mask, probs
model = Fastai2_model(learn.load("../../Downloads/multi_ce_full"))

How does the trained model perform on this (training) example?

#Time to apply the model to the example image and see how it works (we need to call prepareSize manually here):

from misas.core import *
from misas.core import default_cmap


# Predict the segmentation of the example image (predict() applies prepareSize to its input itself).
mask = model.predict(img())

# Show the predicted mask.
fig,ax = plt.subplots(figsize=(4.5,4.5))
plt.imshow(mask, cmap=default_cmap)

# Show the true mask, passed through prepareSize by hand so it is cropped like the model input.
fig,ax = plt.subplots(figsize=(4.5,4.5))
plt.imshow(np.array(model.prepareSize((trueMask()))), cmap=default_cmap)

Robustness to basic transformations

#img = lambda: open_image(files[0]).resize(256)
#trueMask = lambda: open_mask(files[0].replace("image","mask"))

Sensitivity to orientation

Changes in orientation are very common. This is not because it is common to acquire images in different orientations, but because the way data is stored differs between file formats such as NIfTI and DICOM. So it is interesting to see how the model works in all possible orientations (including flips).

# Build the dihedral series for the example image and the model, then plot it in two rows.
dihed = get_dihedral_series(img(),model)
plot_series(dihed, nrow=2, figsize=(20,12))

#plt.imshow(np.array(dihed [0][2].convert ("I")))

Sensitivity to rotation

Let’s get an impression of how quickly segmentation performance decreases with deviations in rotation.

# Rotation series with step=30, plotted in two rows; the varied parameter is labelled "deg".
plot_series(get_rotation_series(img(),model, step=30), nrow=2, param_name="deg")

# Coarser overview: step=60, plotted in a single row.
plot_series(get_rotation_series(img(),model, step=60), nrow=1, param_name="deg")

# Evaluate against the true mask from start=-180 to end=180 (presumably degrees, matching the
# "deg" label above), with one component name per class.
results = eval_rotation_series(img(),trueMask(),model,start=-180,end=180,components=["bg","LV","MY","RV","edema","scar"])
import altair as alt

So the range where prediction performance remains stable is quite large for most classes. However the rarer pathology classes scar and particularly edema react much more sensitively to rotation.

    get_rotation_series(img(),model, start=0, end=360,step=10),

segmentation sensitivity to rotation

Sensitivity to cropping

Another variation that might occur in real life is a difference in field of view. This can happen due to different settings when acquiring the images or due to pre-processing steps in an analysis pipeline.

plot_series(get_crop_series(img(),model, start = 50, end = 230, step = 10, finalSize=400), nrow=2, vmax=5)

plot_series(get_crop_series(img(),model, start = 80, end = 230, step = 20, finalSize=400), nrow=1, vmax=5)

    get_crop_series(img(),model, start=50, end=250,step=10),

segmentation sensitivity to cropping - myops

It seems to be okay to crop the image to some extent. But performance degrades even before we start to crop part of the heart.

results = eval_crop_series(img(),trueMask(),model,start = 50, end=256, finalSize=400, components=["bg","LV","MY","RV","edema","scar"])

Sensitivity to brightness

plot_series(get_brightness_series(img(),model), nrow=1) #end = 0.99, step = 0.18

results = eval_bright_series(img(),trueMask(),model, components=["bg","LV","MY","RV","edema","scar"])

Sensitivity to contrast

plot_series(get_contrast_series(img(),model), nrow = 1, vmax=5) #start=0.1, end=3, step=0.5

results = eval_contrast_series(img(),trueMask(),model, components=["bg","LV","MY","RV","edema","scar"]) #end = 2.5, step = 0.3

Sensitivity to zoom

plot_series(get_zoom_series(img(),model), param_name="zoom", nrow=2, vmax=5) #start=160,end=750,step=60, finalSize=480)

plot_series(get_zoom_series(img(),model), param_name="zoom", nrow=1, vmax=5)#,start=160,end=770,step=120, finalSize=480)

results = eval_zoom_series(img(),trueMask(),model,components=["bg","LV","MY","RV","edema","scar"]) #,start=160,end=900,step=20,finalSize=480
    get_zoom_series(img(),model) , #start=50, end=900,step=50

segmentation sensitivity to zoom

IDEA: Channel imbalance transformation. For multi-channel images it might be useful to consider transformations that work differently on different channels.

Robustness to MR artifacts

It would be nice to analyze the effect of MR artifacts. However, we are dealing with multi-channel images here. Each channel is a separate MR image. So it is not obvious how to deal with this.

