Introduction

In this notebook I implement a neural-network-based solution for building footprint detection on the SpaceNet7 dataset. I ignore the temporal aspect of the original challenge and focus on performing segmentation to detect buildings in single images. I use fastai, a deep learning library based on PyTorch. It provides functionality to train neural networks with modern best practices while reducing the amount of boilerplate code required.

Dataset Downloading

from google.colab import drive
drive.mount('/content/gdrive')
Mounted at /content/gdrive
cd /content/gdrive/Shareddrives/Undrive
/content/gdrive/Shareddrives/Undrive
# !unzip spacenet-7-multitemporal-urban-development.zip -d s7


The dataset is stored on AWS. Instructions on how to download it are here.
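For reference, a download via the AWS CLI might look roughly like this (a sketch; the bucket path is from memory of the SpaceNet documentation and should be verified against the official instructions):

!aws s3 cp s3://spacenet-dataset/spacenet/SN7_buildings/tarballs/SN7_buildings_train.tar.gz . --no-sign-request  # assumes the bucket is publicly readable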

Installing Libraries and Preparing requirements.txt for Reproducibility

 
!pip freeze > requirements.txt
ls
models/           s7/                                             wandb/
requirements.txt  spacenet-7-multitemporal-urban-development.zip

Setup

cd /content/gdrive/Shareddrives/Undrive/s7/SN7_buildings_train/train/
/content/gdrive/Shareddrives/Undrive/s7/SN7_buildings_train/train
 
!ls
L15-0331E-1257N_1327_3160_13  L15-1203E-1203N_4815_3378_13
L15-0357E-1223N_1429_3296_13  L15-1204E-1202N_4816_3380_13
L15-0358E-1220N_1433_3310_13  L15-1204E-1204N_4819_3372_13
L15-0361E-1300N_1446_2989_13  L15-1209E-1113N_4838_3737_13
L15-0368E-1245N_1474_3210_13  L15-1210E-1025N_4840_4088_13
L15-0387E-1276N_1549_3087_13  L15-1276E-1107N_5105_3761_13
L15-0434E-1218N_1736_3318_13  L15-1289E-1169N_5156_3514_13
L15-0457E-1135N_1831_3648_13  L15-1296E-1198N_5184_3399_13
L15-0487E-1246N_1950_3207_13  L15-1298E-1322N_5193_2903_13
L15-0506E-1204N_2027_3374_13  L15-1335E-1166N_5342_3524_13
L15-0544E-1228N_2176_3279_13  L15-1389E-1284N_5557_3054_13
L15-0566E-1185N_2265_3451_13  L15-1438E-1134N_5753_3655_13
L15-0571E-1075N_2287_3888_13  L15-1439E-1134N_5759_3655_13
L15-0577E-1243N_2309_3217_13  L15-1479E-1101N_5916_3785_13
L15-0586E-1127N_2345_3680_13  L15-1481E-1119N_5927_3715_13
L15-0595E-1278N_2383_3079_13  L15-1538E-1163N_6154_3539_13
L15-0614E-0946N_2459_4406_13  L15-1615E-1205N_6460_3370_13
L15-0632E-0892N_2528_4620_13  L15-1615E-1206N_6460_3366_13
L15-0683E-1006N_2732_4164_13  L15-1617E-1207N_6468_3360_13
L15-0760E-0887N_3041_4643_13  L15-1669E-1153N_6678_3579_13
L15-0924E-1108N_3699_3757_13  L15-1669E-1160N_6678_3548_13
L15-0977E-1187N_3911_3441_13  L15-1669E-1160N_6679_3549_13
L15-1014E-1375N_4056_2688_13  L15-1672E-1207N_6691_3363_13
L15-1015E-1062N_4061_3941_13  L15-1690E-1211N_6763_3346_13
L15-1025E-1366N_4102_2726_13  L15-1691E-1211N_6764_3347_13
L15-1049E-1370N_4196_2710_13  L15-1703E-1219N_6813_3313_13
L15-1138E-1216N_4553_3325_13  L15-1709E-1112N_6838_3742_13
L15-1172E-1306N_4688_2967_13  L15-1716E-1211N_6864_3345_13
L15-1185E-0935N_4742_4450_13  models
L15-1200E-0847N_4802_4803_13  wandb
 
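(Note: `path` is used below but its definition isn't shown; presumably it was set to the train directory in an omitted setup cell, along with the notebook-wide fastai import, e.g.:)

from fastai.vision.all import *  # assumed import providing Path, PILImage, get_image_files, ...
path = Path('/content/gdrive/Shareddrives/Undrive/s7/SN7_buildings_train/train')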
path.ls()
(#59) [Path('L15-0331E-1257N_1327_3160_13'),Path('L15-0357E-1223N_1429_3296_13'),Path('L15-0358E-1220N_1433_3310_13'),Path('L15-0361E-1300N_1446_2989_13'),Path('L15-0368E-1245N_1474_3210_13'),Path('L15-0387E-1276N_1549_3087_13'),Path('L15-0434E-1218N_1736_3318_13'),Path('L15-0457E-1135N_1831_3648_13'),Path('L15-0487E-1246N_1950_3207_13'),Path('L15-0506E-1204N_2027_3374_13')...]

Defining training parameters:

BATCH_SIZE = 12  # 3 for xresnet50, 12 for xresnet34 with Tesla P100/T4 (16GB)
TILES_PER_SCENE = 16
ARCHITECTURE = xresnet34  # alternative run: xresnet50 with BATCH_SIZE = 3
EPOCHS = 40               # alternative run: 80 epochs with xresnet34
CLASS_WEIGHTS = [0.25,0.75]
LR_MAX = 3e-4
ENCODER_FACTOR = 10
CODES = ['Land','Building']
# Weights and Biases config
config_dictionary = dict(
    bs=BATCH_SIZE,
    tiles_per_scene=TILES_PER_SCENE,
    architecture = str(ARCHITECTURE),
    epochs = EPOCHS,
    class_weights = CLASS_WEIGHTS,
    lr_max = LR_MAX,
    encoder_factor = ENCODER_FACTOR
)

Data Preprocessing

First we check the available GPU, then explore the dataset structure and display sample scene directories:

!nvidia-smi
Tue Jul  6 07:44:58 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P8     9W /  70W |      3MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

scenes = path.ls().sorted()
print(f'Number of scenes: {len(scenes)}')
pprint(list(scenes)[:5])
Number of scenes: 59
[Path('.ipynb_checkpoints'),
 Path('L15-0331E-1257N_1327_3160_13'),
 Path('L15-0357E-1223N_1429_3296_13'),
 Path('L15-0358E-1220N_1433_3310_13'),
 Path('L15-0361E-1300N_1446_2989_13')]

These are the folders in each scene (the last three were added later during preprocessing):

sample_scene = (path/'L15-0683E-1006N_2732_4164_13')
pprint(list(sample_scene.ls()))
[Path('L15-0683E-1006N_2732_4164_13/UDM_masks'),
 Path('L15-0683E-1006N_2732_4164_13/images'),
 Path('L15-0683E-1006N_2732_4164_13/images_masked'),
 Path('L15-0683E-1006N_2732_4164_13/labels'),
 Path('L15-0683E-1006N_2732_4164_13/labels_match'),
 Path('L15-0683E-1006N_2732_4164_13/labels_match_pix'),
 Path('L15-0683E-1006N_2732_4164_13/binary_mask'),
 Path('L15-0683E-1006N_2732_4164_13/img_tiles'),
 Path('L15-0683E-1006N_2732_4164_13/mask_tiles')]

How many images are in a specific scene:

images_masked = (sample_scene/'images_masked').ls().sorted()
labels = (sample_scene/'labels_match').ls().sorted()
print(f'Number of images in scene: {len(images_masked)}')
pprint(list(images_masked[:5]))
Number of images in scene: 22
[Path('L15-0683E-1006N_2732_4164_13/images_masked/global_monthly_2018_01_mosaic_L15-0683E-1006N_2732_4164_13.tif'),
 Path('L15-0683E-1006N_2732_4164_13/images_masked/global_monthly_2018_02_mosaic_L15-0683E-1006N_2732_4164_13.tif'),
 Path('L15-0683E-1006N_2732_4164_13/images_masked/global_monthly_2018_03_mosaic_L15-0683E-1006N_2732_4164_13.tif'),
 Path('L15-0683E-1006N_2732_4164_13/images_masked/global_monthly_2018_04_mosaic_L15-0683E-1006N_2732_4164_13.tif'),
 Path('L15-0683E-1006N_2732_4164_13/images_masked/global_monthly_2018_06_mosaic_L15-0683E-1006N_2732_4164_13.tif')]

There are 58 scenes of 4km x 4km in the dataset, each containing about 24 images over the span of two years.

Let's pick one example image and its polygons:

image, shapes = images_masked[0], labels[0]

We use the masked images, in which UDM (unusable data mask) regions cover the areas where clouds were in the original picture:

show_image(PILImage.create(image), figsize=(12,12));

Creating binary masks

This is a function to generate binary mask images from GeoJSON vector files (source linked in the code below).

import os           # used in generate_mask below
import numpy as np  # used in generate_mask below

import rasterio
from rasterio.plot import reshape_as_image
import rasterio.mask
from rasterio.features import rasterize

import pandas as pd
import geopandas as gpd
from shapely.geometry import mapping, Point, Polygon
from shapely.ops import cascaded_union

# SOURCE:  https://lpsmlgeo.github.io/2019-09-22-binary_mask/

def generate_mask(raster_path, shape_path, output_path=None, file_name=None):

    """Function that generates a binary mask from a vector file (shp or geojson)
    raster_path = path to the .tif;
    shape_path = path to the shapefile or GeoJson.
    output_path = Path to save the binary mask.
    file_name = Name of the file.
    """
    
    #load raster
    
    with rasterio.open(raster_path, "r") as src:
        raster_img = src.read()
        raster_meta = src.meta
    
    #load the shapefile or GeoJSON
    train_df = gpd.read_file(shape_path)
    
    #Verify crs
    if train_df.crs != src.crs:
        print(" Raster crs : {}, Vector crs : {}.\n Convert vector and raster to the same CRS.".format(src.crs,train_df.crs))
        
        
    #Converts a polygon from map coordinates to pixel coordinates (inverse affine transform)
    def poly_from_utm(polygon, transform):
        poly_pts = []

        poly = cascaded_union(polygon)
        for i in np.array(poly.exterior.coords):

            poly_pts.append(~transform * tuple(i))

        new_poly = Polygon(poly_pts)
        return new_poly
    
    
    poly_shp = []
    im_size = (src.meta['height'], src.meta['width'])
    for num, row in train_df.iterrows():
        if row['geometry'].geom_type == 'Polygon':
            poly = poly_from_utm(row['geometry'], src.meta['transform'])
            poly_shp.append(poly)
        else:
            for p in row['geometry']:
                poly = poly_from_utm(p, src.meta['transform'])
                poly_shp.append(poly)

    
    if len(poly_shp) > 0:
      mask = rasterize(shapes=poly_shp,
                      out_shape=im_size)
    else:
      mask = np.zeros(im_size)
    
    # Save or show mask
    mask = mask.astype("uint8")    
    bin_mask_meta = src.meta.copy()
    bin_mask_meta.update({'count': 1})
    if (output_path != None and file_name != None):
      os.chdir(output_path)
      with rasterio.open(file_name, 'w', **bin_mask_meta) as dst:
          dst.write(mask * 255, 1)
    else: 
      return mask

Show a mask:

mask = generate_mask(image, shapes)
plt.figure(figsize=(12,12))
plt.tight_layout()
plt.xticks([])
plt.yticks([])
plt.imshow(mask,cmap='cividis');

Note: we can see that the mask correctly contains no buildings where the UDM mask covers the image.

Now we create and save a mask file for every image in the 'images_masked' folder of every scene.

path.ls()
(#59) [Path('L15-0331E-1257N_1327_3160_13'),Path('L15-0357E-1223N_1429_3296_13'),Path('L15-0358E-1220N_1433_3310_13'),Path('L15-0361E-1300N_1446_2989_13'),Path('L15-0368E-1245N_1474_3210_13'),Path('L15-0387E-1276N_1549_3087_13'),Path('L15-0434E-1218N_1736_3318_13'),Path('L15-0457E-1135N_1831_3648_13'),Path('L15-0487E-1246N_1950_3207_13'),Path('L15-0506E-1204N_2027_3374_13')...]

def save_masks():
  for scene in tqdm(path.ls().sorted()):
    for img in (scene/'images_masked').ls():
      shapes = scene/'labels_match'/(img.name[:-4]+'_Buildings.geojson')
      if not os.path.exists(scene/'binary_mask'/img.name):
        if not os.path.exists(scene/'binary_mask'):
          os.makedirs(scene/'binary_mask')
        generate_mask(img, shapes, scene/'binary_mask', img.name)
save_masks()

Mask creation failed on one image for no obvious reason, so I simply deleted that image from the training set.

Creating a subset of the dataset

Let's look at how the images in a scene change over time:

[animation: the monthly mosaic images of one scene played in sequence]

We can see that the ~24 images of each scene are quite similar. The vegetation changes with the seasons and some scenes show building activity, but overall the similarities are greater than the differences.

Therefore I decided to ignore most images. I originally planned to keep every fifth image of every scene (for example January, June, November, April, and September) to make use of the variability of the different seasons. But it turned out that selecting just one image per scene yielded similar results with a fraction of the training time.

def get_masked_images(path:Path, n=1)->list:
  "Returns the first `n` pictures from every scene"
  files = []
  for folder in path.ls():
    files.extend(get_image_files(path=folder, folders='images_masked')[:n])
  return files
masked_images = get_masked_images(path, 1)
len(masked_images)
58

After the data cleaning step, the dataset consists of 58 valid full-size images.

Cutting images into tiles

Since the images are large (1024x1024), we cut them into 16 smaller tiles (255x255) and save them to disk. Most structures are small in relation to the whole scene, so this should not hurt training too much. Smaller tiles allow larger batch sizes and/or deeper models to fit in GPU RAM.

Most images have 1024x1024 pixels, but some have only 1023 pixels in one dimension, so I chose 255 instead of 256 as the tile size (4 x 255 = 1020 fits into both 1023 and 1024). This throws away a few edge pixels in most images, but maintains an equal tile size for all images.

To do: ideally, we would create overlapping tiles to avoid some buildings being cut in half and never seen in their full shape by the model; see the sketch below.
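As a minimal sketch of what overlapping tiling could look like (not used in this notebook; the stride value and the `tile_starts` helper are hypothetical):

def tile_starts(length:int, tile_size:int, stride:int) -> list:
  "Start offsets such that consecutive tiles overlap by `tile_size - stride` pixels"
  starts = list(range(0, length - tile_size + 1, stride))
  if starts[-1] != length - tile_size:  # ensure the last tile reaches the image edge
    starts.append(length - tile_size)
  return starts

# tile_starts(1023, 255, 128) -> [0, 128, 256, 384, 512, 640, 768]
# Each pair (i, j) of offsets would then index img[i:i+tile_size, j:j+tile_size],
# analogous to the loop in cut_tiles below.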

def cut_tiles(tile_size:int):
  "Cuts the large images and masks into equal tiles and saves them to disk"
  masked_images = get_masked_images(path, 5)
  for fn in tqdm(masked_images):
    scene = fn.parent.parent

    # Create directories
    if not os.path.exists(scene/'img_tiles'):
      os.makedirs(scene/'img_tiles')
    if not os.path.exists(scene/'mask_tiles'):
      os.makedirs(scene/'mask_tiles')

    # Create mask for current image
    img = np.array(PILImage.create(fn))
    msk_fn = str(fn).replace('images_masked', 'binary_mask')
    msk = np.array(PILMask.create(msk_fn))
    x, y, _ = img.shape

    # Cut tiles and save them
    for i in range(x//tile_size):
      for j in range(y//tile_size):
        img_tile = img[i*tile_size:(i+1)*tile_size,j*tile_size:(j+1)*tile_size]
        msk_tile = msk[i*tile_size:(i+1)*tile_size,j*tile_size:(j+1)*tile_size]
        Image.fromarray(img_tile).save(f'{scene}/img_tiles/{fn.name[:-4]}_{i}_{j}.png')
        Image.fromarray(msk_tile).save(f'{scene}/mask_tiles/{fn.name[:-4]}_{i}_{j}.png')
 
len(masked_images) # the images that are ready to be cut into tiles after the data cleaning step
58
#del masked_images[-2]
masked_images
[Path('L15-0331E-1257N_1327_3160_13/images_masked/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13.tif'),
 Path('L15-0357E-1223N_1429_3296_13/images_masked/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13.tif'),
 Path('L15-0358E-1220N_1433_3310_13/images_masked/global_monthly_2018_03_mosaic_L15-0358E-1220N_1433_3310_13.tif'),
 Path('L15-0361E-1300N_1446_2989_13/images_masked/global_monthly_2018_01_mosaic_L15-0361E-1300N_1446_2989_13.tif'),
 Path('L15-0368E-1245N_1474_3210_13/images_masked/global_monthly_2018_01_mosaic_L15-0368E-1245N_1474_3210_13.tif'),
 Path('L15-0387E-1276N_1549_3087_13/images_masked/global_monthly_2018_03_mosaic_L15-0387E-1276N_1549_3087_13.tif'),
 Path('L15-0434E-1218N_1736_3318_13/images_masked/global_monthly_2018_01_mosaic_L15-0434E-1218N_1736_3318_13.tif'),
 Path('L15-0457E-1135N_1831_3648_13/images_masked/global_monthly_2018_01_mosaic_L15-0457E-1135N_1831_3648_13.tif'),
 Path('L15-0487E-1246N_1950_3207_13/images_masked/global_monthly_2018_02_mosaic_L15-0487E-1246N_1950_3207_13.tif'),
 Path('L15-0506E-1204N_2027_3374_13/images_masked/global_monthly_2018_01_mosaic_L15-0506E-1204N_2027_3374_13.tif'),
 Path('L15-0544E-1228N_2176_3279_13/images_masked/global_monthly_2018_03_mosaic_L15-0544E-1228N_2176_3279_13.tif'),
 Path('L15-0566E-1185N_2265_3451_13/images_masked/global_monthly_2018_03_mosaic_L15-0566E-1185N_2265_3451_13.tif'),
 Path('L15-0571E-1075N_2287_3888_13/images_masked/global_monthly_2018_03_mosaic_L15-0571E-1075N_2287_3888_13.tif'),
 Path('L15-0577E-1243N_2309_3217_13/images_masked/global_monthly_2018_04_mosaic_L15-0577E-1243N_2309_3217_13.tif'),
 Path('L15-0586E-1127N_2345_3680_13/images_masked/global_monthly_2018_03_mosaic_L15-0586E-1127N_2345_3680_13.tif'),
 Path('L15-0595E-1278N_2383_3079_13/images_masked/global_monthly_2018_02_mosaic_L15-0595E-1278N_2383_3079_13.tif'),
 Path('L15-0614E-0946N_2459_4406_13/images_masked/global_monthly_2018_05_mosaic_L15-0614E-0946N_2459_4406_13.tif'),
 Path('L15-0632E-0892N_2528_4620_13/images_masked/global_monthly_2018_01_mosaic_L15-0632E-0892N_2528_4620_13.tif'),
 Path('L15-0683E-1006N_2732_4164_13/images_masked/global_monthly_2018_02_mosaic_L15-0683E-1006N_2732_4164_13.tif'),
 Path('L15-0760E-0887N_3041_4643_13/images_masked/global_monthly_2018_02_mosaic_L15-0760E-0887N_3041_4643_13.tif'),
 Path('L15-0924E-1108N_3699_3757_13/images_masked/global_monthly_2018_01_mosaic_L15-0924E-1108N_3699_3757_13.tif'),
 Path('L15-0977E-1187N_3911_3441_13/images_masked/global_monthly_2018_04_mosaic_L15-0977E-1187N_3911_3441_13.tif'),
 Path('L15-1014E-1375N_4056_2688_13/images_masked/global_monthly_2018_03_mosaic_L15-1014E-1375N_4056_2688_13.tif'),
 Path('L15-1015E-1062N_4061_3941_13/images_masked/global_monthly_2018_02_mosaic_L15-1015E-1062N_4061_3941_13.tif'),
 Path('L15-1025E-1366N_4102_2726_13/images_masked/global_monthly_2018_01_mosaic_L15-1025E-1366N_4102_2726_13.tif'),
 Path('L15-1049E-1370N_4196_2710_13/images_masked/global_monthly_2018_01_mosaic_L15-1049E-1370N_4196_2710_13.tif'),
 Path('L15-1138E-1216N_4553_3325_13/images_masked/global_monthly_2018_02_mosaic_L15-1138E-1216N_4553_3325_13.tif'),
 Path('L15-1172E-1306N_4688_2967_13/images_masked/global_monthly_2018_01_mosaic_L15-1172E-1306N_4688_2967_13.tif'),
 Path('L15-1185E-0935N_4742_4450_13/images_masked/global_monthly_2018_04_mosaic_L15-1185E-0935N_4742_4450_13.tif'),
 Path('L15-1200E-0847N_4802_4803_13/images_masked/global_monthly_2018_02_mosaic_L15-1200E-0847N_4802_4803_13.tif'),
 Path('L15-1203E-1203N_4815_3378_13/images_masked/global_monthly_2018_02_mosaic_L15-1203E-1203N_4815_3378_13.tif'),
 Path('L15-1204E-1202N_4816_3380_13/images_masked/global_monthly_2017_08_mosaic_L15-1204E-1202N_4816_3380_13.tif'),
 Path('L15-1204E-1204N_4819_3372_13/images_masked/global_monthly_2018_01_mosaic_L15-1204E-1204N_4819_3372_13.tif'),
 Path('L15-1209E-1113N_4838_3737_13/images_masked/global_monthly_2018_01_mosaic_L15-1209E-1113N_4838_3737_13.tif'),
 Path('L15-1210E-1025N_4840_4088_13/images_masked/global_monthly_2018_01_mosaic_L15-1210E-1025N_4840_4088_13.tif'),
 Path('L15-1276E-1107N_5105_3761_13/images_masked/global_monthly_2018_01_mosaic_L15-1276E-1107N_5105_3761_13.tif'),
 Path('L15-1289E-1169N_5156_3514_13/images_masked/global_monthly_2018_03_mosaic_L15-1289E-1169N_5156_3514_13.tif'),
 Path('L15-1296E-1198N_5184_3399_13/images_masked/global_monthly_2018_02_mosaic_L15-1296E-1198N_5184_3399_13.tif'),
 Path('L15-1298E-1322N_5193_2903_13/images_masked/global_monthly_2018_01_mosaic_L15-1298E-1322N_5193_2903_13.tif'),
 Path('L15-1335E-1166N_5342_3524_13/images_masked/global_monthly_2018_01_mosaic_L15-1335E-1166N_5342_3524_13.tif'),
 Path('L15-1389E-1284N_5557_3054_13/images_masked/global_monthly_2018_02_mosaic_L15-1389E-1284N_5557_3054_13.tif'),
 Path('L15-1438E-1134N_5753_3655_13/images_masked/global_monthly_2018_03_mosaic_L15-1438E-1134N_5753_3655_13.tif'),
 Path('L15-1439E-1134N_5759_3655_13/images_masked/global_monthly_2018_02_mosaic_L15-1439E-1134N_5759_3655_13.tif'),
 Path('L15-1479E-1101N_5916_3785_13/images_masked/global_monthly_2018_03_mosaic_L15-1479E-1101N_5916_3785_13.tif'),
 Path('L15-1481E-1119N_5927_3715_13/images_masked/global_monthly_2018_01_mosaic_L15-1481E-1119N_5927_3715_13.tif'),
 Path('L15-1538E-1163N_6154_3539_13/images_masked/global_monthly_2018_01_mosaic_L15-1538E-1163N_6154_3539_13.tif'),
 Path('L15-1615E-1205N_6460_3370_13/images_masked/global_monthly_2017_07_mosaic_L15-1615E-1205N_6460_3370_13.tif'),
 Path('L15-1615E-1206N_6460_3366_13/images_masked/global_monthly_2017_10_mosaic_L15-1615E-1206N_6460_3366_13.tif'),
 Path('L15-1617E-1207N_6468_3360_13/images_masked/global_monthly_2018_02_mosaic_L15-1617E-1207N_6468_3360_13.tif'),
 Path('L15-1669E-1153N_6678_3579_13/images_masked/global_monthly_2018_01_mosaic_L15-1669E-1153N_6678_3579_13.tif'),
 Path('L15-1669E-1160N_6678_3548_13/images_masked/global_monthly_2017_10_mosaic_L15-1669E-1160N_6678_3548_13.tif'),
 Path('L15-1669E-1160N_6679_3549_13/images_masked/global_monthly_2017_10_mosaic_L15-1669E-1160N_6679_3549_13.tif'),
 Path('L15-1672E-1207N_6691_3363_13/images_masked/global_monthly_2018_02_mosaic_L15-1672E-1207N_6691_3363_13.tif'),
 Path('L15-1690E-1211N_6763_3346_13/images_masked/global_monthly_2017_08_mosaic_L15-1690E-1211N_6763_3346_13.tif'),
 Path('L15-1691E-1211N_6764_3347_13/images_masked/global_monthly_2018_02_mosaic_L15-1691E-1211N_6764_3347_13.tif'),
 Path('L15-1703E-1219N_6813_3313_13/images_masked/global_monthly_2018_01_mosaic_L15-1703E-1219N_6813_3313_13.tif'),
 Path('L15-1709E-1112N_6838_3742_13/images_masked/global_monthly_2018_01_mosaic_L15-1709E-1112N_6838_3742_13.tif'),
 Path('L15-1716E-1211N_6864_3345_13/images_masked/global_monthly_2018_02_mosaic_L15-1716E-1211N_6864_3345_13.tif')]
path.ls()
(#59) [Path('L15-0331E-1257N_1327_3160_13'),Path('L15-0357E-1223N_1429_3296_13'),Path('L15-0358E-1220N_1433_3310_13'),Path('L15-0361E-1300N_1446_2989_13'),Path('L15-0368E-1245N_1474_3210_13'),Path('L15-0387E-1276N_1549_3087_13'),Path('L15-0434E-1218N_1736_3318_13'),Path('L15-0457E-1135N_1831_3648_13'),Path('L15-0487E-1246N_1950_3207_13'),Path('L15-0506E-1204N_2027_3374_13')...]

Data Loading Functions

Little helper functions

def get_image_tiles(path:Path, n_tiles=TILES_PER_SCENE) -> L:
  "Returns a list of the first `n` image tile filenames in `path`"
  files = L()
  for folder in path.ls():
    files.extend(get_image_files(path=folder, folders='img_tiles')[:n_tiles])
  return files
def get_y_fn(fn:Path) -> str:
  "Returns filename of the associated mask tile for a given image tile"
  return str(fn).replace('img_tiles', 'mask_tiles')
def get_y(fn:Path) -> PILMask:
  "Returns a PILMask object of 0s and 1s for a given tile"
  fn = get_y_fn(fn)
  msk = np.array(PILMask.create(fn))
  msk[msk==255] = 1
  return PILMask.create(msk)
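For example, for a hypothetical tile path (the name pattern follows cut_tiles above):

fn = Path('L15-0331E-1257N_1327_3160_13/img_tiles/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13_0_0.png')
get_y_fn(fn)  # same path with 'img_tiles' replaced by 'mask_tiles'
get_y(fn)     # PILMask with values in {0, 1}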

Visualizing Data

Let's look at some raw image tiles and their masks.

def show_tiles(n):
  all_tiles = get_image_tiles(path)
  subset = random.sample(all_tiles, n)
  fig, ax = plt.subplots(n//2, 4, figsize=(14,14))
  for i in range(n):
    y = i//2
    x = 2*i%4
    PILImage.create(subset[i]).show(ctx=ax[y, x])
    get_y(subset[i]).show(ctx=ax[y, x+1], cmap='cividis')
  fig.tight_layout()
  plt.show()
show_tiles(8)

Challenges of the dataset

As we can see in the visualizations, the dataset provides some challenges:

  • The buildings are often extremely small, just a few pixels, and very close to each other
  • On the other hand, there are large structures that cover a much greater area than small buildings
  • Some buildings are hard to recognize, even for the human eye
  • The density of buildings varies greatly. There are tiles with no buildings at all, other tiles show urban scenes with hundreds of buildings
  • The images are very diverse, with great differences in topography, vegetation and urbanization
  • Some tiles are covered partially or completely with UDM masks

Distribution of building density

To explore exactly how imbalanced the data is, we analyze the percentage of building pixels in each tile. We create a simple dataloader to easily load and analyze the masks.

tiles = DataBlock(
      blocks = (ImageBlock(),MaskBlock(codes=CODES)),
      get_items = get_image_tiles,
      get_y = get_y
    )              
dls = tiles.dataloaders(path, bs=BATCH_SIZE)
dls.vocab = CODES                 
targs = torch.zeros((0,255,255))
for _, masks in dls[0]:
  targs = torch.cat((targs, masks.cpu()), dim=0)
targs.shape
torch.Size([732, 255, 255])

We have 732 image tiles in total.

Calculating the percentage of building pixels vs background pixels:

total_pixels = targs.shape[1]**2
percentages = torch.count_nonzero(targs, dim=(1,2))/total_pixels
plt.hist(percentages, bins=20)
plt.ylabel('Number of tiles')
plt.xlabel('Ratio of pixels that are of class `building`')
plt.gca().spines['top'].set_color('none')
plt.gca().spines['right'].set_color('none')
plt.show()

We can see that many tiles contain no or very few buildings.

torch.count_nonzero((percentages==0.).float()).item()
64

64 tiles do not contain a single pixel of the building class; that's almost 10% of the tiles. These can be areas of empty land, water, or tiles that were completely covered in clouds.

What is the tile with the largest percentage of buildings?

targs[percentages.argsort(descending=True)[0]].show();

What is the overall ratio building/background?

print(percentages.mean().item(), percentages.median().item())
0.06513693183660507 0.036139946430921555

On average, 6.5% of a tile's pixels belong to the building class; the median is only 3.6%. This is a rather imbalanced dataset.
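One common heuristic for turning such pixel frequencies into loss weights is inverse-frequency weighting, sketched below using the `percentages` tensor computed above (shown for illustration only; the notebook instead uses the empirically chosen CLASS_WEIGHTS = [0.25, 0.75] defined earlier):

building_frac = percentages.mean().item()     # ~0.065
weights = [building_frac, 1 - building_frac]  # [Land, Building] ~ [0.07, 0.93], normalized inverse frequencies
# Weighting this aggressively can over-reward building predictions;
# the hand-tuned [0.25, 0.75] is a milder compromise.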

Validation Strategy

To evaluate the performance of our model, we set aside about 15% of the dataset as a validation set.

We must be thoughtful about how we create this validation set. Using random images would be too easy: we have several images per scene that differ only slightly, so the validation set would not be thoroughly separated from the training set.

Instead, I randomly selected 9 scenes to serve as validation data. The model never sees any images from these scenes during training.

VALID_SCENES = ['L15-0571E-1075N_2287_3888_13',
 'L15-1615E-1205N_6460_3370_13',
 'L15-1210E-1025N_4840_4088_13',
 'L15-1185E-0935N_4742_4450_13',
 'L15-1481E-1119N_5927_3715_13',
 'L15-0632E-0892N_2528_4620_13',
 'L15-1438E-1134N_5753_3655_13',
 'L15-0924E-1108N_3699_3757_13',
 'L15-0457E-1135N_1831_3648_13']
def valid_split(item):
  scene = item.parent.parent.name
  return scene in VALID_SCENES
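A quick sanity check of the split (a sketch using the helpers defined above):

files = get_image_tiles(path)
n_valid = sum(valid_split(f) for f in files)
n_valid, len(files) - n_valid  # tiles from validation scenes vs. training scenes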

Undersampling

To help mitigate the class imbalance, we remove all tiles that contain no buildings from the training set. This reduces the number of samples by ~10%, accelerating training while helping the model perform better.

def has_buildings(fn:Path) -> bool:
  """Returns whether the mask of a given image tile
  contains at least one pixel of a building"""
  fn = get_y_fn(fn)
  msk = tensor(PILMask.create(fn))
  count = torch.count_nonzero(msk)
  return count>0.

def get_undersampled_tiles(path:Path) -> L:
  """Returns a list of image tile filenames in `path`.
  For tiles in the training set, empty tiles are ignored.
  All tiles in the validation set are included."""

  files = get_image_tiles(path)
  train_idxs, valid_idxs = FuncSplitter(valid_split)(files)
  train_files = L(filter(has_buildings, files[train_idxs]))
  valid_files = files[valid_idxs]

  return train_files + valid_files

Creating Dataloaders

The following transformations seem reasonable for satellite images: we flip the tiles vertically and horizontally, rotate them, and change brightness, contrast, and saturation by a small amount. We normalize according to ImageNet stats so that we can use a pretrained model later.

tfms = [Dihedral(0.5),              # Horizontal and vertical flip
        Rotate(max_deg=180, p=0.9), # Rotation in any direction possible
        Brightness(0.2, p=0.75),
        Contrast(0.2),
        Saturation(0.2),
        Normalize.from_stats(*imagenet_stats)]

To create the datasets, we use fastai's convenient DataBlock API. We only load 16 tiles per scene, i.e. one image per region.

tiles = DataBlock(
      blocks = (ImageBlock(),MaskBlock(codes=CODES)), # Independent variable is Image, dependent variable is Mask
      get_items = get_undersampled_tiles,             # Collect undersampled tiles
      get_y = get_y,                                  # Get dependent variable: mask
      splitter = FuncSplitter(valid_split),           # Split into training and validation set
      batch_tfms = tfms                               # Transforms on GPU: augmentation, normalization
    )                              
dls = tiles.dataloaders(path, bs=BATCH_SIZE)
dls.vocab = CODES
len(dls.train_ds), len(dls.valid_ds)
(715, 144)

We have 715 non-empty tiles in the training set and 144 tiles in the validation set.

Making sure the batches look okay:

inputs, targets = dls.one_batch()
inputs.shape, targets.shape
(torch.Size([12, 3, 255, 255]), torch.Size([12, 255, 255]))

These dimensions are as expected:

  • 12 images per batch
  • 3 channels for the input images
  • no color channels for the target mask
  • image size: 255x255.

Check that the mask looks as expected, 0s and 1s:

targets[0].unique()
TensorMask([0, 1], device='cuda:0')

Defining the Model

The task at hand is an image segmentation problem. The original competition requires assigning an individual label to each building to track it over time (instance segmentation). Here, I chose to do semantic segmentation instead: simply classifying, for every pixel, whether it belongs to a building or not.

The fastai library allows the remarkably simple creation of a U-Net, a standard architecture for image segmentation problems. The DynamicUnet module, provided with an encoder architecture, automatically constructs a decoder and cross connections. This makes it easy to build a U-Net out of different (and pretrained) architectures. I chose this approach to have more time to experiment instead of writing code from scratch. I considered the following aspects:

  • Encoder: I picked an xResNet34 model that has been pretrained on ImageNet. A 34-layer encoder seems like a good compromise between accuracy and memory/compute requirements.
  • Loss function: the choice of the loss function is important for segmentation problems. I'll use a weighted pixel-wise cross-entropy loss; the weights are important for the imbalanced dataset.
  • Optimizer: I use the default optimizer, Adam.
  • Metrics:
    • As the classes are very imbalanced, a simple accuracy metric would not be helpful. In a picture with 3% buildings, the model could predict "no building" for every pixel and still get 97% accuracy.
    • Instead, I focus on the Dice metric, which is often used for segmentation tasks. It is equivalent to the F1 score and measures the ratio $\frac{2TP}{2TP + FP + FN}$ (see the sketch after this list).
    • Additionally, I added fastai's foreground_acc, which measures the percentage of foreground pixels correctly classified (the recall). Foreground in this case is the building class.
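As a quick illustration, the Dice score of a pair of binary masks could be computed like this (a minimal sketch, independent of fastai's Dice implementation):

import torch

def dice_score(pred:torch.Tensor, targ:torch.Tensor) -> float:
  "Dice score 2TP/(2TP+FP+FN) for binary masks where 1 = building"
  tp = ((pred==1) & (targ==1)).sum().item()
  fp = ((pred==1) & (targ==0)).sum().item()
  fn = ((pred==0) & (targ==1)).sum().item()
  denom = 2*tp + fp + fn
  return 2*tp/denom if denom > 0 else 1.0

# e.g. dice_score(logits.argmax(dim=1), targets) for a batch of predictions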
weights = Tensor(CLASS_WEIGHTS).cuda()
loss_func = CrossEntropyLossFlat(axis=1, weight=weights)

With some experimentation, class weights of 0.25 for the background and 0.75 for the building class seem to work well.

learn = unet_learner(dls,                                 # DataLoaders
                     ARCHITECTURE,                        # xResNet34
                     loss_func = loss_func,               # Weighted cross entropy loss
                     opt_func = Adam,                     # Adam optimizer
                     metrics = [Dice(), foreground_acc],  # Custom metrics
                     self_attention = False,
                     cbs = [SaveModelCallback(
                              monitor='dice',
                              comp=np.greater,
                              fname='best-model-34'
                            )]
                     )

Summary of the model:

learn.summary()

DynamicUnet (Input shape: 12)
============================================================================
Layer (type)         Output Shape         Param #    Trainable 
============================================================================
                     12 x 32 x 128 x 128 
Conv2d                                    864        False     
BatchNorm2d                               64         True      
ReLU                                                           
Conv2d                                    9216       False     
BatchNorm2d                               64         True      
ReLU                                                           
____________________________________________________________________________
                     12 x 64 x 128 x 128 
Conv2d                                    18432      False     
BatchNorm2d                               128        True      
ReLU                                                           
MaxPool2d                                                      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 128 x 32 x 32  
Conv2d                                    73728      False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
____________________________________________________________________________
                     []                  
AvgPool2d                                                      
____________________________________________________________________________
                     12 x 128 x 32 x 32  
Conv2d                                    8192       False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 256 x 16 x 16  
Conv2d                                    294912     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
____________________________________________________________________________
                     []                  
AvgPool2d                                                      
____________________________________________________________________________
                     12 x 256 x 16 x 16  
Conv2d                                    32768      False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 512 x 8 x 8    
Conv2d                                    1179648    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
____________________________________________________________________________
                     []                  
AvgPool2d                                                      
____________________________________________________________________________
                     12 x 512 x 8 x 8    
Conv2d                                    131072     False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
Sequential                                                     
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
Sequential                                                     
ReLU                                                           
BatchNorm2d                               1024       True      
ReLU                                                           
____________________________________________________________________________
                     12 x 1024 x 8 x 8   
Conv2d                                    4719616    True      
ReLU                                                           
____________________________________________________________________________
                     12 x 512 x 8 x 8    
Conv2d                                    4719104    True      
ReLU                                                           
____________________________________________________________________________
                     12 x 1024 x 8 x 8   
Conv2d                                    525312     True      
ReLU                                                           
PixelShuffle                                                   
BatchNorm2d                               512        True      
Conv2d                                    2359808    True      
ReLU                                                           
Conv2d                                    2359808    True      
ReLU                                                           
ReLU                                                           
____________________________________________________________________________
                     12 x 1024 x 16 x 16 
Conv2d                                    525312     True      
ReLU                                                           
PixelShuffle                                                   
BatchNorm2d                               256        True      
Conv2d                                    1327488    True      
ReLU                                                           
Conv2d                                    1327488    True      
ReLU                                                           
ReLU                                                           
____________________________________________________________________________
                     12 x 768 x 32 x 32  
Conv2d                                    295680     True      
ReLU                                                           
PixelShuffle                                                   
BatchNorm2d                               128        True      
Conv2d                                    590080     True      
ReLU                                                           
Conv2d                                    590080     True      
ReLU                                                           
ReLU                                                           
____________________________________________________________________________
                     12 x 512 x 64 x 64  
Conv2d                                    131584     True      
ReLU                                                           
PixelShuffle                                                   
BatchNorm2d                               128        True      
____________________________________________________________________________
                     12 x 96 x 128 x 128 
Conv2d                                    165984     True      
ReLU                                                           
Conv2d                                    83040      True      
ReLU                                                           
ReLU                                                           
____________________________________________________________________________
                     12 x 384 x 128 x 12 
Conv2d                                    37248      True      
ReLU                                                           
PixelShuffle                                                   
ResizeToOrig                                                   
MergeLayer                                                     
Conv2d                                    88308      True      
ReLU                                                           
Conv2d                                    88308      True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 2 x 255 x 255  
Conv2d                                    200        True      
ToTensorBase                                                   
____________________________________________________________________________

Total params: 41,240,400
Total trainable params: 19,953,648
Total non-trainable params: 21,286,752

Optimizer used: <function Adam at 0x7fea8c0c5200>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group #2

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
  - SaveModelCallback

This is the full UNet model:

learn.model

DynamicUnet(
  (layers): ModuleList(
    (0): Sequential(
      (0): ConvLayer(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (1): ConvLayer(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (2): ConvLayer(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (1): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (2): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
      )
      (5): Sequential(
        (0): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential(
            (0): AvgPool2d(kernel_size=2, stride=2, padding=0)
            (1): ConvLayer(
              (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (act): ReLU(inplace=True)
        )
        (1): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (2): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (3): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
      )
      (6): Sequential(
        (0): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential(
            (0): AvgPool2d(kernel_size=2, stride=2, padding=0)
            (1): ConvLayer(
              (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (act): ReLU(inplace=True)
        )
        (1): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (2): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (3): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (4): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (5): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
      )
      (7): Sequential(
        (0): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential(
            (0): AvgPool2d(kernel_size=2, stride=2, padding=0)
            (1): ConvLayer(
              (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (act): ReLU(inplace=True)
        )
        (1): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (2): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
      )
    )
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Sequential(
      (0): ConvLayer(
        (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (1): ConvLayer(
        (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
    )
    (4): UnetBlock(
      (shuf): PixelShuffle_ICNR(
        (0): ConvLayer(
          (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU()
        )
        (1): PixelShuffle(upscale_factor=2)
      )
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): ConvLayer(
        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (conv2): ConvLayer(
        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (relu): ReLU()
    )
    (5): UnetBlock(
      (shuf): PixelShuffle_ICNR(
        (0): ConvLayer(
          (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU()
        )
        (1): PixelShuffle(upscale_factor=2)
      )
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): ConvLayer(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (conv2): ConvLayer(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (relu): ReLU()
    )
    (6): UnetBlock(
      (shuf): PixelShuffle_ICNR(
        (0): ConvLayer(
          (0): Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU()
        )
        (1): PixelShuffle(upscale_factor=2)
      )
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): ConvLayer(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (conv2): ConvLayer(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (relu): ReLU()
    )
    (7): UnetBlock(
      (shuf): PixelShuffle_ICNR(
        (0): ConvLayer(
          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU()
        )
        (1): PixelShuffle(upscale_factor=2)
      )
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): ConvLayer(
        (0): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (conv2): ConvLayer(
        (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (relu): ReLU()
    )
    (8): PixelShuffle_ICNR(
      (0): ConvLayer(
        (0): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU()
      )
      (1): PixelShuffle(upscale_factor=2)
    )
    (9): ResizeToOrig()
    (10): MergeLayer()
    (11): ResBlock(
      (convpath): Sequential(
        (0): ConvLayer(
          (0): Conv2d(99, 99, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU()
        )
        (1): ConvLayer(
          (0): Conv2d(99, 99, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (idpath): Sequential()
      (act): ReLU(inplace=True)
    )
    (12): ConvLayer(
      (0): Conv2d(99, 2, kernel_size=(1, 1), stride=(1, 1))
    )
    (13): ToTensorBase(tensor_cls=<class 'fastai.torch_core.TensorBase'>)
  )
)

Training

We can use fastai's learning rate finder to pick a reasonable learning rate:

learn.lr_find(suggestions=False)
SuggestedLRs(valley=tensor(6.3096e-05))

A learning rate around 1e-4 seems reasonable; in that region the loss still decreases steadily.

lr_max = LR_MAX # 3e-4

We unfreeze the model to train the encoder and decoder simultaneously. The encoder should be trained at a lower learning rate, since we don't want to change the pretrained features too much. This is achieved by passing slice(lr_max/10, lr_max) as the learning rate, as sketched below.
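
For intuition: fastai spreads the values of such a slice across the parameter groups on a roughly log-uniform scale. A minimal sketch of that spacing, assuming three parameter groups and illustrative values (this is not fastai's internal code):

import numpy as np

def discriminative_lrs(lo: float, hi: float, n_groups: int = 3):
    # Log-spaced learning rates: the earliest (encoder) groups get lo,
    # the head gets hi; an illustrative stand-in for fastai's internals.
    return np.geomspace(lo, hi, num=n_groups)

print(discriminative_lrs(3e-5, 3e-4))  # -> approx. [3.0e-05 9.5e-05 3.0e-04]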

We use the fit_one_cycle method, where the learning rate starts low for a warm-up period, reaches a maximum of lr_max and then anneals towards 0 at the end of training.
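
To make the schedule concrete, here is a small sketch of a one-cycle learning-rate curve: cosine warm-up followed by cosine annealing. The defaults pct_start=0.25, div=25 and div_final=1e5 mirror fit_one_cycle's, but the function itself is only an illustration, not fastai's implementation:

import math

def one_cycle_lr(pct, lr_max, pct_start=0.25, div=25.0, div_final=1e5):
    # pct in [0, 1] is the fraction of training completed
    if pct < pct_start:  # warm-up: lr_max/div -> lr_max
        p = pct / pct_start
        lo, hi = lr_max / div, lr_max
    else:                # anneal: lr_max -> lr_max/div_final
        p = (pct - pct_start) / (1 - pct_start)
        lo, hi = lr_max, lr_max / div_final
    return lo + (hi - lo) * (1 - math.cos(math.pi * p)) / 2

[one_cycle_lr(p, 3e-4) for p in (0.0, 0.25, 1.0)]  # start low, peak at lr_max, end near 0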

!nvidia-smi
Tue Jul  6 07:59:28 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   60C    P0    29W /  70W |   8166MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

Experiment Run 1

#collapse-output
learn.unfreeze()
learn.fit_one_cycle(
    EPOCHS,
    lr_max=slice(lr_max/ENCODER_FACTOR, lr_max),
    cbs=[WandbCallback()]
)

epoch train_loss valid_loss dice foreground_acc time
0 0.388443 0.413384 0.289366 0.174573 00:50
1 0.349046 0.359303 0.461197 0.445781 00:44
2 0.332909 0.349615 0.490519 0.481620 00:45
3 0.328296 0.404246 0.307558 0.176636 00:44
4 0.311649 0.331803 0.518899 0.552372 00:44
5 0.315864 0.333278 0.496052 0.437288 00:45
6 0.305936 0.369581 0.500008 0.671848 00:45
7 0.306035 0.335628 0.522586 0.520908 00:45
8 0.295921 0.330779 0.530826 0.513260 00:46
9 0.290359 0.311421 0.545738 0.579524 00:46
10 0.294897 0.325334 0.530101 0.488870 00:46
11 0.286177 0.327654 0.539763 0.652674 00:46
12 0.278001 0.354929 0.485753 0.677047 00:46
13 0.275921 0.346134 0.534323 0.538300 00:46
14 0.271893 0.336453 0.529238 0.586955 00:46
15 0.270011 0.351094 0.522521 0.472264 00:46
16 0.266521 0.455999 0.513976 0.443002 00:46
17 0.260103 0.323886 0.543388 0.594548 00:46
18 0.259548 0.317344 0.557026 0.604788 00:46
19 0.252726 0.342917 0.537662 0.612756 00:46
20 0.248773 0.313652 0.553115 0.602638 00:46
21 0.244121 0.335954 0.549108 0.563031 00:46
22 0.246857 0.315527 0.554965 0.656071 00:46
23 0.246583 0.309595 0.560354 0.621056 00:46
24 0.246868 0.322774 0.548355 0.565873 00:46
25 0.243858 0.314193 0.549589 0.677579 00:46
26 0.238657 0.318479 0.553847 0.552282 00:46
27 0.236678 0.319910 0.553423 0.588908 00:46
28 0.234124 0.306141 0.552694 0.669329 00:46
29 0.230853 0.328426 0.550115 0.606573 00:46
30 0.232298 0.314922 0.556050 0.686611 00:46
31 0.230854 0.324863 0.551161 0.623305 00:46
32 0.231132 0.321707 0.548863 0.629805 00:46
33 0.231481 0.318916 0.552225 0.637253 00:46
34 0.225996 0.320821 0.552282 0.626153 00:46
35 0.220080 0.326714 0.551576 0.615147 00:46
36 0.225850 0.327396 0.551351 0.619062 00:47
37 0.227767 0.326052 0.552426 0.623551 00:46
38 0.221003 0.317112 0.554034 0.647620 00:46
39 0.220426 0.318014 0.554319 0.637630 00:46
Better model found at epoch 0 with dice value: 0.289366090884329.
Better model found at epoch 1 with dice value: 0.4611968357497782.
Better model found at epoch 2 with dice value: 0.4905194190974164.
Better model found at epoch 4 with dice value: 0.518899214365881.
Better model found at epoch 7 with dice value: 0.5225864940854509.
Better model found at epoch 8 with dice value: 0.5308255345561974.
Better model found at epoch 9 with dice value: 0.5457375700715181.
Better model found at epoch 18 with dice value: 0.557025656659146.
Better model found at epoch 23 with dice value: 0.5603543324449821.
learn.recorder.plot_loss()
learn.save('best-model-34')
Path('models/best-model-34.pth')

The best model, at epoch 23, reaches a Dice score of 0.56. The theoretical maximum, a perfect segmentation, is 1.0.
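
For reference, the Dice score is twice the overlap between prediction and target divided by the total number of foreground pixels in both. A minimal sketch on binary NumPy masks (the arrays below are made up):

import numpy as np

def dice_score(pred, target, eps=1e-8):
    # pred, target: binary (0/1) arrays of the same shape
    inter = (pred * target).sum()
    return (2 * inter + eps) / (pred.sum() + target.sum() + eps)

pred   = np.array([[1, 1], [0, 0]])
target = np.array([[1, 0], [0, 0]])
dice_score(pred, target)  # 2*1 / (2+1) ≈ 0.67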

learn.recorder.plot_loss()
run.finish()

Waiting for W&B process to finish, PID 6232
Program ended successfully.
Find user logs for this run at: /content/gdrive/Shareddrives/Undrive/s7/SN7_buildings_train/train/wandb/run-20210706_090514-3cwyiul0/logs/debug.log
Find internal logs for this run at: /content/gdrive/Shareddrives/Undrive/s7/SN7_buildings_train/train/wandb/run-20210706_090514-3cwyiul0/logs/debug-internal.log
Synced 4 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)

Visualizing the Results

Note: The following results come from the model of a different training run, with a similar Dice score of 0.57.
Let's get all predictions on the validation set and have a look at them.
probs,targets,preds,losses = learn.get_preds(dl=dls.valid,
                                             with_loss=True,
                                             with_decoded=True,
                                             act=None)

Sort descending by loss:

loss_sorted = torch.argsort(losses, descending=True)
n = len(loss_sorted)

Helper function to show predictions:

def show_single_pred(index:int):
  "Show input, target mask, predicted mask and building probability for one tile"
  fig, ax = plt.subplots(1, 4, figsize=(20,5))
  dls.valid_ds[index][0].show(ctx=ax[0])
  ax[0].set_title("Input")
  show_at(dls.valid_ds, index, cmap='Blues', ctx=ax[1])
  ax[1].set_title("Target")
  preds[index].show(cmap='Blues', ctx=ax[2])
  ax[2].set_title("Prediction Mask")
  probs[index][1].show(cmap='viridis', ctx=ax[3])
  ax[3].set_title("Building class probability")

Plot the samples with the highest losses:

for idx in loss_sorted[:3]:
  print(f'Tile #{idx}, loss: {losses[idx]}')
  show_single_pred(idx)
Tile #110, loss: 0.7653330564498901
Tile #101, loss: 0.541749119758606
Tile #111, loss: 0.4768364429473877

All images with the highest losses show dense urban areas. What's noticeable is the trouble the model has with large buildings, which are often overlooked completely. It's also clear that very small buildings are merged into "blobs". I suspect tracking individual buildings would be difficult here.

Plot samples with medium losses:

for idx in loss_sorted[n//2-1:n//2+2]:
  print(f'Tile #{idx}, loss: {losses[idx]}')
  show_single_pred(idx)
Tile #1, loss: 0.06586822122335434
Tile #7, loss: 0.06451009213924408
Tile #55, loss: 0.0639784187078476

The model tends to merge small buildings into larger blobs, and there are several false positives. But there are also some quite good predictions, detecting buildings that are hard to spot even for the human eye.

Plot some examples with low losses:

for idx in loss_sorted[-21:-18]:
  print(f'Tile #{idx}, loss: {losses[idx]}')
  show_single_pred(idx)
Tile #51, loss: 0.0051777600310742855
Tile #92, loss: 0.005046222358942032
Tile #9, loss: 0.004601902794092894

The model shows mixed performance on images with few buildings. Overall, the accuracy looks better here than in dense areas, but the model tends to produce false positives, and some tiles show odd artifacts in the corners. It seems as if the model interprets the corners themselves as buildings, especially on tiles largely covered by water.

Show complete scenes

Predict all tiles of a scene:

path.ls()
(#60) [Path('L15-0760E-0887N_3041_4643_13'),Path('L15-0683E-1006N_2732_4164_13'),Path('L15-0632E-0892N_2528_4620_13'),Path('L15-0614E-0946N_2459_4406_13'),Path('L15-0595E-1278N_2383_3079_13'),Path('L15-0586E-1127N_2345_3680_13'),Path('L15-0977E-1187N_3911_3441_13'),Path('L15-0924E-1108N_3699_3757_13'),Path('L15-1014E-1375N_4056_2688_13'),Path('L15-1025E-1366N_4102_2726_13')...]

def save_predictions(scene, path=path) -> None:
  "Predicts all 16 tiles of one scene and saves them to disk"
  output_folder = path/scene/'predicted_tiles'
  if not os.path.exists(output_folder):
    os.makedirs(output_folder)
  tiles = get_image_files(path/scene/'img_tiles').sorted()
  for i in range(16):
    # tile_preds[2] holds the per-class probabilities; index [1] selects the building class
    tile_preds = learn.predict(tiles[i])
    to_image(tile_preds[2][1].repeat(3,1,1)).save(output_folder/f'{i:02d}.png')
VALID_SCENES
['L15-0571E-1075N_2287_3888_13',
 'L15-1615E-1205N_6460_3370_13',
 'L15-1210E-1025N_4840_4088_13',
 'L15-1185E-0935N_4742_4450_13',
 'L15-1481E-1119N_5927_3715_13',
 'L15-0632E-0892N_2528_4620_13',
 'L15-1438E-1134N_5753_3655_13',
 'L15-0924E-1108N_3699_3757_13',
 'L15-0457E-1135N_1831_3648_13']
scene = VALID_SCENES[2]
scene
'L15-1210E-1025N_4840_4088_13'
save_predictions(scene)

Helper function to show several tiles as a large image:

def unblockshaped(arr, h, w):
    """
    Return an array of shape (h, w) (or (h, w, c)) where
    h * w = arr.size

    If arr is of shape (n, nrows, ncols), i.e. n sub-blocks of shape
    (nrows, ncols), the returned array preserves the "physical" layout
    of the sub-blocks.

    Source: https://stackoverflow.com/a/16873755
    """
    try: # with color channel
      n, nrows, ncols, c = arr.shape
      return (arr.reshape(h//nrows, -1, nrows, ncols, c)
                .swapaxes(1,2)
                .reshape(h, w, c))
    except ValueError: # without color channel
      n, nrows, ncols = arr.shape
      return (arr.reshape(h//nrows, -1, nrows, ncols)
                .swapaxes(1,2)
                .reshape(h, w))
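
A quick sanity check with hypothetical data: 16 grayscale tiles of 255x255 (the prediction tile size used here) reassemble into a 4x4 grid, i.e. one 1020x1020 mask:

import numpy as np
tiles = np.zeros((16, 255, 255))
assert unblockshaped(tiles, 1020, 1020).shape == (1020, 1020)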

Load saved predictions:

def get_saved_preds(scene, path=path):
  "Load saved prediction mask tiles for a scene and return image + assembled mask"
  image_file = (path/scene/'images_masked').ls()[0]
  image = load_image(image_file)

  mask_tiles = get_image_files(path/scene/'predicted_tiles').sorted()
  # fastcore's maps composes the given functions before mapping over mask_tiles
  mask_arrs = np.array(list(maps(partial(load_image, mode="L"), np.asarray, mask_tiles)))
  mask_array = unblockshaped(mask_arrs, 1020, 1020)

  return (image, mask_array)

Show image + stitched predictions:

def show_complete_preds(image, mask_array, scene):
  "Source: https://github.com/CosmiQ/CosmiQ_SN7_Baseline/blob/master/notebooks/sn7_baseline.ipynb"
  figsize = (25, 16)
  fig, (ax0, ax1) = plt.subplots(1, 2, figsize=figsize)
  _ = ax0.imshow(image)
  ax0.set_xticks([])
  ax0.set_yticks([])
  ax0.set_title('Image')
  _ = ax1.imshow(mask_array, cmap='viridis')
  ax1.set_xticks([])
  ax1.set_yticks([])
  ax1.set_title('Prediction Mask')
  plt.suptitle(scene)
  plt.tight_layout()
  plt.savefig(os.path.join(path, scene + '_im0+mask0+dice575.png'))
  plt.show()
show_complete_preds(*get_saved_preds(scene), scene)
from time import sleep
for scene in VALID_SCENES:
  save_predictions(scene)
  show_complete_preds(*get_saved_preds(scene), scene)
  sleep(1)

Experiment Run 2

del learn
learn = unet_learner(dls,                                 # DataLoaders
                     ARCHITECTURE,                        # xResNet34
                     loss_func = loss_func,               # Weighted cross entropy loss
                     opt_func = Adam,                     # Adam optimizer
                     metrics = [Dice(), foreground_acc],  # Custom metrics
                     self_attention = False,
                     cbs = [SaveModelCallback(
                              monitor='dice',
                              comp=np.greater,
                              fname='best-model'
                            )]
                     )
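
As an aside: the "Weighted cross entropy loss" mentioned above can be built with fastai's CrossEntropyLossFlat, which accepts the same keyword arguments as PyTorch's CrossEntropyLoss. A sketch with placeholder weights (not the values used in this notebook):

import torch
from fastai.vision.all import CrossEntropyLossFlat

# Illustrative weights: upweight the rarer building class relative to
# background; axis=1 flattens over the class axis for segmentation.
class_weights = torch.tensor([1., 5.])
loss_func = CrossEntropyLossFlat(weight=class_weights, axis=1)
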
learn.summary()

DynamicUnet (Input shape: 12)
============================================================================
Layer (type)         Output Shape         Param #    Trainable 
============================================================================
                     12 x 32 x 128 x 128 
Conv2d                                    864        False     
BatchNorm2d                               64         True      
ReLU                                                           
Conv2d                                    9216       False     
BatchNorm2d                               64         True      
ReLU                                                           
____________________________________________________________________________
                     12 x 64 x 128 x 128 
Conv2d                                    18432      False     
BatchNorm2d                               128        True      
ReLU                                                           
MaxPool2d                                                      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 128 x 32 x 32  
Conv2d                                    73728      False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
____________________________________________________________________________
                     []                  
AvgPool2d                                                      
____________________________________________________________________________
                     12 x 128 x 32 x 32  
Conv2d                                    8192       False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 256 x 16 x 16  
Conv2d                                    294912     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
____________________________________________________________________________
                     []                  
AvgPool2d                                                      
____________________________________________________________________________
                     12 x 256 x 16 x 16  
Conv2d                                    32768      False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 512 x 8 x 8    
Conv2d                                    1179648    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
____________________________________________________________________________
                     []                  
AvgPool2d                                                      
____________________________________________________________________________
                     12 x 512 x 8 x 8    
Conv2d                                    131072     False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
Sequential                                                     
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
Sequential                                                     
ReLU                                                           
BatchNorm2d                               1024       True      
ReLU                                                           
____________________________________________________________________________
                     12 x 1024 x 8 x 8   
Conv2d                                    4719616    True      
ReLU                                                           
____________________________________________________________________________
                     12 x 512 x 8 x 8    
Conv2d                                    4719104    True      
ReLU                                                           
____________________________________________________________________________
                     12 x 1024 x 8 x 8   
Conv2d                                    525312     True      
ReLU                                                           
PixelShuffle                                                   
BatchNorm2d                               512        True      
Conv2d                                    2359808    True      
ReLU                                                           
Conv2d                                    2359808    True      
ReLU                                                           
ReLU                                                           
____________________________________________________________________________
                     12 x 1024 x 16 x 16 
Conv2d                                    525312     True      
ReLU                                                           
PixelShuffle                                                   
BatchNorm2d                               256        True      
Conv2d                                    1327488    True      
ReLU                                                           
Conv2d                                    1327488    True      
ReLU                                                           
ReLU                                                           
____________________________________________________________________________
                     12 x 768 x 32 x 32  
Conv2d                                    295680     True      
ReLU                                                           
PixelShuffle                                                   
BatchNorm2d                               128        True      
Conv2d                                    590080     True      
ReLU                                                           
Conv2d                                    590080     True      
ReLU                                                           
ReLU                                                           
____________________________________________________________________________
                     12 x 512 x 64 x 64  
Conv2d                                    131584     True      
ReLU                                                           
PixelShuffle                                                   
BatchNorm2d                               128        True      
____________________________________________________________________________
                     12 x 96 x 128 x 128 
Conv2d                                    165984     True      
ReLU                                                           
Conv2d                                    83040      True      
ReLU                                                           
ReLU                                                           
____________________________________________________________________________
                     12 x 384 x 128 x 128
Conv2d                                    37248      True      
ReLU                                                           
PixelShuffle                                                   
ResizeToOrig                                                   
MergeLayer                                                     
Conv2d                                    88308      True      
ReLU                                                           
Conv2d                                    88308      True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 2 x 255 x 255  
Conv2d                                    200        True      
____________________________________________________________________________

Total params: 41,240,400
Total trainable params: 19,953,648
Total non-trainable params: 21,286,752

Optimizer used: <function Adam at 0x7eff70070cb0>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group #2

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
  - SaveModelCallback
learn.model

(output omitted: the same DynamicUnet architecture as printed for Run 1 above)
 
learn.lr_find(suggestions=False)
# Run2
learn.unfreeze()
learn.fit_one_cycle(
    EPOCHS,
    lr_max=slice(lr_max/ENCODER_FACTOR, lr_max),
    cbs=[WandbCallback()]
)

epoch train_loss valid_loss dice foreground_acc time
0 1.937740 0.952615 0.182434 0.229873 00:43
1 0.905160 0.588179 0.333524 0.360573 00:46
2 0.587758 0.485159 0.409871 0.416817 00:47
3 0.461647 0.448966 0.419921 0.572241 00:47
4 0.431045 0.412626 0.434957 0.437551 00:47
5 0.384136 0.399362 0.441417 0.484893 00:48
6 0.372647 0.469187 0.417687 0.342526 00:49
7 0.358055 0.401312 0.433877 0.474332 00:48
8 0.335158 0.426308 0.409999 0.420248 00:49
9 0.340261 0.388585 0.453323 0.680999 00:49
10 0.350476 0.424114 0.427702 0.423918 00:49
11 0.340361 0.830213 0.363972 0.274424 00:49
12 0.347802 0.396467 0.447701 0.424461 00:49
13 0.347854 0.361305 0.471549 0.464989 00:49
14 0.322265 0.496307 0.421688 0.320692 00:49
15 0.312543 0.419182 0.407756 0.352709 00:49
16 0.304165 0.388873 0.451288 0.371343 00:48
17 0.299325 0.477858 0.439302 0.351467 00:48
18 0.308990 0.409736 0.413224 0.461884 00:48
19 0.288411 0.498319 0.370609 0.280490 00:48
20 0.281468 0.360414 0.485021 0.485740 00:48
21 0.274416 0.418477 0.475279 0.414546 00:48
22 0.278710 0.325033 0.520833 0.521762 00:48
23 0.270224 0.359991 0.506736 0.490314 00:49
24 0.278280 0.342068 0.522744 0.638762 00:48
25 0.268276 0.384951 0.448934 0.658303 00:48
26 0.266887 0.318082 0.537392 0.628973 00:48
27 0.262979 0.364846 0.507638 0.516916 00:48
28 0.257349 0.333423 0.531541 0.588342 00:48
29 0.254129 0.337612 0.534185 0.635413 00:47
30 0.256843 0.371552 0.523938 0.519276 00:47
31 0.252976 0.376065 0.502235 0.588909 00:48
32 0.250810 0.393963 0.521790 0.486784 00:48
33 0.246234 0.375275 0.519142 0.489137 00:48
34 0.245064 0.331335 0.528236 0.686900 00:48
35 0.246199 0.377710 0.537571 0.545870 00:47
36 0.244448 0.368422 0.542517 0.553300 00:49
37 0.241390 0.349010 0.519374 0.491989 00:48
38 0.243630 0.359600 0.509863 0.521250 00:48
39 0.241627 0.336892 0.539313 0.583476 00:48
40 0.238271 0.367133 0.536665 0.515662 00:47
41 0.232399 0.347052 0.545673 0.578082 00:48
42 0.232242 0.319940 0.536428 0.673036 00:48
43 0.235520 0.320792 0.553474 0.609017 00:47
44 0.226998 0.326368 0.549632 0.638580 00:48
45 0.229752 0.361821 0.542705 0.521525 00:48
46 0.231764 0.356062 0.544539 0.571413 00:48
47 0.226325 0.372531 0.541578 0.529744 00:48
48 0.226063 0.374176 0.542866 0.539490 00:48
49 0.223156 0.416989 0.514208 0.441924 00:48
50 0.225011 0.389048 0.541989 0.558846 00:48
51 0.221490 0.363521 0.539199 0.517616 00:48
52 0.221223 0.344703 0.554917 0.570877 00:48
53 0.219953 0.349854 0.550852 0.559267 00:47
54 0.220182 0.371416 0.547601 0.532858 00:48
55 0.218004 0.387494 0.507081 0.445336 00:47
56 0.216417 0.355569 0.554721 0.555101 00:48
57 0.214520 0.388316 0.538123 0.508856 00:48
58 0.219770 0.350133 0.553124 0.555288 00:48
59 0.213634 0.380796 0.545702 0.530716 00:48
60 0.214770 0.403674 0.544516 0.526405 00:48
61 0.212837 0.382615 0.547741 0.530949 00:49
62 0.210907 0.385081 0.551419 0.559432 00:50
63 0.211821 0.383210 0.538854 0.495931 00:49
64 0.211744 0.366420 0.551287 0.553639 00:49
65 0.208235 0.393033 0.544909 0.509530 00:49
66 0.209825 0.404355 0.536005 0.474010 00:49
67 0.210346 0.362994 0.552082 0.538402 00:49
68 0.208779 0.372804 0.547688 0.514751 00:48
69 0.209830 0.351810 0.555377 0.548285 00:48
70 0.207887 0.386565 0.541370 0.496355 00:49
71 0.206442 0.377267 0.548791 0.524507 00:48
72 0.207366 0.374546 0.548827 0.523453 00:48
73 0.205153 0.365717 0.549390 0.525388 00:49
74 0.207726 0.367150 0.552252 0.542283 00:48
75 0.208129 0.372454 0.549932 0.532464 00:48
76 0.208461 0.360281 0.553864 0.548102 00:48
77 0.207639 0.372237 0.550863 0.531641 00:48
78 0.208036 0.364154 0.553543 0.545880 00:48
79 0.206991 0.361695 0.552300 0.540092 00:47
Better model found at epoch 0 with dice value: 0.1824341691331724.
Better model found at epoch 1 with dice value: 0.333523651466895.
Better model found at epoch 2 with dice value: 0.4098708280811938.
Better model found at epoch 3 with dice value: 0.4199211105853764.
Better model found at epoch 4 with dice value: 0.4349567064541659.
Better model found at epoch 5 with dice value: 0.4414173943941379.
Better model found at epoch 9 with dice value: 0.4533229756356818.
Better model found at epoch 13 with dice value: 0.4715493178986515.
Better model found at epoch 20 with dice value: 0.48502134642674577.
Better model found at epoch 22 with dice value: 0.5208334194588561.
Better model found at epoch 24 with dice value: 0.5227439605738604.
Better model found at epoch 26 with dice value: 0.5373924110580985.
Better model found at epoch 35 with dice value: 0.5375711386569465.
Better model found at epoch 36 with dice value: 0.5425170370282605.
Better model found at epoch 41 with dice value: 0.5456728971466592.
Better model found at epoch 43 with dice value: 0.5534739872514315.
Better model found at epoch 52 with dice value: 0.5549168145010914.
Better model found at epoch 69 with dice value: 0.5553773335762703.
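Note that fit_one_cycle's lr_max=slice(lr_max/ENCODER_FACTOR, lr_max) applies discriminative learning rates: fastai's optimizer expands a slice across the learner's parameter groups with multiplicatively even spacing, so the pretrained encoder trains about an order of magnitude slower than the randomly initialized decoder. A minimal sketch of that spacing, assuming three parameter groups (typical for fastai's pretrained splits, but an assumption here) and this run's values:

import numpy as np

def even_mults(start, stop, n):
    "Log-evenly spaced values from start to stop (mirrors fastai's helper of the same name)."
    if n == 1: return np.array([stop])
    step = (stop / start) ** (1 / (n - 1))
    return np.array([start * step**i for i in range(n)])

even_mults(3e-4 / 10, 3e-4, 3)  # -> approx. array([3.0e-05, 9.5e-05, 3.0e-04])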
learn.recorder.plot_loss()
learn.save('xres34-best-long')
run.finish()

-----Experiment Run: 3

del learn
BATCH_SIZE = 3  # 3 for xresnet50, 12 for xresnet34 with Tesla P100 (16GB)
TILES_PER_SCENE = 16
ARCHITECTURE = xresnet50
EPOCHS = 40
CLASS_WEIGHTS = [0.25, 0.75]
LR_MAX = 3e-4
ENCODER_FACTOR = 10
CODES = ['Land', 'Building']
# Weights and Biases config
config_dictionary = dict(
    bs=BATCH_SIZE,
    tiles_per_scene=TILES_PER_SCENE,
    architecture=str(ARCHITECTURE),
    epochs=EPOCHS,
    class_weights=CLASS_WEIGHTS,
    lr_max=LR_MAX,
    encoder_factor=ENCODER_FACTOR
)
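CLASS_WEIGHTS feeds the weighted cross-entropy loss passed to the learner below as loss_func. Its actual definition lives in an earlier cell; this sketch shows the usual fastai construction under that assumption:

import torch
from fastai.vision.all import CrossEntropyLossFlat

# [0.25, 0.75]: errors on 'Building' pixels cost three times as much as on 'Land'
weights = torch.tensor(CLASS_WEIGHTS, dtype=torch.float32)
loss_func = CrossEntropyLossFlat(axis=1, weight=weights)  # axis=1: class dim of (N, C, H, W) logits
# The weight tensor may need .to(dls.device) so it sits on the same device as the model.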
learn = unet_learner(dls,                                 # DataLoaders
                     ARCHITECTURE,                        # xResNet50
                     loss_func = loss_func,               # Weighted cross entropy loss
                     opt_func = Adam,                     # Adam optimizer
                     metrics = [Dice(), foreground_acc],  # Custom metrics
                     self_attention = False,
                     cbs = [SaveModelCallback(
                              monitor='dice',
                              comp=np.greater,
                              fname='best-model-0'
                            )]
                    )
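Both metrics ignore how well the dominant 'Land' class is predicted: Dice scores the overlap of predicted and true building masks, and foreground_acc is fastai's pixel accuracy restricted to non-background pixels, roughly equivalent to this illustrative re-implementation:

import torch

def foreground_acc_sketch(inp, targ, bkg_idx=0, axis=1):
    "Pixel accuracy over non-background pixels only (illustrative re-implementation)."
    targ = targ.squeeze(1)       # drop the mask's channel dimension
    mask = targ != bkg_idx       # keep only pixels whose label is not 'Land' (class 0)
    return (inp.argmax(dim=axis)[mask] == targ[mask]).float().mean()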
learn.model

DynamicUnet(
  (layers): ModuleList(
    (0): Sequential(
      (0): ConvLayer(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (1): ConvLayer(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (2): ConvLayer(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (2): ConvLayer(
              (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (act): ReLU(inplace=True)
        )
        (1): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (2): ConvLayer(
              (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (2): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (2): ConvLayer(
              (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
      )
      (5): Sequential(
        (0): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (2): ConvLayer(
              (0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential(
            (0): AvgPool2d(kernel_size=2, stride=2, padding=0)
            (1): ConvLayer(
              (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (act): ReLU(inplace=True)
        )
        (1): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (2): ConvLayer(
              (0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (2): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (2): ConvLayer(
              (0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (3): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (2): ConvLayer(
              (0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
      )
      (6): Sequential(
        (0): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (2): ConvLayer(
              (0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential(
            (0): AvgPool2d(kernel_size=2, stride=2, padding=0)
            (1): ConvLayer(
              (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (act): ReLU(inplace=True)
        )
        (1): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (2): ConvLayer(
              (0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (2): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (2): ConvLayer(
              (0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (3): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (2): ConvLayer(
              (0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (4): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (2): ConvLayer(
              (0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (5): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (2): ConvLayer(
              (0): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
      )
      (7): Sequential(
        (0): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (2): ConvLayer(
              (0): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential(
            (0): AvgPool2d(kernel_size=2, stride=2, padding=0)
            (1): ConvLayer(
              (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (act): ReLU(inplace=True)
        )
        (1): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (2): ConvLayer(
              (0): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
        (2): ResBlock(
          (convpath): Sequential(
            (0): ConvLayer(
              (0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): ConvLayer(
              (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (2): ConvLayer(
              (0): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (idpath): Sequential()
          (act): ReLU(inplace=True)
        )
      )
    )
    (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Sequential(
      (0): ConvLayer(
        (0): Conv2d(2048, 4096, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (1): ConvLayer(
        (0): Conv2d(4096, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
    )
    (4): UnetBlock(
      (shuf): PixelShuffle_ICNR(
        (0): ConvLayer(
          (0): Conv2d(2048, 4096, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU()
        )
        (1): PixelShuffle(upscale_factor=2)
      )
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): ConvLayer(
        (0): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (conv2): ConvLayer(
        (0): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (relu): ReLU()
    )
    (5): UnetBlock(
      (shuf): PixelShuffle_ICNR(
        (0): ConvLayer(
          (0): Conv2d(2048, 4096, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU()
        )
        (1): PixelShuffle(upscale_factor=2)
      )
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): ConvLayer(
        (0): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (conv2): ConvLayer(
        (0): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (relu): ReLU()
    )
    (6): UnetBlock(
      (shuf): PixelShuffle_ICNR(
        (0): ConvLayer(
          (0): Conv2d(1536, 3072, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU()
        )
        (1): PixelShuffle(upscale_factor=2)
      )
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): ConvLayer(
        (0): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (conv2): ConvLayer(
        (0): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (relu): ReLU()
    )
    (7): UnetBlock(
      (shuf): PixelShuffle_ICNR(
        (0): ConvLayer(
          (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU()
        )
        (1): PixelShuffle(upscale_factor=2)
      )
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): ConvLayer(
        (0): Conv2d(576, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (conv2): ConvLayer(
        (0): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (relu): ReLU()
    )
    (8): PixelShuffle_ICNR(
      (0): ConvLayer(
        (0): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU()
      )
      (1): PixelShuffle(upscale_factor=2)
    )
    (9): ResizeToOrig()
    (10): MergeLayer()
    (11): ResBlock(
      (convpath): Sequential(
        (0): ConvLayer(
          (0): Conv2d(291, 291, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU()
        )
        (1): ConvLayer(
          (0): Conv2d(291, 291, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (idpath): Sequential()
      (act): ReLU(inplace=True)
    )
    (12): ConvLayer(
      (0): Conv2d(291, 2, kernel_size=(1, 1), stride=(1, 1))
    )
    (13): ToTensorBase(tensor_cls=<class 'fastai.torch_core.TensorBase'>)
  )
)
learn.unfreeze()
learn.fit_one_cycle(
    EPOCHS,
    lr_max=slice(LR_MAX/ENCODER_FACTOR, LR_MAX),
    cbs=[WandbCallback()]
)
learn = unet_learner(dls,                                 # DataLoaders
                     ARCHITECTURE,                        # xResNet50
                     loss_func = loss_func,               # Weighted cross entropy loss
                     opt_func = Adam,                     # Adam optimizer
                     metrics = [Dice(), foreground_acc],  # Custom metrics
                     self_attention = False,
                     cbs = [SaveModelCallback(
                              monitor='dice',
                              comp=np.greater,
                              fname='best-model'
                            )]
                     )
Downloading: "https://s3.amazonaws.com/fast-ai-modelzoo/xrn50_940.pth" to /root/.cache/torch/hub/checkpoints/xrn50_940.pth
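Re-creating the learner starts again from the ImageNet-pretrained encoder, hence the download above. To continue from the checkpoint that SaveModelCallback wrote for the previous learner instead, the weights could be restored explicitly; a sketch, assuming the fname used above:

learn = learn.load('best-model-0')  # loads models/best-model-0.pth written by SaveModelCallback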

learn.summary()

DynamicUnet (Input shape: 12 x 3 x 256 x 256)
============================================================================
Layer (type)         Output Shape         Param #    Trainable 
============================================================================
                     12 x 32 x 128 x 128 
Conv2d                                    864        False     
BatchNorm2d                               64         True      
ReLU                                                           
Conv2d                                    9216       False     
BatchNorm2d                               64         True      
ReLU                                                           
____________________________________________________________________________
                     12 x 64 x 128 x 128 
Conv2d                                    18432      False     
BatchNorm2d                               128        True      
ReLU                                                           
MaxPool2d                                                      
Conv2d                                    4096       False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
____________________________________________________________________________
                     12 x 256 x 64 x 64  
Conv2d                                    16384      False     
BatchNorm2d                               512        True      
Conv2d                                    16384      False     
BatchNorm2d                               512        True      
ReLU                                                           
____________________________________________________________________________
                     12 x 64 x 64 x 64   
Conv2d                                    16384      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
____________________________________________________________________________
                     12 x 256 x 64 x 64  
Conv2d                                    16384      False     
BatchNorm2d                               512        True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 64 x 64 x 64   
Conv2d                                    16384      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
____________________________________________________________________________
                     12 x 256 x 64 x 64  
Conv2d                                    16384      False     
BatchNorm2d                               512        True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 128 x 64 x 64  
Conv2d                                    32768      False     
BatchNorm2d                               256        True      
ReLU                                                           
____________________________________________________________________________
                     12 x 128 x 32 x 32  
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
____________________________________________________________________________
                     12 x 512 x 32 x 32  
Conv2d                                    65536      False     
BatchNorm2d                               1024       True      
____________________________________________________________________________
                     []                  
AvgPool2d                                                      
____________________________________________________________________________
                     12 x 512 x 32 x 32  
Conv2d                                    131072     False     
BatchNorm2d                               1024       True      
ReLU                                                           
____________________________________________________________________________
                     12 x 128 x 32 x 32  
Conv2d                                    65536      False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
____________________________________________________________________________
                     12 x 512 x 32 x 32  
Conv2d                                    65536      False     
BatchNorm2d                               1024       True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 128 x 32 x 32  
Conv2d                                    65536      False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
____________________________________________________________________________
                     12 x 512 x 32 x 32  
Conv2d                                    65536      False     
BatchNorm2d                               1024       True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 128 x 32 x 32  
Conv2d                                    65536      False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
____________________________________________________________________________
                     12 x 512 x 32 x 32  
Conv2d                                    65536      False     
BatchNorm2d                               1024       True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 256 x 32 x 32  
Conv2d                                    131072     False     
BatchNorm2d                               512        True      
ReLU                                                           
____________________________________________________________________________
                     12 x 256 x 16 x 16  
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
____________________________________________________________________________
                     12 x 1024 x 16 x 16 
Conv2d                                    262144     False     
BatchNorm2d                               2048       True      
____________________________________________________________________________
                     []                  
AvgPool2d                                                      
____________________________________________________________________________
                     12 x 1024 x 16 x 16 
Conv2d                                    524288     False     
BatchNorm2d                               2048       True      
ReLU                                                           
____________________________________________________________________________
                     12 x 256 x 16 x 16  
Conv2d                                    262144     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
____________________________________________________________________________
                     12 x 1024 x 16 x 16 
Conv2d                                    262144     False     
BatchNorm2d                               2048       True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 256 x 16 x 16  
Conv2d                                    262144     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
____________________________________________________________________________
                     12 x 1024 x 16 x 16 
Conv2d                                    262144     False     
BatchNorm2d                               2048       True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 256 x 16 x 16  
Conv2d                                    262144     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
____________________________________________________________________________
                     12 x 1024 x 16 x 16 
Conv2d                                    262144     False     
BatchNorm2d                               2048       True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 256 x 16 x 16  
Conv2d                                    262144     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
____________________________________________________________________________
                     12 x 1024 x 16 x 16 
Conv2d                                    262144     False     
BatchNorm2d                               2048       True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 256 x 16 x 16  
Conv2d                                    262144     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
____________________________________________________________________________
                     12 x 1024 x 16 x 16 
Conv2d                                    262144     False     
BatchNorm2d                               2048       True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 512 x 16 x 16  
Conv2d                                    524288     False     
BatchNorm2d                               1024       True      
ReLU                                                           
____________________________________________________________________________
                     12 x 512 x 8 x 8    
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
ReLU                                                           
____________________________________________________________________________
                     12 x 2048 x 8 x 8   
Conv2d                                    1048576    False     
BatchNorm2d                               4096       True      
____________________________________________________________________________
                     []                  
AvgPool2d                                                      
____________________________________________________________________________
                     12 x 2048 x 8 x 8   
Conv2d                                    2097152    False     
BatchNorm2d                               4096       True      
ReLU                                                           
____________________________________________________________________________
                     12 x 512 x 8 x 8    
Conv2d                                    1048576    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
ReLU                                                           
____________________________________________________________________________
                     12 x 2048 x 8 x 8   
Conv2d                                    1048576    False     
BatchNorm2d                               4096       True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 512 x 8 x 8    
Conv2d                                    1048576    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
ReLU                                                           
____________________________________________________________________________
                     12 x 2048 x 8 x 8   
Conv2d                                    1048576    False     
BatchNorm2d                               4096       True      
Sequential                                                     
ReLU                                                           
BatchNorm2d                               4096       True      
ReLU                                                           
____________________________________________________________________________
                     12 x 4096 x 8 x 8   
Conv2d                                    75501568   True      
ReLU                                                           
____________________________________________________________________________
                     12 x 2048 x 8 x 8   
Conv2d                                    75499520   True      
ReLU                                                           
____________________________________________________________________________
                     12 x 4096 x 8 x 8   
Conv2d                                    8392704    True      
ReLU                                                           
PixelShuffle                                                   
BatchNorm2d                               2048       True      
Conv2d                                    37750784   True      
ReLU                                                           
Conv2d                                    37750784   True      
ReLU                                                           
ReLU                                                           
____________________________________________________________________________
                     12 x 4096 x 16 x 16 
Conv2d                                    8392704    True      
ReLU                                                           
PixelShuffle                                                   
BatchNorm2d                               1024       True      
Conv2d                                    21235200   True      
ReLU                                                           
Conv2d                                    21235200   True      
ReLU                                                           
ReLU                                                           
____________________________________________________________________________
                     12 x 3072 x 32 x 32 
Conv2d                                    4721664    True      
ReLU                                                           
PixelShuffle                                                   
BatchNorm2d                               512        True      
Conv2d                                    9438208    True      
ReLU                                                           
Conv2d                                    9438208    True      
ReLU                                                           
ReLU                                                           
____________________________________________________________________________
                     12 x 2048 x 64 x 64 
Conv2d                                    2099200    True      
ReLU                                                           
PixelShuffle                                                   
BatchNorm2d                               128        True      
____________________________________________________________________________
                     12 x 288 x 128 x 12 
Conv2d                                    1493280    True      
ReLU                                                           
Conv2d                                    746784     True      
ReLU                                                           
ReLU                                                           
____________________________________________________________________________
                     12 x 1152 x 128 x 1 
Conv2d                                    332928     True      
ReLU                                                           
PixelShuffle                                                   
ResizeToOrig                                                   
MergeLayer                                                     
Conv2d                                    762420     True      
ReLU                                                           
Conv2d                                    762420     True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     12 x 2 x 255 x 255  
Conv2d                                    584        True      
____________________________________________________________________________

Total params: 339,089,232
Total trainable params: 315,615,216
Total non-trainable params: 23,474,016

Optimizer used: <function Adam at 0x7ff189be5680>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group #2

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
  - SaveModelCallback
# Plot loss against learning rate to choose lr_max for fine-tuning
learn.lr_find(suggestions=False)
!nvidia-smi
Tue Jun  8 16:12:57 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   72C    P0    32W /  70W |  13552MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
# Release cached GPU memory before unfreezing the full model
import torch, gc
gc.collect()
torch.cuda.empty_cache()
learn.unfreeze()
learn.fit_one_cycle(
    EPOCHS,
    # Discriminative learning rates: the pretrained encoder trains with a rate
    # ENCODER_FACTOR times smaller than the newly initialized decoder
    lr_max=slice(lr_max/ENCODER_FACTOR, lr_max),
    cbs=[WandbCallback()]
)
learn.save('xres43-best-55')
Path('models/xres43-best-55.pth')
run.finish()

Waiting for W&B process to finish, PID 15662
Program ended successfully.
Find user logs for this run at: /content/gdrive/Shareddrives/dataset/s7/SN7_buildings_train/train/L15-1716E-1211N_6864_3345_13/wandb/run-20210607_160231-1bj2r9kt/logs/debug.log
Find internal logs for this run at: /content/gdrive/Shareddrives/dataset/s7/SN7_buildings_train/train/L15-1716E-1211N_6864_3345_13/wandb/run-20210607_160231-1bj2r9kt/logs/debug-internal.log
Synced 4 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)

Discussion

I'm not sure how to rate the results because I don't have a baseline to compare against. A Dice score of 0.57 doesn't sound great, but considering how difficult the dataset is and that I didn't customize the architecture at all, I'm quite pleased with the result. There is a lot to improve, however! The original SpaceNet7 challenge required recognizing individual buildings and tracking them through time; that's something I'd like to work on in the future.

What worked?

  • Using a pretrained encoder.
  • Ignoring most images in the dataset. I tried using 5 images per scene instead of 1, which increased training time fivefold but did not improve the results significantly.
  • Standard data augmentations. Without them, the model started to overfit sooner.
  • Undersampling. While it did not have a large effect, it sped up training a little and slightly improved accuracy.
  • Weighted cross-entropy loss. Without the weights, the model had a strong bias towards the dominant background class and failed to recognize many buildings (see the sketch after this list).
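
For reference, this is roughly how a class-weighted loss can be plugged into a fastai learner. The weight values here are illustrative, not the exact ones used in this run:

import torch
from fastai.vision.all import CrossEntropyLossFlat

# Hypothetical weights: down-weight the dominant background class (index 0)
# and up-weight the much rarer building class (index 1)
class_weights = torch.tensor([0.25, 0.75]).cuda()

# axis=1 tells fastai to flatten segmentation logits over the channel axis
loss_func = CrossEntropyLossFlat(weight=class_weights, axis=1)
# learn = unet_learner(dls, arch, loss_func=loss_func, ...)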

What didn't?

  • I hoped to get faster training with the Mish activation function, but it made training unstable.
  • Replacing cross-entropy with Dice loss was unstable as well (a common soft-Dice formulation is sketched after this list).
  • I tried adding self-attention to the U-Net, hoping it would help with classifying larger structures, but I did not notice a significant difference.
  • A deeper xResNet50 encoder increased training time six-fold but did not improve results.
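
One common soft-Dice formulation looks roughly like the sketch below (a sketch of the general technique, not the exact code from this run):

import torch
import torch.nn.functional as F

def soft_dice_loss(logits, targets, eps=1e-6):
    # logits: (N, 2, H, W) raw model outputs; targets: (N, H, W) class indices
    probs = F.softmax(logits, dim=1)[:, 1]  # probability of the building class
    targets = targets.float()
    inter = (probs * targets).sum(dim=(1, 2))
    union = probs.sum(dim=(1, 2)) + targets.sum(dim=(1, 2))
    return 1 - ((2 * inter + eps) / (union + eps)).mean()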

Other ideas to improve the results

  • Better data processing: using overlapping tiles and scaling up the image tiles.
  • Dynamic thresholding for turning the predicted probabilities into a binary mask (sketched below).
  • Implementing recent advances in segmentation models, e.g. U-Net with ASPP or Eff-UNet.
  • More compute: deeper models, cross-validation with several folds to utilize all 60 scenes, and ensembling different models.
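
To illustrate the dynamic-thresholding idea: instead of a fixed 0.5 cut-off, sweep thresholds on held-out predictions and keep the one that maximizes Dice. A minimal sketch, where probs and targets are assumed to be flattened validation tensors (names are hypothetical):

import torch

def best_threshold(probs, targets, steps=19):
    # probs: 1-D tensor of building probabilities; targets: 1-D 0/1 labels
    best_t, best_dice = 0.5, 0.0
    for t in torch.linspace(0.05, 0.95, steps):
        preds = (probs > t).float()
        inter = (preds * targets).sum()
        dice = (2 * inter / (preds.sum() + targets.sum() + 1e-6)).item()
        if dice > best_dice:
            best_t, best_dice = float(t), dice
    return best_t, best_dice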

Thank you for reading this far! The challenge was fun and I learned a lot. There is also a lot of room for improvement and work to do :)