Aim: Apply to MyoPS data

This is a more complex application of misas to a multi-channel input model with multiple output classes. It uses data from the Myocardial pathology segmentation combining multi-sequence CMR challenge (MyoPS 2020).

General results are published in "Myocardial Pathology Segmentation Combining Multi-Sequence Cardiac Magnetic Resonance Images." First Challenge, MyoPS 2020, Held in Conjunction with MICCAI 2020, Lima, Peru, October 4, 2020, Proceedings and the specific model is described in

Ankenbrand M.J., Lohr D., Schreiber L.M. (2020) "Exploring Ensemble Applications for Multi-sequence Myocardial Pathology Segmentation." In: Zhuang X., Li L. (eds) Myocardial Pathology Segmentation Combining Multi-Sequence Cardiac Magnetic Resonance Images. MyoPS 2020. Lecture Notes in Computer Science, vol 12554. Springer, Cham. supplemented by https://github.com/chfc-cmi/miccai2020-myops

To reproduce these results, obtain the data (see the DOI) and install the specific versions of packages listed at the bottom of the page (other versions might work too but are untested).

Prepare Model for misas

from fastai2.vision.all import *
from fastai2.vision.models import resnet34
class AddMaskCodeMapping(Transform):
    "Add mapping of pixel value to class for a `TensorMask`"
    def __init__(self, mapping, codes=None):
        # `mapping` is a LongTensor used as a lookup table: raw pixel value i -> mapping[i]
        self.mapping = mapping
        self.codes = codes
        # fastai reads `vocab`/`c` off the transform to label classes and size the model head
        if codes is not None: self.vocab,self.c = codes,len(codes)

    def encodes(self, o:PILMask):
        # Convert the mask to a long tensor, remap every pixel through the lookup
        # table, then rebuild a uint8 PILMask so later item transforms still apply.
        mo = ToTensor()(o)
        mo = mo.to(dtype=torch.long)
        mo = self.mapping.index_select(0,mo.flatten()).reshape(*mo.shape)
        mo = PILMask.create(mo.to(dtype=torch.uint8))
        return mo

    def decodes(self, o:TensorMask):
        # decoding of inputs works out of the box, but get_preds are not properly decoded
        if len(o.shape) > 2:
            # prediction tensors come as (classes, H, W); collapse to a class-index mask
            o = o.argmax(dim=0)
        # attach codes so downstream show/decode steps can label the classes
        if self.codes is not None: o._meta = {'codes': self.codes}
        return o
def MappedMaskBlock(mapping,codes=None):
    "A `TransformBlock` for segmentation masks, with mapping of pixel values to classes, potentially with `codes`"
    mask_tfm = AddMaskCodeMapping(mapping=mapping, codes=codes)
    return TransformBlock(
        type_tfms=PILMask.create,
        item_tfms=mask_tfm,
        batch_tfms=IntToFloatTensor,
    )
def getMappedMaskBlock(predefined_mapping_name):
    "Look up one of the predefined pixel-value mappings by name and build the matching `MappedMaskBlock`."
    # Original MyoPS label values: 0 bg, 1 lv, 2 my, 3 rv, 4 ed, 5 sc.
    # Each entry remaps those six values to a reduced class set with its codes.
    predefined = {
        'full': ([0, 1, 2, 3, 4, 5], ['bg', 'lv', 'my', 'rv', 'ed', 'sc']),
        'edOnly': ([0, 0, 0, 0, 1, 0], ['bg', 'ed']),
        'edScCombined': ([0, 0, 0, 0, 1, 1], ['bg', 'edSc']),
        'scOnly': ([0, 0, 0, 0, 0, 1], ['bg', 'sc']),
        'edScOnly': ([0, 0, 0, 0, 1, 2], ['bg', 'ed', 'sc']),
    }
    mapping, codes = predefined[predefined_mapping_name]
    return MappedMaskBlock(mapping=torch.LongTensor(mapping), codes=codes)
def get_train_files(path, prefix="1"):
    "Return image files under `path` whose name starts with `prefix` (MyoPS training cases are numbered 1xx)."
    items = get_image_files(path)
    items = L([x for x in items if x.name.startswith(prefix)])
    return items
def getMyopsDls(mapping_name="full", images="images", path='example/myops'):
    # Build fastai DataLoaders for the MyoPS pngs: images from `{path}/{images}`,
    # masks from the sibling "masks" folder, labels remapped per `mapping_name`.
    mmb = getMappedMaskBlock(mapping_name)
    myopsData = DataBlock(blocks=(ImageBlock, mmb),#['bg','lv','my','rv','ed','sc'])),
        get_items=get_train_files,
        splitter=FuncSplitter(lambda o: False),  # no validation split: every item goes to the train set
        get_y=lambda o: str(o).replace(images,"masks"),  # NOTE(review): replaces *all* occurrences of `images` in the path — confirm the folder name cannot appear elsewhere
        item_tfms=CropPad(256),
        batch_tfms=aug_transforms(max_rotate=90,pad_mode="zeros"))
    dls = myopsData.dataloaders(f'{path}/{images}',num_workers=4,batch_size=12)
    dls[1].bs = 12  # presumably forces the batch size of the (empty) validation loader — TODO confirm this is needed
    return dls
