Skip to content
Snippets Groups Projects
Commit 531d5384 authored by Formanek Balázs István's avatar Formanek Balázs István
Browse files

test transform impact - codes

parent e54dabf9
Branches
No related tags found
No related merge requests found
Balanced Accuracy vs. Rotation.png

45.1 KiB

import os
import torch
import torchvision.transforms as transforms
#import torch.nn as nn
#import torch.optim as optim
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
# from torchvision.transforms import v2
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
import numpy as np
import pandas as pd
#import matplotlib.pyplot as plt
#import sklearn.metrics as metrics
from sklearn.metrics import confusion_matrix
#import seaborn as sns
from itertools import product
import random
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)
# Define a function to get the transformations based on the configuration
def get_transform(use_flip, rotation):
transform_list = []
# resize
transform_list.append(transforms.Resize(224))
# flip
if use_flip:
transform_list.append(transforms.RandomHorizontalFlip())
# rotation
transform_list.append(transforms.RandomRotation(degrees=rotation))
transform_list.append(transforms.ToTensor())
return transforms.Compose(transform_list)
def predict(model, data_loader, device):
model.eval()
# save the predictions in this list
y_pred = []
# no gradient needed
with torch.no_grad():
# go over each batch in the loader. We can ignore the targets here
for batch, _ in data_loader:
# Move batch to the GPU
batch = batch.to(device)
# predict probabilities of each class
predictions = model(batch)
# apply a softmax to the predictions
predictions = F.softmax(predictions, dim=1)
# move to the cpu and convert to numpy
predictions = predictions.cpu().numpy()
# save
y_pred.append(predictions)
# stack predictions into a (num_samples, 10) array
y_pred = np.vstack(y_pred)
return y_pred
# MAIN
if torch.cuda.is_available():
device = 'cuda'
print('cuda is available')
else:
device = 'cpu'
print('cuda is not available, change to cpu')
model_path = '/net/cremi/bformanek/TRDP_II/local_models/'
model_name = 'transfer_checkpoints_resnet18_adam_amp_criterion_balanced_avp_025.pkl'
model = torch.load(model_path + model_name, map_location=torch.device('cpu'), weights_only=False)
model.to(device)
DATA_PATH = '/net/travail/bformanek/MRI_dataset'
TEST_FOLDER = DATA_PATH + '/test'
# Set up your transformations with a list of options
use_random_flip = [False, True]
random_rotation = list(range(0, 180, 10))
BATCH_SIZE = 64
WORKERS = 8
# Initialize list to store results
results = []
for flip, rotation in product(use_random_flip, random_rotation):
transform = get_transform(flip, rotation)
test_set = ImageFolder(TEST_FOLDER, transform = transform)
test_loader = DataLoader(test_set, batch_size = BATCH_SIZE, shuffle = False, num_workers=WORKERS)
# compute predictions on the test set
y_pred = predict(model, test_loader, device)
# find the argmax of each of the predictions
y_pred = y_pred.argmax(axis=1)
# get the true labels and convert to numpy
y_true = np.array(test_set.targets)
# balanced accuracy
conf_matrix = confusion_matrix(y_true, y_pred)
TP = conf_matrix.diagonal()
P = conf_matrix.sum(axis=1)
balanced_accuracy = sum(TP / P) / len(P)
results.append({
"use_random_flip": flip,
"random_rotation": rotation,
"balanced_accuracy": balanced_accuracy
})
print(f"use_random_flip - {flip}, random_rotation - {rotation}, balanced_accuracy - {balanced_accuracy}")
# Save results to a DataFrame and export to CSV for easy analysis
df_results = pd.DataFrame(results)
df_results.to_csv("transform_evaluation_results.csv", index=False)
\ No newline at end of file
This diff is collapsed.
use_random_flip,random_rotation,balanced_accuracy
False,0,0.9910891981996583
False,10,0.9791003824858996
False,20,0.9555263606009453
False,30,0.9350323928076885
False,40,0.920475736344423
False,50,0.9035698129166397
False,60,0.8884172084451903
False,70,0.8242767620826293
False,80,0.8627538664121938
False,90,0.8084297581588451
False,100,0.7491355267513998
False,110,0.798434987707292
False,120,0.7858226803376203
False,130,0.780957605269648
False,140,0.7319620036389785
False,150,0.7699308726795506
False,160,0.7279718128212013
False,170,0.7794329344544063
True,0,0.9458076701140617
True,10,0.9380200646297198
True,20,0.9179824545578235
True,30,0.8951788501005208
True,40,0.8899794965416481
True,50,0.8776486711418438
True,60,0.8615161023146182
True,70,0.8088444130000744
True,80,0.8430272303480528
True,90,0.7870537398293254
True,100,0.7825543670178262
True,110,0.7304412568002858
True,120,0.7700303001804264
True,130,0.7172527098110664
True,140,0.7275591371037369
True,150,0.7743519262432214
True,160,0.728260065937451
True,170,0.773318541938019
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment