In [1]:
import os

import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset

from torchvision.transforms import v2

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import numpy as np
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
import random

random.seed(42)
torch.manual_seed(42)
np.random.seed(42)

import timm
from pprint import pprint
from collections import Counter
In [2]:
device = 'cuda'
In [3]:
DATA_PATH = '/net/travail/bformanek/MRI_dataset'
TRAIN_FOLDER = DATA_PATH + '/train'
VAL_FOLDER = DATA_PATH + '/val'
TEST_FOLDER = DATA_PATH + '/test'

train_categories = os.listdir(TRAIN_FOLDER)
val_categories = os.listdir(VAL_FOLDER)
test_categories = os.listdir(TEST_FOLDER)

print("Train image distribution: ")
class_num_in_train = []
for i in range(0, len(train_categories)):
  CLASS_FOLDER = TRAIN_FOLDER + '/' + train_categories[i]
  class_elements = os.listdir(CLASS_FOLDER)
  class_num_in_train.append(len(class_elements))
  print(f' {train_categories[i]}: {class_num_in_train[i]}')
  
print("Validation image distribution: ")
class_num_in_val = []
for i in range(0, len(val_categories)):
  CLASS_FOLDER = VAL_FOLDER + '/' + val_categories[i]
  class_elements = os.listdir(CLASS_FOLDER)
  class_num_in_val.append(len(class_elements))
  print(f' {val_categories[i]}: {class_num_in_val[i]}')
  
print("Test image distribution: ")
class_num_in_test = []
for i in range(0, len(test_categories)):
  CLASS_FOLDER = TEST_FOLDER + '/' + test_categories[i]
  class_elements = os.listdir(CLASS_FOLDER)
  class_num_in_test.append(len(class_elements))
  print(f' {test_categories[i]}: {class_num_in_test[i]}')
  
num_classes = len(class_num_in_train)
Train image distribution: 
 T2star: 25
 T2w: 1156
 FLAIRCE: 1126
 FLAIR: 5950
 T1w: 5881
 OTHER: 382
 T1wCE: 5944
Validation image distribution: 
 T2w: 160
 FLAIRCE: 157
 FLAIR: 844
 T1w: 838
 OTHER: 49
 T1wCE: 844
Test image distribution: 
 T2star: 4
 T2w: 325
 FLAIRCE: 316
 FLAIR: 1693
 T1w: 1678
 OTHER: 118
 T1wCE: 1696
In [4]:
def train_for_epoch_with_scaler(model, train_loader, optimizer, criterion, scaler, device):
    # set model to train
    model.train()
    
    train_losses = []
    train_accuracies = []
    counter = 0

    for batch, target in train_loader:

        # data to GPU
        batch = batch.to(device)
        target = target.to(device)

        # reset optimizer
        optimizer.zero_grad()

        # forward pass
        predictions = model(batch)

        # calculate accuracy
        accuracy = (torch.argmax(predictions, dim=1) == target).sum().item() / target.size(0)
        
        # calculate loss
        loss = criterion(predictions, target)

        # backward pass
        scaler.scale(loss).backward()

        # parameter update
        scaler.step(optimizer)
        scaler.update()

        # track loss
        train_losses.append(float(loss.item()))
        train_accuracies.append(accuracy)

        counter += 1
        if counter % 20 == 0:
          print('[{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                int(counter * len(batch)), len(train_loader.dataset),
                100. * counter / len(train_loader), loss.item()))

    train_loss = np.mean(np.array(train_losses))
    train_accuracy = np.mean(np.array(train_accuracies))
    
    print('\nTrain: Average loss: {:.4f}, Accuracy: {:.4f}\n'.format(
        train_loss, train_accuracy))
    
    return train_loss, train_accuracy

def validate(model, val_loader, criterion, device):
    model.eval()
    
    val_losses = []
    y_true, y_pred = [], []

    with torch.no_grad():
        for batch, target in val_loader:

            # move data to the device
            batch = batch.to(device)
            target = target.to(device)

            with torch.autocast(device_type=device, dtype=torch.float16):
              # make predictions
              predictions = model(batch)

              # calculate loss
              loss = criterion(predictions, target)

            # track losses and predictions
            val_losses.append(float(loss.item()))
            y_true.extend(target.cpu().numpy())
            y_pred.extend(predictions.argmax(dim=1).cpu().numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    val_losses = np.array(val_losses)

    # calculate validation accuracy from y_true and y_pred
    val_accuracy = np.mean(y_true == y_pred)

    # calculate the mean validation loss
    val_loss = np.mean(val_losses)

    print('Validation: Average loss: {:.4f}, Accuracy: {:.4f}\n'.format(
        val_loss, val_accuracy))

    return val_loss, val_accuracy

def train_with_scaler(model, train_loader, val_loader, optimizer, criterion, epochs, scaler, device, checkpoints_foler = None, first_epoch=1):
    train_losses, val_losses = [],  []
    train_accuracies, val_accuracies = [], []
    max_val_acc = 0
    best_epoch = 0

    for epoch in range(first_epoch, epochs+first_epoch):

        print('Train Epoch: {}'.format(epoch))

        # train
        train_loss, train_acc = train_for_epoch_with_scaler(model, train_loader, optimizer, criterion, scaler, device)

        # validation
        valid_loss, valid_acc = validate(model, val_loader, criterion, device)

        train_losses.append(train_loss)
        val_losses.append(valid_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(valid_acc)

        # save checkpoint
        if checkpoints_foler != None and max_val_acc < valid_acc:
          max_val_acc = valid_acc
          best_epoch = epoch
          torch.save(model, checkpoints_foler+f'/avp_{epoch:03d}.pkl')

    return train_losses, val_losses, train_accuracies, val_accuracies, best_epoch
In [5]:
# define custom resample class to change image resolution without rescaling
class RandomResample:
    def __init__(self, scale_factor):
        self.scale_factor = random.uniform(0,scale_factor)
    
    def __call__(self, img):
        # Downsample
        width, height = img.size
        downscaled_size = (int(width / self.scale_factor), int(height / self.scale_factor))
        
        # Downsample the image
        img_downsampled = img.resize(downscaled_size)
        
        # Upsample back to the original size
        img_upsampled = img_downsampled.resize((width, height))
        
        return img_upsampled
In [6]:
train_transform = transforms.Compose([
    transforms.v2.Resize(224),
    
    # augmentations
    transforms.v2.RandomHorizontalFlip(p=0.5),
    transforms.v2.RandomVerticalFlip(p=0.5),
    transforms.v2.RandomRotation(degrees=360, expand=True),  # expand=True: esnure that the whole image is represented on the rotated image
    #transforms.v2.ColorJitter(contrast=0.1),
    #transforms.v2.GaussianBlur(7, sigma=2),
    RandomResample(scale_factor=2),
    
    transforms.v2.Resize(224),
    transforms.ToTensor()
    
])
valid_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor()
])
In [7]:
train_set = ImageFolder(TRAIN_FOLDER, transform = train_transform)
val_set = ImageFolder(VAL_FOLDER, transform = valid_transform)
test_set = ImageFolder(TEST_FOLDER, transform = valid_transform)

