Setup

The dataset is acquired from Kaggle. The SARS-CoV-2 CT scan dataset is a public dataset containing 1252 CT (computed tomography) scans from patients infected with SARS-CoV-2 (COVID-19) and 1230 CT scans from patients not infected with SARS-CoV-2. The data was collected from real patients in Sao Paulo, Brazil.

import opendatasets as od

dataset_url = 'https://www.kaggle.com/plameneduardo/sarscov2-ctscan-dataset'
od.download(dataset_url)
Downloading sarscov2-ctscan-dataset.zip to ./sarscov2-ctscan-dataset
100%|██████████| 230M/230M [00:04<00:00, 52.3MB/s] 

from fastai.vision.all import *  # provides Path (with .ls()), DataBlock, cnn_learner, etc.

path = Path('sarscov2-ctscan-dataset')

The dataset contains two folders, COVID and non-COVID, holding the patients' CT scan images:

path.ls()
(#2) [Path('sarscov2-ctscan-dataset/COVID'),Path('sarscov2-ctscan-dataset/non-COVID')]

Preprocessing

Exploring the dataset structure and displaying sample CT scan directories:

path.ls()
(#2) [Path('sarscov2-ctscan-dataset/COVID'),Path('sarscov2-ctscan-dataset/non-COVID')]

There are 1252 CT scan images from SARS-CoV-2 infected patients.

(path/'COVID').ls()
(#1252) [Path('sarscov2-ctscan-dataset/COVID/Covid (1).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (10).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (100).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (1000).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (1001).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (1002).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (1003).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (1004).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (1005).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (1006).png')...]

The non-COVID folder lists 1229 CT scan images from patients not infected with SARS-CoV-2 (the dataset description reports 1230).

(path/'non-COVID').ls()
(#1229) [Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (1).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (10).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (100).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (1000).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (1001).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (1002).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (1003).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (1004).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (1005).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (1006).png')...]
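
The file counts in the listings above can also be checked without fastai. The following is a minimal sketch using only the standard library; `count_images` is a hypothetical helper name (not part of fastai or of the dataset), and it assumes every image is a `.png` file stored directly inside its class folder.

```python
from pathlib import Path

def count_images(root, suffix=".png"):
    """Map each class folder under `root` to the number of image files it holds."""
    root = Path(root)
    return {
        folder.name: sum(1 for f in folder.iterdir() if f.suffix.lower() == suffix)
        for folder in sorted(root.iterdir())
        if folder.is_dir()
    }

# Only runs when the dataset has already been downloaded next to this notebook.
if Path("sarscov2-ctscan-dataset").is_dir():
    print(count_images("sarscov2-ctscan-dataset"))
```

Comparing the result with the published dataset description is a quick way to spot a missing or extra file before training.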

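Visualizing the images: Let's look at some raw images in the dataset:

import PIL.Image  # import the submodule explicitly; a bare `import PIL` does not guarantee PIL.Image is loaded
img1 = PIL.Image.open((path/'COVID').ls()[0])  # open the first COVID scan listed above
img1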

Creating a Datablock

DataBlock API: We split the dataset into a training set and a validation set, and pass a fixed seed to RandomSplitter so that the split (and therefore the result) can be replicated. The valid_pct argument is the proportion of the dataset held out for validation (20% in our case). Images are presized to 460 pixels, augmentations are then applied with random crops that keep at least 75% of each image (min_scale=0.75), and the result is normalized with the ImageNet statistics so that transfer learning can be applied later.

Creating Dataloaders

def get_dls(bs,size):
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       get_items=get_image_files,
                       get_y=parent_label,
                       splitter=RandomSplitter(valid_pct=0.2, seed=42),
                       item_tfms=Resize(460), #presizing is done 
                       batch_tfms=[*aug_transforms(size=size,min_scale=0.75),
                       Normalize.from_stats(*imagenet_stats)])
    return dblock.dataloaders(path,bs=bs)
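
To illustrate why a fixed seed makes the 80/20 split reproducible, here is a minimal sketch in plain Python. It is not fastai's RandomSplitter: `random_split` is a hypothetical helper, and the shuffle order will differ from fastai's, but the idea (shuffle the indices with a seeded generator, then cut off the first `valid_pct` share for validation) is the same. The total of 2481 images assumes the 1252 and 1229 files listed earlier.

```python
import random

def random_split(n_items, valid_pct=0.2, seed=42):
    """Return (train_idxs, valid_idxs) for a reproducible random split."""
    idxs = list(range(n_items))
    random.Random(seed).shuffle(idxs)   # private generator: same seed, same order
    cut = int(valid_pct * n_items)      # number of validation items
    return idxs[cut:], idxs[:cut]

n_images = 1252 + 1229                  # files listed in the two class folders
train_idxs, valid_idxs = random_split(n_images)
print(len(train_idxs), len(valid_idxs))  # 1985 496
```

A validation set of 496 images is consistent with the error rates in the training log further down, which are all multiples of 1/496 (for example 0.002016 is one misclassified image).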

To create the dataloaders we use the DataBlock API of fastai. We use images of size 224×224 and a batch size of 224 images.

dls=get_dls(224,224)
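
The normalization step in get_dls can be pictured per pixel. The sketch below is plain Python rather than fastai's Normalize transform; the helper names are hypothetical, and the mean and standard deviation values are the commonly used ImageNet statistics (the same numbers that fastai exposes as imagenet_stats, to the best of my knowledge). It assumes channel values already scaled to the range [0, 1].

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)   # per-channel mean (R, G, B)
IMAGENET_STD = (0.229, 0.224, 0.225)    # per-channel standard deviation

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Shift and scale one RGB pixel so each channel matches ImageNet statistics."""
    return tuple((c - m) / s for c, m, s in zip(rgb, mean, std))

def denormalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Invert normalize_pixel, e.g. before displaying an image."""
    return tuple(c * s + m for c, m, s in zip(rgb, mean, std))

grey = (0.5, 0.5, 0.5)
print(normalize_pixel(grey))
```

Using the statistics of the pre-training dataset matters because the pre-trained weights expect inputs on that scale.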

Visualizing the Dataloaders

The images in the dataloaders look like:

dls.show_batch(nrows=3, figsize=(7,6))

A batch of images in a grid looks like:

@patch
@delegates(to=draw_label, but=["font_color", "location", "draw_rect", "fsize_div_factor", "font_path", "font_size"])
def show_batch_grid(self:TfmdDL, b=None, n=20, ncol=4, show=True, unique=False,
                    unique_each=True, font_path=None, font_size=20, **kwargs):
    """Show a batch of images
    Key Params:
      * n:      No. of images to display
      * ncol:   No. of columns in the grid
      * unique: Display the same image with different augmentations
      * unique_each: If True, displays a different img on each call
      * font_path:   Path to the `.ttf` font file. Required to display labels
      * font_size:   Size of the font
    """
    if font_path is not None: self.set_font_path(font_path)
    if not hasattr(self, 'font_path'):
        self.font_path = font_path
    if unique:
        old_get_idxs = self.get_idxs
        if unique_each:
            i = np.random.choice(self.n)
            self.get_idxs = partial(itertools.repeat, i)
        else:
            self.get_idxs = lambda: Inf.zeros
    if b is None: b = self.one_batch()
    if not show: return self._pre_show_batch(b, max_n=n)
    _,__, b = self._pre_show_batch(b, max_n=n)
    if unique: self.get_idxs = old_get_idxs
    return make_img_grid([draw_label(i, font_path=self.font_path, font_size=font_size) for i in b],
                         ncol=ncol, img_size=None)

Transfer Learning

The Resnet50 model

  1. What transfer learning is and why we use it

    • Transfer learning means using a pre-trained model to build our classifier. A pre-trained model is a model that has previously been trained on a dataset and therefore already contains learned weights and biases. Using a pre-trained model saves time and computational resources. Another advantage is that pre-trained models often perform better than architectures trained from scratch.
    • To better understand this point, suppose we want to build a classifier that can sort different sailboat types. A model pre-trained on ships would already have captured some boat features in its first layers, so it would learn faster and distinguish the sailboat types more accurately.
  2. The Resnet50 architecture:

    • Resnet50 is generally considered a good first architecture to test: it performs well without being excessively large, which allows a larger batch size and thus less computation time. For this reason, Resnet50 is a good compromise to try before testing more complex architectures.
    • Residual networks were designed to address the vanishing gradient problem. Very deep networks with a large number of hidden layers are effective at solving complicated tasks because their structure lets them capture patterns in complicated data. When such a network is trained, however, the early layers tend to train more slowly (their gradients are smaller during backpropagation). The initial layers are important because they learn the basic features of an object (edges, corners and so on). Failing to train these layers properly decreases the overall accuracy of the model.
    • Residual neural networks address this issue with skip (shortcut) connections: the input of a block is added to the block's output, so the signal and its gradient can bypass the layers in between. Each block then only has to learn a residual correction to its input. See the original research article.
  3. Test the Resnet50 architecture with our dataset:

    • Now we are going to test how the fastai implementation of this architecture works with the COVID dataset.
    • Create the convolutional neural network: First we create the convolutional neural network based on this architecture using fastai's cnn_learner function (previously create_cnn). We pass the loaded data, specify the model, pass error_rate and accuracy as a list to the metrics parameter so that both are reported, and finally specify a weight decay of 1e-1 (0.1).
    • learn.lr_find() & learn.recorder.plot(): these functions run the LR Finder, which helps to find a good learning rate for our network. For more information, see the original paper.

    • The learn.recorder.plot() function plots the loss against the learning rate (in fastai v2, learn.lr_find() draws this plot itself by default). The best learning rate is usually chosen where the curve is steepest, that is, where the loss is falling fastest. You may try several learning rate values in order to pick the best one.

    • learn.fit_one_cycle() & learn.recorder.plot_losses(): The learn.fit_one_cycle() function fits the model. Fit one cycle reaches a comparable accuracy faster than the plain fit function when training complex models. Instead of keeping the learning rate fixed for all iterations, it first increases the learning rate and then decreases it again (this process is what is called one cycle). This learning rate variation also helps to prevent overfitting. We use 5 for the parameter cyc_len, which specifies the number of epochs in the cycle, and max_lr to specify the maximum learning rate, which we set to 0.001 (in fastai v2 these parameters are named n_epoch and lr_max). The learning rate starts from a value ten times smaller than the selected maximum. For more information about fit one cycle, see the article.
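
The one-cycle idea described in the last bullet can be sketched in a few lines of plain Python. This is a simplified, piecewise-linear illustration and not fastai's implementation: `one_cycle_lr`, `div` and `pct_start` are names chosen here, the starting rate is the maximum divided by 10 as in the description above, and for simplicity the rate returns to that same value at the end. fastai's own schedule is, as far as I know, smoother and uses different defaults.

```python
def one_cycle_lr(step, total_steps, lr_max=1e-3, div=10.0, pct_start=0.25):
    """Learning rate at `step` for a simple piecewise-linear one-cycle schedule."""
    if not 0 <= step <= total_steps:
        raise ValueError("step must lie in [0, total_steps]")
    lr_min = lr_max / div                      # starting (and final) learning rate
    warmup_steps = pct_start * total_steps     # length of the increasing phase
    if step <= warmup_steps:
        frac = step / warmup_steps
        return lr_min + frac * (lr_max - lr_min)      # ramp up to lr_max
    frac = (step - warmup_steps) / (total_steps - warmup_steps)
    return lr_max - frac * (lr_max - lr_min)          # ramp back down

schedule = [one_cycle_lr(s, 100) for s in range(101)]
print(schedule[0], max(schedule), schedule[-1])
```

Plotting `schedule` against the step number gives the characteristic triangle: a short warm-up to the maximum followed by a longer decay.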

Testing with Deeper Architectures

learn = cnn_learner(dls, resnet101, loss_func=CrossEntropyLossFlat(), metrics=[error_rate,accuracy], wd=1e-1).to_fp16()
learn.cbs
(#4) [TrainEvalCallback,Recorder,ProgressCallback,MixedPrecision]

We apply a very powerful data augmentation technique, MixUp, and train the model.

learn.fit_one_cycle(80, 3e-3, cbs=MixUp(0.5))
 
epoch train_loss valid_loss error_rate accuracy time
0 1.098607 0.923316 0.405242 0.594758 00:23
1 0.959609 0.496383 0.227823 0.772177 00:23
2 0.907947 0.482087 0.223790 0.776210 00:23
3 0.874337 0.519458 0.252016 0.747984 00:22
4 0.838613 0.474054 0.221774 0.778226 00:23
5 0.800547 0.362412 0.159274 0.840726 00:23
6 0.766916 0.340699 0.137097 0.862903 00:23
7 0.732544 0.190343 0.060484 0.939516 00:23
8 0.705716 0.234552 0.092742 0.907258 00:22
9 0.672490 0.225245 0.092742 0.907258 00:23
10 0.646345 0.196367 0.082661 0.917339 00:23
11 0.614197 0.207071 0.082661 0.917339 00:22
12 0.580931 0.145530 0.058468 0.941532 00:22
13 0.551455 0.160765 0.062500 0.937500 00:22
14 0.530735 0.156187 0.058468 0.941532 00:22
15 0.508143 0.133556 0.040323 0.959677 00:23
16 0.486114 0.130424 0.048387 0.951613 00:23
17 0.468766 0.112146 0.036290 0.963710 00:23
18 0.451705 0.106726 0.038306 0.961694 00:23
19 0.435460 0.140806 0.046371 0.953629 00:23
20 0.421944 0.129412 0.042339 0.957661 00:23
21 0.409734 0.107145 0.034274 0.965726 00:23
22 0.399341 0.125089 0.042339 0.957661 00:23
23 0.389513 0.099228 0.036290 0.963710 00:23
24 0.380188 0.082548 0.018145 0.981855 00:23
25 0.370885 0.071890 0.016129 0.983871 00:23
26 0.363348 0.126151 0.044355 0.955645 00:23
27 0.356708 0.085095 0.022177 0.977823 00:23
28 0.351457 0.082022 0.030242 0.969758 00:23
29 0.346920 0.082360 0.022177 0.977823 00:23
30 0.343117 0.086793 0.026210 0.973790 00:23
31 0.338743 0.084433 0.028226 0.971774 00:23
32 0.332180 0.050694 0.012097 0.987903 00:23
33 0.329243 0.075656 0.022177 0.977823 00:23
34 0.325888 0.074826 0.018145 0.981855 00:22
35 0.321171 0.051103 0.016129 0.983871 00:22
36 0.317992 0.068456 0.014113 0.985887 00:23
37 0.317117 0.095658 0.038306 0.961694 00:23
38 0.314691 0.075247 0.026210 0.973790 00:23
39 0.312669 0.059977 0.014113 0.985887 00:23
40 0.311207 0.062207 0.016129 0.983871 00:23
41 0.307136 0.079891 0.032258 0.967742 00:23
42 0.303215 0.060350 0.014113 0.985887 00:23
43 0.301979 0.061862 0.014113 0.985887 00:23
44 0.302073 0.056083 0.012097 0.987903 00:23
45 0.298332 0.054264 0.014113 0.985887 00:23
46 0.297571 0.050670 0.006048 0.993952 00:23
47 0.295835 0.053044 0.014113 0.985887 00:23
48 0.295003 0.053177 0.006048 0.993952 00:23
49 0.295658 0.070317 0.012097 0.987903 00:23
50 0.293548 0.051080 0.008064 0.991935 00:23
51 0.293651 0.061804 0.016129 0.983871 00:23
52 0.291592 0.044012 0.010081 0.989919 00:23
53 0.289374 0.047382 0.006048 0.993952 00:23
54 0.288036 0.050668 0.006048 0.993952 00:23
55 0.286396 0.057323 0.016129 0.983871 00:23
56 0.285277 0.049304 0.012097 0.987903 00:23
57 0.283157 0.047652 0.010081 0.989919 00:23
58 0.282250 0.046741 0.008064 0.991935 00:23
59 0.282024 0.043001 0.006048 0.993952 00:23
60 0.281422 0.043425 0.004032 0.995968 00:22
61 0.279792 0.048245 0.004032 0.995968 00:23
62 0.278259 0.050301 0.008064 0.991935 00:23
63 0.276450 0.042498 0.006048 0.993952 00:23
64 0.275557 0.043382 0.008064 0.991935 00:23
65 0.274992 0.046327 0.008064 0.991935 00:23
66 0.274949 0.051264 0.012097 0.987903 00:23
67 0.276006 0.050355 0.010081 0.989919 00:23
68 0.277512 0.047513 0.008064 0.991935 00:22
69 0.275122 0.044733 0.006048 0.993952 00:22
70 0.275745 0.042205 0.006048 0.993952 00:23
71 0.274163 0.041508 0.006048 0.993952 00:23
72 0.273943 0.042359 0.006048 0.993952 00:23
73 0.273899 0.042546 0.006048 0.993952 00:23
74 0.271843 0.044013 0.006048 0.993952 00:23
75 0.271735 0.043344 0.006048 0.993952 00:22
76 0.271392 0.045417 0.008064 0.991935 00:23
77 0.270836 0.044158 0.006048 0.993952 00:23
78 0.272595 0.043604 0.008064 0.991935 00:23
79 0.272234 0.044281 0.008064 0.991935 00:23
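
For readers unfamiliar with MixUp, the following plain-Python sketch shows the core idea behind the callback used in the run above: two training samples and their labels are blended with a weight drawn from a Beta(alpha, alpha) distribution. It is a conceptual illustration only; `mixup_pair` is a hypothetical helper, images are represented as flat lists of numbers, and fastai's MixUp callback works on whole batches and differs in its details.

```python
import random

def mixup_pair(x1, y1, x2, y2, alpha=0.5, rng=random):
    """Blend two samples (flat pixel lists) and their one-hot labels."""
    lam = rng.betavariate(alpha, alpha)   # mixing weight in (0, 1)
    lam = max(lam, 1.0 - lam)             # keep the first sample dominant
    mix = lambda a, b: [lam * u + (1.0 - lam) * v for u, v in zip(a, b)]
    return mix(x1, x2), mix(y1, y2), lam

# Toy inputs: an all-black "COVID" image and an all-white "non-COVID" image.
x, y, lam = mixup_pair([0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0],
                       rng=random.Random(0))
print(x, y, lam)
```

Because the network is trained on blended images with soft targets, its training loss is measured on a harder task than the validation loss, which is one plausible reason train_loss stays above valid_loss throughout the log.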

TTA (Test Time Augmentation)

preds,targs = learn.tta() # TTA applied for validation dataset
accuracy(preds, targs).item()
0.9959677457809448

With TTA we get an accuracy of 99.60% on the validation set.
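
Test time augmentation predicts on several augmented copies of each image and combines the results. The sketch below shows one way to combine them in plain Python; `tta_blend` and `beta` are names chosen for this illustration, and the weighting (average the augmented predictions, then blend with the prediction on the unaugmented image) is an assumption about how learn.tta() behaves by default rather than a copy of its code. The probabilities are made-up toy values.

```python
def tta_blend(plain_probs, augmented_probs, beta=0.25):
    """Average augmented predictions, then blend with the un-augmented one."""
    n_aug = len(augmented_probs)
    aug_mean = [sum(p[i] for p in augmented_probs) / n_aug
                for i in range(len(plain_probs))]
    # beta is the weight given to the prediction on the original image
    return [beta * p + (1.0 - beta) * a for p, a in zip(plain_probs, aug_mean)]

plain = [0.9, 0.1]                       # toy [COVID, non-COVID] probabilities
augmented = [[0.8, 0.2], [0.6, 0.4]]     # toy predictions on two augmented copies
print(tta_blend(plain, augmented))
```

Averaging over augmentations tends to smooth out predictions that depend on the exact crop or orientation of a scan.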

ClassificationInterpretationEx

We examine the model predictions in more depth:

import fastai

def _get_truths(vocab, label_idx, is_multilabel):
    if is_multilabel:
          return ';'.join([vocab[i] for i in torch.where(label_idx==1)][0])
    else: return vocab[label_idx]

class ClassificationInterpretationEx(ClassificationInterpretation):
    """
    Extend fastai2's `ClassificationInterpretation` to analyse model predictions in more depth
    See:
      * self.preds_df
      * self.plot_label_confidence()
      * self.plot_confusion_matrix()
      * self.plot_accuracy()
      * self.get_fnames()
      * self.plot_top_losses_grid()
      * self.print_classification_report()
    """
    def __init__(self, dl, inputs, preds, targs, decoded, losses):
        super().__init__(dl, inputs, preds, targs, decoded, losses)
        self.vocab = self.dl.vocab
        if is_listy(self.vocab): self.vocab = self.vocab[-1]
        if self.targs.__class__ == fastai.torch_core.TensorMultiCategory:
              self.is_multilabel = True
        else: self.is_multilabel = False
        self.compute_label_confidence()
        self.determine_classifier_type()

    def determine_classifier_type(self):
        if self.targs[0].__class__==fastai.torch_core.TensorCategory:
            self.is_multilabel = False
        if self.targs[0].__class__==fastai.torch_core.TensorMultiCategory:
            self.is_multilabel = True
            self.thresh = self.dl.loss_func.thresh

    def compute_label_confidence(self, df_colname:Optional[str]="fnames"):
        """
        Collate prediction confidence, filenames, and ground truth labels
        in DataFrames, and store them as class attributes
        `self.preds_df` and `self.preds_df_each`

        If the `DataLoaders` is constructed from a `pd.DataFrame`, use
        `df_colname` to specify the column name with the filepaths
        """
        if not isinstance(self.dl.items, pd.DataFrame):
            self._preds_collated = [
                #(item, self.dl.vocab[label_idx], *preds.numpy()*100)\
                (item, _get_truths(self.dl.vocab, label_idx, self.is_multilabel), *preds.numpy()*100)\
                for item,label_idx,preds in zip(self.dl.items,
                                                self.targs,
                                                self.preds)
            ]
        ## need to extract fname from DataFrame
        elif isinstance(self.dl.items, pd.DataFrame):
            self._preds_collated = [
                #(item[df_colname], self.dl.vocab[label_idx], *preds.numpy()*100)\
                (item[df_colname], _get_truths(self.dl.vocab, label_idx, self.is_multilabel), *preds.numpy()*100)\
                for (_,item),label_idx,preds in zip(self.dl.items.iterrows(),
                                                self.targs,
                                                self.preds)
            ]

        self.preds_df       = pd.DataFrame(self._preds_collated, columns = ['fname','truth', *self.dl.vocab])
        self.preds_df.insert(2, column='loss', value=self.losses.numpy())

        if self.is_multilabel: return # preds_df_each doesnt make sense for multi-label
        self._preds_df_each = {l:self.preds_df.copy()[self.preds_df.truth == l].reset_index(drop=True) for l in self.dl.vocab}
        self.preds_df_each  = defaultdict(dict)


        sort_desc = lambda x,col: x.sort_values(col, ascending=False).reset_index(drop=True)
        for label,df in self._preds_df_each.items():
            filt = df[label] == df[self.dl.vocab].max(axis=1)
            self.preds_df_each[label]['accurate']   = df.copy()[filt]
            self.preds_df_each[label]['inaccurate'] = df.copy()[~filt]

            self.preds_df_each[label]['accurate']   = sort_desc(self.preds_df_each[label]['accurate'], label)
            self.preds_df_each[label]['inaccurate'] = sort_desc(self.preds_df_each[label]['inaccurate'], label)
            assert len(self.preds_df_each[label]['accurate']) + len(self.preds_df_each[label]['inaccurate']) == len(df)

    def get_fnames(self, label:str,
                   mode:('accurate','inaccurate'),
                   conf_level:Union[int,float,tuple]) -> np.ndarray:
        """
        Utility function to grab filenames of a particular label `label` that were classified
        as per `mode` (accurate|inaccurate).
        These filenames are filtered by `conf_level` which can be above or below a certain
        threshold (above if `mode` == 'accurate' else below), or in confidence ranges
        """
        assert label in self.dl.vocab
        if not hasattr(self, 'preds_df_each'): self.compute_label_confidence()
        df = self.preds_df_each[label][mode].copy()
        if mode == 'accurate':
            if isinstance(conf_level, tuple):       filt = df[label].between(*conf_level)
            if isinstance(conf_level, (int,float)): filt = df[label] > conf_level
        if mode == 'inaccurate':
            if isinstance(conf_level, tuple):       filt = df[label].between(*conf_level)
            if isinstance(conf_level, (int,float)): filt = df[label] < conf_level
        return df[filt].fname.values
fname truth loss COVID non-COVID
0 sarscov2-ctscan-dataset/non-COVID/Non-Covid (386).png non-COVID 0.047037 4.594821 95.405182
1 sarscov2-ctscan-dataset/COVID/Covid (581).png COVID 0.018380 98.178818 1.821182

ClassificationInterpretationEx.get_fnames

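Returns accurately classified files with confidence above 99.95%:

interp.get_fnames('COVID', 'accurate', 99.95)  # the label comes first; 'COVID' is chosen here for illustration

Returns accurately classified files with confidence between 84.1% and 85.2%:

interp.get_fnames('COVID', 'accurate', (84.1, 85.2))  # the label must be in the vocab; 'COVID' is chosen for illustration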

Confusion Matrix

Checking the Confusion Matrix:
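
The cell that drew the confusion matrix is not preserved here; with a fastai ClassificationInterpretation object the usual call is plot_confusion_matrix(). As a library-free illustration of what that matrix contains, the sketch below counts (truth, prediction) pairs. `confusion_matrix` is a hypothetical helper and the labels are toy values, not results from this model.

```python
from collections import Counter

def confusion_matrix(truths, preds, vocab):
    """Rows are true labels, columns are predicted labels, in `vocab` order."""
    counts = Counter(zip(truths, preds))
    return [[counts[(t, p)] for p in vocab] for t in vocab]

vocab = ["COVID", "non-COVID"]
toy_truths = ["COVID", "COVID", "COVID", "COVID", "non-COVID", "non-COVID"]
toy_preds = ["COVID", "COVID", "COVID", "non-COVID", "non-COVID", "non-COVID"]
for label, row in zip(vocab, confusion_matrix(toy_truths, toy_preds, vocab)):
    print(label.ljust(10), row)
```

In this toy example the off-diagonal entry in the first row is a COVID scan predicted as non-COVID, which is the kind of error that matters most in a screening setting.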

Plot Accuracy

Plotting curves of the training process:

A function to plot the accuracy of each label:

@patch
def plot_accuracy(self:ClassificationInterpretationEx, width=0.9, figsize=(6,6), return_fig=False,
                  title='Accuracy Per Label', ylabel='Accuracy (%)', style='ggplot',
                  color='#2a467e', vertical_labels=True):
    'Plot a bar plot showing accuracy per label'
    if not hasattr(self, 'preds_df_each'):
        raise NotImplementedError
    plt.style.use(style)
    if not hasattr(self, 'preds_df_each'): self.compute_label_confidence()
    self.accuracy_dict = defaultdict()

    for label,df in self.preds_df_each.items():
        total = len(df['accurate']) + len(df['inaccurate'])
        self.accuracy_dict[label] = 100 * len(df['accurate']) / total

    fig,ax = plt.subplots(figsize=figsize)

    x = self.accuracy_dict.keys()
    y = [v for k,v in self.accuracy_dict.items()]

    rects = ax.bar(x,y,width,color=color)
    for rect in rects:
        ht = rect.get_height()
        ax.annotate(f"{ht:.02f}",  # text passed positionally: the `s` keyword is deprecated since Matplotlib 3.3
                    xy = (rect.get_x() + rect.get_width()/2, ht),
                    xytext = (0,3), # offset vertically by 3 points
                    textcoords = 'offset points',
                    ha = 'center', va = 'bottom'
                   )
    ax.set_ybound(lower=0, upper=100)
    ax.set_yticks(np.arange(0,110,10))
    ax.set_ylabel(ylabel)
    # rotate the existing tick labels; calling set_xticklabels without fixed ticks triggers a FixedFormatter warning
    ax.tick_params(axis='x', labelrotation=90 if vertical_labels else 0)
    plt.suptitle(title)
    plt.tight_layout()
    if return_fig: return fig
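
The quantity that plot_accuracy draws, the share of correctly classified samples per true label, can be computed without any plotting. The following sketch uses a hypothetical helper `accuracy_per_label` and toy labels; it mirrors the accurate / (accurate + inaccurate) ratio in the function above but is not taken from fastai or fastai_amalgam.

```python
from collections import defaultdict

def accuracy_per_label(truths, preds):
    """Return {label: accuracy in percent} from parallel truth/prediction lists."""
    total, correct = defaultdict(int), defaultdict(int)
    for t, p in zip(truths, preds):
        total[t] += 1
        correct[t] += int(t == p)
    return {label: 100.0 * correct[label] / total[label] for label in total}

toy_truths = ["COVID", "COVID", "COVID", "COVID", "non-COVID", "non-COVID"]
toy_preds = ["COVID", "COVID", "COVID", "non-COVID", "non-COVID", "non-COVID"]
print(accuracy_per_label(toy_truths, toy_preds))  # {'COVID': 75.0, 'non-COVID': 100.0}
```

Looking at accuracy per label rather than overall accuracy shows whether the errors are concentrated in one class.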

Plot Label Confidence

Plotting label confidence as histograms for each label:

@patch
def plot_label_confidence(self:ClassificationInterpretationEx, bins:int=5, fig_width:int=12, fig_height_base:int=4,
                          title:str='Accurate vs. Inaccurate Predictions Confidence (%) Levels Per Label',
                          return_fig:bool=False, label_bars:bool=True, style='ggplot', dpi=150,
                          accurate_color='#2a467e', inaccurate_color='#dc4a46'):
    """Plot label confidence histograms for each label
    Key Args:
      * `bins`:       No. of bins on each side of the plot
      * `return_fig`: If True, returns the figure that can be easily saved to disk
      * `label_bars`: If True, displays the % of samples that fall into each bar
      * `style`:      A matplotlib style. See `plt.style.available` for more
      * `accurate_color`:   Color of the accurate bars
      * `inaccurate_color`: Color of the inaccurate bars
    """
    if not hasattr(self, 'preds_df_each'):
        raise NotImplementedError
    plt.style.use(style)
    fig, axes = plt.subplots(nrows = len(self.preds_df_each.keys()), ncols=2, dpi=dpi,
                             figsize = (fig_width, fig_height_base * len(self.dl.vocab)))
    for i, (label, df) in enumerate(self.preds_df_each.items()):
        height=0
        # find max height
        for mode in ['inaccurate', 'accurate']:
            len_bins,_ = np.histogram(df[mode][label], bins=bins)
            if len_bins.max() > height: height=len_bins.max()

        for mode,ax in zip(['inaccurate', 'accurate'], axes[i]):
            range_ = (50,100) if mode == 'accurate' else (0,50)
            color  = accurate_color if mode == 'accurate' else inaccurate_color
            num,_,patches = ax.hist(df[mode][label], bins=bins, range=range_, rwidth=.95, color=color)
            num_samples = len(df['inaccurate'][label]) + len(df['accurate'][label])
            pct_share   = len(df[mode][label]) / num_samples
            if label_bars:
                for rect in patches:
                    ht = rect.get_height()
                    ax.annotate(f"{round((int(ht) / num_samples) * 100, 1) if ht > 0 else 0}%",  # positional text: `s` is deprecated since Matplotlib 3.3
                        xy = (rect.get_x() + rect.get_width()/2, ht),
                        xytext = (0,3), # offset vertically by 3 points
                        textcoords = 'offset points',
                        ha = 'center', va = 'bottom'
                       )
            ax.set_ybound(upper=height + height*0.3)
            ax.set_xlabel(f'{label}: {mode.capitalize()} ({round(pct_share * 100, 2)}%)')
            ax.set_ylabel(f'Num. {mode.capitalize()} ({len(df[mode][label])} of {num_samples})')
    fig.suptitle(title, y=1.0)
    plt.subplots_adjust(top = 0.9, bottom=0.01, hspace=0.25, wspace=0.2)
    plt.tight_layout()
    if return_fig: return fig

Plot Top Losses grid

Plotting the top losses in a grid:

from fastai_amalgam.utils import *
@patch
def plot_top_losses_grid(self:ClassificationInterpretationEx, k=16, ncol=4, __largest=True,
                         font_path=None, font_size=12, use_dedicated_layout=True) -> PIL.Image.Image:
    """Plot top losses in a grid

    Uses fastai's `ClassificationInterpretation.plot_top_losses` to fetch
    predictions, and makes a grid with the ground truth labels, predictions,
    prediction confidence and loss ingrained into the image

    By default, `use_dedicated_layout` is used to plot the loss (bottom),
    truths (top-left), and predictions (top-right) in dedicated areas of the
    image. If this is set to `False`, everything is printed at the bottom of the
    image
    """
    # all of the pred fetching code is copied over from
    # fastai's `ClassificationInterpretation.plot_top_losses`
    # and only plotting code is added here
    losses,idx = self.top_losses(k, largest=__largest)
    if not isinstance(self.inputs, tuple): self.inputs = (self.inputs,)
    if isinstance(self.inputs[0], Tensor): inps = tuple(o[idx] for o in self.inputs)
    else: inps = self.dl.create_batch(self.dl.before_batch([tuple(o[i] for o in self.inputs) for i in idx]))
    b = inps + tuple(o[idx] for o in (self.targs if is_listy(self.targs) else (self.targs,)))
    x,y,its = self.dl._pre_show_batch(b, max_n=k)
    b_out = inps + tuple(o[idx] for o in (self.decoded if is_listy(self.decoded) else (self.decoded,)))
    x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)
    #if its is not None:
    #    _plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), self.preds[idx], losses,  **kwargs)
    plot_items = its.itemgot(0), its.itemgot(1), outs.itemgot(slice(len(inps), None)), self.preds[idx], losses
    def draw_label(x:TensorImage, labels):
        return PILImage.create(x).draw_labels(labels, font_path=font_path, font_size=font_size, location="bottom")
    # return plot_items
    results = []
    for x, truth, preds, preds_raw, loss in zip(*plot_items):
        if self.is_multilabel:
            preds = preds[0]
        probs_i = np.array([self.dl.vocab.o2i[o] for o in preds])
        pred2prob = [f"{pred} ({round(prob.item()*100,2)}%)" for pred,prob in zip(preds,preds_raw[probs_i])]
        if use_dedicated_layout:
            # draw loss at the bottom, preds on top-right
            # and truths on the top
            img = PILImage.create(x)
            if isinstance(truth, Category): truth = [truth]
            truth.insert(0, "TRUTH: ")
            pred2prob.insert(0, 'PREDS: ')
            loss_text = f"{'LOSS: '.rjust(8)} {round(loss.item(), 4)}"
            img.draw_labels(truth,     location="top-left", font_size=font_size, font_path=font_path)
            img.draw_labels(pred2prob, location="top-right", font_size=font_size, font_path=font_path)
            img.draw_labels(loss_text, location="bottom", font_size=font_size, font_path=font_path)
            results.append(img)
        else:
            # draw everything at the bottom
            out = []
            out.append(f"{'TRUTH: '.rjust(8)} {truth}")
            bsl = '\n' # since f-strings can't have backslashes
            out.append(f"{'PRED: '.rjust(8)} {bsl.join(pred2prob)}")
            if self.is_multilabel: out.append('\n')
            out.append(f"{'LOSS: '.rjust(8)} {round(loss.item(), 4)}")
            results.append(draw_label(x, out))
    return make_img_grid(results, img_size=None, ncol=ncol)
/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai_amalgam/utils.py:92: UserWarning: Loaded default PIL ImageFont. It's highly recommended you use a custom font as the default font's size cannot be tweaked
  warnings.warn("Loaded default PIL ImageFont. It's highly recommended you use a custom font as the default font's size cannot be tweaked")
/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai_amalgam/utils.py:94: UserWarning: `font_size` cannot be used when not using a custom font passed via `font_path`
  warnings.warn(f"`font_size` cannot be used when not using a custom font passed via `font_path`")

Plot Lowest Losses Grid

Plotting the lowest losses in a grid:

@patch
@delegates(to=ClassificationInterpretationEx.plot_top_losses_grid, but=['largest'])
def plot_lowest_losses_grid(self:ClassificationInterpretationEx, **kwargs):
    """Plot the lowest losses. Exact opposite of `ClassificationInterpretationEx.plot_top_losses`
    """
    return self.plot_top_losses_grid(__largest=False, **kwargs)