From e091881626929af7390411322cb8ddf77ec52c94 Mon Sep 17 00:00:00 2001
From: V Moni <vajay.monika@hallgato.ppke.hu>
Date: Sat, 11 Jan 2025 17:20:12 +0100
Subject: [PATCH] evaluation scripts

---
 evaluation/barchart_slices.py              |  81 ++++++++++++
 evaluation/count_misclassified_instance.py | 128 +++++++++++++++++++
 evaluation/get_all_slice.py                |  35 ++++++
 evaluation/parameters.py                   |  29 +++++
 evaluation/visualization.py                | 140 +++++++++++++++++++
 evaluation/visualization_functions.py      |  78 ++++++++++++
 6 files changed, 491 insertions(+)
 create mode 100644 evaluation/barchart_slices.py
 create mode 100644 evaluation/count_misclassified_instance.py
 create mode 100644 evaluation/get_all_slice.py
 create mode 100644 evaluation/parameters.py
 create mode 100644 evaluation/visualization.py
 create mode 100644 evaluation/visualization_functions.py

diff --git a/evaluation/barchart_slices.py b/evaluation/barchart_slices.py
new file mode 100644
index 0000000..3bdbac6
--- /dev/null
+++ b/evaluation/barchart_slices.py
@@ -0,0 +1,81 @@
+import json
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import seaborn as sns
+
+
+# Load the data
+def load_data(all_slices_path, mistake_slices_path):
+    """Split all slice positions into correctly and incorrectly classified ones."""
+    with open(all_slices_path, 'r') as all_file:
+        all_data = json.load(all_file)
+
+    with open(mistake_slices_path, 'r') as mistake_file:
+        mistake_data = json.load(mistake_file)
+
+    # Remove misclassified slices to get the correctly classified ("good") slices
+    good_positions = {}
+    mistake_positions = {}
+
+    for plane in all_data:
+        all_slices = all_data[plane]
+        misclassified_slices = mistake_data['mistake_positions'].get(plane, [])
+        good_positions[plane] = []
+        temp_misclassified = misclassified_slices.copy()
+        for slice_pos in all_slices:
+            if slice_pos in temp_misclassified:
+                temp_misclassified.remove(slice_pos)
+            else:
+                good_positions[plane].append(slice_pos)
+        mistake_positions[plane] = misclassified_slices
+
+    return good_positions, mistake_positions
+
+
+def prepare_data(positions, label):
+    """Count how often each slice position occurs and tag the rows with a label."""
+    counts = pd.Series(positions).value_counts().reset_index()
+    counts.columns = ['x_value', 'count']
+    counts['label'] = label
+    return counts
+
+
+# Load slices from the JSON files
+all_slices_path = 'all_slices_N3.json'
+mistake_slices_path = "comparision_N1_N2_N3_graphs/slices_evaluation/resnet18_basic_N1_misclassified_test_set_N3slices_ev.json"
+good_data, mistake_data = load_data(all_slices_path, mistake_slices_path)
+
+# Prepare data for each plane
+dfs = []
+for plane in ['sagittal', 'axial', 'coronial']:
+    good_df = prepare_data(good_data[plane], 'Good')
+    good_df['plane'] = plane
+    mistake_df = prepare_data(mistake_data[plane], 'Mistake')
+    mistake_df['plane'] = plane
+    #dfs.append(mistake_df)
+    dfs.append(pd.concat([good_df, mistake_df]))
+
+# Combine all data
+all_data = pd.concat(dfs)
+
+# Plot settings
+sns.set_theme(style="whitegrid")
+colors = {'Good': 'green', 'Mistake': 'red'}
+
+# Create bar plots for each plane
+for plane in ['sagittal', 'axial', 'coronial']:
+    plt.figure(figsize=(10, 6))
+    plane_data = all_data[all_data['plane'] == plane]
+    sns.barplot(data=plane_data, x='x_value', y='count', hue='label', palette=colors)
+    plt.title(f'{plane.capitalize()} Plane Bar Chart')
+    plt.xlabel('Slice position')
+    plt.ylabel('Count')
+    plt.legend(title='Position Type')
+    plt.xticks(rotation=45)
+    plt.tight_layout()
+plt.show()
\ No newline at end of file
diff --git a/evaluation/count_misclassified_instance.py b/evaluation/count_misclassified_instance.py
new file mode 100644
index 0000000..918591d
--- /dev/null
+++ b/evaluation/count_misclassified_instance.py
@@ -0,0 +1,128 @@
+import json
+import os
+
+import pandas as pd
+
+# TODO: check how the results are written out,
+# count how many elements there are under each key,
+# and if the export format is rewritten: store only the name of the misclassified file,
+# with everything else as sub-entries.
+AXIS_CONVERTION = {
+    '0': "axial",
+    '1': "coronial",
+    '2': "sagittal"
+}
+#file_name = "comparision_N1_N2_N3_graphs/resnet18_basic_N1_misclassified_test_set_N3.json"
+
+folder_path = "comparision_N1_N2_N3_graphs/misclassified_data/"
+for item in os.listdir(folder_path):
+    file_name = os.path.join(folder_path, item)  # Get full path
+    print(file_name)
+    with open(file_name, 'r') as file:
+        data = json.load(file)
+
+    all_mistake = 0
+    individual_mistakes = 0
+
+    mistake_positions = {"sagittal": [], "axial": [], "coronial": []}
+    good_positions = {"sagittal": [], "axial": [], "coronial": []}
+    mistakes = {"sagittal": 0, "axial": 0, "coronial": 0}
+    mistakes_all = {"sagittal": 0, "axial": 0, "coronial": 0}
+    mistake_types = {"FLAIR": 0, "FLAIRCE": 0, "OTHER": 0, "T1w": 0, "T1wCE": 0, "T2star": 0, "T2w": 0}
+
+    if "averaging" not in file_name:
+        # Per-slice results: {ground_truth: {base_name: {axis: {"slice_position": ...}}}}
+        for ground_truth, base_names in data.items():
+            for base_name, axes in base_names.items():
+                individual_mistakes += 1
+                for axis, axis_data in axes.items():
+                    slice_positions = axis_data["slice_position"]
+
+                    if isinstance(slice_positions, str):
+                        slice_positions = int(slice_positions)
+                        mistake_count = 1
+                        mistake_positions[axis].append(slice_positions)
+                    else:
+                        mistake_count = len(slice_positions)
+                        for i in range(mistake_count):
+                            mistake_positions[axis].append(int(slice_positions[i]))
+                            print(f"{int(slice_positions[i])}, len: {mistake_count} type: {type(int(slice_positions[i]))}")
+
+                    all_mistake += mistake_count
+                    mistakes_all[axis] += mistake_count
+                    mistake_types[ground_truth] += mistake_count
+
+                    # Count separately the instances that were not misclassified in all three views
+                    if len(axes) != 3:
+                        mistakes[axis] += mistake_count
+        print("Mistakes by axis: ", mistakes)
+        print("Mistakes by axis: ", mistakes_all)
+
+        #Mistakes by axis: {'sagittal': 378, 'axial': 118, 'coronial': 260}
+        #Mistakes by axis: {'sagittal': 460, 'axial': 209, 'coronial': 355}
+    else:
+        # Averaged results: {base_name: {"gt": ..., "each_pred": [...], "axis": [...], "slice_position": [...]}}
+        for base_name, values in data.items():
+            individual_mistakes += 1
+            gt = values["gt"]
+            each_pred = values["each_pred"]
+            indexes = [index for index, element in enumerate(each_pred) if element == gt]
+            for i in range(len(each_pred)):
+                axis = AXIS_CONVERTION[values["axis"][i]]
+                slice_pos = int(values["slice_position"][i])
+
+                if i in indexes:
+                    good_positions[axis].append(slice_pos)
+                else:
+                    all_mistake += 1
+                    mistake_positions[axis].append(slice_pos)
+                    mistakes_all[axis] += 1
+
+    for key in good_positions:
+        good_positions[key] = sorted(good_positions[key], key=int)
+
+    print("Good slice positions by axis:", good_positions)
+
+    for key in mistake_positions:
+        mistake_positions[key] = sorted(mistake_positions[key], key=int)
+
+    print("Mistake slice positions by axis:", mistake_positions)
+
+    print("All mistakes:", all_mistake)
+    print("Individual mistakes: ", individual_mistakes)
+
+    attributes = ["file_name", "all_mistake", "individual_mistakes", "mistakes_sagittal", "mistakes_axial", "mistakes_coronial",
+                  "mistakes_all_sagittal", "mistakes_all_axial", "mistakes_all_coronial", "FLAIR", "FLAIRCE", "OTHER", "T1w",
+                  "T1wCE", "T2star", "T2w"]
+
+    new_row = {
+        "file_name": item,
+        "all_mistake": all_mistake,
+        "individual_mistakes": individual_mistakes,
+        "mistakes_sagittal": mistakes["sagittal"],
+        "mistakes_coronial": mistakes["coronial"],
+        "mistakes_axial": mistakes["axial"],
+        "mistakes_all_sagittal": mistakes_all["sagittal"],
+        "mistakes_all_coronial": mistakes_all["coronial"],
+        "mistakes_all_axial": mistakes_all["axial"],
+        "FLAIR": mistake_types["FLAIR"],
+        "FLAIRCE": mistake_types["FLAIRCE"],
+        "OTHER": mistake_types["OTHER"],
+        "T1w": mistake_types["T1w"],
+        "T1wCE": mistake_types["T1wCE"],
+        "T2star": mistake_types["T2star"],
+        "T2w": mistake_types["T2w"]
+    }
+
+    # Append the per-file summary to the evaluation spreadsheet
+    table_name = "comparision_N1_N2_N3_graphs/misclassified_evaluation.xlsx"
+    if os.path.exists(table_name):
+        existing_df = pd.read_excel(table_name)
+    else:
+        existing_df = pd.DataFrame(columns=attributes)
+
+    existing_df = pd.concat([existing_df, pd.DataFrame([new_row])], ignore_index=True)
+    existing_df.to_excel(table_name, index=False)
+
+    print(f"Table updated and saved to {table_name}")
+
+    # Save the good/mistake slice positions next to the evaluation graphs
+    all_slices = {"good_positions": good_positions,
+                  "mistake_positions": mistake_positions}
+
+    with open("comparision_N1_N2_N3_graphs/slices_evaluation/" + item[:-5] + "slices_ev.json", "w") as file:
+        json.dump(all_slices, file, indent=4)
+    print(f"Slice positions saved to {item[:-5]}slices_ev.json")
\ No newline at end of file
diff --git a/evaluation/get_all_slice.py b/evaluation/get_all_slice.py
new file mode 100644
index 0000000..077696f
--- /dev/null
+++ b/evaluation/get_all_slice.py
@@ -0,0 +1,35 @@
+import os
+import json
+
+AXIS_CONVERTION = {
+    '0': "axial",
+    '1': "coronial",
+    '2': "sagittal"
+}
+
+def organize_images_by_axis(folder_path):
+    """Collect the slice position of every PNG in the folder, grouped by anatomical plane."""
+    axis_dict = {}
+
+    for root, _, files in os.walk(folder_path):
+        for file in files:
+            if file.endswith('.png'):
+                # Expects file names of the form <base>_<axis index>_<slice position>.png
+                image_path = os.path.join(root, file)
+                axis = AXIS_CONVERTION[image_path.rsplit("_", 2)[1]]
+                slice_pos = int(image_path.rsplit("_", 2)[2][:-4])
+
+                if axis not in axis_dict:
+                    axis_dict[axis] = []
+                axis_dict[axis].append(slice_pos)
+
+    return axis_dict
+
+# Example usage
+folder_path = "C:/Users/Monika/Documents/IPCV/TRDP/new_data/N1/test"
+image_data = organize_images_by_axis(folder_path)
+
+# Save the dictionary to a JSON file
+output_path = 'all_slices_N1.json'  # Specify the output file path
+with open(output_path, 'w') as json_file:
+    json.dump(image_data, json_file, indent=4)  # Write the dictionary to the JSON file with indentation
+
+print(f"Image data has been saved to {output_path}")
\ No newline at end of file
diff --git a/evaluation/parameters.py b/evaluation/parameters.py
new file mode 100644
index 0000000..0b6e507
--- /dev/null
+++ b/evaluation/parameters.py
@@ -0,0 +1,29 @@
+import seaborn as sns
+
+# Number of test instances (used for the secondary "misclassified instances" axis)
+INDIVIDUALS = 1941
+cmap = sns.color_palette("icefire", as_cmap=True)
+spectral_colors = [cmap(i / 6) for i in range(7)]
+
+# DATA SET
+data_colors = {'N1': spectral_colors[0], 'N2': spectral_colors[2], 'N3': spectral_colors[4], 'N3_c1': spectral_colors[5], 'N3_c2': spectral_colors[6]}
+data_shapes = {'N1': 'o', 'N2': 's', 'N3': '*', 'N3_c1': 'D', 'N3_c2': '^'}
+
+# METHODS
+methods_colors = {
+    'SM-GI': spectral_colors[0], 'SM-GI-Avg': spectral_colors[1], 'SM-DP': spectral_colors[2],
+    'Ens-GI': spectral_colors[3], 'Ens-SP-O': spectral_colors[4],
+    'SM-SP-O': spectral_colors[5], 'SM-SP-O-Avg': spectral_colors[6]
+}
+
+methods_shapes = {'SM-GI': 'o', 'SM-GI-Avg': 's', 'SM-DP': '^',
+                  'Ens-GI': 'X', 'Ens-SP-O': 'P',
+                  'SM-SP-O': 'D', 'SM-SP-O-Avg': 'H'}
+
+# AUGMENTATION
+augmentation_colors = {'-': spectral_colors[0],
+                       'flip': spectral_colors[2], 'flip+rotation': spectral_colors[4], 'flip + rotation + scale': spectral_colors[6]}
+augmentation_shapes = {'-': 'o', 'flip': 's', 'flip+rotation': '^', 'flip + rotation + scale': '*'}
+augmentation_shapes_same = {'-': 'o', 'flip': 'o', 'flip+rotation': 'o', 'flip + rotation + scale': 'o'}
+
+# MODELS
+models_colors = {'resnet18': 'orange', 'resnet18.a2_in1k': 'blue'}
+models_shapes = {'resnet18': 'o', 'resnet18.a2_in1k': 's'}
\ No newline at end of file
diff --git a/evaluation/visualization.py b/evaluation/visualization.py
new file mode 100644
index 0000000..afc55d0
--- /dev/null
+++ b/evaluation/visualization.py
@@ -0,0 +1,140 @@
+import pandas as pd
+import matplotlib.pyplot as plt
+
+from visualization_functions import plt_vis, sns_vis
+from parameters import data_colors, data_shapes, methods_colors, methods_shapes, augmentation_colors, augmentation_shapes, augmentation_shapes_same, models_colors, models_shapes
+
+# Toggle which comparison plots to generate
+MODEL_COMPARISON = False
+
+AUGMENTATION_COMPARISON = False
+DATASET_COMPARISON = False
+METHOD_COMPARISON = False
+
+SM_GI_COMPARISION = False
+INPUT_COMPARISION = True
+ENSEMBLE_COMPARISION = False
+BEST_MODELS_COMPARISION = False
+
+model_name = "_resnet18_out2"
+folder_models = "comparision_models_graphs/"
+folder_all_attribrutes = "comparision_all_attributes_graphs/"
+folder_N1_N2_N3 = "comparision_N1_N2_N3_graphs/"
+folder_inputs_comp = "comparision_inputs_graphs/"
+folder_ensembles = "comparision_ensembles_graphs/"
+
+if MODEL_COMPARISON:  # visualize resnet18 vs resnet18.a2_in1k
+    table1 = 'models.xlsx'
+    data_real = pd.read_excel(table1)
+
+    resnet18_var_subtable = data_real[['Model name', 'Test accuracy', 'Balanced test accuracy', 'Dataset name', 'Method']].dropna()
+
+    sns_vis(resnet18_var_subtable, 'Method', 'Test accuracy', 'Model name', 'Dataset name', models_colors, data_shapes, "Model comparison")
+    plt.savefig(folder_models + "model_comparison_TA_method.png", dpi=300, bbox_inches='tight')
+    sns_vis(resnet18_var_subtable, 'Method', 'Balanced test accuracy', 'Model name', 'Dataset name', models_colors, data_shapes, "Model comparison")
+    plt.savefig(folder_models + "model_comparison_BTA_method.png", dpi=300, bbox_inches='tight')
+    sns_vis(resnet18_var_subtable, 'Dataset name', 'Test accuracy', 'Model name', 'Method', models_colors, methods_shapes, "Model comparison")
+    plt.savefig(folder_models + "model_comparison_TA_dataset.png", dpi=300, bbox_inches='tight')
+    sns_vis(resnet18_var_subtable, 'Dataset name', 'Balanced test accuracy', 'Model name', 'Method', models_colors, methods_shapes, "Model comparison")
+    plt.savefig(folder_models + "model_comparison_BTA_dataset.png", dpi=300, bbox_inches='tight')
+
+table2 = 'resnet18.xlsx'
+resnet18_var_data = pd.read_excel(table2)
+resnet18_var_subtable = resnet18_var_data[['Augmentation', 'Test accuracy', 'Balanced test accuracy', 'Dataset name', 'Method']].dropna()
+
+if AUGMENTATION_COMPARISON:
+    #sns_vis(resnet18_var_subtable, 'Augmentation', 'Test accuracy', 'Method', 'Dataset name', methods_colors, data_shapes, "Augmentation comparison")
+    #sns_vis(resnet18_var_subtable, 'Augmentation', 'Balanced test accuracy', 'Method', 'Dataset name', methods_colors, data_shapes, "Augmentation comparison")
+
+    sns_vis(resnet18_var_subtable, 'Augmentation', 'Test accuracy', 'Dataset name', 'Method', data_colors, methods_shapes, "Augmentation comparison", "out")
+    plt.savefig(folder_all_attribrutes + "augmentation_comparison_TA" + model_name + ".png", dpi=300, bbox_inches='tight')
+    sns_vis(resnet18_var_subtable, 'Augmentation', 'Balanced test accuracy', 'Dataset name', 'Method', data_colors, methods_shapes, "Augmentation comparison", "out")
+    plt.savefig(folder_all_attribrutes + "augmentation_comparison_BTA" + model_name + ".png", dpi=300, bbox_inches='tight')
+
+if DATASET_COMPARISON:
+    #sns_vis(resnet18_var_subtable, 'Dataset name', 'Test accuracy', 'Method', 'Augmentation', methods_colors, augmentation_shapes, "Dataset comparison")
+    #sns_vis(resnet18_var_subtable, 'Dataset name', 'Balanced test accuracy', 'Method', 'Augmentation', methods_colors, augmentation_shapes, "Dataset comparison")
+
+    sns_vis(resnet18_var_subtable, 'Dataset name', 'Test accuracy', 'Augmentation', 'Method', augmentation_colors, methods_shapes, "Dataset comparison", "out")
+    plt.savefig(folder_all_attribrutes + "dataset_comparison_TA" + model_name + ".png", dpi=300, bbox_inches='tight')
+    sns_vis(resnet18_var_subtable, 'Dataset name', 'Balanced test accuracy', 'Augmentation', 'Method', augmentation_colors, methods_shapes, "Dataset comparison", "out")
+    plt.savefig(folder_all_attribrutes + "dataset_comparison_BTA" + model_name + ".png", dpi=300, bbox_inches='tight')
+
+if METHOD_COMPARISON:
+    #sns_vis(resnet18_var_subtable, 'Method', 'Test accuracy', 'Dataset name', 'Augmentation', data_colors, augmentation_shapes, "Method comparison")
+    #sns_vis(resnet18_var_subtable, 'Method', 'Balanced test accuracy', 'Dataset name', 'Augmentation', data_colors, augmentation_shapes, "Dataset comparison")
+
+    sns_vis(resnet18_var_subtable, 'Method', 'Test accuracy', 'Augmentation', 'Dataset name', augmentation_colors, data_shapes, "Method comparison", "out")
+    plt.savefig(folder_all_attribrutes + "method_comparison_TA" + model_name + ".png", dpi=300, bbox_inches='tight')
+    sns_vis(resnet18_var_subtable, 'Method', 'Balanced test accuracy', 'Augmentation', 'Dataset name', augmentation_colors, data_shapes, "Method comparison", "out")
+    plt.savefig(folder_all_attribrutes + "method_comparison_BTA" + model_name + ".png", dpi=300, bbox_inches='tight')
+
+if SM_GI_COMPARISION:
+    table3 = "comparision_N1_N2_N3.xlsx"
+    comparision_n1_n2_n3 = pd.read_excel(table3)
+    comparision_n1_n2_n3_subtable = comparision_n1_n2_n3[['Augmentation', 'Test accuracy', 'Balanced test accuracy', 'Train set', 'Test set', 'Method']].dropna()
+
+    augmentation_types = comparision_n1_n2_n3_subtable['Augmentation'].unique()
+    subtables = {}
+
+    for augm in augmentation_types:
+        subtables[augm] = comparision_n1_n2_n3_subtable[(comparision_n1_n2_n3_subtable['Augmentation'] == augm) & (comparision_n1_n2_n3_subtable['Method'] == 'SM-GI')]
+        sns_vis(subtables[augm], 'Train set', 'Balanced test accuracy', 'Test set', 'Method', data_colors, methods_shapes, "Dataset comparison", "out")
+        plt.savefig(folder_N1_N2_N3 + augm + "_ SM-GI_" + "_data_comparison_BTA.png", dpi=300, bbox_inches='tight')
+
+        subtables[augm] = comparision_n1_n2_n3_subtable[(comparision_n1_n2_n3_subtable['Augmentation'] == augm) & (comparision_n1_n2_n3_subtable['Method'] == 'SM-GI-Avg')]
+        sns_vis(subtables[augm], 'Train set', 'Balanced test accuracy', 'Test set', 'Method', data_colors, methods_shapes, "Dataset comparison", "out")
+        plt.savefig(folder_N1_N2_N3 + augm + "_ SM-GI-AVG_" + "_data_comparison_BTA.png", dpi=300, bbox_inches='tight')
+
+        subtables[augm] = comparision_n1_n2_n3_subtable[comparision_n1_n2_n3_subtable['Augmentation'] == augm]
+        sns_vis(subtables[augm], 'Train set', 'Balanced test accuracy', 'Test set', 'Method', data_colors, methods_shapes, "Dataset comparison",
"out") + plt.savefig(folder_N1_N2_N3 + augm + "_data_comparison_BTA.png", dpi=300, bbox_inches='tight') + + +if INPUT_COMPARISION: + table4 = "comparision_inputs.xlsx" + comparision_inputs = pd.read_excel(table4) + comparision_inputs_subtable = comparision_inputs[['Augmentation', 'Balanced test accuracy', 'Method', 'Dataset name']].dropna() + + augmentation_types = comparision_inputs_subtable['Augmentation'].unique() + subtables = {} + + for augm in augmentation_types: + subtables[augm] = comparision_inputs_subtable[comparision_inputs_subtable['Augmentation'] == augm] + sns_vis(subtables[augm], 'Method', 'Balanced test accuracy', 'Dataset name', None, data_colors, None, legend = "out") + plt.savefig(folder_inputs_comp + augm + "_method_comparison_BTA_same_axis.png", dpi=300, bbox_inches='tight') + + sns_vis(comparision_inputs_subtable, 'Method', 'Balanced test accuracy', 'Dataset name', 'Augmentation', data_colors, augmentation_shapes, "Method comparison", "out") + plt.savefig(folder_inputs_comp + "all_aug_method_comparison_BTA_same_axis.png", dpi=300, bbox_inches='tight') + +if ENSEMBLE_COMPARISION: + table5 = "comparision_ensembles.xlsx" + comparision_ensembles = pd.read_excel(table5) + comparision_ensembles_subtable = comparision_ensembles[['Augmentation', 'Balanced test accuracy', 'Method', 'Dataset name']].dropna() + + augmentation_types = comparision_ensembles_subtable['Augmentation'].unique() + subtables = {} + + """for augm in augmentation_types: + subtables[augm] = comparision_ensembles_subtable[comparision_ensembles_subtable['Augmentation'] == augm] + sns_vis(subtables[augm], 'Method', 'Balanced test accuracy', 'Dataset name', None, data_colors, None, legend = "out") + plt.savefig(folder_ensembles + augm + "_ensemble_comparison_BTA.png", dpi=300, bbox_inches='tight')""" + + sns_vis(comparision_ensembles_subtable, 'Augmentation', 'Balanced test accuracy', 'Dataset name', 'Method', data_colors, methods_shapes, legend = "out") + plt.savefig(folder_ensembles + "all_aug_ensemble_comparison_BTA_4.png", dpi=300, bbox_inches='tight') + +if BEST_MODELS_COMPARISION: + table6 = "comparision_best_models.xlsx" + comparision_best_models = pd.read_excel(table6) + comparision_best_models_subtable = comparision_best_models[['Augmentation', 'Balanced test accuracy', 'Method', 'Dataset name']].dropna() + + sns_vis(comparision_best_models_subtable, 'Augmentation', 'Balanced test accuracy', 'Dataset name', 'Method', data_colors, methods_shapes, legend = "out") + plt.savefig("best_models_comparison_BTA.png", dpi=300, bbox_inches='tight') + + table7 = "comparision_best_models_aug_test_set.xlsx" + comparision_best_models_aug_test_set = pd.read_excel(table7) + comparision_best_models_aug_test_set_subtable = comparision_best_models_aug_test_set[['Augmentation', 'Balanced test accuracy', 'Method', 'Dataset name']].dropna() + + sns_vis(comparision_best_models_aug_test_set_subtable, 'Augmentation', 'Balanced test accuracy', 'Dataset name', 'Method', data_colors, methods_shapes, legend = "out") + plt.savefig("best_models_comparison_aug_test_set_BTA.png", dpi=300, bbox_inches='tight') +#plt.show() \ No newline at end of file diff --git a/evaluation/visualization_functions.py b/evaluation/visualization_functions.py new file mode 100644 index 0000000..8d6d598 --- /dev/null +++ b/evaluation/visualization_functions.py @@ -0,0 +1,78 @@ +import seaborn as sns +import matplotlib.pyplot as plt +from parameters import INDIVIDUALS + +def plt_vis(data, x_axis, y_axis1, color_attr, shape_attr, colors, shapes): + # Plot + 
+    plt.figure(figsize=(8, 6))
+    for _, row in data.iterrows():
+        plt.scatter(row[x_axis], row[y_axis1],
+                    color=colors[row[color_attr]],
+                    marker=shapes[row[shape_attr]],
+                    label=f"{row[color_attr]} - {row[shape_attr]}")
+
+    # Custom legend: drop duplicate labels created by plotting point by point
+    handles, labels = plt.gca().get_legend_handles_labels()
+    by_label = dict(zip(labels, handles))
+    plt.legend(by_label.values(), by_label.keys(), bbox_to_anchor=(1.05, 1), loc='upper left')
+    plt.xticks(rotation=20)
+
+    plt.xlabel(x_axis)
+    plt.ylabel(y_axis1)
+    plt.title('Matplotlib Visualization')
+    plt.tight_layout()
+
+
+def sns_vis(data, x_axis, y_axis1, color_attr, shape_attr, colors, shapes, title="", legend="out"):
+    plt.figure(figsize=(8.5, 4))  # 8.5, 7
+    plt.rcParams.update({
+        'font.size': 14,        # Global font size
+        'axes.titlesize': 18,   # Title font size
+        'axes.labelsize': 16,   # Axis labels font size
+        'xtick.labelsize': 14,  # X-axis tick labels
+        'ytick.labelsize': 14,  # Y-axis tick labels
+        'legend.fontsize': 14   # Legend font size
+    })
+    sns.scatterplot(data=data, x=x_axis, y=y_axis1,
+                    hue=color_attr, style=shape_attr,
+                    palette=colors, markers=shapes, s=100)
+
+    plt.xlabel(x_axis)  # plt.xlabel("Augmentation on the training set")
+    plt.ylabel(y_axis1)
+    plt.xticks(rotation=20)
+    plt.title(title)
+    if legend == "out":
+        plt.legend(bbox_to_anchor=(1.15, 1), loc='upper left')  # bbox_to_anchor=(1.35, 1)
+    elif legend == "left":
+        plt.legend(loc='lower left', frameon=True)
+    elif legend == "right":
+        plt.legend(loc='lower right', frameon=True)
+    plt.tight_layout()
+    primary_ax = plt.gca()
+    primary_ax.set_ylim(0.89, 1.0)
+
+    # Secondary y-axis: map accuracy onto the number of misclassified test instances
+    def secondary_y_axis_transform(value):
+        return INDIVIDUALS * (1 - value)
+
+    def secondary_y_axis_inverse(value):
+        return 1 - (value / INDIVIDUALS)
+
+    secondary_ax = primary_ax.twinx()
+
+    # Set the secondary y-axis limits based on the primary y-axis
+    secondary_ax.set_ylim(
+        secondary_y_axis_transform(primary_ax.get_ylim()[0]),
+        secondary_y_axis_transform(primary_ax.get_ylim()[1])
+    )
+    secondary_ax.set_ylabel('Misclassified instances')
+
+    # Synchronize the primary y-axis when the secondary y-limits change
+    def sync_axes(ax):
+        primary_ax.set_ylim(
+            secondary_y_axis_inverse(ax.get_ylim()[0]),
+            secondary_y_axis_inverse(ax.get_ylim()[1])
+        )
+
+    secondary_ax.callbacks.connect("ylim_changed", lambda ax: sync_axes(ax))
\ No newline at end of file
-- 
GitLab
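Review note (not part of the patch): `sns_vis` builds the right-hand "Misclassified instances" axis with `twinx()` plus a manually connected `ylim_changed` callback. A minimal alternative sketch, assuming matplotlib >= 3.1 and the same `INDIVIDUALS` constant from `parameters.py`, would be `Axes.secondary_yaxis`, which takes a `(forward, inverse)` pair and keeps the two scales synchronized without any callback:

```python
import matplotlib.pyplot as plt

INDIVIDUALS = 1941  # total test instances, mirroring parameters.py


def acc_to_errors(acc):
    # forward: balanced accuracy -> number of misclassified instances
    return INDIVIDUALS * (1 - acc)


def errors_to_acc(errors):
    # inverse: misclassified instances -> balanced accuracy
    return 1 - errors / INDIVIDUALS


fig, ax = plt.subplots(figsize=(8.5, 4))
ax.set_ylim(0.89, 1.0)
ax.set_ylabel('Balanced test accuracy')

# The secondary axis is re-derived from the primary limits on every redraw,
# so rescaling the plot keeps both axes consistent automatically.
secax = ax.secondary_yaxis('right', functions=(acc_to_errors, errors_to_acc))
secax.set_ylabel('Misclassified instances')

plt.show()
```

This is only a sketch of the idea; the scatter plotting, palettes, and legend handling from `sns_vis` would stay as they are in the patch.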