BATCH_SIZE = 64
WORKERS = 8
train_loader = DataLoader(train_set, batch_size = BATCH_SIZE, shuffle = True, num_workers=WORKERS)
val_loader = DataLoader(val_set, batch_size = BATCH_SIZE,  shuffle = False, num_workers=WORKERS)
test_loader = DataLoader(test_set, batch_size = BATCH_SIZE,  shuffle = False, num_workers=WORKERS)

# print(f'train samples: {len(train_set)}  validation samples: {len(val_set)}  test samples: {len(test_set)}')

#for image_batch, labels_batch in train_loader:
#  print("Batch sizes:", image_batch.shape, "(batch, channels, height, width)")
#  print("Label vector size:", labels_batch.shape)
#  break
In [8]:
num_in_class_dict = dict(Counter(train_set.targets[i] for i in range(len(train_set))))
num_in_class = np.zeros([1,len(num_in_class_dict)])
for i in range(0, len(num_in_class_dict)):
  num_in_class[0, i] = num_in_class_dict[i]

class_weights = 1-(num_in_class/num_in_class.sum()).squeeze()
class_weights_tensor = torch.Tensor(class_weights).to(device)

# print(num_in_class_dict)
# print(num_in_class)
In [9]:
MODEL_NAME = 'resnet18' ##resnet18, resnet50, efficientnet_b0
In [10]:
model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=num_classes)
model.to(device)
model.safetensors:   0%|          | 0.00/46.8M [00:00<?, ?B/s]
Out[10]:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
  )
  (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))
  (fc): Linear(in_features=512, out_features=7, bias=True)
)
In [11]:
criterion_balanced = nn.CrossEntropyLoss(weight = class_weights_tensor)
optimizer_Adam = optim.Adam(model.parameters(), 1e-3)
scaler = torch.cuda.amp.GradScaler()
/tmp/cache-bformanek/ipykernel_128826/3247579378.py:3: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  scaler = torch.cuda.amp.GradScaler()
In [12]:
RESULT_FOLDER_NAME = MODEL_NAME+"_flips_90_resample"

checkpoints_foler = '/net/travail/bformanek/checkpoints/transfer_checkpoints_'+RESULT_FOLDER_NAME
if not os.path.exists(checkpoints_foler):
    os.mkdir(checkpoints_foler)
In [13]:
epochs = 30
train_losses, val_losses, train_accuracies, val_accuracies, best_epoch = train_with_scaler(model, train_loader, val_loader, optimizer_Adam, criterion_balanced, 
                                                                                           epochs, scaler, device, checkpoints_foler=checkpoints_foler)
Train Epoch: 1
[1280/20457 (6%)]	Loss: 1.054133
[2560/20457 (12%)]	Loss: 0.725362
[3840/20457 (19%)]	Loss: 0.610338
[5120/20457 (25%)]	Loss: 0.386082
[6400/20457 (31%)]	Loss: 0.483535
[7680/20457 (38%)]	Loss: 0.419246
[8960/20457 (44%)]	Loss: 0.328190
[10240/20457 (50%)]	Loss: 0.308688
[11520/20457 (56%)]	Loss: 0.554291
[12800/20457 (62%)]	Loss: 0.545511
[14080/20457 (69%)]	Loss: 0.206743
[15360/20457 (75%)]	Loss: 0.248842
[16640/20457 (81%)]	Loss: 0.231432
[17920/20457 (88%)]	Loss: 0.220167
[19200/20457 (94%)]	Loss: 0.388959
[13120/20457 (100%)]	Loss: 0.223442

Train: Average loss: 0.4534, Accuracy: 0.8358

Validation: Average loss: 1.5219, Accuracy: 0.7027

Train Epoch: 2
[1280/20457 (6%)]	Loss: 0.154668
[2560/20457 (12%)]	Loss: 0.071672
[3840/20457 (19%)]	Loss: 0.456156
[5120/20457 (25%)]	Loss: 0.200928
[6400/20457 (31%)]	Loss: 0.288852
[7680/20457 (38%)]	Loss: 0.125828
[8960/20457 (44%)]	Loss: 0.158839
[10240/20457 (50%)]	Loss: 0.223312
[11520/20457 (56%)]	Loss: 0.165408
[12800/20457 (62%)]	Loss: 0.213668
[14080/20457 (69%)]	Loss: 0.175157
[15360/20457 (75%)]	Loss: 0.117205
[16640/20457 (81%)]	Loss: 0.178432
[17920/20457 (88%)]	Loss: 0.235235
[19200/20457 (94%)]	Loss: 0.135907
[13120/20457 (100%)]	Loss: 0.271736

