From e091881626929af7390411322cb8ddf77ec52c94 Mon Sep 17 00:00:00 2001
From: V Moni <vajay.monika@hallgato.ppke.hu>
Date: Sat, 11 Jan 2025 17:20:12 +0100
Subject: [PATCH] evaluation scripts

---
 evaluation/barchart_slices.py              |  81 ++++++++++++
 evaluation/count_misclassified_instance.py | 128 +++++++++++++++++++
 evaluation/get_all_slice.py                |  35 ++++++
 evaluation/parameters.py                   |  29 +++++
 evaluation/visualization.py                | 140 +++++++++++++++++++++
 evaluation/visualization_functions.py      |  78 ++++++++++++
 6 files changed, 491 insertions(+)
 create mode 100644 evaluation/barchart_slices.py
 create mode 100644 evaluation/count_misclassified_instance.py
 create mode 100644 evaluation/get_all_slice.py
 create mode 100644 evaluation/parameters.py
 create mode 100644 evaluation/visualization.py
 create mode 100644 evaluation/visualization_functions.py

diff --git a/evaluation/barchart_slices.py b/evaluation/barchart_slices.py
new file mode 100644
index 0000000..3bdbac6
--- /dev/null
+++ b/evaluation/barchart_slices.py
@@ -0,0 +1,80 @@
+import json
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import seaborn as sns
+
+
+def load_data(all_slices_path, mistake_slices_path):
+    """Split slice positions into good and misclassified sets per plane.
+
+    For every plane, each misclassified occurrence cancels exactly one
+    matching entry of the full slice list; the remainder is "good".
+    Returns (good_positions, mistake_positions), both keyed by plane.
+    """
+    with open(all_slices_path, 'r') as all_file:
+        all_data = json.load(all_file)
+
+    with open(mistake_slices_path, 'r') as mistake_file:
+        mistake_data = json.load(mistake_file)
+
+    good_positions = {}
+    mistake_positions = {}
+
+    for plane, all_slices in all_data.items():
+        misclassified_slices = mistake_data['mistake_positions'].get(plane, [])
+        good_positions[plane] = []
+        # Work on a copy so removals do not mutate the loaded JSON data.
+        temp_misclassified = misclassified_slices.copy()
+        for slice_pos in all_slices:
+            if slice_pos in temp_misclassified:
+                temp_misclassified.remove(slice_pos)
+            else:
+                good_positions[plane].append(slice_pos)
+        mistake_positions[plane] = misclassified_slices
+
+    return good_positions, mistake_positions
+
+
+def prepare_data(positions, label):
+    """Count occurrences of each slice position and tag the counts with *label*."""
+    counts = pd.Series(positions).value_counts().reset_index()
+    counts.columns = ['x_value', 'count']
+    counts['label'] = label
+    return counts
+
+
+# Load slices from the JSON files
+all_slices_path = 'all_slices_N3.json'
+mistake_slices_path = "comparision_N1_N2_N3_graphs/slices_evaluation/resnet18_basic_N1_misclassified_test_set_N3slices_ev.json"
+good_data, mistake_data = load_data(all_slices_path, mistake_slices_path)
+
+# Prepare data for each plane
+dfs = []
+for plane in ['sagittal', 'axial', 'coronial']:
+    good_df = prepare_data(good_data[plane], 'Good')
+    good_df['plane'] = plane
+    mistake_df = prepare_data(mistake_data[plane], 'Mistake')
+    mistake_df['plane'] = plane
+    dfs.append(pd.concat([good_df, mistake_df]))
+
+# Combine all data
+all_data = pd.concat(dfs)
+
+# Plot settings
+sns.set(style="whitegrid")
+colors = {'Good': 'green', 'Mistake': 'red'}
+
+# Create bar plots for each plane
+for plane in ['sagittal', 'axial', 'coronial']:
+    plt.figure(figsize=(10, 6))
+    plane_data = all_data[all_data['plane'] == plane]
+    sns.barplot(data=plane_data, x='x_value', y='count', hue='label', palette=colors)
+    plt.title(f'{plane.capitalize()} Plane Bar Chart')
+    plt.xlabel('X Value')
+    plt.ylabel('Count')
+    plt.legend(title='Position Type')
+    plt.xticks(rotation=45)
+    plt.tight_layout()
+
+plt.show()
diff --git a/evaluation/count_misclassified_instance.py b/evaluation/count_misclassified_instance.py
new file mode 100644
index 0000000..918591d
--- /dev/null
+++ b/evaluation/count_misclassified_instance.py
@@ -0,0 +1,129 @@
+import json
+import os
+
+import pandas as pd
+
+# Map the numeric axis labels stored in the JSON files to plane names.
+AXIS_CONVERTION = {
+    '0': "axial",
+    '1': "coronial",
+    '2': "sagittal"
+}
+
+# Summarise every misclassification JSON in the folder: count mistakes per
+# plane and per ground-truth class, append one summary row per file to an
+# Excel table, and export the good/mistake slice positions as JSON.
+folder_path = "comparision_N1_N2_N3_graphs/misclassified_data/"
+for item in os.listdir(folder_path):
+    file_name = os.path.join(folder_path, item)  # Get full path
+    print(file_name)
+    with open(file_name, 'r') as file:
+        data = json.load(file)
+
+    all_mistake = 0
+    individual_mistakes = 0
+
+    mistake_positions = {"sagittal": [], "axial": [], "coronial": []}
+    good_positions = {"sagittal": [], "axial": [], "coronial": []}
+    mistakes = {"sagittal": 0, "axial": 0, "coronial": 0}
+    mistakes_all = {"sagittal": 0, "axial": 0, "coronial": 0}
+    mistake_types = {"FLAIR": 0, "FLAIRCE": 0, "OTHER": 0, "T1w": 0, "T1wCE": 0, "T2star": 0, "T2w": 0}
+
+    if "averaging" not in file_name:
+        # Per-slice files: {ground_truth: {base_name: {axis: {...}}}}.
+        for ground_truth, base_names in data.items():
+            for base_name, axes in base_names.items():
+                individual_mistakes += 1
+                for axis, axis_data in axes.items():
+                    slice_positions = axis_data["slice_position"]
+
+                    # "slice_position" is either a single string or a list of strings.
+                    if isinstance(slice_positions, str):
+                        mistake_count = 1
+                        mistake_positions[axis].append(int(slice_positions))
+                    else:
+                        mistake_count = len(slice_positions)
+                        for pos in slice_positions:
+                            mistake_positions[axis].append(int(pos))
+
+                    all_mistake += mistake_count
+                    mistakes_all[axis] += mistake_count
+                    mistake_types[ground_truth] += mistake_count
+
+                    # NOTE(review): counted only when not all 3 axes failed -- confirm intent.
+                    if len(axes) != 3:
+                        mistakes[axis] += mistake_count
+        print("Mistakes by axis: ", mistakes)
+        print("Mistakes by axis: ", mistakes_all)
+    else:
+        # Averaging files: flat {base_name: {"gt", "each_pred", "axis", "slice_position"}}.
+        for base_name, values in data.items():
+            individual_mistakes += 1
+            gt = values["gt"]
+            each_pred = values["each_pred"]
+            indexes = [index for index, element in enumerate(each_pred) if element == gt]
+            for i in range(len(each_pred)):
+                axis = AXIS_CONVERTION[values["axis"][i]]
+                slice_pos = int(values["slice_position"][i])
+
+                if i in indexes:
+                    good_positions[axis].append(slice_pos)
+                else:
+                    all_mistake += 1
+                    mistake_positions[axis].append(slice_pos)
+                    mistakes_all[axis] += 1
+
+        for key in good_positions:
+            good_positions[key] = sorted(good_positions[key])
+
+        print("Good slice positions by axis:", good_positions)
+
+    for key in mistake_positions:
+        mistake_positions[key] = sorted(mistake_positions[key])
+
+    print("Mistake slice positions by axis:", mistake_positions)
+
+    print("All mistakes:", all_mistake)
+    print("Individual mistake: ", individual_mistakes)
+
+    attributes = ["file_name", "all_mistake", "individual_mistakes", "mistakes_sagittal", "mistakes_axial", "mistakes_coronial",
+                  "mistakes_all_sagittal", "mistakes_all_axial", "mistakes_all_coronial", "FLAIR", "FLAIRCE", "OTHER", "T1w",
+                  "T1wCE", "T2star", "T2w"]
+    new_row = {
+        "file_name": item,
+        "all_mistake": all_mistake,
+        "individual_mistakes": individual_mistakes,
+        "mistakes_sagittal": mistakes["sagittal"],
+        "mistakes_coronial": mistakes["coronial"],
+        "mistakes_axial": mistakes["axial"],
+        "mistakes_all_sagittal": mistakes_all["sagittal"],
+        "mistakes_all_coronial": mistakes_all["coronial"],
+        "mistakes_all_axial": mistakes_all["axial"],
+        "FLAIR": mistake_types["FLAIR"],
+        "FLAIRCE": mistake_types["FLAIRCE"],
+        "OTHER": mistake_types["OTHER"],
+        "T1w": mistake_types["T1w"],
+        "T1wCE": mistake_types["T1wCE"],
+        "T2star": mistake_types["T2star"],
+        "T2w": mistake_types["T2w"]
+    }
+
+    # Append this file's summary row to the shared Excel table.
+    table_name = "comparision_N1_N2_N3_graphs/misclassified_evaluation.xlsx"
+    if os.path.exists(table_name):
+        existing_df = pd.read_excel(table_name)
+    else:
+        existing_df = pd.DataFrame(columns=attributes)
+
+    existing_df = pd.concat([existing_df, pd.DataFrame([new_row])], ignore_index=True)
+    existing_df.to_excel(table_name, index=False)
+
+    print(f"Table updated and saved to {table_name}")
+
+    all_slices = {"good_positions": good_positions,
+                  "mistake_positions": mistake_positions}
+
+    slices_path = "comparision_N1_N2_N3_graphs/slices_evaluation/" + item[:-5] + "slices_ev.json"
+    with open(slices_path, "w") as file:
+        json.dump(all_slices, file, indent=4)
+    print(f"Slice positions saved to {slices_path}")
diff --git a/evaluation/get_all_slice.py b/evaluation/get_all_slice.py
new file mode 100644
index 0000000..077696f
--- /dev/null
+++ b/evaluation/get_all_slice.py
@@ -0,0 +1,42 @@
+import json
+import os
+from collections import defaultdict
+
+# Map the numeric axis label embedded in the file name to a plane name.
+AXIS_CONVERTION = {
+    '0': "axial",
+    '1': "coronial",
+    '2': "sagittal"
+}
+
+
+def organize_images_by_axis(folder_path):
+    """Collect slice positions of every .png under *folder_path*, grouped by plane.
+
+    File names are expected to end with "_<axis>_<slice>.png", where <axis>
+    is a key of AXIS_CONVERTION and <slice> is an integer position.
+    """
+    axis_dict = defaultdict(list)
+
+    for root, _, files in os.walk(folder_path):
+        for file in files:
+            if file.endswith('.png'):
+                image_path = os.path.join(root, file)
+                # "<name>_<axis>_<slice>.png" -> axis label and slice position.
+                axis = AXIS_CONVERTION[image_path.rsplit("_", 2)[1]]
+                slice_pos = int(image_path.rsplit("_", 2)[2][:-4])
+                axis_dict[axis].append(slice_pos)
+
+    return axis_dict
+
+
+# Example usage
+folder_path = "C:/Users/Monika/Documents/IPCV/TRDP/new_data/N1/test"
+image_data = organize_images_by_axis(folder_path)
+
+# Save the dictionary to a JSON file
+output_path = 'all_slices_N1.json'
+with open(output_path, 'w') as json_file:
+    json.dump(image_data, json_file, indent=4)
+
+print(f"Image data has been saved to {output_path}")
diff --git a/evaluation/parameters.py b/evaluation/parameters.py
new file mode 100644
index 0000000..0b6e507
--- /dev/null
+++ b/evaluation/parameters.py
@@ -0,0 +1,53 @@
+import seaborn as sns
+
+# Number of individuals in the test set (used for the misclassified-count axis).
+INDIVIDUALS = 1941
+
+# Seven evenly spaced colours sampled from the "icefire" palette.
+cmap = sns.color_palette("icefire", as_cmap=True)
+spectral_colors = [cmap(i / 6) for i in range(7)]
+
+# DATA SET
+data_colors = {
+    'N1': spectral_colors[0],
+    'N2': spectral_colors[2],
+    'N3': spectral_colors[4],
+    'N3_c1': spectral_colors[5],
+    'N3_c2': spectral_colors[6],
+}
+data_shapes = {'N1': 'o', 'N2': 's', 'N3': '*', 'N3_c1': 'D', 'N3_c2': '^'}
+
+# METHODS
+methods_colors = {
+    'SM-GI': spectral_colors[0],
+    'SM-GI-Avg': spectral_colors[1],
+    'SM-DP': spectral_colors[2],
+    'Ens-GI': spectral_colors[3],
+    'Ens-SP-O': spectral_colors[4],
+    'SM-SP-O': spectral_colors[5],
+    'SM-SP-O-Avg': spectral_colors[6],
+}
+methods_shapes = {
+    'SM-GI': 'o',
+    'SM-GI-Avg': 's',
+    'SM-DP': '^',
+    'Ens-GI': 'X',
+    'Ens-SP-O': 'P',
+    'SM-SP-O': 'D',
+    'SM-SP-O-Avg': 'H',
+}
+
+# AUGMENTATION
+augmentation_colors = {
+    '-': spectral_colors[0],
+    'flip': spectral_colors[2],
+    'flip+rotation': spectral_colors[4],
+    'flip + rotation + scale': spectral_colors[6],
+}
+augmentation_shapes = {'-': 'o', 'flip': 's', 'flip+rotation': '^', 'flip + rotation + scale': '*'}
+# Same marker for every augmentation (colour alone distinguishes them).
+augmentation_shapes_same = {'-': 'o', 'flip': 'o', 'flip+rotation': 'o', 'flip + rotation + scale': 'o'}
+
+# MODELS
+models_colors = {'resnet18': 'orange', 'resnet18.a2_in1k': 'blue'}
+models_shapes = {'resnet18': 'o', 'resnet18.a2_in1k': 's'}
diff --git a/evaluation/visualization.py b/evaluation/visualization.py
new file mode 100644
index 0000000..afc55d0
--- /dev/null
+++ b/evaluation/visualization.py
@@ -0,0 +1,124 @@
+import pandas as pd
+import matplotlib.pyplot as plt
+
+from visualization_functions import sns_vis
+from parameters import data_colors, data_shapes, methods_colors, methods_shapes, augmentation_colors, augmentation_shapes, augmentation_shapes_same, models_colors, models_shapes
+
+# Toggle which comparison figures are generated on this run.
+MODEL_COMPARISON = False
+
+AUGMENTATION_COMPARISON = False
+DATASET_COMPARISON = False
+METHOD_COMPARISON = False
+
+SM_GI_COMPARISION = False
+INPUT_COMPARISION = True
+ENSEMBLE_COMPARISION = False
+BEST_MODELS_COMPARISION = False
+
+# Output folders and file-name suffix for the saved figures.
+model_name = "_resnet18_out2"
+folder_models = "comparision_models_graphs/"
+folder_all_attribrutes = "comparision_all_attributes_graphs/"
+folder_N1_N2_N3 = "comparision_N1_N2_N3_graphs/"
+folder_inputs_comp = "comparision_inputs_graphs/"
+folder_ensembles = "comparision_ensembles_graphs/"
+
+if MODEL_COMPARISON:    # visualize resnet18 vs resnet18.a2_in1k
+    table1 = 'models.xlsx'
+    data_real = pd.read_excel(table1)
+
+    resnet18_var_subtable = data_real[['Model name', 'Test accuracy', 'Balanced test accuracy', 'Dataset name', 'Method']].dropna()
+
+    sns_vis(resnet18_var_subtable, 'Method', 'Test accuracy', 'Model name', 'Dataset name', models_colors, data_shapes, "Model comparison")
+    plt.savefig(folder_models + "model_comparison_TA_method.png", dpi=300, bbox_inches='tight')
+    sns_vis(resnet18_var_subtable, 'Method', 'Balanced test accuracy', 'Model name', 'Dataset name', models_colors, data_shapes, "Model comparison")
+    plt.savefig(folder_models + "model_comparison_BTA_method.png", dpi=300, bbox_inches='tight')
+    sns_vis(resnet18_var_subtable, 'Dataset name', 'Test accuracy', 'Model name', 'Method', models_colors, methods_shapes, "Model comparison")
+    plt.savefig(folder_models + "model_comparison_TA_dataset.png", dpi=300, bbox_inches='tight')
+    sns_vis(resnet18_var_subtable, 'Dataset name', 'Balanced test accuracy', 'Model name', 'Method', models_colors, methods_shapes, "Model comparison")
+    plt.savefig(folder_models + "model_comparison_BTA_dataset.png", dpi=300, bbox_inches='tight')
+
+table2 = 'resnet18.xlsx'
+resnet18_var_data = pd.read_excel(table2)
+resnet18_var_subtable = resnet18_var_data[['Augmentation', 'Test accuracy', 'Balanced test accuracy', 'Dataset name', 'Method']].dropna()
+
+if AUGMENTATION_COMPARISON:
+    sns_vis(resnet18_var_subtable, 'Augmentation', 'Test accuracy', 'Dataset name', 'Method', data_colors, methods_shapes, "Augmentation comparison", "out")
+    plt.savefig(folder_all_attribrutes + "augmentation_comparison_TA" + model_name + ".png", dpi=300, bbox_inches='tight')
+    sns_vis(resnet18_var_subtable, 'Augmentation', 'Balanced test accuracy', 'Dataset name', 'Method', data_colors, methods_shapes, "Augmentation comparison", "out")
+    plt.savefig(folder_all_attribrutes + "augmentation_comparison_BTA" + model_name + ".png", dpi=300, bbox_inches='tight')
+
+if DATASET_COMPARISON:
+    sns_vis(resnet18_var_subtable, 'Dataset name', 'Test accuracy', 'Augmentation', 'Method', augmentation_colors, methods_shapes, "Dataset comparison", "out")
+    plt.savefig(folder_all_attribrutes + "dataset_comparison_TA" + model_name + ".png", dpi=300, bbox_inches='tight')
+    sns_vis(resnet18_var_subtable, 'Dataset name', 'Balanced test accuracy', 'Augmentation', 'Method', augmentation_colors, methods_shapes, "Dataset comparison", "out")
+    plt.savefig(folder_all_attribrutes + "dataset_comparison_BTA" + model_name + ".png", dpi=300, bbox_inches='tight')
+
+if METHOD_COMPARISON:
+    sns_vis(resnet18_var_subtable, 'Method', 'Test accuracy', 'Augmentation', 'Dataset name', augmentation_colors, data_shapes, "Method comparison", "out")
+    plt.savefig(folder_all_attribrutes + "method_comparison_TA" + model_name + ".png", dpi=300, bbox_inches='tight')
+    sns_vis(resnet18_var_subtable, 'Method', 'Balanced test accuracy', 'Augmentation', 'Dataset name', augmentation_colors, data_shapes, "Method comparison", "out")
+    plt.savefig(folder_all_attribrutes + "method_comparison_BTA" + model_name + ".png", dpi=300, bbox_inches='tight')
+
+if SM_GI_COMPARISION:
+    table3 = "comparision_N1_N2_N3.xlsx"
+    comparision_n1_n2_n3 = pd.read_excel(table3)
+    comparision_n1_n2_n3_subtable = comparision_n1_n2_n3[['Augmentation', 'Test accuracy', 'Balanced test accuracy', 'Train set', 'Test set', 'Method']].dropna()
+
+    augmentation_types = comparision_n1_n2_n3_subtable['Augmentation'].unique()
+    subtables = {}
+
+    for augm in augmentation_types:
+        subtables[augm] = comparision_n1_n2_n3_subtable[(comparision_n1_n2_n3_subtable['Augmentation'] == augm) & (comparision_n1_n2_n3_subtable['Method'] == 'SM-GI')]
+        sns_vis(subtables[augm], 'Train set', 'Balanced test accuracy', 'Test set', 'Method', data_colors, methods_shapes, "Dataset comparison", "out")
+        plt.savefig(folder_N1_N2_N3 + augm + "_ SM-GI_" + "_data_comparison_BTA.png", dpi=300, bbox_inches='tight')
+
+        subtables[augm] = comparision_n1_n2_n3_subtable[(comparision_n1_n2_n3_subtable['Augmentation'] == augm) & (comparision_n1_n2_n3_subtable['Method'] == 'SM-GI-Avg')]
+        sns_vis(subtables[augm], 'Train set', 'Balanced test accuracy', 'Test set', 'Method', data_colors, methods_shapes, "Dataset comparison", "out")
+        plt.savefig(folder_N1_N2_N3 + augm + "_ SM-GI-AVG_" + "_data_comparison_BTA.png", dpi=300, bbox_inches='tight')
+
+        subtables[augm] = comparision_n1_n2_n3_subtable[comparision_n1_n2_n3_subtable['Augmentation'] == augm]
+        sns_vis(subtables[augm], 'Train set', 'Balanced test accuracy', 'Test set', 'Method', data_colors, methods_shapes, "Dataset comparison", "out")
+        plt.savefig(folder_N1_N2_N3 + augm + "_data_comparison_BTA.png", dpi=300, bbox_inches='tight')
+
+if INPUT_COMPARISION:
+    table4 = "comparision_inputs.xlsx"
+    comparision_inputs = pd.read_excel(table4)
+    comparision_inputs_subtable = comparision_inputs[['Augmentation', 'Balanced test accuracy', 'Method', 'Dataset name']].dropna()
+
+    augmentation_types = comparision_inputs_subtable['Augmentation'].unique()
+    subtables = {}
+
+    for augm in augmentation_types:
+        subtables[augm] = comparision_inputs_subtable[comparision_inputs_subtable['Augmentation'] == augm]
+        sns_vis(subtables[augm], 'Method', 'Balanced test accuracy', 'Dataset name', None, data_colors, None, legend="out")
+        plt.savefig(folder_inputs_comp + augm + "_method_comparison_BTA_same_axis.png", dpi=300, bbox_inches='tight')
+
+    sns_vis(comparision_inputs_subtable, 'Method', 'Balanced test accuracy', 'Dataset name', 'Augmentation', data_colors, augmentation_shapes, "Method comparison", "out")
+    plt.savefig(folder_inputs_comp + "all_aug_method_comparison_BTA_same_axis.png", dpi=300, bbox_inches='tight')
+
+if ENSEMBLE_COMPARISION:
+    table5 = "comparision_ensembles.xlsx"
+    comparision_ensembles = pd.read_excel(table5)
+    comparision_ensembles_subtable = comparision_ensembles[['Augmentation', 'Balanced test accuracy', 'Method', 'Dataset name']].dropna()
+
+    sns_vis(comparision_ensembles_subtable, 'Augmentation', 'Balanced test accuracy', 'Dataset name', 'Method', data_colors, methods_shapes, legend="out")
+    plt.savefig(folder_ensembles + "all_aug_ensemble_comparison_BTA_4.png", dpi=300, bbox_inches='tight')
+
+if BEST_MODELS_COMPARISION:
+    table6 = "comparision_best_models.xlsx"
+    comparision_best_models = pd.read_excel(table6)
+    comparision_best_models_subtable = comparision_best_models[['Augmentation', 'Balanced test accuracy', 'Method', 'Dataset name']].dropna()
+
+    sns_vis(comparision_best_models_subtable, 'Augmentation', 'Balanced test accuracy', 'Dataset name', 'Method', data_colors, methods_shapes, legend="out")
+    plt.savefig("best_models_comparison_BTA.png", dpi=300, bbox_inches='tight')
+
+    table7 = "comparision_best_models_aug_test_set.xlsx"
+    comparision_best_models_aug_test_set = pd.read_excel(table7)
+    comparision_best_models_aug_test_set_subtable = comparision_best_models_aug_test_set[['Augmentation', 'Balanced test accuracy', 'Method', 'Dataset name']].dropna()
+
+    sns_vis(comparision_best_models_aug_test_set_subtable, 'Augmentation', 'Balanced test accuracy', 'Dataset name', 'Method', data_colors, methods_shapes, legend="out")
+    plt.savefig("best_models_comparison_aug_test_set_BTA.png", dpi=300, bbox_inches='tight')
+
+# plt.show()  # Uncomment to display figures interactively as well.
diff --git a/evaluation/visualization_functions.py b/evaluation/visualization_functions.py
new file mode 100644
index 0000000..8d6d598
--- /dev/null
+++ b/evaluation/visualization_functions.py
@@ -0,0 +1,88 @@
+import seaborn as sns
+import matplotlib.pyplot as plt
+
+from parameters import INDIVIDUALS
+
+
+def plt_vis(data, x_axis, y_axis1, color_attr, shape_attr, colors, shapes):
+    """Scatter-plot *data* with matplotlib, colour-coded by *color_attr* and
+    marker-coded by *shape_attr*.
+
+    *colors*/*shapes* map attribute values to matplotlib colours/markers.
+    """
+    plt.figure(figsize=(8, 6))
+    for _, row in data.iterrows():
+        plt.scatter(row[x_axis], row[y_axis1],
+                    color=colors[row[color_attr]],
+                    marker=shapes[row[shape_attr]],
+                    label=f"{row[color_attr]} - {row[shape_attr]}")
+
+    # De-duplicate legend entries (one per colour/shape combination).
+    handles, labels = plt.gca().get_legend_handles_labels()
+    by_label = dict(zip(labels, handles))
+    plt.legend(by_label.values(), by_label.keys(), bbox_to_anchor=(1.05, 1), loc='upper left')
+    plt.xticks(rotation=20)
+
+    plt.xlabel(x_axis)
+    plt.ylabel(y_axis1)
+    plt.title('Matplotlib Visualization')
+    plt.tight_layout()
+
+
+def sns_vis(data, x_axis, y_axis1, color_attr, shape_attr, colors, shapes, title="", legend="out"):
+    """Scatter-plot accuracies with seaborn and add a secondary y-axis that
+    shows the corresponding number of misclassified instances.
+
+    legend: "out" (outside the axes), "left" or "right" (inside corners).
+    """
+    plt.figure(figsize=(8.5, 4))
+    plt.rcParams.update({
+        'font.size': 14,         # Global font size
+        'axes.titlesize': 18,    # Title font size
+        'axes.labelsize': 16,    # Axis labels font size
+        'xtick.labelsize': 14,   # X-axis tick labels
+        'ytick.labelsize': 14,   # Y-axis tick labels
+        'legend.fontsize': 14    # Legend font size
+    })
+    sns.scatterplot(data=data, x=x_axis, y=y_axis1,
+                    hue=color_attr, style=shape_attr,
+                    palette=colors, markers=shapes, s=100)
+
+    plt.xlabel(x_axis)
+    plt.ylabel(y_axis1)
+    plt.xticks(rotation=20)
+    plt.title(title)
+    if legend == "out":
+        plt.legend(bbox_to_anchor=(1.15, 1), loc='upper left')
+    elif legend == "left":
+        plt.legend(loc='lower left', frameon=True)
+    elif legend == "right":
+        plt.legend(loc='lower right', frameon=True)
+    plt.tight_layout()
+    primary_ax = plt.gca()
+    primary_ax.set_ylim(0.89, 1.0)
+
+    # Secondary y-axis: accuracy -> number of misclassified individuals.
+    def secondary_y_axis_transform(value):
+        return INDIVIDUALS * (1 - value)
+
+    def secondary_y_axis_inverse(value):
+        return 1 - (value / INDIVIDUALS)
+
+    secondary_ax = primary_ax.twinx()
+
+    # Set the secondary y-axis limits based on the primary y-axis.
+    secondary_ax.set_ylim(
+        secondary_y_axis_transform(primary_ax.get_ylim()[0]),
+        secondary_y_axis_transform(primary_ax.get_ylim()[1])
+    )
+    secondary_ax.set_ylabel('Misclassified instances')
+
+    # Keep the primary axis consistent if the secondary limits change.
+    def sync_axes(ax):
+        primary_ax.set_ylim(
+            secondary_y_axis_inverse(ax.get_ylim()[0]),
+            secondary_y_axis_inverse(ax.get_ylim()[1])
+        )
+
+    secondary_ax.callbacks.connect("ylim_changed", sync_axes)
-- 
GitLab