Skip to content
Snippets Groups Projects
visualization.py 9.51 KiB
Newer Older
  • Learn to ignore specific revisions
  • Vajay Mónika's avatar
    Vajay Mónika committed
    import os

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt

    from visualization_functions import plt_vis, sns_vis
    from parameters import data_colors, data_shapes, methods_colors, methods_shapes, augmentation_colors, augmentation_shapes, augmentation_shapes_same, models_colors, models_shapes
    
    # --- Section switches -------------------------------------------------------
    # Each flag enables one independent `if <FLAG>:` section further down in this
    # script; a section reads its own Excel workbook and saves its own PNG figures.

    # Model comparison figures built from 'models.xlsx'.
    MODEL_COMPARISON = False
    
    # The next three sections all plot the table loaded from 'resnet18.xlsx'.
    AUGMENTATION_COMPARISON = False
    DATASET_COMPARISON = False
    METHOD_COMPARISON = False
    
    # Sections backed by the 'comparision_*.xlsx' workbooks.
    SM_GI_COMPARISION = False
    INPUT_COMPARISION = True
    ENSEMBLE_COMPARISION = False
    BEST_MODELS_COMPARISION = False
    
    # Suffix inserted into the file names of the augmentation / dataset / method
    # comparison figures.
    model_name = "_resnet18_out2"
    # Path prefixes that are concatenated with a file name when figures are saved.
    # The trailing slash is required because plain string concatenation is used.
    folder_models = "comparision_models_graphs/"
    folder_all_attribrutes = "comparision_all_attributes_graphs/"
    folder_N1_N2_N3 = "comparision_N1_N2_N3_graphs/"
    folder_inputs_comp = "comparision_inputs_graphs/"
    folder_ensembles = "comparision_ensembles_graphs/"
    
    if MODEL_COMPARISON:    # visualize resnet18 vs resnet18.a2_in1k
        # plt.savefig does not create missing directories, so make sure the
        # output folder exists before the first figure is written.
        os.makedirs(folder_models, exist_ok=True)

        models_table = pd.read_excel('models.xlsx')
        # Keep only the columns that are plotted and drop incomplete rows.
        models_subtable = models_table[
            ['Model name', 'Test accuracy', 'Balanced test accuracy', 'Dataset name', 'Method']
        ].dropna()

        # Every figure groups by 'Model name' with models_colors; only the x
        # column and the second grouping column (with its shape mapping) vary.
        # NOTE(review): the exact meaning of each sns_vis argument is defined in
        # visualization_functions - confirm there.
        # (x column, second grouping column, shape mapping, file-name suffix)
        model_views = (
            ('Method', 'Dataset name', data_shapes, 'method'),
            ('Dataset name', 'Method', methods_shapes, 'dataset'),
        )
        # (y column, abbreviation used in the file name)
        model_metrics = (('Test accuracy', 'TA'), ('Balanced test accuracy', 'BTA'))

        for x_col, group_col, shape_map, suffix in model_views:
            for metric, tag in model_metrics:
                sns_vis(models_subtable, x_col, metric, 'Model name', group_col, models_colors, shape_map, "Model comparison")
                plt.savefig(folder_models + "model_comparison_" + tag + "_" + suffix + ".png", dpi=300, bbox_inches='tight')
    
    # The resnet18 results are consumed only by the augmentation, dataset and
    # method comparison sections. Reading the workbook unconditionally made the
    # script fail on a missing 'resnet18.xlsx' even when just an unrelated
    # comparison was enabled, so load it only when one of its consumers is on.
    if AUGMENTATION_COMPARISON or DATASET_COMPARISON or METHOD_COMPARISON:
        table2 = 'resnet18.xlsx'
        resnet18_var_data = pd.read_excel(table2)
        # Keep only the plotted columns and drop rows with missing values.
        resnet18_var_subtable = resnet18_var_data[
            ['Augmentation', 'Test accuracy', 'Balanced test accuracy', 'Dataset name', 'Method']
        ].dropna()
    
    if AUGMENTATION_COMPARISON:
        # plt.savefig does not create missing directories, so make sure the
        # output folder exists before the first figure is written.
        os.makedirs(folder_all_attribrutes, exist_ok=True)

        # Alternative encoding that is not plotted: group by 'Method' with
        # methods_colors and by 'Dataset name' with data_shapes.

        # One figure per metric: 'Augmentation' on x, grouped by 'Dataset name'
        # (data_colors) and 'Method' (methods_shapes).
        for metric, tag in (('Test accuracy', 'TA'), ('Balanced test accuracy', 'BTA')):
            sns_vis(resnet18_var_subtable, 'Augmentation', metric, 'Dataset name', 'Method', data_colors, methods_shapes, "Augmentation comparison", "out")
            plt.savefig(folder_all_attribrutes + "augmentation_comparison_" + tag + model_name + ".png", dpi=300, bbox_inches='tight')
    
    if DATASET_COMPARISON:
        # plt.savefig does not create missing directories, so make sure the
        # output folder exists before the first figure is written.
        os.makedirs(folder_all_attribrutes, exist_ok=True)

        # Alternative encoding that is not plotted: group by 'Method' with
        # methods_colors and by 'Augmentation' with augmentation_shapes.

        # One figure per metric: 'Dataset name' on x, grouped by 'Augmentation'
        # (augmentation_colors) and 'Method' (methods_shapes).
        for metric, tag in (('Test accuracy', 'TA'), ('Balanced test accuracy', 'BTA')):
            sns_vis(resnet18_var_subtable, 'Dataset name', metric, 'Augmentation', 'Method', augmentation_colors, methods_shapes, "Dataset comparison", "out")
            plt.savefig(folder_all_attribrutes + "dataset_comparison_" + tag + model_name + ".png", dpi=300, bbox_inches='tight')
    
    if METHOD_COMPARISON:
        # plt.savefig does not create missing directories, so make sure the
        # output folder exists before the first figure is written.
        os.makedirs(folder_all_attribrutes, exist_ok=True)

        # Alternative encoding that is not plotted: group by 'Dataset name' with
        # data_colors and by 'Augmentation' with augmentation_shapes.

        # One figure per metric: 'Method' on x, grouped by 'Augmentation'
        # (augmentation_colors) and 'Dataset name' (data_shapes).
        for metric, tag in (('Test accuracy', 'TA'), ('Balanced test accuracy', 'BTA')):
            sns_vis(resnet18_var_subtable, 'Method', metric, 'Augmentation', 'Dataset name', augmentation_colors, data_shapes, "Method comparison", "out")
            plt.savefig(folder_all_attribrutes + "method_comparison_" + tag + model_name + ".png", dpi=300, bbox_inches='tight')
     
    if SM_GI_COMPARISION:
        # plt.savefig does not create missing directories, so make sure the
        # output folder exists before the first figure is written.
        os.makedirs(folder_N1_N2_N3, exist_ok=True)

        n1_n2_n3_table = pd.read_excel("comparision_N1_N2_N3.xlsx")
        # Keep only the plotted columns and drop rows with missing values.
        n1_n2_n3_subtable = n1_n2_n3_table[
            ['Augmentation', 'Test accuracy', 'Balanced test accuracy', 'Train set', 'Test set', 'Method']
        ].dropna()

        # For every augmentation three figures are produced: rows of method
        # 'SM-GI' only, rows of method 'SM-GI-Avg' only, and all rows.
        # (value of 'Method' to keep, or None for no filter; file-name infix)
        # NOTE(review): the infixes keep their original spelling (leading space,
        # and a doubled underscore once joined with the suffix) so that the
        # names of already generated files do not change.
        method_views = (
            ('SM-GI', "_ SM-GI_"),
            ('SM-GI-Avg', "_ SM-GI-AVG_"),
            (None, ""),
        )

        for augm in n1_n2_n3_subtable['Augmentation'].unique():
            augm_rows = n1_n2_n3_subtable[n1_n2_n3_subtable['Augmentation'] == augm]
            for method, infix in method_views:
                if method is None:
                    plot_rows = augm_rows
                else:
                    plot_rows = augm_rows[augm_rows['Method'] == method]
                sns_vis(plot_rows, 'Train set', 'Balanced test accuracy', 'Test set', 'Method', data_colors, methods_shapes, "Dataset comparison", "out")
                plt.savefig(folder_N1_N2_N3 + augm + infix + "_data_comparison_BTA.png", dpi=300, bbox_inches='tight')
    
     
    if INPUT_COMPARISION:
        # plt.savefig does not create missing directories, so make sure the
        # output folder exists before the first figure is written.
        os.makedirs(folder_inputs_comp, exist_ok=True)

        inputs_table = pd.read_excel("comparision_inputs.xlsx")
        # Keep only the plotted columns and drop rows with missing values.
        inputs_subtable = inputs_table[
            ['Augmentation', 'Balanced test accuracy', 'Method', 'Dataset name']
        ].dropna()

        # One figure per augmentation value, in order of first appearance.
        # Only 'Dataset name' (data_colors) is used for grouping here; the
        # second grouping column and its shape mapping are passed as None.
        for augm in inputs_subtable['Augmentation'].unique():
            augm_rows = inputs_subtable[inputs_subtable['Augmentation'] == augm]
            sns_vis(augm_rows, 'Method', 'Balanced test accuracy', 'Dataset name', None, data_colors, None, legend = "out")
            plt.savefig(folder_inputs_comp + augm + "_method_comparison_BTA_same_axis.png", dpi=300, bbox_inches='tight')

        # Combined figure over all augmentations, additionally grouped by
        # 'Augmentation' with augmentation_shapes.
        sns_vis(inputs_subtable, 'Method', 'Balanced test accuracy', 'Dataset name', 'Augmentation', data_colors, augmentation_shapes, "Method comparison", "out")
        plt.savefig(folder_inputs_comp + "all_aug_method_comparison_BTA_same_axis.png", dpi=300, bbox_inches='tight')
    
    if ENSEMBLE_COMPARISION:
        # plt.savefig does not create missing directories, so make sure the
        # output folder exists before the figure is written.
        os.makedirs(folder_ensembles, exist_ok=True)

        ensembles_table = pd.read_excel("comparision_ensembles.xlsx")
        # Keep only the plotted columns and drop rows with missing values.
        ensembles_subtable = ensembles_table[
            ['Augmentation', 'Balanced test accuracy', 'Method', 'Dataset name']
        ].dropna()

        # Disabled: one figure per augmentation value.
        # for augm in ensembles_subtable['Augmentation'].unique():
        #     augm_rows = ensembles_subtable[ensembles_subtable['Augmentation'] == augm]
        #     sns_vis(augm_rows, 'Method', 'Balanced test accuracy', 'Dataset name', None, data_colors, None, legend = "out")
        #     plt.savefig(folder_ensembles + augm + "_ensemble_comparison_BTA.png", dpi=300, bbox_inches='tight')

        # Single figure over all augmentations: 'Augmentation' on x, grouped by
        # 'Dataset name' (data_colors) and 'Method' (methods_shapes).
        sns_vis(ensembles_subtable, 'Augmentation', 'Balanced test accuracy', 'Dataset name', 'Method', data_colors, methods_shapes, legend = "out")
        plt.savefig(folder_ensembles + "all_aug_ensemble_comparison_BTA_4.png", dpi=300, bbox_inches='tight')
    
    if BEST_MODELS_COMPARISION:
        # Both workbooks are reduced to the same columns and plotted the same
        # way; only the input file and the output figure differ. The figures
        # are written to the current working directory.
        best_model_columns = ['Augmentation', 'Balanced test accuracy', 'Method', 'Dataset name']
        # (workbook to read, figure file to write)
        best_model_plots = (
            ("comparision_best_models.xlsx", "best_models_comparison_BTA.png"),
            ("comparision_best_models_aug_test_set.xlsx", "best_models_comparison_aug_test_set_BTA.png"),
        )

        for workbook, figure_file in best_model_plots:
            # Drop rows that miss any of the plotted columns.
            best_rows = pd.read_excel(workbook)[best_model_columns].dropna()
            sns_vis(best_rows, 'Augmentation', 'Balanced test accuracy', 'Dataset name', 'Method', data_colors, methods_shapes, legend = "out")
            plt.savefig(figure_file, dpi=300, bbox_inches='tight')
    #plt.show()