Train: Average loss: 0.2184, Accuracy: 0.9194

Validation: Average loss: 1.2750, Accuracy: 0.8500

Train Epoch: 3
[1280/20457 (6%)]	Loss: 0.250221
[2560/20457 (12%)]	Loss: 0.124799
[3840/20457 (19%)]	Loss: 0.126421
[5120/20457 (25%)]	Loss: 0.105148
[6400/20457 (31%)]	Loss: 0.213071
[7680/20457 (38%)]	Loss: 0.075168
[8960/20457 (44%)]	Loss: 0.452005
[10240/20457 (50%)]	Loss: 0.217459
[11520/20457 (56%)]	Loss: 0.391471
[12800/20457 (62%)]	Loss: 0.218012
[14080/20457 (69%)]	Loss: 0.207581
[15360/20457 (75%)]	Loss: 0.060743
[16640/20457 (81%)]	Loss: 0.155725
[17920/20457 (88%)]	Loss: 0.163342
[19200/20457 (94%)]	Loss: 0.328871
[13120/20457 (100%)]	Loss: 0.150410

Train: Average loss: 0.1682, Accuracy: 0.9380

Validation: Average loss: 1.3876, Accuracy: 0.6535

Train Epoch: 4
[1280/20457 (6%)]	Loss: 0.155711
[2560/20457 (12%)]	Loss: 0.111006
[3840/20457 (19%)]	Loss: 0.153953
[5120/20457 (25%)]	Loss: 0.161102
[6400/20457 (31%)]	Loss: 0.193571
[7680/20457 (38%)]	Loss: 0.209899
[8960/20457 (44%)]	Loss: 0.139377
[10240/20457 (50%)]	Loss: 0.081941
[11520/20457 (56%)]	Loss: 0.179336
[12800/20457 (62%)]	Loss: 0.148540
[14080/20457 (69%)]	Loss: 0.137478
[15360/20457 (75%)]	Loss: 0.170339
[16640/20457 (81%)]	Loss: 0.022386
[17920/20457 (88%)]	Loss: 0.103414
[19200/20457 (94%)]	Loss: 0.172884
[13120/20457 (100%)]	Loss: 0.150939

Train: Average loss: 0.1391, Accuracy: 0.9509

Validation: Average loss: 1.3734, Accuracy: 0.7769

Train Epoch: 5
[1280/20457 (6%)]	Loss: 0.062116
[2560/20457 (12%)]	Loss: 0.379405
[3840/20457 (19%)]	Loss: 0.125783
[5120/20457 (25%)]	Loss: 0.279066
[6400/20457 (31%)]	Loss: 0.059350
[7680/20457 (38%)]	Loss: 0.189140
[8960/20457 (44%)]	Loss: 0.164375
[10240/20457 (50%)]	Loss: 0.063068
[11520/20457 (56%)]	Loss: 0.139561
[12800/20457 (62%)]	Loss: 0.068787
[14080/20457 (69%)]	Loss: 0.048323
[15360/20457 (75%)]	Loss: 0.120081
[16640/20457 (81%)]	Loss: 0.119275
[17920/20457 (88%)]	Loss: 0.179179
[19200/20457 (94%)]	Loss: 0.179864
[13120/20457 (100%)]	Loss: 0.312999

Train: Average loss: 0.1280, Accuracy: 0.9526

Validation: Average loss: 1.2929, Accuracy: 0.7526

Train Epoch: 6
[1280/20457 (6%)]	Loss: 0.113230
[2560/20457 (12%)]	Loss: 0.062875
[3840/20457 (19%)]	Loss: 0.216621
[5120/20457 (25%)]	Loss: 0.073777
[6400/20457 (31%)]	Loss: 0.144536
[7680/20457 (38%)]	Loss: 0.115534
[8960/20457 (44%)]	Loss: 0.097828
[10240/20457 (50%)]	Loss: 0.098064
[11520/20457 (56%)]	Loss: 0.118195
[12800/20457 (62%)]	Loss: 0.119108
[14080/20457 (69%)]	Loss: 0.038263
[15360/20457 (75%)]	Loss: 0.056077
[16640/20457 (81%)]	Loss: 0.143987
[17920/20457 (88%)]	Loss: 0.074841
[19200/20457 (94%)]	Loss: 0.082691
[13120/20457 (100%)]	Loss: 0.117237

Train: Average loss: 0.1107, Accuracy: 0.9585

Validation: Average loss: 1.4264, Accuracy: 0.8396

Train Epoch: 7
[1280/20457 (6%)]	Loss: 0.198399
[2560/20457 (12%)]	Loss: 0.164641
[3840/20457 (19%)]	Loss: 0.051346
[5120/20457 (25%)]	Loss: 0.137102
[6400/20457 (31%)]	Loss: 0.164529
[7680/20457 (38%)]	Loss: 0.102826
[8960/20457 (44%)]	Loss: 0.151736
[10240/20457 (50%)]	Loss: 0.176340
[11520/20457 (56%)]	Loss: 0.020442
[12800/20457 (62%)]	Loss: 0.168744
[14080/20457 (69%)]	Loss: 0.127338
[15360/20457 (75%)]	Loss: 0.064539
[16640/20457 (81%)]	Loss: 0.041766
[17920/20457 (88%)]	Loss: 0.105609
[19200/20457 (94%)]	Loss: 0.116521
[13120/20457 (100%)]	Loss: 0.195548

