import os
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
from torchvision.transforms import v2
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
import numpy as np
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
import random
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)
import timm
from pprint import pprint
from collections import Counter
device = 'cuda'
DATA_PATH = '/net/travail/bformanek/MRI_dataset'
TRAIN_FOLDER = DATA_PATH + '/train'
VAL_FOLDER = DATA_PATH + '/val'
TEST_FOLDER = DATA_PATH + '/test'
# sort folder names so the listing matches ImageFolder's (sorted) label order
train_categories = sorted(os.listdir(TRAIN_FOLDER))
val_categories = sorted(os.listdir(VAL_FOLDER))
test_categories = sorted(os.listdir(TEST_FOLDER))
print("Train image distribution: ")
class_num_in_train = []
for i in range(0, len(train_categories)):
CLASS_FOLDER = TRAIN_FOLDER + '/' + train_categories[i]
class_elements = os.listdir(CLASS_FOLDER)
class_num_in_train.append(len(class_elements))
print(f' {train_categories[i]}: {class_num_in_train[i]}')
print("Validation image distribution: ")
class_num_in_val = []
for i in range(0, len(val_categories)):
CLASS_FOLDER = VAL_FOLDER + '/' + val_categories[i]
class_elements = os.listdir(CLASS_FOLDER)
class_num_in_val.append(len(class_elements))
print(f' {val_categories[i]}: {class_num_in_val[i]}')
print("Test image distribution: ")
class_num_in_test = []
for i in range(0, len(test_categories)):
CLASS_FOLDER = TEST_FOLDER + '/' + test_categories[i]
class_elements = os.listdir(CLASS_FOLDER)
class_num_in_test.append(len(class_elements))
print(f' {test_categories[i]}: {class_num_in_test[i]}')
num_classes = len(class_num_in_train)
Train image distribution:
  T2star: 25
  T2w: 1156
  FLAIRCE: 1126
  FLAIR: 5950
  T1w: 5881
  OTHER: 382
  T1wCE: 5944
Validation image distribution:
  T2w: 160
  FLAIRCE: 157
  FLAIR: 844
  T1w: 838
  OTHER: 49
  T1wCE: 844
Test image distribution:
  T2star: 4
  T2w: 325
  FLAIRCE: 316
  FLAIR: 1693
  T1w: 1678
  OTHER: 118
  T1wCE: 1696
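# Note the listing above: the validation split has no T2star folder at all, and the class
# is tiny in train and test. Because ImageFolder assigns label indices from the sorted
# folder names it finds, a missing class silently shifts the label mapping between splits.
# A small sanity check along these lines (hypothetical helper, not part of the original
# notebook) makes such a mismatch explicit:
def check_split_classes(*folders):
    """Warn when the class subfolders differ between dataset splits."""
    class_sets = {folder: set(os.listdir(folder)) for folder in folders}
    all_classes = set.union(*class_sets.values())
    for folder, classes in class_sets.items():
        missing = all_classes - classes
        if missing:
            print(f'WARNING: {folder} is missing classes: {sorted(missing)}')

check_split_classes(TRAIN_FOLDER, VAL_FOLDER, TEST_FOLDER)
# here this would report that the val split lacks the T2star folder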
def train_for_epoch_with_scaler(model, train_loader, optimizer, criterion, scaler, device):
    # set model to train mode
    model.train()
    train_losses = []
    train_accuracies = []
    samples_seen = 0
    for counter, (batch, target) in enumerate(train_loader, start=1):
        # data to GPU
        batch = batch.to(device)
        target = target.to(device)
        # reset optimizer
        optimizer.zero_grad()
        # forward pass under autocast, so the GradScaler actually has fp16 gradients to scale
        with torch.autocast(device_type=device, dtype=torch.float16):
            predictions = model(batch)
            # calculate loss
            loss = criterion(predictions, target)
        # calculate accuracy
        accuracy = (torch.argmax(predictions, dim=1) == target).sum().item() / target.size(0)
        # backward pass with loss scaling
        scaler.scale(loss).backward()
        # parameter update (skipped automatically if gradients overflowed)
        scaler.step(optimizer)
        scaler.update()
        # track loss and accuracy
        train_losses.append(loss.item())
        train_accuracies.append(accuracy)
        # count actual samples, so the last (smaller) batch reports the correct progress
        samples_seen += target.size(0)
        if counter % 20 == 0:
            print('[{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                samples_seen, len(train_loader.dataset),
                100. * counter / len(train_loader), loss.item()))
    train_loss = np.mean(train_losses)
    train_accuracy = np.mean(train_accuracies)
    print('\nTrain: Average loss: {:.4f}, Accuracy: {:.4f}\n'.format(
        train_loss, train_accuracy))
    return train_loss, train_accuracy
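# For reference, the scale/step/update sequence above follows PyTorch's standard
# mixed-precision recipe: autocast runs the forward pass in fp16, scaler.scale() scales
# the loss so small fp16 gradients don't underflow, scaler.step() unscales the gradients
# and skips the update if any are inf/NaN, and scaler.update() adapts the scale factor.
# A minimal, self-contained illustration of the same pattern (assumes a CUDA device):
tiny_model = nn.Linear(10, 2).to('cuda')
tiny_optim = optim.SGD(tiny_model.parameters(), lr=0.1)
tiny_scaler = torch.amp.GradScaler('cuda')
x = torch.randn(4, 10, device='cuda')
y = torch.randint(0, 2, (4,), device='cuda')
tiny_optim.zero_grad()
with torch.autocast(device_type='cuda', dtype=torch.float16):
    tiny_loss = F.cross_entropy(tiny_model(x), y)
tiny_scaler.scale(tiny_loss).backward()  # scale up to avoid fp16 gradient underflow
tiny_scaler.step(tiny_optim)             # unscales; skips the step on inf/NaN gradients
tiny_scaler.update()                     # grows/shrinks the scale factor adaptively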
def validate(model, val_loader, criterion, device):
    model.eval()
    val_losses = []
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch, target in val_loader:
            # move data to the device
            batch = batch.to(device)
            target = target.to(device)
            with torch.autocast(device_type=device, dtype=torch.float16):
                # make predictions
                predictions = model(batch)
                # calculate loss
                loss = criterion(predictions, target)
            # track losses and predictions
            val_losses.append(loss.item())
            y_true.extend(target.cpu().numpy())
            y_pred.extend(predictions.argmax(dim=1).cpu().numpy())
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    # validation accuracy is the fraction of correct predictions
    val_accuracy = np.mean(y_true == y_pred)
    # mean validation loss over all batches
    val_loss = np.mean(val_losses)
    print('Validation: Average loss: {:.4f}, Accuracy: {:.4f}\n'.format(
        val_loss, val_accuracy))
    return val_loss, val_accuracy
def train_with_scaler(model, train_loader, val_loader, optimizer, criterion, epochs, scaler, device, checkpoints_folder=None, first_epoch=1):
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []
    max_val_acc = 0
    best_epoch = 0
    for epoch in range(first_epoch, epochs + first_epoch):
        print('Train Epoch: {}'.format(epoch))
        # train
        train_loss, train_acc = train_for_epoch_with_scaler(model, train_loader, optimizer, criterion, scaler, device)
        # validation
        valid_loss, valid_acc = validate(model, val_loader, criterion, device)
        train_losses.append(train_loss)
        val_losses.append(valid_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(valid_acc)
        # save a checkpoint whenever validation accuracy improves
        if checkpoints_folder is not None and max_val_acc < valid_acc:
            max_val_acc = valid_acc
            best_epoch = epoch
            torch.save(model, checkpoints_folder + f'/avp_{epoch:03d}.pkl')
    return train_losses, val_losses, train_accuracies, val_accuracies, best_epoch
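# Alternative checkpointing (a sketch, not what this notebook does): torch.save(model, ...)
# pickles the whole module, which later forces torch.load(..., weights_only=False).
# Saving only the state dict keeps checkpoints loadable with the safer weights_only=True.
def save_state_dict_checkpoint(model, folder, epoch):
    # parameters and buffers only, no arbitrary pickled objects
    torch.save(model.state_dict(), folder + f'/avp_{epoch:03d}.pt')

def load_state_dict_checkpoint(model_name, num_classes, folder, epoch, device):
    # rebuild the architecture, then restore the saved weights
    model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
    state = torch.load(folder + f'/avp_{epoch:03d}.pt', map_location=device, weights_only=True)
    model.load_state_dict(state)
    return model.to(device)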
# define custom resample class to lower image resolution without changing image size
class RandomResample:
    def __init__(self, scale_factor):
        # maximum downsampling factor; the actual factor is drawn per image
        self.scale_factor = scale_factor

    def __call__(self, img):
        # draw a fresh factor in [1, scale_factor] for every image
        # (sampling in __init__ would fix one factor for the whole run)
        factor = random.uniform(1, self.scale_factor)
        width, height = img.size
        downscaled_size = (max(1, int(width / factor)), max(1, int(height / factor)))
        # downsample the image
        img_downsampled = img.resize(downscaled_size)
        # upsample back to the original size
        img_upsampled = img_downsampled.resize((width, height))
        return img_upsampled
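# Quick sanity check of RandomResample on a synthetic PIL image (hypothetical usage,
# not part of the original notebook): the output keeps the input size, only the
# effective resolution drops.
from PIL import Image
_dummy = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
_out = RandomResample(scale_factor=2)(_dummy)
print(_dummy.size, '->', _out.size)  # (224, 224) -> (224, 224)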
train_transform = transforms.Compose([
    v2.Resize(224),
    # augmentations
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
    v2.RandomRotation(degrees=360, expand=True),  # expand=True: ensure the whole image stays inside the rotated frame
    # v2.ColorJitter(contrast=0.1),
    # v2.GaussianBlur(7, sigma=2),
    RandomResample(scale_factor=2),
    v2.Resize(224),  # restore a 224-pixel short side after the expanded rotation
    transforms.ToTensor()
])
valid_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor()
])
train_set = ImageFolder(TRAIN_FOLDER, transform=train_transform)
val_set = ImageFolder(VAL_FOLDER, transform=valid_transform)
test_set = ImageFolder(TEST_FOLDER, transform=valid_transform)
BATCH_SIZE = 64
WORKERS = 8
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=WORKERS)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=WORKERS)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=WORKERS)
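# The global seeds set at the top fix the main process, but each of the 8 loader workers
# gets its own RNG state, so augmentations are not bit-reproducible across runs. PyTorch's
# documented recipe (a sketch; the notebook does not use it) seeds every worker and the
# shuffling generator:
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32  # derived per worker by PyTorch
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(42)
# train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
#                           num_workers=WORKERS, worker_init_fn=seed_worker, generator=g)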
# print(f'train samples: {len(train_set)} validation samples: {len(val_set)} test samples: {len(test_set)}')
#for image_batch, labels_batch in train_loader:
# print("Batch sizes:", image_batch.shape, "(batch, channels, height, width)")
# print("Label vector size:", labels_batch.shape)
# break
# per-class sample counts, indexed by ImageFolder class id
num_in_class_dict = dict(Counter(train_set.targets))
num_in_class = np.array([num_in_class_dict[i] for i in range(len(num_in_class_dict))], dtype=np.float64)
# weight each class by one minus its relative frequency, so rarer classes get larger weights
class_weights = 1 - num_in_class / num_in_class.sum()
class_weights_tensor = torch.Tensor(class_weights).to(device)
# print(num_in_class_dict)
# print(num_in_class)
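# The 1 - frequency weighting above stays close to 1.0 for every class (the rarest class
# gets ~0.999, the most common ~0.71). A more aggressive, widely used alternative is
# inverse-frequency weighting (a sketch, not what was used for the results below):
inv_freq_weights = num_in_class.sum() / (len(num_in_class) * num_in_class)
# e.g. pass torch.tensor(inv_freq_weights, dtype=torch.float32, device=device)
# as the `weight` argument of nn.CrossEntropyLoss instead of class_weights_tensor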
MODEL_NAME = 'resnet18'  # options: resnet18, resnet50, efficientnet_b0
model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=num_classes)
model.to(device)
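# The transforms above skip mean/std normalization, while timm's pretrained resnet18
# weights were trained with ImageNet normalization. timm exposes the preprocessing a
# checkpoint expects; a sketch of how it could be inspected or adopted (not used for the
# results below):
data_config = timm.data.resolve_data_config({}, model=model)
print(data_config)  # includes input_size, interpolation, mean, std, crop_pct
# pretrained_transform = timm.data.create_transform(**data_config)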
[Model summary output: a timm ResNet-18 — conv1 7x7/stride 2, BatchNorm, ReLU, maxpool, four BasicBlock stages (64, 128, 256, 512 channels), global average pooling, and a final classifier (fc): Linear(in_features=512, out_features=7, bias=True).]
criterion_balanced = nn.CrossEntropyLoss(weight=class_weights_tensor)
optimizer_Adam = optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler('cuda')  # torch.cuda.amp.GradScaler is deprecated
RESULT_FOLDER_NAME = MODEL_NAME + "_flips_90_resample"
checkpoints_folder = '/net/travail/bformanek/checkpoints/transfer_checkpoints_' + RESULT_FOLDER_NAME
if not os.path.exists(checkpoints_folder):
    os.mkdir(checkpoints_folder)
epochs = 30
train_losses, val_losses, train_accuracies, val_accuracies, best_epoch = train_with_scaler(model, train_loader, val_loader, optimizer_Adam, criterion_balanced,
                                                                                           epochs, scaler, device, checkpoints_folder=checkpoints_folder)
[Per-batch progress prints omitted; per-epoch training and validation summaries:]
Epoch  1  Train: loss 0.4534, acc 0.8358   Validation: loss 1.5219, acc 0.7027
Epoch  2  Train: loss 0.2184, acc 0.9194   Validation: loss 1.2750, acc 0.8500
Epoch  3  Train: loss 0.1682, acc 0.9380   Validation: loss 1.3876, acc 0.6535
Epoch  4  Train: loss 0.1391, acc 0.9509   Validation: loss 1.3734, acc 0.7769
Epoch  5  Train: loss 0.1280, acc 0.9526   Validation: loss 1.2929, acc 0.7526
Epoch  6  Train: loss 0.1107, acc 0.9585   Validation: loss 1.4264, acc 0.8396
Epoch  7  Train: loss 0.1048, acc 0.9613   Validation: loss 1.7050, acc 0.8177
Epoch  8  Train: loss 0.0958, acc 0.9647   Validation: loss 1.7110, acc 0.7037
Epoch  9  Train: loss 0.0857, acc 0.9704   Validation: loss 1.4652, acc 0.7879
Epoch 10  Train: loss 0.0763, acc 0.9722   Validation: loss 1.2155, acc 0.8538
Epoch 11  Train: loss 0.0817, acc 0.9689   Validation: loss 1.4669, acc 0.7765
Epoch 12  Train: loss 0.0765, acc 0.9726   Validation: loss 2.1678, acc 0.7117
Epoch 13  Train: loss 0.0678, acc 0.9748   Validation: loss 2.2481, acc 0.7897
Epoch 14  Train: loss 0.0692, acc 0.9752   Validation: loss 2.1919, acc 0.7568
Epoch 15  Train: loss 0.0623, acc 0.9763   Validation: loss 2.1169, acc 0.8132
Epoch 16  Train: loss 0.0679, acc 0.9757   Validation: loss 2.1286, acc 0.8565
Epoch 17  Train: loss 0.0612, acc 0.9788   Validation: loss 3.0565, acc 0.6857
Epoch 18  Train: loss 0.0526, acc 0.9809   Validation: loss 2.0993, acc 0.7918
Epoch 19  Train: loss 0.0571, acc 0.9794   Validation: loss 2.0757, acc 0.8565
Epoch 20  Train: loss 0.0502, acc 0.9824   Validation: loss 2.2133, acc 0.7647
Epoch 21  Train: loss 0.0542, acc 0.9791   Validation: loss 2.5180, acc 0.8222
Epoch 22  Train: loss 0.0483, acc 0.9823   Validation: loss 2.4675, acc 0.8105
Epoch 23  Train: loss 0.0450, acc 0.9836   Validation: loss 2.6579, acc 0.7443
Epoch 24  Train: loss 0.0477, acc 0.9834   Validation: loss 3.1563, acc 0.8219
Epoch 25  Train: loss 0.0456, acc 0.9837   Validation: loss 2.5554, acc 0.8389
Epoch 26  Train: loss 0.0533, acc 0.9811   Validation: loss 4.0893, acc 0.7737
Epoch 27  Train: loss 0.0417, acc 0.9856   Validation: loss 3.4930, acc 0.8046
Epoch 28  Train: loss 0.0433, acc 0.9843   Validation: loss 2.9932, acc 0.8039
Epoch 29  Train: loss 0.0421, acc 0.9840   Validation: loss 3.1402, acc 0.8191
Epoch 30  Train: loss 0.0381, acc 0.9865   Validation: loss 2.6988, acc 0.8299
epoch_range = range(1, len(train_losses) + 1)  # avoid shadowing the `epochs` count above
plt.figure(figsize=(15, 6))
plt.subplot(1, 2, 1)
plt.plot(epoch_range, train_losses, '-o', label='Training loss')
plt.plot(epoch_range, val_losses, '-o', label='Validation loss')
plt.legend()
plt.title('Learning curves: loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.xticks(epoch_range)
plt.subplot(1, 2, 2)
plt.plot(epoch_range, train_accuracies, '-o', label='Training accuracy')
plt.plot(epoch_range, val_accuracies, '-o', label='Validation accuracy')
plt.legend()
plt.title('Learning curves: accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.xticks(epoch_range)
plt.show()
# best_epoch = 32
model = torch.load(checkpoints_folder + f'/avp_{best_epoch:03d}.pkl', weights_only=False)  # the checkpoint stores the full pickled module
def predict(model, data_loader):
    model.eval()
    # save the predictions in this list
    y_pred = []
    # no gradient needed
    with torch.no_grad():
        # go over each batch in the loader; the targets can be ignored here
        for batch, _ in data_loader:
            # move batch to the GPU
            batch = batch.to(device)
            # predict class scores
            predictions = model(batch)
            # apply a softmax to turn the scores into probabilities
            predictions = F.softmax(predictions, dim=1)
            # move to the cpu and convert to numpy
            predictions = predictions.cpu().numpy()
            # save
            y_pred.append(predictions)
    # stack predictions into a (num_samples, num_classes) array
    y_pred = np.vstack(y_pred)
    return y_pred
# compute predictions on the test set
y_pred = predict(model, test_loader)
# find the argmax of each of the predictions
y_pred = y_pred.argmax(axis=1)
# get the true labels and convert to numpy
y_true = np.array(test_set.targets)
num_errors = np.sum(y_true != y_pred)
print(f'Test errors {num_errors} (out of {len(test_set)}) {num_errors/len(test_set)*100:0.2f}%')
print(f'Test accuracy {100-num_errors/len(test_set)*100:0.2f}%')
Test errors 484 (out of 5823) 8.31%
Test accuracy 91.69%
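# Overall accuracy hides the per-class picture, which matters here given the imbalance
# (T2star has only 4 test images). The already-imported sklearn.metrics module provides
# a per-class precision/recall/F1 breakdown:
print(metrics.classification_report(y_true, y_pred, target_names=test_set.classes))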
from sklearn.metrics import confusion_matrix
import seaborn as sns
conf_matrix = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=test_set.classes,   # ImageFolder's sorted class order, matching the label indices
            yticklabels=test_set.classes)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.show()
TP = conf_matrix.diagonal()   # correctly classified samples per class
P = conf_matrix.sum(axis=1)   # total true samples per class (support)
# balanced accuracy = mean per-class recall, robust to the class imbalance here
balanced_accuracy = sum(TP / P) / len(P)
print(f'Balanced accuracy {balanced_accuracy*100:0.2f}%')
Balanced accuracy 88.54%
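# Cross-check: sklearn computes the same mean per-class recall directly.
from sklearn.metrics import balanced_accuracy_score
print(f'Balanced accuracy (sklearn): {balanced_accuracy_score(y_true, y_pred)*100:0.2f}%')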