def multi_dice(input:Tensor, targs:Tensor, class_id=0, inverse=False)->Tensor:
    "Batch-mean Dice score for one class: binarise predictions and targets on `class_id` (optionally inverted) and average 2*|A∩B| / (|A|+|B|) over the batch."
    batch = targs.shape[0]
    # hard predictions: per-pixel argmax over the class dimension, flattened per sample
    preds = input.argmax(dim=1).view(batch, -1)
    # one-vs-rest binarisation for the requested class
    pred_bin = (preds == class_id).float()
    targ_bin = (targs.view(batch, -1) == class_id).float()
    if inverse:
        # score the complement of the class instead
        pred_bin, targ_bin = 1 - pred_bin, 1 - targ_bin
    overlap = (pred_bin * targ_bin).sum(dim=1).float()
    total = (pred_bin + targ_bin).sum(dim=1).float()
    dice = 2. * overlap / total
    # class absent in both prediction and target -> 0/0 -> define as perfect agreement (1)
    dice[torch.isnan(dice)] = 1
    return dice.mean()

# Per-class Dice metrics; class ids follow the 'full' mapping: 0 bg, 1 lv, 2 my, 3 rv, 4 ed, 5 sc.
def diceFG(input, targs): return multi_dice(input,targs,class_id=1)  # foreground Dice for the binary mappings — same id as diceLV by design
def diceLV(input, targs): return multi_dice(input,targs,class_id=1)
def diceMY(input, targs): return multi_dice(input,targs,class_id=2)
def diceRV(input, targs): return multi_dice(input,targs,class_id=3)
def diceEd(input, targs): return multi_dice(input,targs,class_id=4)
def diceSc(input, targs): return multi_dice(input,targs,class_id=5)
dices = [diceLV,diceMY,diceRV,diceEd,diceSc]  # metric set for the 'full' (6-class) mapping
# Build a U-Net learner on the full 6-class MyoPS dataloaders with a resnet34 encoder.
# NOTE(review): the `dices` metric list defined above is not passed to the learner here — confirm intended.
learn = unet_learner(
        getMyopsDls("full", "images"),
        resnet34
    )
from fastai.vision import Image as F1Image
from fastai.vision import ImageSegment as F1ImageSegment

Prepare Dataset for misas

Data is available as png images and masks, which is just fine for misas.

from fastai.vision import open_image, open_mask
# Example training slice (case 101, slice 4); lambdas so every call yields a fresh fastai-v1 image/mask
img = lambda: open_image("example/myops/images/101-orig-4.png")
trueMask = lambda: open_mask("example/myops/masks/101-orig-4.png")
img().show()
trueMask().show()
class Fastai2_model:
    "Adapter exposing a fastai-v2 `Learner` through the fastai-v1-style interface misas expects (`prepareSize`, `predict`)."
    def __init__(self, learner):
        self.trainedModel = learner

    def prepareSize(self, item, asPIL=False):
        "Center crop/pad `item` (a v1 `Image` or `ImageSegment`) to 256x256; return the intermediate v2 PIL type if `asPIL`."
        if isinstance(item, F1ImageSegment):
            # mask path: (1, H, W) byte tensor -> PILMask -> crop/pad -> back to v1 ImageSegment
            pilmask = PILMask(Image.fromarray((item.data.squeeze(0)).numpy().astype(np.uint8)))
            pilmask = CropPad(256)(pilmask)
            return F1ImageSegment(torch.ByteTensor(np.array(pilmask)).unsqueeze(0))
        # image path: CHW float data scaled to uint8 HWC for PIL, then crop/pad
        pilimg = PILImage(Image.fromarray((item.data.permute(1,2,0) * 255).numpy().astype(np.uint8)))
        pilimg = CropPad(256)(pilimg)
        if asPIL:
            return pilimg
        # back to a v1 Image: rescale to [0,1] and restore CHW layout
        return F1Image(torch.Tensor(np.array(pilimg)/255).permute(2,0,1))

    def predict(self, image):
        "Run the learner on `image`; return `(mask, probs)` with the mask as a v1 `ImageSegment`."
        pilimg = self.prepareSize(image, asPIL=True)
        with self.trainedModel.no_bar():
            mask, probs, _ = self.trainedModel.predict(pilimg)  # third element of v2 predict() is unused here
        return F1ImageSegment(torch.ByteTensor(np.array(mask)).unsqueeze(0)), probs

# model = Fastai1_model('example/b0','b0_transversal_5_5') # if it were local
# Load the pretrained MyoPS weights into the learner and wrap it for misas
model = Fastai2_model(learn.load("../example/myops/multi_ce_full"))

How does the trained model perform on this (training) example?

Time to apply the model to the example image and see how it works (we need to call prepareSize manually here):

from misas.core import *
from misas.core import default_cmap
# Sanity check: cropping works for both the image and the true mask
model.prepareSize(img())
model.prepareSize(trueMask())
# Predicted mask for the example image
mask = model.predict(img())[0]
mask
# Overlay the prediction on the cropped image ...
fig,ax = plt.subplots(figsize=(4.5,4.5))
model.prepareSize(img()).show(ax=ax)
mask.show(ax=ax,cmap=default_cmap)
# ... and the ground-truth mask for comparison
fig,ax = plt.subplots(figsize=(4.5,4.5))
model.prepareSize(img()).show(ax=ax)
model.prepareSize(trueMask()).show(ax=ax,cmap=default_cmap)

Robustness to basic transformations

#trueMask = lambda: open_mask(files[0].replace("image","mask"))

Sensitivity to orientation

Changes in orientation are very common. Not because it is common to acquire images in different orientations, but because the way data is stored in different file formats like nifti and dicom differs. So it is interesting to see how the model works in all possible orientations (including flips).

# All eight dihedral transforms (rotations and flips) of the example image, predicted and plotted
dihed = get_dihedral_series(img(),model)
plot_series(dihed, nrow=2, figsize=(20,12))