Train: Average loss: 0.1048, Accuracy: 0.9613

Validation: Average loss: 1.7050, Accuracy: 0.8177

Train Epoch: 8
[1280/20457 (6%)]	Loss: 0.050569
[2560/20457 (12%)]	Loss: 0.050108
[3840/20457 (19%)]	Loss: 0.094313
[5120/20457 (25%)]	Loss: 0.073465
[6400/20457 (31%)]	Loss: 0.012324
[7680/20457 (38%)]	Loss: 0.253782
[8960/20457 (44%)]	Loss: 0.053417
[10240/20457 (50%)]	Loss: 0.077354
[11520/20457 (56%)]	Loss: 0.061405
[12800/20457 (62%)]	Loss: 0.043902
[14080/20457 (69%)]	Loss: 0.098953
[15360/20457 (75%)]	Loss: 0.085991
[16640/20457 (81%)]	Loss: 0.071721
[17920/20457 (88%)]	Loss: 0.068376
[19200/20457 (94%)]	Loss: 0.042487
[13120/20457 (100%)]	Loss: 0.047457

Train: Average loss: 0.0958, Accuracy: 0.9647

Validation: Average loss: 1.7110, Accuracy: 0.7037

Train Epoch: 9
[1280/20457 (6%)]	Loss: 0.039522
[2560/20457 (12%)]	Loss: 0.090020
[3840/20457 (19%)]	Loss: 0.203921
[5120/20457 (25%)]	Loss: 0.052970
[6400/20457 (31%)]	Loss: 0.165568
[7680/20457 (38%)]	Loss: 0.052486
[8960/20457 (44%)]	Loss: 0.110796
[10240/20457 (50%)]	Loss: 0.118852
[11520/20457 (56%)]	Loss: 0.089906
[12800/20457 (62%)]	Loss: 0.037743
[14080/20457 (69%)]	Loss: 0.071622
[15360/20457 (75%)]	Loss: 0.091884
[16640/20457 (81%)]	Loss: 0.144048
[17920/20457 (88%)]	Loss: 0.020264
[19200/20457 (94%)]	Loss: 0.154216
[13120/20457 (100%)]	Loss: 0.068246

Train: Average loss: 0.0857, Accuracy: 0.9704

Validation: Average loss: 1.4652, Accuracy: 0.7879

Train Epoch: 10
[1280/20457 (6%)]	Loss: 0.087559
[2560/20457 (12%)]	Loss: 0.075757
[3840/20457 (19%)]	Loss: 0.131536
[5120/20457 (25%)]	Loss: 0.294059
[6400/20457 (31%)]	Loss: 0.050103
[7680/20457 (38%)]	Loss: 0.272602
[8960/20457 (44%)]	Loss: 0.051185
[10240/20457 (50%)]	Loss: 0.026480
[11520/20457 (56%)]	Loss: 0.145056
[12800/20457 (62%)]	Loss: 0.037322
[14080/20457 (69%)]	Loss: 0.050382
[15360/20457 (75%)]	Loss: 0.103761
[16640/20457 (81%)]	Loss: 0.096145
[17920/20457 (88%)]	Loss: 0.084374
[19200/20457 (94%)]	Loss: 0.165116
[13120/20457 (100%)]	Loss: 0.038771

Train: Average loss: 0.0763, Accuracy: 0.9722

Validation: Average loss: 1.2155, Accuracy: 0.8538

Train Epoch: 11
[1280/20457 (6%)]	Loss: 0.093485
[2560/20457 (12%)]	Loss: 0.161965
[3840/20457 (19%)]	Loss: 0.139011
[5120/20457 (25%)]	Loss: 0.096717
[6400/20457 (31%)]	Loss: 0.047687
[7680/20457 (38%)]	Loss: 0.135407
[8960/20457 (44%)]	Loss: 0.037390
[10240/20457 (50%)]	Loss: 0.108663
[11520/20457 (56%)]	Loss: 0.062230
[12800/20457 (62%)]	Loss: 0.090022
[14080/20457 (69%)]	Loss: 0.049208
[15360/20457 (75%)]	Loss: 0.063414
[16640/20457 (81%)]	Loss: 0.038339
[17920/20457 (88%)]	Loss: 0.089041
[19200/20457 (94%)]	Loss: 0.171923
[13120/20457 (100%)]	Loss: 0.111392

Train: Average loss: 0.0817, Accuracy: 0.9689

Validation: Average loss: 1.4669, Accuracy: 0.7765

Train Epoch: 12
[1280/20457 (6%)]	Loss: 0.063647
[2560/20457 (12%)]	Loss: 0.073979
[3840/20457 (19%)]	Loss: 0.020393
[5120/20457 (25%)]	Loss: 0.051756
[6400/20457 (31%)]	Loss: 0.045920
[7680/20457 (38%)]	Loss: 0.027033
[8960/20457 (44%)]	Loss: 0.078584
[10240/20457 (50%)]	Loss: 0.112048
[11520/20457 (56%)]	Loss: 0.044356
[12800/20457 (62%)]	Loss: 0.037912
[14080/20457 (69%)]	Loss: 0.076442
[15360/20457 (75%)]	Loss: 0.062173
[16640/20457 (81%)]	Loss: 0.110197
[17920/20457 (88%)]	Loss: 0.105149
[19200/20457 (94%)]	Loss: 0.019773
[13120/20457 (100%)]	Loss: 0.018527

Train: Average loss: 0.0765, Accuracy: 0.9726

