Setup

The dataset is the public SARS-CoV-2 CT scan dataset from Kaggle, containing 1252 CT (computed tomography) scans of patients infected with SARS-CoV-2 (COVID-19) and 1230 CT scans of non-infected patients. The data was collected from real patients in Sao Paulo, Brazil.

import opendatasets as od

dataset_url = 'https://www.kaggle.com/plameneduardo/sarscov2-ctscan-dataset'
od.download(dataset_url)
Downloading sarscov2-ctscan-dataset.zip to ./sarscov2-ctscan-dataset
100%|██████████| 230M/230M [00:04<00:00, 52.3MB/s]

from fastai.vision.all import * # provides Path along with the .ls() convenience method

path = Path('sarscov2-ctscan-dataset')

The dataset contains two folders, COVID and non-COVID, holding the CT scan images of the patients:

path.ls()
(#2) [Path('sarscov2-ctscan-dataset/COVID'),Path('sarscov2-ctscan-dataset/non-COVID')]

Preprocessing

Exploring the dataset structure:

There are 1252 CT scan images from SARS-CoV-2 infected patients.

(path/'COVID').ls()
(#1252) [Path('sarscov2-ctscan-dataset/COVID/Covid (1).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (10).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (100).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (1000).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (1001).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (1002).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (1003).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (1004).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (1005).png'),Path('sarscov2-ctscan-dataset/COVID/Covid (1006).png')...]

The dataset description lists 1230 CT scan images from SARS-CoV-2 non-infected patients (the downloaded folder contains 1229 files).

(path/'non-COVID').ls()
(#1229) [Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (1).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (10).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (100).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (1000).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (1001).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (1002).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (1003).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (1004).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (1005).png'),Path('sarscov2-ctscan-dataset/non-COVID/Non-Covid (1006).png')...]

Visualizing the Images: let's look at some raw images in the dataset:

import PIL # look at the downloaded images
img1 = PIL.Image.open((path/'COVID').ls()[0])
img1

Creating a DataBlock

The DataBlock API: we split the dataset into a training and a validation set, passing a fixed seed to RandomSplitter so the split (and hence the results) can be reproduced. The valid_pct argument is the proportion of the dataset held out for validation (in our case 20%). Presizing is applied: every item is first resized to 460px, then the batch transformations randomly crop and augment the images, keeping at least 75% of each image (min_scale=0.75), and finally the batches are normalized with the ImageNet statistics so we can apply transfer learning later.

Creating Dataloaders

def get_dls(bs, size):
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       get_items=get_image_files,
                       get_y=parent_label,
                       splitter=RandomSplitter(valid_pct=0.2, seed=42),
                       item_tfms=Resize(460), # presizing: first resize to a large 460px image
                       batch_tfms=[*aug_transforms(size=size, min_scale=0.75),
                                   Normalize.from_stats(*imagenet_stats)])
    return dblock.dataloaders(path, bs=bs)

To create the dataloaders we use fastai's DataBlock API. We use images of size 224×224 and a batch size of 224.

dls=get_dls(224,224)

Visualizing the Dataloaders

The images in the dataloaders look like:

dls.show_batch(nrows=3, figsize=(7,6))

A batch of images in a grid looks like:

@patch
@delegates(to=draw_label, but=["font_color", "location", "draw_rect", "fsize_div_factor", "font_path", "font_size"])
def show_batch_grid(self:TfmdDL, b=None, n=20, ncol=4, show=True, unique=False,
                    unique_each=True, font_path=None, font_size=20, **kwargs):
    """Show a batch of images
    Key Params:
      * n:      No. of images to display
      * ncol:   No. of columns in the grid
      * unique: Display the same image with different augmentations
      * unique_each: If True, displays a different img on each call
      * font_path:   Path to the `.ttf` font file. Required to display labels
      * font_size:   Size of the font
    """
    if font_path is not None: self.set_font_path(font_path)
    if not hasattr(self, 'font_path'):
        self.font_path = font_path
    if unique:
        old_get_idxs = self.get_idxs
        if unique_each:
            i = np.random.choice(self.n)
            self.get_idxs = partial(itertools.repeat, i)
        else:
            self.get_idxs = lambda: Inf.zeros
    if b is None: b = self.one_batch()
    if not show: return self._pre_show_batch(b, max_n=n)
    _,__, b = self._pre_show_batch(b, max_n=n)
    if unique: self.get_idxs = old_get_idxs
    return make_img_grid([draw_label(i, font_path=self.font_path, font_size=font_size) for i in b],
                         ncol=ncol, img_size=None)
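
With this patch in place, a labelled grid can be drawn for a training batch; for instance (the font path here is an assumption, any .ttf file on disk works):

dls.train.show_batch_grid(n=8, ncol=4, font_path='/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf')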

Transfer Learning

The Resnet50 model

  1. What is transfer learning, and why did we use it?

    • Transfer learning means using a pre-trained model to build our classifier. A pre-trained model is a model that has previously been trained on another dataset, so it already comes with learned weights and biases. Using a pre-trained model saves time and computational resources; another advantage is that pre-trained models often perform better than architectures trained from scratch.
    • To better understand this point, suppose we want to build a classifier that can sort different sailboat types. A model pre-trained on ships would already have captured basic boat features in its first layers, so it would learn to distinguish the different sailboat types faster and with better accuracy.
  2. The Resnet50 architecture:

    • Resnet50 is generally considered a good first architecture to test: it shows good performance without an excessive size, which allows a higher batch size and thus less computation time. For this reason, Resnet50 is a good compromise to try before testing more complex architectures.
    • Residual networks were devised to solve the problem of vanishing gradients. Deep networks with a large number of hidden layers are effective on complicated tasks because their structure lets them catch patterns in complicated data, but during training the early layers tend to train more slowly (their gradients become smaller during backpropagation). The initial layers matter because they learn the basic features of an object (edges, corners and so on), and failing to train them properly lowers the overall accuracy of the model.
    • Residual neural networks were devised to solve this issue: each Resnet block adds a skip (identity) connection that lets the block's input bypass its layers and be added to its output, so gradients can flow directly back to the earlier layers (a minimal sketch of such a block follows this list). Original research article
  3. Testing the Resnet50 architecture with our dataset:

    • Now we are going to test how the fastai implementation of this architecture works with the COVID dataset.
    • Create the convolutional neural network. First we create the convolutional neural network based on this architecture, using fastai's cnn_learner (previously create_cnn) function. We pass the loaded data, specify the model, pass [error_rate, accuracy] for the metrics parameter so we see both the error rate and the accuracy, and finally specify a weight decay of 1e-1 (0.1).
    • learn.lr_find() runs the LR Finder, which helps find the best learning rate to use with our network. For more information see the original paper.

    • learn.recorder.plot() plots the loss against the learning rate. The best learning rate is usually the value where the curve is steepest; you may try a few values around it to pick the best one.

    • [learn.fit_one_cycle() & learn.recorder.plot_losses()] The learn.fit_one_cycle() function fits the model. Fit one cycle reaches a comparable accuracy faster than the plain fit function when training complex models: instead of keeping the learning rate fixed for all iterations, it increases the learning rate and then decreases it again (this process is one "cycle"), starting from roughly a tenth of the chosen maximum. This learning rate variation also helps prevent overfitting. We use 5 for the cyc_len parameter to specify the number of cycles to run (one cycle can be considered equivalent to an epoch), and max_lr to set the maximum learning rate, which we set to 0.001. For more information about fit one cycle: article. These calls are sketched right after this list.
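
To make the skip-connection idea concrete, here is a minimal PyTorch sketch of a residual block (illustrative only, not torchvision's exact Resnet50 block, which uses a three-layer bottleneck):

import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    "y = conv(x) + x: the identity shortcut lets gradients flow straight to earlier layers"
    def __init__(self, ch):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(ch, ch, 3, padding=1), nn.BatchNorm2d(ch), nn.ReLU(),
            nn.Conv2d(ch, ch, 3, padding=1), nn.BatchNorm2d(ch))
    def forward(self, x):
        return F.relu(self.convs(x) + x)  # the skip connection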
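
And a sketch of the training calls described above, assuming the dls created earlier and the usual from fastai.vision.all import * (in fastai v2 the loss-curve plot is learn.recorder.plot_loss()):

learn = cnn_learner(dls, resnet50, metrics=[error_rate, accuracy], wd=1e-1)
lr_min, lr_steep = learn.lr_find()   # LR Finder: plots loss against learning rate
learn.fit_one_cycle(5, lr_max=1e-3)  # 5 cycles with a maximum learning rate of 0.001
learn.recorder.plot_loss()           # training/validation loss curves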

Testing with Deeper Architectures

learn = cnn_learner(dls, resnet101, loss_func=CrossEntropyLossFlat(), metrics=[error_rate,accuracy], wd=1e-1).to_fp16()
learn.cbs
(#4) [TrainEvalCallback,Recorder,ProgressCallback,MixedPrecision]

We apply Mixup, a very powerful data augmentation technique, and train the model. Mixup trains on convex combinations of pairs of images and their labels rather than on the raw samples; a minimal sketch of the idea follows.
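
A rough sketch of what Mixup does to a batch (illustrative only; fastai's MixUp callback handles this internally, mixing the loss rather than integer labels):

import torch
import numpy as np

def mixup_batch(x, y, alpha=0.5):
    lam = np.random.beta(alpha, alpha)       # mixing coefficient
    perm = torch.randperm(x.size(0))         # random partner for each sample
    x_mixed = lam * x + (1 - lam) * x[perm]  # blend the images
    y_mixed = lam * y + (1 - lam) * y[perm]  # blend (one-hot) targets the same way
    return x_mixed, y_mixed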

learn.fit_one_cycle(80, 3e-3, cbs=MixUp(0.5))
 
epoch train_loss valid_loss error_rate accuracy time
0 1.098607 0.923316 0.405242 0.594758 00:23
1 0.959609 0.496383 0.227823 0.772177 00:23
2 0.907947 0.482087 0.223790 0.776210 00:23
3 0.874337 0.519458 0.252016 0.747984 00:22
4 0.838613 0.474054 0.221774 0.778226 00:23
5 0.800547 0.362412 0.159274 0.840726 00:23
6 0.766916 0.340699 0.137097 0.862903 00:23
7 0.732544 0.190343 0.060484 0.939516 00:23
8 0.705716 0.234552 0.092742 0.907258 00:22
9 0.672490 0.225245 0.092742 0.907258 00:23
10 0.646345 0.196367 0.082661 0.917339 00:23
11 0.614197 0.207071 0.082661 0.917339 00:22
12 0.580931 0.145530 0.058468 0.941532 00:22
13 0.551455 0.160765 0.062500 0.937500 00:22
14 0.530735 0.156187 0.058468 0.941532 00:22
15 0.508143 0.133556 0.040323 0.959677 00:23
16 0.486114 0.130424 0.048387 0.951613 00:23
17 0.468766 0.112146 0.036290 0.963710 00:23
18 0.451705 0.106726 0.038306 0.961694 00:23
19 0.435460 0.140806 0.046371 0.953629 00:23
20 0.421944 0.129412 0.042339 0.957661 00:23
21 0.409734 0.107145 0.034274 0.965726 00:23
22 0.399341 0.125089 0.042339 0.957661 00:23
23 0.389513 0.099228 0.036290 0.963710 00:23
24 0.380188 0.082548 0.018145 0.981855 00:23
25 0.370885 0.071890 0.016129 0.983871 00:23
26 0.363348 0.126151 0.044355 0.955645 00:23
27 0.356708 0.085095 0.022177 0.977823 00:23
28 0.351457 0.082022 0.030242 0.969758 00:23
29 0.346920 0.082360 0.022177 0.977823 00:23
30 0.343117 0.086793 0.026210 0.973790 00:23
31 0.338743 0.084433 0.028226 0.971774 00:23
32 0.332180 0.050694 0.012097 0.987903 00:23
33 0.329243 0.075656 0.022177 0.977823 00:23
34 0.325888 0.074826 0.018145 0.981855 00:22
35 0.321171 0.051103 0.016129 0.983871 00:22
36 0.317992 0.068456 0.014113 0.985887 00:23
37 0.317117 0.095658 0.038306 0.961694 00:23
38 0.314691 0.075247 0.026210 0.973790 00:23
39 0.312669 0.059977 0.014113 0.985887 00:23
40 0.311207 0.062207 0.016129 0.983871 00:23
41 0.307136 0.079891 0.032258 0.967742 00:23
42 0.303215 0.060350 0.014113 0.985887 00:23
43 0.301979 0.061862 0.014113 0.985887 00:23
44 0.302073 0.056083 0.012097 0.987903 00:23
45 0.298332 0.054264 0.014113 0.985887 00:23
46 0.297571 0.050670 0.006048 0.993952 00:23
47 0.295835 0.053044 0.014113 0.985887 00:23
48 0.295003 0.053177 0.006048 0.993952 00:23
49 0.295658 0.070317 0.012097 0.987903 00:23
50 0.293548 0.051080 0.008064 0.991935 00:23
51 0.293651 0.061804 0.016129 0.983871 00:23
52 0.291592 0.044012 0.010081 0.989919 00:23
53 0.289374 0.047382 0.006048 0.993952 00:23
54 0.288036 0.050668 0.006048 0.993952 00:23
55 0.286396 0.057323 0.016129 0.983871 00:23
56 0.285277 0.049304 0.012097 0.987903 00:23
57 0.283157 0.047652 0.010081 0.989919 00:23
58 0.282250 0.046741 0.008064 0.991935 00:23
59 0.282024 0.043001 0.006048 0.993952 00:23
60 0.281422 0.043425 0.004032 0.995968 00:22
61 0.279792 0.048245 0.004032 0.995968 00:23
62 0.278259 0.050301 0.008064 0.991935 00:23
63 0.276450 0.042498 0.006048 0.993952 00:23
64 0.275557 0.043382 0.008064 0.991935 00:23
65 0.274992 0.046327 0.008064 0.991935 00:23
66 0.274949 0.051264 0.012097 0.987903 00:23
67 0.276006 0.050355 0.010081 0.989919 00:23
68 0.277512 0.047513 0.008064 0.991935 00:22
69 0.275122 0.044733 0.006048 0.993952 00:22
70 0.275745 0.042205 0.006048 0.993952 00:23
71 0.274163 0.041508 0.006048 0.993952 00:23
72 0.273943 0.042359 0.006048 0.993952 00:23
73 0.273899 0.042546 0.006048 0.993952 00:23
74 0.271843 0.044013 0.006048 0.993952 00:23
75 0.271735 0.043344 0.006048 0.993952 00:22
76 0.271392 0.045417 0.008064 0.991935 00:23
77 0.270836 0.044158 0.006048 0.993952 00:23
78 0.272595 0.043604 0.008064 0.991935 00:23
79 0.272234 0.044281 0.008064 0.991935 00:23

TTA (Test Time Augmentation)

preds,targs = learn.tta() # TTA applied for validation dataset
accuracy(preds, targs).item()
0.9959677457809448

Averaging predictions over several augmented versions of each validation image (TTA), we get an accuracy of about 99.6% on the validation set.

ClassificationInterpretationEx

We examine the model predictions in more depth:

import fastai

def _get_truths(vocab, label_idx, is_multilabel):
    if is_multilabel:
        # torch.where returns a tuple of index tensors; take the first
        return ';'.join([vocab[i] for i in torch.where(label_idx==1)[0]])
    else: return vocab[label_idx]

class ClassificationInterpretationEx(ClassificationInterpretation):
    """
    Extend fastai2's `ClassificationInterpretation` to analyse model predictions in more depth
    See:
      * self.preds_df
      * self.plot_label_confidence()
      * self.plot_confusion_matrix()
      * self.plot_accuracy()
      * self.get_fnames()
      * self.plot_top_losses_grid()
      * self.print_classification_report()
    """
    def __init__(self, dl, inputs, preds, targs, decoded, losses):
        super().__init__(dl, inputs, preds, targs, decoded, losses)
        self.vocab = self.dl.vocab
        if is_listy(self.vocab): self.vocab = self.vocab[-1]
        if self.targs.__class__ == fastai.torch_core.TensorMultiCategory:
            self.is_multilabel = True
        else:
            self.is_multilabel = False
        self.compute_label_confidence()
        self.determine_classifier_type()

    def determine_classifier_type(self):
        if self.targs[0].__class__==fastai.torch_core.TensorCategory:
            self.is_multilabel = False
        if self.targs[0].__class__==fastai.torch_core.TensorMultiCategory:
            self.is_multilabel = True
            self.thresh = self.dl.loss_func.thresh

    def compute_label_confidence(self, df_colname:Optional[str]="fnames"):
        """
        Collate prediction confidence, filenames, and ground truth labels
        in DataFrames, and store them as class attributes
        `self.preds_df` and `self.preds_df_each`

        If the `DataLoaders` is constructed from a `pd.DataFrame`, use
        `df_colname` to specify the column name with the filepaths
        """
        if not isinstance(self.dl.items, pd.DataFrame):
            self._preds_collated = [
                #(item, self.dl.vocab[label_idx], *preds.numpy()*100)\
                (item, _get_truths(self.dl.vocab, label_idx, self.is_multilabel), *preds.numpy()*100)\
                for item,label_idx,preds in zip(self.dl.items,
                                                self.targs,
                                                self.preds)
            ]
        ## need to extract fname from DataFrame
        elif isinstance(self.dl.items, pd.DataFrame):
            self._preds_collated = [
                #(item[df_colname], self.dl.vocab[label_idx], *preds.numpy()*100)\
                (item[df_colname], _get_truths(self.dl.vocab, label_idx, self.is_multilabel), *preds.numpy()*100)\
                for (_,item),label_idx,preds in zip(self.dl.items.iterrows(),
                                                self.targs,
                                                self.preds)
            ]

        self.preds_df       = pd.DataFrame(self._preds_collated, columns = ['fname','truth', *self.dl.vocab])
        self.preds_df.insert(2, column='loss', value=self.losses.numpy())

        if self.is_multilabel: return # preds_df_each doesnt make sense for multi-label
        self._preds_df_each = {l:self.preds_df.copy()[self.preds_df.truth == l].reset_index(drop=True) for l in self.dl.vocab}
        self.preds_df_each  = defaultdict(dict)


        sort_desc = lambda x,col: x.sort_values(col, ascending=False).reset_index(drop=True)
        for label,df in self._preds_df_each.items():
            filt = df[label] == df[self.dl.vocab].max(axis=1)
            self.preds_df_each[label]['accurate']   = df.copy()[filt]
            self.preds_df_each[label]['inaccurate'] = df.copy()[~filt]

            self.preds_df_each[label]['accurate']   = sort_desc(self.preds_df_each[label]['accurate'], label)
            self.preds_df_each[label]['inaccurate'] = sort_desc(self.preds_df_each[label]['inaccurate'], label)
            assert len(self.preds_df_each[label]['accurate']) + len(self.preds_df_each[label]['inaccurate']) == len(df)

    def get_fnames(self, label:str,
                   mode:('accurate','inaccurate'),
                   conf_level:Union[int,float,tuple]) -> np.ndarray:
        """
        Utility function to grab filenames of a particular label `label` that were classified
        as per `mode` (accurate|inaccurate).
        These filenames are filtered by `conf_level` which can be above or below a certain
        threshold (above if `mode` == 'accurate' else below), or in confidence ranges
        """
        assert label in self.dl.vocab
        if not hasattr(self, 'preds_df_each'): self.compute_label_confidence()
        df = self.preds_df_each[label][mode].copy()
        if mode == 'accurate':
            if isinstance(conf_level, tuple):       filt = df[label].between(*conf_level)
            if isinstance(conf_level, (int,float)): filt = df[label] > conf_level
        if mode == 'inaccurate':
            if isinstance(conf_level, tuple):       filt = df[label].between(*conf_level)
            if isinstance(conf_level, (int,float)): filt = df[label] < conf_level
        return df[filt].fname.values
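
The preview below was generated by constructing the interpretation object from our trained learner and sampling two rows, along these lines (from_learner is inherited from fastai's ClassificationInterpretation):

interp = ClassificationInterpretationEx.from_learner(learn)
interp.preds_df.sample(2)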
fname truth loss COVID non-COVID
0 sarscov2-ctscan-dataset/non-COVID/Non-Covid (386).png non-COVID 0.047037 4.594821 95.405182
1 sarscov2-ctscan-dataset/COVID/Covid (581).png COVID 0.018380 98.178818 1.821182

ClassificationInterpretationEx.get_fnames

Returns accurately classified files for a given label whose prediction confidence is above a threshold, e.g. 'COVID' files above 99.95%:

interp.get_fnames('COVID', 'accurate', 99.95)

Returns inaccurately classified files whose confidence falls in a given range, e.g. 'COVID' files between 84.1% and 85.2%:

interp.get_fnames('COVID', 'inaccurate', (84.1, 85.2))

Confusion Matrix

Checking the Confusion Matrix:

Plot Accuracy

Plotting the curves of the training process:

A function to plot the accuracy of each label:

@patch
def plot_accuracy(self:ClassificationInterpretationEx, width=0.9, figsize=(6,6), return_fig=False,
                  title='Accuracy Per Label', ylabel='Accuracy (%)', style='ggplot',
                  color='#2a467e', vertical_labels=True):
    'Plot a bar plot showing accuracy per label'
    if not hasattr(self, 'preds_df_each'):
        # `preds_df_each` is never built for multi-label models
        raise NotImplementedError
    plt.style.use(style)
    self.accuracy_dict = defaultdict()

    for label,df in self.preds_df_each.items():
        total = len(df['accurate']) + len(df['inaccurate'])
        self.accuracy_dict[label] = 100 * len(df['accurate']) / total

    fig,ax = plt.subplots(figsize=figsize)

    x = list(self.accuracy_dict.keys())
    y = list(self.accuracy_dict.values())

    rects = ax.bar(x,y,width,color=color)
    for rect in rects:
        ht = rect.get_height()
        ax.annotate(text = f"{ht:.02f}",
                    xy = (rect.get_x() + rect.get_width()/2, ht),
                    xytext = (0,3), # offset vertically by 3 points
                    textcoords = 'offset points',
                    ha = 'center', va = 'bottom'
                   )
    ax.set_ybound(lower=0, upper=100)
    ax.set_yticks(np.arange(0,110,10))
    ax.set_ylabel(ylabel)
    ax.set_xticks(range(len(x)))
    ax.set_xticklabels(x, rotation='vertical' if vertical_labels else 'horizontal')
    plt.suptitle(title)
    plt.tight_layout()
    if return_fig: return fig
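
With the patch in place, per-label accuracy can be plotted with:

interp.plot_accuracy()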

Plot Label Confidence

Plotting label confidence as histograms for each label:

@patch
def plot_label_confidence(self:ClassificationInterpretationEx, bins:int=5, fig_width:int=12, fig_height_base:int=4,
                          title:str='Accurate vs. Inaccurate Predictions Confidence (%) Levels Per Label',
                          return_fig:bool=False, label_bars:bool=True, style='ggplot', dpi=150,
                          accurate_color='#2a467e', inaccurate_color='#dc4a46'):
    """Plot label confidence histograms for each label
    Key Args:
      * `bins`:       No. of bins on each side of the plot
      * `return_fig`: If True, returns the figure that can be easily saved to disk
      * `label_bars`: If True, displays the % of samples that fall into each bar
      * `style`:      A matplotlib style. See `plt.style.available` for more
      * `accurate_color`:   Color of the accurate bars
      * `inaccurate_color`: Color of the inaccurate bars
    """
    if not hasattr(self, 'preds_df_each'):
        # `preds_df_each` is never built for multi-label models
        raise NotImplementedError
    plt.style.use(style)
    fig, axes = plt.subplots(nrows = len(self.preds_df_each.keys()), ncols=2, dpi=dpi,
                             figsize = (fig_width, fig_height_base * len(self.dl.vocab)))
    for i, (label, df) in enumerate(self.preds_df_each.items()):
        height=0
        # find max height
        for mode in ['inaccurate', 'accurate']:
            len_bins,_ = np.histogram(df[mode][label], bins=bins)
            if len_bins.max() > height: height=len_bins.max()

        for mode,ax in zip(['inaccurate', 'accurate'], axes[i]):
            range_ = (50,100) if mode == 'accurate' else (0,50)
            color  = accurate_color if mode == 'accurate' else inaccurate_color
            num,_,patches = ax.hist(df[mode][label], bins=bins, range=range_, rwidth=.95, color=color)
            num_samples = len(df['inaccurate'][label]) + len(df['accurate'][label])
            pct_share   = len(df[mode][label]) / num_samples
            if label_bars:
                for rect in patches:
                    ht = rect.get_height()
                    ax.annotate(text = f"{round((int(ht) / num_samples) * 100, 1) if ht > 0 else 0}%",
                        xy = (rect.get_x() + rect.get_width()/2, ht),
                        xytext = (0,3), # offset vertically by 3 points
                        textcoords = 'offset points',
                        ha = 'center', va = 'bottom'
                       )
            ax.set_ybound(upper=height + height*0.3)
            ax.set_xlabel(f'{label}: {mode.capitalize()} ({round(pct_share * 100, 2)}%)')
            ax.set_ylabel(f'Num. {mode.capitalize()} ({len(df[mode][label])} of {num_samples})')
    fig.suptitle(title, y=1.0)
    plt.subplots_adjust(top = 0.9, bottom=0.01, hspace=0.25, wspace=0.2)
    plt.tight_layout()
    if return_fig: return fig
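
And the per-label confidence histograms with:

interp.plot_label_confidence()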

Plot Top Losses grid

plotting the top losses in a grid:

from fastai_amalgam.utils import *
@patch
def plot_top_losses_grid(self:ClassificationInterpretationEx, k=16, ncol=4, __largest=True,
                         font_path=None, font_size=12, use_dedicated_layout=True) -> PIL.Image.Image:
    """Plot top losses in a grid

    Uses fastai's `ClassificationInterpretation.plot_top_losses` to fetch
    predictions, and makes a grid with the ground truth labels, predictions,
    prediction confidence and loss ingrained into the image

    By default, `use_dedicated_layout` is used to plot the loss (bottom),
    truths (top-left), and predictions (top-right) in dedicated areas of the
    image. If this is set to `False`, everything is printed at the bottom of the
    image
    """
    # all of the pred fetching code is copied over from
    # fastai's `ClassificationInterpretation.plot_top_losses`
    # and only plotting code is added here
    losses,idx = self.top_losses(k, largest=__largest)
    if not isinstance(self.inputs, tuple): self.inputs = (self.inputs,)
    if isinstance(self.inputs[0], Tensor): inps = tuple(o[idx] for o in self.inputs)
    else: inps = self.dl.create_batch(self.dl.before_batch([tuple(o[i] for o in self.inputs) for i in idx]))
    b = inps + tuple(o[idx] for o in (self.targs if is_listy(self.targs) else (self.targs,)))
    x,y,its = self.dl._pre_show_batch(b, max_n=k)
    b_out = inps + tuple(o[idx] for o in (self.decoded if is_listy(self.decoded) else (self.decoded,)))
    x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)
    #if its is not None:
    #    _plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), self.preds[idx], losses,  **kwargs)
    plot_items = its.itemgot(0), its.itemgot(1), outs.itemgot(slice(len(inps), None)), self.preds[idx], losses
    def draw_label(x:TensorImage, labels):
        return PILImage.create(x).draw_labels(labels, font_path=font_path, font_size=font_size, location="bottom")
    # return plot_items
    results = []
    for x, truth, preds, preds_raw, loss in zip(*plot_items):
        if self.is_multilabel:
            preds = preds[0]
        probs_i = np.array([self.dl.vocab.o2i[o] for o in preds])
        pred2prob = [f"{pred} ({round(prob.item()*100,2)}%)" for pred,prob in zip(preds,preds_raw[probs_i])]
        if use_dedicated_layout:
            # draw loss at the bottom, preds on top-right
            # and truths on the top
            img = PILImage.create(x)
            if isinstance(truth, Category): truth = [truth]
            truth.insert(0, "TRUTH: ")
            pred2prob.insert(0, 'PREDS: ')
            loss_text = f"{'LOSS: '.rjust(8)} {round(loss.item(), 4)}"
            img.draw_labels(truth,     location="top-left", font_size=font_size, font_path=font_path)
            img.draw_labels(pred2prob, location="top-right", font_size=font_size, font_path=font_path)
            img.draw_labels(loss_text, location="bottom", font_size=font_size, font_path=font_path)
            results.append(img)
        else:
            # draw everything at the bottom
            out = []
            out.append(f"{'TRUTH: '.rjust(8)} {truth}")
            bsl = '\n' # since f-strings can't have backslashes
            out.append(f"{'PRED: '.rjust(8)} {bsl.join(pred2prob)}")
            if self.is_multilabel: out.append('\n')
            out.append(f"{'LOSS: '.rjust(8)} {round(loss.item(), 4)}")
            results.append(draw_label(x, out))
    return make_img_grid(results, img_size=None, ncol=ncol)
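
The grid (and the font warnings below, since no custom font was passed) comes from a call along the lines of:

interp.plot_top_losses_grid(k=16)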
/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai_amalgam/utils.py:92: UserWarning: Loaded default PIL ImageFont. It's highly recommended you use a custom font as the default font's size cannot be tweaked
  warnings.warn("Loaded default PIL ImageFont. It's highly recommended you use a custom font as the default font's size cannot be tweaked")
/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai_amalgam/utils.py:94: UserWarning: `font_size` cannot be used when not using a custom font passed via `font_path`
  warnings.warn(f"`font_size` cannot be used when not using a custom font passed via `font_path`")

Plot Lowest Losses Grid

Plotting the lowest losses in a grid:

@patch
@delegates(to=ClassificationInterpretationEx.plot_top_losses_grid, but=['largest'])
def plot_lowest_losses_grid(self:ClassificationInterpretationEx, **kwargs):
    """Plot the lowest losses. Exact opposite of `ClassificationInterpretationEx.plot_top_losses`
    """
    return self.plot_top_losses_grid(__largest=False, **kwargs)

Classification Report

scikit-learn Classification report:

import sklearn.metrics as skm
@patch
def print_classification_report(self:ClassificationInterpretationEx, as_dict=False):
    "Get scikit-learn classification report"
    # `flatten_check` and `skm.classification_report` don't play
    # nice together for multi-label
    # d,t = flatten_check(self.decoded, self.targs)
    d,t = self.decoded, self.targs
    return skm.classification_report(t, d, labels=list(self.vocab.o2i.values()),
                                     target_names=[str(v) for v in self.vocab],
                                     output_dict=as_dict)
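
The report below was produced by printing the returned string:

print(interp.print_classification_report())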
              precision    recall  f1-score   support

       COVID       0.99      0.98      0.99       243
   non-COVID       0.98      0.99      0.99       253

    accuracy                           0.99       496
   macro avg       0.99      0.99      0.99       496
weighted avg       0.99      0.99      0.99       496

TTA (Test Time Augmentation)

Getting the TTA score on the validation set, using the same calls as before:

preds, targs = learn.tta()
accuracy(preds, targs).item()
0.9939516186714172

Checking the confusion matrix:

Exporting the learner into a pickle file:

learn.export()
(#1) [Path('export.pkl')]

Resnet-50 Test

We train with smaller images of size 128×128 (rather than the original image size) and a smaller batch size for faster training.

dls2=get_dls(128,128)

learn2 = cnn_learner(dls2, xresnet50, metrics=[error_rate,accuracy], wd=1e-1).to_fp16()

Running the LR Finder:

lr_min, lr_steep = learn2.lr_find()
print(f"Minimum/10: {lr_min:.2e}, steepest point: {lr_steep:.2e}")
Minimum/10: 8.32e-03, steepest point: 3.31e-04

Training the model in the first run:

learn2.fit_one_cycle(5, 3e-3)
epoch train_loss valid_loss error_rate accuracy time
0 0.825077 2.717013 0.504032 0.495968 00:09
1 0.637306 0.874084 0.328629 0.671371 00:09
2 0.538363 0.493991 0.203629 0.796371 00:09
3 0.473004 0.253523 0.122984 0.877016 00:09
4 0.427544 0.230788 0.098790 0.901210 00:09

Plotting the curves of the training process:

Unfreezing the model and running the LR Finder again to get the optimal learning rate (the fine-tuning approach):

learn2.unfreeze()
lrs = learn2.lr_find()
lrs

SuggestedLRs(lr_min=6.309573450380412e-08, lr_steep=2.75422871709452e-06)

print(f"Minimum/10: {lrs.lr_min:.2e}, steepest point: {lrs.lr_steep:.2e}")
Minimum/10: 6.31e-08, steepest point: 2.75e-06

learn2.dls = get_dls(12, 224) # progressive resizing: train at the original size (224px) with a smaller batch
learn2.fit_one_cycle(12, slice(1e-5, 1e-4))
epoch train_loss valid_loss error_rate accuracy time
0 0.326667 0.237666 0.106855 0.893145 00:10
1 0.314330 0.248609 0.110887 0.889113 00:10
2 0.306127 0.233944 0.090726 0.909274 00:10
3 0.306658 0.229167 0.094758 0.905242 00:10
4 0.305483 0.281535 0.125000 0.875000 00:10
5 0.293949 0.245766 0.102823 0.897177 00:10
6 0.290125 0.226233 0.102823 0.897177 00:10
7 0.279198 0.230645 0.110887 0.889113 00:10
8 0.271748 0.244468 0.110887 0.889113 00:10
9 0.270165 0.208932 0.086694 0.913306 00:10
10 0.268667 0.208460 0.084677 0.915323 00:10
11 0.264588 0.216885 0.096774 0.903226 00:10

Checking the curves again:

Checking the Confusion Matrix:

interp = ClassificationInterpretation.from_learner(learn2) # plot confusion matrix
interp.plot_confusion_matrix(figsize=(12,12), dpi=50)

Plotting Top Losses

Saving and reloading the trained weights:

learn2.save('resnet50run')
Path('models/resnet50run.pth')
learn2 = learn2.load('resnet50run')

End Test

interp = ClassificationInterpretation.from_learner(learn) # plot confusion matrix
interp.plot_confusion_matrix(figsize=(12,12), dpi=50)

interp.plot_top_losses(5, nrows=10) # plot top losses
learn1 = load_learner("export.pkl") # reload the exported learner

GradCam Testing

Steps for plotting GradCAM:

  1. Create your Learner's test_dl w.r.t. one image and label
  2. Compute activations (forward pass) and gradients (backward pass)
  3. Compute the gradcam-map (7x7 in this case):
       a. Take the mean of the gradients across the feature maps: (2048,7,7) --> (2048,1,1)
       b. Multiply the mean by the activations: (2048,1,1) * (2048,7,7) --> (2048,7,7)
       c. Sum (b) across all 2048 channels: (2048,7,7) --> (7,7)
  4. Plot the gradcam-map over the image

These steps are shown below one by one, and later combined in a single Learner.gradcam call.

1. Create Learner's test_dl w.r.t. one image and label

def create_test_img(learn, f, return_img=True):
    img = PILImage.create(f)
    x = first(learn.dls.test_dl([f]))
    x = x[0]
    if return_img: return img,x
    return x
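
For example, building a test batch from the first COVID scan (an arbitrary choice); this creates the img3 and x used in the steps below:

img3, x = create_test_img(learn, (path/'COVID').ls()[0])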

2. Compute activations (forward pass) and gradients (backward pass)

def get_label_idx(learn:Learner, preds:torch.Tensor,
                  label:Union[str,int,None]) -> Tuple[int,str]:
    """Either:
    * Get the label idx of a specific `label`
    * Get the max pred using `learn.loss_func.decode` and `learn.loss_func.activation`
        * Only works for `softmax` activations as the backward pass requires a scalar index
        * Throws a `RuntimeError` if the activation is a `sigmoid` activation
    """
    if label is not None:
        # if `label` is a string, check that it exists in the vocab
        # and return the label's index
        if isinstance(label,str):
            if not label in learn.dls.vocab: raise ValueError(f"'{label}' is not part of the Learner's vocab: {learn.dls.vocab}")
            return learn.dls.vocab.o2i[label], label
        # if `label` is an index, return itself
        elif isinstance(label,int): return label, learn.dls.vocab[label]
        else: raise TypeError(f"Expected `str`, `int` or `None`, got {type(label)} instead")
    else:
        # if no `label` is specified, check that `learn.loss_func` has `decodes`
        # and `activation` implemented, run the predictions through them,
        # then check that the output length is 1. If not, the activation must be
        # sigmoid, which is incompatible
        if not hasattr(learn.loss_func, 'activation') or\
           not hasattr(learn.loss_func, 'decodes'):
            raise NotImplementedError(f"learn.loss_func does not have `.activation` or `.decodes` methods implemented")
        decode_pred = compose(learn.loss_func.activation, learn.loss_func.decodes)
        label_idx   = decode_pred(preds)
        if len(label_idx) > 1:
            raise RuntimeError(f"Output label idx must be of length==1. If your loss func has a sigmoid activation, please specify `label`")
        return label_idx, learn.dls.vocab[label_idx][0]

def compute_gcam_items(learn: Learner,
                       x: TensorImage,
                       label: Union[str,int,None] = None,
                       target_layer: Union[nn.Module, Callable, None] = None
                      ) -> Tuple[torch.Tensor]:
    """Compute gradient and activations of `target_layer` of `learn.model`
    for `x` with respect to `label`.

    If `target_layer` is None, then it is set to `learn.model[:-1]`
    """
    to_cuda(learn.model, x)
    target_layer = get_target_layer(learn, target_layer)
    with HookBwd(target_layer) as hook_g:
        with Hook(target_layer) as hook:
            preds       = learn.model.eval()(x)
            activations = hook.stored
            label_idx, label = get_label_idx(learn,preds,label)
            #print(preds.shape, label, label_idx)
            #print(preds)
        preds[0, label_idx].backward()
        gradients = hook_g.stored

    preds = getattr(learn.loss_func, 'activation', noop)(preds)

    # remove the leading batch_size axis
    gradients   = gradients[0]
    activations = activations[0]
    preds       = preds.detach().cpu().numpy().flatten()
    return gradients, activations, preds, label

Shapes of the gradients, activations and predictions:
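
This output was presumably produced with a call along these lines, using the test batch x from step 1:

gradients, activations, preds, label = compute_gcam_items(learn, x)
gradients.shape, activations.shape, preds.shape, label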

<ipython-input-137-c549b6a525cc>:6: UserWarning: Detected a pooling layer in the model body. Unless this is intentional, ensure that the feature map is not flattened
  warnings.warn(f"Detected a pooling layer in the model body. Unless this is intentional, ensure that the feature map is not flattened")
(torch.Size([2048, 7, 7]), torch.Size([2048, 7, 7]), (2,), 'COVID')

3. Compute gradcam-map

def compute_gcam_map(gradients, activations) -> torch.Tensor:
    """Take the mean of `gradients`, multiply by `activations`,
    sum it up and return a GradCAM feature map
    """
    # Mean over the feature maps. If you don't use `keepdim`, it returns
    # a value of shape (2048) which isn't amenable to `*` with the activations
    gcam_weights = gradients.mean(dim=[1,2], keepdim=True) # (2048,7,7)   --> (2048,1,1)
    gcam_map     = (gcam_weights * activations) # (2048,1,1) * (2048,7,7) --> (2048,7,7)
    gcam_map     = gcam_map.sum(0)              # (2048,7,7) --> (7,7)
    return gcam_map
gcam_map = compute_gcam_map(gradients, activations)
gcam_map.shape
torch.Size([7, 7])

4. Plot gradcam-map over the image

Plotting the Grad-CAM map over the image:

plot_gcam(learn, img3, x, gcam_map, full_size=True, dpi=300)
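
plot_gcam comes from fastai_amalgam; a minimal sketch of the underlying overlay idea (the function name and defaults here are assumptions, not the library's exact API):

import matplotlib.pyplot as plt
import torch.nn.functional as F

def plot_gcam_sketch(img, gcam_map, alpha=0.6, cmap='magma'):
    # upsample the (7,7) map to the original image resolution
    w, h = img.size
    hm = F.interpolate(gcam_map[None, None], size=(h, w),
                       mode='bilinear', align_corners=False)[0, 0]
    plt.imshow(img)                                        # base CT scan
    plt.imshow(hm.detach().cpu(), alpha=alpha, cmap=cmap)  # Grad-CAM heatmap overlay
    plt.axis('off')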

learn.model[1]
Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): Flatten(full=False)
  (2): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.25, inplace=False)
  (4): Linear(in_features=4096, out_features=512, bias=False)
  (5): ReLU(inplace=True)
  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): Dropout(p=0.5, inplace=False)
  (8): Linear(in_features=512, out_features=2, bias=False)
)

learn.gradcam(item=im, target_layer=learn.model[0])
/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai_amalgam/utils.py:92: UserWarning: Loaded default PIL ImageFont. It's highly recommended you use a custom font as the default font's size cannot be tweaked
  warnings.warn("Loaded default PIL ImageFont. It's highly recommended you use a custom font as the default font's size cannot be tweaked")

GUI Building

Creating Buttons:

from fastai.vision.widgets import * # ipywidgets helpers: widgets, VBox, Layout, ...

btn_upload = widgets.FileUpload()
btn_upload
img = PILImage.create(btn_upload.data[-1])
img.shape
(350, 408)
out_pl = widgets.Output()
out_pl.clear_output()
with out_pl: display(img.to_thumb(384,404))
out_pl
dls.vocab
['COVID', 'non-COVID']
pred,pred_idx,probs = learn.predict(img)
lbl_pred = widgets.Label()
lbl_pred.value = f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'
lbl_pred
btn_run = widgets.Button(description='Classify',layout=Layout(width='40%', height='80px'), button_style='success')
btn_run

A click event handler adds functionality to the buttons:

def on_click_classify(change):
    img = PILImage.create(btn_upload.data[-1])
    out_pl.clear_output()
    with out_pl: display(img.to_thumb(320,320))
    pred,pred_idx,probs = learn.predict(img)
    lbl_pred.value = f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'

btn_run.on_click(on_click_classify)

Adding a heatmap button and its functionality:

HeatMp = widgets.Button(description='MAGIC', layout=btn_run.layout, button_style='danger')
HeatMp
def on_click_map(change):
    with out_pl: display(img.to_thumb(320,320))
    learn.gradcam(img).clear(out_pl)
HeatMp.on_click(on_click_map)

Putting all the pieces together in a Vertical Stack for the final GUI:

VBox([widgets.Label('INPUT YOUR CT SCAN IMAGE FOR DETECTION!'),
      btn_upload, btn_run, out_pl, lbl_pred,
      widgets.Label('Do you want to see how our model decides which scan is Covid and which is not?'),
      widgets.Label("Click here to learn how these predictions are made"),
      HeatMp])

If you want to see the GUI I built for this project, check out my other blog post: Covify.

What Worked?

  • Using a pretrained Resnet model reduced training time and improved results.
  • Data augmentations reduced overfitting.
  • The Mixup approach worked like a charm and also prevented overfitting.
  • The presizing approach (resizing to 460px first, then augmenting down to the training size) worked well.
  • The progressive resizing approach (training on 128px images first, then on 224px ones) greatly improved results and reduced training time.

What didn't?

  • I tried implementing the bottleneck-layer design on Resnets myself, but training was unstable.
  • I tried a deeper vanilla Resnet101 model but did not notice a significant difference.

Other ideas to improve the results?

  • Trying different architectures like DenseNet, EfficientNet, etc.
  • Trying out different metrics and improving on them for better results.
  • More compute: deeper models, cross-validation with several folds, and ensembling models.

Thank you for reading this far! 😊 This was a great challenge and I learned a lot throughout the process. There is also a lot of room for improvement and work to do :)