import numpy as np
from PIL import Image

from misas.fastai_model import Fastai2_model
Example data (kaggle)
# Dummy definitions so the pre-trained fastai learner can be loaded
# (it references these names); they are never called here.
def label_func(x):
    pass

def acc_seg(input, target):
    pass

def diceComb(input, targs):
    pass

def diceLV(input, targs):
    pass

def diceMY(input, targs):
    pass
def img():
"""
Opens the sample image as a PIL image
"""
return Image.open("example/kaggle/images/1-frame014-slice005.png").convert("RGB")
#img = lambda: Image.open("example/kaggle/images/1-frame014-slice005.png")
def trueMask():
"""
Opens the true mask as a PIL image
"""
return Image.open("example/kaggle/masks/1-frame014-slice005.png").convert("I")
#trueMask = lambda: Image.open("example/kaggle/masks/1-frame014-slice005.png")
trainedModel = Fastai2_model("chfc-cmi/cmr-seg-tl", "cmr_seg_base")
Using cache found in /home/csa84mikl/.cache/torch/hub/chfc-cmi_cmr-seg-tl_master
Default color map
Define a default color map for the sample image derived from viridis, and a default color map for the true mask derived from plasma, with the color for class “0” set to completely transparent. This makes sense if class “0” is the background.
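A minimal sketch of how such color maps can be built with matplotlib is shown below; the names default_cmap and default_cmap_true_mask are illustrative, not necessarily the ones misas uses internally.
import numpy as np
from matplotlib import cm
from matplotlib.colors import ListedColormap

def transparent_cmap(base_cmap, n_classes=3):
    # sample n_classes colors from the base map and make class "0" fully transparent
    colors = base_cmap(np.linspace(0, 1, n_classes))
    colors[0, 3] = 0.0
    return ListedColormap(colors)

default_cmap = transparent_cmap(cm.viridis)            # for predictions
default_cmap_true_mask = transparent_cmap(cm.plasma)   # for the true mask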
Generic functions
get_generic_series
get_generic_series (image, model, transform, truth=None, tfm_y=False, start=0, end=180, step=30, log_steps=False)
Generic function for transforming an image and collecting model predictions. Input: image (a PIL image, usually your sample image opened by img()); model (the model wrapper that produces the predicted mask); transform (your specific transformation function); truth=None (pass a true mask if available); tfm_y=False (set to True if the true mask has to be transformed as well to match the transformed sample image, e.g. in case of a rotation); start, end, step (the parameter values passed to the transform function); log_steps=False (if True, the parameter is stepped multiplicatively, i.e. in logarithmic steps, instead of linearly). Output: a list of lists [param, img, pred, trueMask], one per parameter value, where img (and optionally trueMask) has been transformed and pred is the model’s prediction on the transformed img.
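As a hedged illustration, any function taking an image and a single parameter can be plugged in as transform; blurTransform below is a hypothetical helper, not part of misas.
from PIL import ImageFilter

def blurTransform(image, radius):
    # hypothetical custom transform: Gaussian blur with the given radius
    return image.filter(ImageFilter.GaussianBlur(radius))

series_blur = get_generic_series(img(), trainedModel, blurTransform,
                                 truth=trueMask(), start=0, end=10, step=2)
plot_series(series_blur, param_name="radius", overlay_truth=True)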
plot_series
plot_series (series, nrow=1, figsize=(16, 6), param_name='param', overlay_truth=False, vmax=None, vmin=0, cmap=<matplotlib.colors.ListedColormap object at 0x7f52ec9228e0>, cmap_true_mask=<matplotlib.colors.ListedColormap object at 0x7f520556ddf0>, **kwargs)
Plots the transformed images with the prediction and optionally the true mask overlaid. Input: series (a list of lists [param, img, pred, trueMask], as returned by get_generic_series); nrow=1 (number of rows used to arrange the transformed images); figsize=(16,6); param_name='param'; overlay_truth=False (if True, the true mask is displayed over the sample image along with the prediction); vmax=None (controls how many colors the prediction gets; can be set manually, otherwise it is determined from the highest class present in the prediction); vmin=0; cmap (the default color map for predictions); cmap_true_mask (the color map for the true mask). Output: a plot generated by matplotlib.
plot_frame
plot_frame (param, img, pred, param_name='param', vmax=None, vmin=0, cmap=<matplotlib.colors.ListedColormap object at 0x7f52ec9228e0>, **kwargs)
Plots a single transformed image with the prediction overlaid; used by gif_series to render individual frames.
gif_series
gif_series (series, fname, duration=150, param_name='param', vmax=None, vmin=0, cmap=<matplotlib.colors.ListedColormap object at 0x7f52ec9228e0>)
Creates a GIF from a series by rendering each frame with plot_frame and saving the result to fname.
eval_generic_series
eval_generic_series (image, mask, model, transform_function, start=0, end=360, step=5, param_name='param', mask_transform_function=None, components=['bg', 'c1', 'c2'], eval_function=<function dice_by_component>, mask_prepareSize=True)
Performs the transformation on the sample image, creates a prediction, and then runs an evaluation function on the prediction and the true mask to measure their overlap (by default the component-wise Dice score).
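For example, the hypothetical blurTransform from above could be evaluated like this (a sketch, assuming the class names of the kaggle example):
results_blur = eval_generic_series(
    img(), trueMask(), trainedModel, blurTransform,
    start=0, end=10, step=2, param_name="radius",
    components=['bg', 'LV', 'MY']
)
plot_eval_series(results_blur)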
plot_eval_series
plot_eval_series (results, chart_type='line', value_vars=None, value_name='Dice Score')
Plots the results of eval_generic_series (and of the more specific eval_*_series functions).
Rotation
get_rotation_series
get_rotation_series (image, model, start=0, end=361, step=60, **kwargs)
runs the get_generic_series with rotationTransform as transform
rotationTransform
rotationTransform (image, deg)
rotates an image by deg degrees
series = get_rotation_series(img(), trainedModel, truth=trueMask())
[W NNPACK.cpp:51] Could not initialize NNPACK! Reason: Unsupported hardware.
plot_series(series, overlay_truth=True)
eval_rotation_series
eval_rotation_series (image, mask, model, step=5, start=0, end=360, param_name='deg', **kwargs)
results = eval_rotation_series(img(), trueMask(), trainedModel, components=['bg','LV','MY'])
plot_eval_series(results)
You can easily generate gifs by plotting multiple frames
gif_series(
    get_rotation_series(img(), trainedModel, start=0, end=360, step=5),
    "example/kaggle/rotation.gif",
    duration=500,
    param_name="degrees"
)
Cropping
get_crop_series
get_crop_series (image, model, start=0, end=256, step=10, finalSize=None, **kwargs)
cropTransform
cropTransform (image, pxls, finalSize=None)
series = get_crop_series(img(), trainedModel, truth=trueMask(), step=10)
plot_series(series, nrow=3, figsize=(16,15), overlay_truth=True)
eval_crop_series
eval_crop_series (image, mask, model, step=10, start=0, end=256, finalSize=None, param_name='pixels', **kwargs)
results = eval_crop_series(img(), trueMask(), trainedModel, components=['bg','LV','MY'])
plot_eval_series(results)
Cropping and comparing to the full original mask might not be desired. In this case it is possible to crop the mask as well. All pixels in the cropped area are set to 0 (commonly the background class). As soon as a class is completely missing, the dice score might jump to 1 because not predicting the class is correct in that case.
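One way to do this is sketched below, using the documented mask_transform_function parameter of eval_generic_series and passing cropTransform for the mask as well; the exact behaviour of the built-in eval_crop_series may differ.
results_cropped_mask = eval_generic_series(
    img(), trueMask(), trainedModel, cropTransform,
    start=0, end=256, step=10, param_name="pixels",
    mask_transform_function=cropTransform,  # crop the true mask along with the image
    components=['bg', 'LV', 'MY']
)
plot_eval_series(results_cropped_mask)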
gif_series(
    get_crop_series(img(), trainedModel, start=0, end=256, step=5),
    "example/kaggle/crop.gif",
    duration=500,
    param_name="pixels"
)
Brightness
get_brightness_series
get_brightness_series (image, model, start=0.25, end=8, step=1.4142135623730951, log_steps=True, **kwargs)
brightnessTransform
brightnessTransform (image, light)
series = get_brightness_series(img(), trainedModel, truth=trueMask(), start=1/8, end=16)
plot_series(series, nrow=3, figsize=(12,6), overlay_truth=True)
eval_bright_series
eval_bright_series (image, mask, model, start=0.05, end=0.95, step=0.05, param_name='brightness', **kwargs)
results = eval_bright_series(img(), trueMask(), trainedModel, start=0, end=1.05, components=['bg','LV','MY'])
plot_eval_series(results)
gif_series(
    get_brightness_series(img(), trainedModel, step=np.sqrt(2)),
    "example/kaggle/brightness.gif",
    duration=500,
    param_name="brightness"
)
Contrast
get_contrast_series
get_contrast_series (image, model, start=0.25, end=8, step=1.4142135623730951, log_steps=True, **kwargs)
contrastTransform
contrastTransform (image, scale)
series = get_contrast_series(img(), trainedModel, truth=trueMask(), start=1/8, end=16)
plot_series(series, nrow=3, figsize=(12,8), overlay_truth=True)
eval_contrast_series
eval_contrast_series (image, mask, model, start=0.25, end=8, step=1.4142135623730951, param_name='contrast', **kwargs)
results = eval_contrast_series(img(), trueMask(), trainedModel, start=0.25, end=8, step=np.sqrt(2), components=['bg','LV','MY'])
plot_eval_series(results)
gif_series(
    get_contrast_series(img(), trainedModel, step=np.sqrt(2)),
    "example/kaggle/contrast.gif",
    duration=500,
    param_name="contrast"
)
Zoom
get_zoom_series
get_zoom_series (image, model, start=0, end=1, step=0.1, finalSize=None, **kwargs)
zoomTransform
zoomTransform (image, zoom, finalSize=None)
series = get_zoom_series(img(), trainedModel, truth=trueMask())
plot_series(series, nrow=2, figsize=(16,8), overlay_truth=True)
eval_zoom_series
eval_zoom_series (image, mask, model, step=0.1, start=0, end=1, finalSize=None, param_name='scale', **kwargs)
results = eval_zoom_series(img(), trueMask(), trainedModel, components=['bg','LV','MY'])
plot_eval_series(results)
gif_series(
    get_zoom_series(img(), trainedModel, start=0, end=1, step=0.1),
    "example/kaggle/zoom.gif",
    duration=500,
    param_name="scale"
)
Dihedral
get_dihedral_series
get_dihedral_series (image, model, start=0, end=8, step=1, **kwargs)
dihedralTransform
dihedralTransform (image, sym_im)
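The parameter selects one of the eight symmetries of the square (four rotations and four reflections). A minimal sketch of such a transform with PIL is shown below; the ordering of operations is illustrative and may differ from the one misas uses.
from PIL import Image

DIHEDRAL_OPS = [None, Image.ROTATE_90, Image.ROTATE_180, Image.ROTATE_270,
                Image.FLIP_LEFT_RIGHT, Image.FLIP_TOP_BOTTOM,
                Image.TRANSPOSE, Image.TRANSVERSE]

def dihedral_example(image, k):
    # apply the k-th of the eight dihedral symmetries (k=0 is the identity)
    op = DIHEDRAL_OPS[k % 8]
    return image if op is None else image.transpose(op)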
series = get_dihedral_series(img(), trainedModel, truth=trueMask())
plot_series(series, overlay_truth=True)
eval_dihedral_series
eval_dihedral_series (image, mask, model, start=0, end=8, step=1, param_name='k', **kwargs)
results = eval_dihedral_series(img(), trueMask(), trainedModel, components=['bg','LV','MY'])
plot_eval_series(results, chart_type="point")
gif_series(
    get_dihedral_series(img(), trainedModel),
    "example/kaggle/dihedral.gif",
    param_name="k",
    duration=1000
)
Resize
resizeTransform
resizeTransform (image, size)
get_resize_series
get_resize_series (image, model, start=10, end=200, step=30, **kwargs)
series = get_resize_series(img(), trainedModel, truth=trueMask())
plot_series(series, sharex=True, sharey=True, overlay_truth=True)
eval_resize_series
eval_resize_series (image, mask, model, start=22, end=3000, step=100, param_name='px', **kwargs)
results = eval_resize_series(img(), trueMask(), trainedModel, components=['bg','LV','MY'])
plot_eval_series(results)
gif_series(
    get_resize_series(img(), trainedModel),
    "example/kaggle/resize.gif",
    param_name="px",
    duration=500
)
More on Evaluation
The default evaluation score is the Dice score, calculated separately for each component. In addition to Dice, misas provides component-wise precision and recall functions, but you can easily define your own.
results_dice = eval_rotation_series(img(), trueMask(), trainedModel, components=['bg','LV','MY'])
plot_dice = plot_eval_series(results_dice, value_vars=['bg','LV','MY'], value_name="Dice Score")

results_precision = eval_rotation_series(img(), trueMask(), trainedModel, components=['bg','LV','MY'], eval_function=precision_by_component)
plot_precision = plot_eval_series(results_precision, value_vars=['bg','LV','MY'], value_name="Precision")

results_recall = eval_rotation_series(img(), trueMask(), trainedModel, components=['bg','LV','MY'], eval_function=recall_by_component)
plot_recall = plot_eval_series(results_recall, value_vars=['bg','LV','MY'], value_name="Recall")
The objects returned by the plot function are altair charts that can be further customized and combined.
plot_dice = plot_dice.properties(title="Dice")
plot_precision = plot_precision.properties(title="Precision")
plot_recall = plot_recall.properties(title="Recall")
plot_dice & plot_precision & plot_recall
In order to define your own evaluation function, you need to define a function that takes the predicted mask as first parameter, the true mask as second parameter, and the component to evaluate as third parameter. Masks are of type ImageSegment and you can access the tensor data using the .data property. The following example defines specificity; it can then be passed as evaluation function.
def specificity_by_component(predictedMask, trueMask, component=1):
    specificity = 1.0
    pred = np.array(predictedMask) != component   # predicted negatives
    msk = np.array(trueMask) != component         # actual negatives
    intersect = pred & msk                        # true negatives
    total = np.sum(pred) + np.sum(msk)
    if total > 0:
        specificity = np.sum(intersect).astype(float) / msk.sum()
    # use float() so this also works when the default value 1.0 is returned
    return float(specificity)
results_specificity = eval_rotation_series(img(), trueMask(), trainedModel, components=['bg','LV','MY'], eval_function=specificity_by_component)
plot_specificity = plot_eval_series(results_specificity, value_vars=['bg','LV','MY'], value_name="Specificity")
plot_specificity
The specificity for the background class degrades so dramatically for rotations of around 180 degrees because the LV and MY classes are no longer detected at all. With everything predicted as background, there are no “true negatives” for the background class, so its specificity drops to zero.
Confusion Matrices
Confusion matrices are useful to evaluate in more detail which classes the model gets wrong. To conveniently generate confusion matrices for single predictions or for whole series, misas provides some convenience functions.
series = get_rotation_series(img(), trainedModel, truth=trueMask())
get_confusion
get_confusion (prediction, truth, max_class=None)
The get_confusion function returns a two-dimensional numpy array with counts for each class combination. The true class is along the columns and the predicted class along the rows. The number of classes is derived from the data if not provided via the max_class parameter. This parameter is important if the given prediction and truth do not contain all available classes.
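Conceptually, the matrix just counts class combinations pixel by pixel. A minimal sketch of this counting with numpy, assuming integer class masks (this is not necessarily how get_confusion is implemented):
import numpy as np

def confusion_counts(pred, truth, max_class=None):
    pred = np.asarray(pred).ravel()
    truth = np.asarray(truth).ravel()
    n = int(max_class if max_class is not None else max(pred.max(), truth.max())) + 1
    counts = np.zeros((n, n), dtype=int)
    # rows: predicted class, columns: true class
    np.add.at(counts, (pred, truth), 1)
    return counts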
cm = get_confusion(series[0][2], series[0][3])
cm
array([[63862, 22, 37],
[ 8, 756, 76],
[ 129, 28, 618]])
This matrix shows that there are 756 pixels classified correctly as “LV” (class 1). However, there are also 28 pixels that are in reality “LV” but predicted as “MY”. Accordingly, there are 76 pixels that are “MY” but predicted as “LV”.
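Per-class precision and recall can be read off directly, using the orientation described above (rows = predicted, columns = true):
lv_tp = cm[1, 1]                       # 756 correctly predicted "LV" pixels
lv_precision = lv_tp / cm[1, :].sum()  # 756 / 840 ≈ 0.90
lv_recall = lv_tp / cm[:, 1].sum()     # 756 / 806 ≈ 0.94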
Looking at tables is much less convenient and informative than looking at graphics, so let’s plot this matrix.
plot_confusion
plot_confusion (confusion_matrix, norm_axis=0, components=None, ax=None, ax_label=True, cmap='Blues')
_ = plot_confusion(cm, components=["bg","LV","MY"])
This is the confusion matrix for one image. Next we want to look at the confusion matrix for a full series of transformed (in this case rotated) images.
plot_confusion_series
plot_confusion_series (series, nrow=1, figsize=(16, 6), param_name='param', cmap='Blues', components=None, norm_axis=0, **kwargs)
plot_series(series, figsize=(16.5,6))
plot_confusion_series(series, components=['bg','LV','MY'])