Validation: Average loss: 2.1678, Accuracy: 0.7117

Train Epoch: 13
[1280/20457 (6%)]	Loss: 0.103463
[2560/20457 (12%)]	Loss: 0.034114
[3840/20457 (19%)]	Loss: 0.096323
[5120/20457 (25%)]	Loss: 0.039516
[6400/20457 (31%)]	Loss: 0.007555
[7680/20457 (38%)]	Loss: 0.019345
[8960/20457 (44%)]	Loss: 0.090555
[10240/20457 (50%)]	Loss: 0.036319
[11520/20457 (56%)]	Loss: 0.069145
[12800/20457 (62%)]	Loss: 0.058661
[14080/20457 (69%)]	Loss: 0.070616
[15360/20457 (75%)]	Loss: 0.075970
[16640/20457 (81%)]	Loss: 0.087841
[17920/20457 (88%)]	Loss: 0.030260
[19200/20457 (94%)]	Loss: 0.027599
[13120/20457 (100%)]	Loss: 0.045015

Train: Average loss: 0.0678, Accuracy: 0.9748

Validation: Average loss: 2.2481, Accuracy: 0.7897

Train Epoch: 14
[1280/20457 (6%)]	Loss: 0.088762
[2560/20457 (12%)]	Loss: 0.064384
[3840/20457 (19%)]	Loss: 0.078746
[5120/20457 (25%)]	Loss: 0.091362
[6400/20457 (31%)]	Loss: 0.013195
[7680/20457 (38%)]	Loss: 0.088708
[8960/20457 (44%)]	Loss: 0.021906
[10240/20457 (50%)]	Loss: 0.141558
[11520/20457 (56%)]	Loss: 0.116566
[12800/20457 (62%)]	Loss: 0.014188
[14080/20457 (69%)]	Loss: 0.043461
[15360/20457 (75%)]	Loss: 0.081282
[16640/20457 (81%)]	Loss: 0.161050
[17920/20457 (88%)]	Loss: 0.116595
[19200/20457 (94%)]	Loss: 0.097829
[13120/20457 (100%)]	Loss: 0.218975

Train: Average loss: 0.0692, Accuracy: 0.9752

Validation: Average loss: 2.1919, Accuracy: 0.7568

Train Epoch: 15
[1280/20457 (6%)]	Loss: 0.076715
[2560/20457 (12%)]	Loss: 0.208490
[3840/20457 (19%)]	Loss: 0.032941
[5120/20457 (25%)]	Loss: 0.032039
[6400/20457 (31%)]	Loss: 0.106540
[7680/20457 (38%)]	Loss: 0.041070
[8960/20457 (44%)]	Loss: 0.016431
[10240/20457 (50%)]	Loss: 0.152360
[11520/20457 (56%)]	Loss: 0.101323
[12800/20457 (62%)]	Loss: 0.185688
[14080/20457 (69%)]	Loss: 0.088974
[15360/20457 (75%)]	Loss: 0.081043
[16640/20457 (81%)]	Loss: 0.128916
[17920/20457 (88%)]	Loss: 0.011430
[19200/20457 (94%)]	Loss: 0.153337
[13120/20457 (100%)]	Loss: 0.151013

Train: Average loss: 0.0623, Accuracy: 0.9763

Validation: Average loss: 2.1169, Accuracy: 0.8132

Train Epoch: 16
[1280/20457 (6%)]	Loss: 0.104899
[2560/20457 (12%)]	Loss: 0.118871
[3840/20457 (19%)]	Loss: 0.004833
[5120/20457 (25%)]	Loss: 0.058821
[6400/20457 (31%)]	Loss: 0.043759
[7680/20457 (38%)]	Loss: 0.116204
[8960/20457 (44%)]	Loss: 0.131321
[10240/20457 (50%)]	Loss: 0.085480
[11520/20457 (56%)]	Loss: 0.008109
[12800/20457 (62%)]	Loss: 0.064811
[14080/20457 (69%)]	Loss: 0.072946
[15360/20457 (75%)]	Loss: 0.124129
[16640/20457 (81%)]	Loss: 0.029261
[17920/20457 (88%)]	Loss: 0.032921
[19200/20457 (94%)]	Loss: 0.016796
[13120/20457 (100%)]	Loss: 0.009986

Train: Average loss: 0.0679, Accuracy: 0.9757

Validation: Average loss: 2.1286, Accuracy: 0.8565

Train Epoch: 17
[1280/20457 (6%)]	Loss: 0.018248
[2560/20457 (12%)]	Loss: 0.085281
[3840/20457 (19%)]	Loss: 0.060398
[5120/20457 (25%)]	Loss: 0.012627
[6400/20457 (31%)]	Loss: 0.079470
[7680/20457 (38%)]	Loss: 0.025762
[8960/20457 (44%)]	Loss: 0.163033
[10240/20457 (50%)]	Loss: 0.021334
[11520/20457 (56%)]	Loss: 0.019311
[12800/20457 (62%)]	Loss: 0.029942
[14080/20457 (69%)]	Loss: 0.023639
[15360/20457 (75%)]	Loss: 0.024024
[16640/20457 (81%)]	Loss: 0.053554
[17920/20457 (88%)]	Loss: 0.014610
[19200/20457 (94%)]	Loss: 0.106618
[13120/20457 (100%)]	Loss: 0.056141

Train: Average loss: 0.0612, Accuracy: 0.9788

Validation: Average loss: 3.0565, Accuracy: 0.6857

Train Epoch: 18
[1280/20457 (6%)]	Loss: 0.053011
[2560/20457 (12%)]	Loss: 0.164775
[3840/20457 (19%)]	Loss: 0.056294
[5120/20457 (25%)]	Loss: 0.087227
[6400/20457 (31%)]	Loss: 0.064727
[7680/20457 (38%)]	Loss: 0.084947
[8960/20457 (44%)]	Loss: 0.011981
[10240/20457 (50%)]	Loss: 0.116578
[11520/20457 (56%)]	Loss: 0.019838
[12800/20457 (62%)]	Loss: 0.070667
[14080/20457 (69%)]	Loss: 0.056050
[15360/20457 (75%)]	Loss: 0.074088
[16640/20457 (81%)]	Loss: 0.008747
[17920/20457 (88%)]	Loss: 0.039805
[19200/20457 (94%)]	Loss: 0.142061
[13120/20457 (100%)]	Loss: 0.025551

Train: Average loss: 0.0526, Accuracy: 0.9809

Validation: Average loss: 2.0993, Accuracy: 0.7918

Train Epoch: 19
[1280/20457 (6%)]	Loss: 0.012908
[2560/20457 (12%)]	Loss: 0.011357
[3840/20457 (19%)]	Loss: 0.034939
[5120/20457 (25%)]	Loss: 0.016155
[6400/20457 (31%)]	Loss: 0.047667
[7680/20457 (38%)]	Loss: 0.023641
[8960/20457 (44%)]	Loss: 0.038973
[10240/20457 (50%)]	Loss: 0.056826
[11520/20457 (56%)]	Loss: 0.016601
[12800/20457 (62%)]	Loss: 0.092312
[14080/20457 (69%)]	Loss: 0.081406
[15360/20457 (75%)]	Loss: 0.062551
[16640/20457 (81%)]	Loss: 0.033870
[17920/20457 (88%)]	Loss: 0.092916
[19200/20457 (94%)]	Loss: 0.145001
[13120/20457 (100%)]	Loss: 0.056936

Train: Average loss: 0.0571, Accuracy: 0.9794

Validation: Average loss: 2.0757, Accuracy: 0.8565

Train Epoch: 20
[1280/20457 (6%)]	Loss: 0.018353
[2560/20457 (12%)]	Loss: 0.030320
[3840/20457 (19%)]	Loss: 0.099210
[5120/20457 (25%)]	Loss: 0.027942
[6400/20457 (31%)]	Loss: 0.058292
[7680/20457 (38%)]	Loss: 0.096841
[8960/20457 (44%)]	Loss: 0.026297
[10240/20457 (50%)]	Loss: 0.052762
[11520/20457 (56%)]	Loss: 0.087915
[12800/20457 (62%)]	Loss: 0.006828
[14080/20457 (69%)]	Loss: 0.007386
[15360/20457 (75%)]	Loss: 0.025286
[16640/20457 (81%)]	Loss: 0.063003
[17920/20457 (88%)]	Loss: 0.010459
[19200/20457 (94%)]	Loss: 0.073292
[13120/20457 (100%)]	Loss: 0.016993

Train: Average loss: 0.0502, Accuracy: 0.9824

Validation: Average loss: 2.2133, Accuracy: 0.7647

Train Epoch: 21
[1280/20457 (6%)]	Loss: 0.046928
[2560/20457 (12%)]	Loss: 0.022332
[3840/20457 (19%)]	Loss: 0.030530
[5120/20457 (25%)]	Loss: 0.024616
[6400/20457 (31%)]	Loss: 0.092403
[7680/20457 (38%)]	Loss: 0.013935
[8960/20457 (44%)]	Loss: 0.044144
[10240/20457 (50%)]	Loss: 0.060256
[11520/20457 (56%)]	Loss: 0.137226
[12800/20457 (62%)]	Loss: 0.052821
[14080/20457 (69%)]	Loss: 0.004136
[15360/20457 (75%)]	Loss: 0.025377
[16640/20457 (81%)]	Loss: 0.056544
[17920/20457 (88%)]	Loss: 0.020745
[19200/20457 (94%)]	Loss: 0.063546
[13120/20457 (100%)]	Loss: 0.068892

Train: Average loss: 0.0542, Accuracy: 0.9791

Validation: Average loss: 2.5180, Accuracy: 0.8222

Train Epoch: 22
[1280/20457 (6%)]	Loss: 0.036498
[2560/20457 (12%)]	Loss: 0.029486
[3840/20457 (19%)]	Loss: 0.009430
[5120/20457 (25%)]	Loss: 0.031182
[6400/20457 (31%)]	Loss: 0.018676
[7680/20457 (38%)]	Loss: 0.009826
[8960/20457 (44%)]	Loss: 0.013832
[10240/20457 (50%)]	Loss: 0.066528
[11520/20457 (56%)]	Loss: 0.012110
[12800/20457 (62%)]	Loss: 0.091061
[14080/20457 (69%)]	Loss: 0.033956
[15360/20457 (75%)]	Loss: 0.031217
[16640/20457 (81%)]	Loss: 0.100485
[17920/20457 (88%)]	Loss: 0.027754
[19200/20457 (94%)]	Loss: 0.038433
[13120/20457 (100%)]	Loss: 0.100482

Train: Average loss: 0.0483, Accuracy: 0.9823

Validation: Average loss: 2.4675, Accuracy: 0.8105

Train Epoch: 23
[1280/20457 (6%)]	Loss: 0.079316
[2560/20457 (12%)]	Loss: 0.012087
[3840/20457 (19%)]	Loss: 0.063074
[5120/20457 (25%)]	Loss: 0.012594
[6400/20457 (31%)]	Loss: 0.089318
[7680/20457 (38%)]	Loss: 0.040973
[8960/20457 (44%)]	Loss: 0.051222
[10240/20457 (50%)]	Loss: 0.241976
[11520/20457 (56%)]	Loss: 0.099788
[12800/20457 (62%)]	Loss: 0.019505
[14080/20457 (69%)]	Loss: 0.039544
[15360/20457 (75%)]	Loss: 0.052573
[16640/20457 (81%)]	Loss: 0.026704
[17920/20457 (88%)]	Loss: 0.010631
[19200/20457 (94%)]	Loss: 0.068934
[13120/20457 (100%)]	Loss: 0.047436

Train: Average loss: 0.0450, Accuracy: 0.9836

Validation: Average loss: 2.6579, Accuracy: 0.7443

Train Epoch: 24
[1280/20457 (6%)]	Loss: 0.018068
[2560/20457 (12%)]	Loss: 0.180993
[3840/20457 (19%)]	Loss: 0.027099
[5120/20457 (25%)]	Loss: 0.031176
[6400/20457 (31%)]	Loss: 0.020737
[7680/20457 (38%)]	Loss: 0.031309
[8960/20457 (44%)]	Loss: 0.023619
[10240/20457 (50%)]	Loss: 0.022132
[11520/20457 (56%)]	Loss: 0.014054
[12800/20457 (62%)]	Loss: 0.009688
[14080/20457 (69%)]	Loss: 0.010825
[15360/20457 (75%)]	Loss: 0.033677
[16640/20457 (81%)]	Loss: 0.003488
[17920/20457 (88%)]	Loss: 0.077383
[19200/20457 (94%)]	Loss: 0.006477
[13120/20457 (100%)]	Loss: 0.021153

Train: Average loss: 0.0477, Accuracy: 0.9834

Validation: Average loss: 3.1563, Accuracy: 0.8219

Train Epoch: 25
[1280/20457 (6%)]	Loss: 0.011585
[2560/20457 (12%)]	Loss: 0.026237
[3840/20457 (19%)]	Loss: 0.032970
[5120/20457 (25%)]	Loss: 0.017178
[6400/20457 (31%)]	Loss: 0.063584
[7680/20457 (38%)]	Loss: 0.030424
[8960/20457 (44%)]	Loss: 0.038144
[10240/20457 (50%)]	Loss: 0.006553
[11520/20457 (56%)]	Loss: 0.084033
[12800/20457 (62%)]	Loss: 0.042627
[14080/20457 (69%)]	Loss: 0.023425
[15360/20457 (75%)]	Loss: 0.016240
[16640/20457 (81%)]	Loss: 0.028662
[17920/20457 (88%)]	Loss: 0.022785
[19200/20457 (94%)]	Loss: 0.023668
[13120/20457 (100%)]	Loss: 0.033970

Train: Average loss: 0.0456, Accuracy: 0.9837

Validation: Average loss: 2.5554, Accuracy: 0.8389

Train Epoch: 26
[1280/20457 (6%)]	Loss: 0.238269
[2560/20457 (12%)]	Loss: 0.046602
[3840/20457 (19%)]	Loss: 0.016120
[5120/20457 (25%)]	Loss: 0.010275
[6400/20457 (31%)]	Loss: 0.063294
[7680/20457 (38%)]	Loss: 0.028716
[8960/20457 (44%)]	Loss: 0.055632
[10240/20457 (50%)]	Loss: 0.003819
[11520/20457 (56%)]	Loss: 0.068559
[12800/20457 (62%)]	Loss: 0.017301
[14080/20457 (69%)]	Loss: 0.091028
[15360/20457 (75%)]	Loss: 0.201758
[16640/20457 (81%)]	Loss: 0.009264
[17920/20457 (88%)]	Loss: 0.058656
[19200/20457 (94%)]	Loss: 0.082569
[13120/20457 (100%)]	Loss: 0.189526

Train: Average loss: 0.0533, Accuracy: 0.9811

Validation: Average loss: 4.0893, Accuracy: 0.7737

Train Epoch: 27
[1280/20457 (6%)]	Loss: 0.030012
[2560/20457 (12%)]	Loss: 0.022481
[3840/20457 (19%)]	Loss: 0.013295
[5120/20457 (25%)]	Loss: 0.013206
[6400/20457 (31%)]	Loss: 0.004704
[7680/20457 (38%)]	Loss: 0.026337
[8960/20457 (44%)]	Loss: 0.122487
[10240/20457 (50%)]	Loss: 0.010481
[11520/20457 (56%)]	Loss: 0.096591
[12800/20457 (62%)]	Loss: 0.006623
[14080/20457 (69%)]	Loss: 0.016750
[15360/20457 (75%)]	Loss: 0.165899
[16640/20457 (81%)]	Loss: 0.066964
[17920/20457 (88%)]	Loss: 0.044991
[19200/20457 (94%)]	Loss: 0.028740
[13120/20457 (100%)]	Loss: 0.010718

Train: Average loss: 0.0417, Accuracy: 0.9856

Validation: Average loss: 3.4930, Accuracy: 0.8046

Train Epoch: 28
[1280/20457 (6%)]	Loss: 0.013860
[2560/20457 (12%)]	Loss: 0.003523
[3840/20457 (19%)]	Loss: 0.017058
[5120/20457 (25%)]	Loss: 0.036954
[6400/20457 (31%)]	Loss: 0.021833
[7680/20457 (38%)]	Loss: 0.108124
[8960/20457 (44%)]	Loss: 0.022643
[10240/20457 (50%)]	Loss: 0.034491
[11520/20457 (56%)]	Loss: 0.046759
[12800/20457 (62%)]	Loss: 0.019284
[14080/20457 (69%)]	Loss: 0.039736
[15360/20457 (75%)]	Loss: 0.030811
[16640/20457 (81%)]	Loss: 0.035037
[17920/20457 (88%)]	Loss: 0.030163
[19200/20457 (94%)]	Loss: 0.015596
[13120/20457 (100%)]	Loss: 0.023386

Train: Average loss: 0.0433, Accuracy: 0.9843

Validation: Average loss: 2.9932, Accuracy: 0.8039

Train Epoch: 29
[1280/20457 (6%)]	Loss: 0.034986
[2560/20457 (12%)]	Loss: 0.014053
[3840/20457 (19%)]	Loss: 0.021670
[5120/20457 (25%)]	Loss: 0.013239
[6400/20457 (31%)]	Loss: 0.019322
[7680/20457 (38%)]	Loss: 0.108043
[8960/20457 (44%)]	Loss: 0.080542
[10240/20457 (50%)]	Loss: 0.050894
[11520/20457 (56%)]	Loss: 0.022049
[12800/20457 (62%)]	Loss: 0.056777
[14080/20457 (69%)]	Loss: 0.036953
[15360/20457 (75%)]	Loss: 0.008373
[16640/20457 (81%)]	Loss: 0.094993
[17920/20457 (88%)]	Loss: 0.046380
[19200/20457 (94%)]	Loss: 0.040345
[13120/20457 (100%)]	Loss: 0.067332

Train: Average loss: 0.0421, Accuracy: 0.9840

Validation: Average loss: 3.1402, Accuracy: 0.8191

Train Epoch: 30
[1280/20457 (6%)]	Loss: 0.038956
[2560/20457 (12%)]	Loss: 0.016095
[3840/20457 (19%)]	Loss: 0.015689
[5120/20457 (25%)]	Loss: 0.022543
[6400/20457 (31%)]	Loss: 0.032722
[7680/20457 (38%)]	Loss: 0.063683
[8960/20457 (44%)]	Loss: 0.009182
[10240/20457 (50%)]	Loss: 0.036861
[11520/20457 (56%)]	Loss: 0.067981
[12800/20457 (62%)]	Loss: 0.020677
[14080/20457 (69%)]	Loss: 0.007287
[15360/20457 (75%)]	Loss: 0.016112
[16640/20457 (81%)]	Loss: 0.008147
[17920/20457 (88%)]	Loss: 0.008376
[19200/20457 (94%)]	Loss: 0.031842
[13120/20457 (100%)]	Loss: 0.003770

Train: Average loss: 0.0381, Accuracy: 0.9865

Validation: Average loss: 2.6988, Accuracy: 0.8299

In [14]:
epochs = range(1, len(train_losses) + 1)

plt.figure(figsize=(15,6))
plt.subplot(1,2,1)
plt.plot(epochs, train_losses, '-o', label='Training loss')
plt.plot(epochs, val_losses, '-o', label='Validation loss')
plt.legend()
plt.title('Learning curves')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.xticks(epochs)
plt.subplot(1,2,2)
plt.plot(epochs, train_accuracies, '-o', label='Training accuracy')
plt.plot(epochs, val_accuracies, '-o', label='Validation accuracy')
plt.legend()
plt.title('Learning curves')
plt.xlabel('Epoch')
plt.ylabel('accuracy')
plt.xticks(epochs)
plt.show()
In [15]:
# best_epoch = 32
model = torch.load(checkpoints_foler+f'/avp_{best_epoch:03d}.pkl')
/tmp/cache-bformanek/ipykernel_128826/529002640.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  model = torch.load(checkpoints_foler+f'/avp_{best_epoch:03d}.pkl')
In [16]:
def predict(model, data_loader):
    model.eval()

    # save the predictions in this list
    y_pred = []

    # no gradient needed
    with torch.no_grad():

        # go over each batch in the loader. We can ignore the targets here
        for batch, _ in data_loader:

            # Move batch to the GPU
            batch = batch.to(device)

            # predict probabilities of each class
            predictions = model(batch)

            # apply a softmax to the predictions
            predictions = F.softmax(predictions, dim=1)

            # move to the cpu and convert to numpy
            predictions = predictions.cpu().numpy()

            # save
            y_pred.append(predictions)

    # stack predictions into a (num_samples, 10) array
    y_pred = np.vstack(y_pred)
    return y_pred
In [17]:
# compute predictions on the test set
y_pred = predict(model, test_loader)
# find the argmax of each of the predictions
y_pred = y_pred.argmax(axis=1)

# get the true labels and convert to numpy
y_true = np.array(test_set.targets)
In [18]:
num_errors = np.sum(y_true != y_pred)
print(f'Test errors {num_errors} (out of {len(test_set)})  {num_errors/len(test_set)*100:0.2f}%')
print(f'Test accuracy {100-num_errors/len(test_set)*100:0.2f}%')
Test errors 484 (out of 5823)  8.31%
Test accuracy 91.69%
In [19]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

conf_matrix = confusion_matrix(y_true, y_pred)

plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=train_categories,
            yticklabels=train_categories)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.show()
/usr/lib/python3/dist-packages/statsmodels/__init__.py:6: UserWarning: This appears to be an armel system, on which statsmodels is buggy (crashes and possibly wrong answers) - https://bugs.debian.org/968210
  warnings.warn("This appears to be an armel system, on which statsmodels is buggy (crashes and possibly wrong answers) - https://bugs.debian.org/968210")
In [20]:
TP = conf_matrix.diagonal()
P = conf_matrix.sum(axis=1)

# Calculate balanced accuracy
balanced_accuracy = sum(TP / P) / len(P)
print(f'Balanced accuracy {balanced_accuracy*100:0.2f}%')
Balanced accuracy 88.54%
In [ ]: