Skip to content
Snippets Groups Projects
barchart_slices.py 2.58 KiB
Newer Older
  • Learn to ignore specific revisions
  • Vajay Mónika's avatar
    Vajay Mónika committed
    import json
    import seaborn as sns
    import matplotlib.pyplot as plt
    import pandas as pd
    
    # Path to a slices-evaluation JSON file (the N2 "averaging" variant, judging
    # by the file name). Nothing is loaded here: this only defines the path.
    # NOTE(review): `file_path` is not referenced anywhere else in the visible
    # code, which reads an N3 file instead -- confirm whether it is still needed.
    file_path = "comparision_N1_N2_N3_graphs/slices_evaluation/resnet18_basic_N1_misclassified_test_set_N2_averagingslices_ev.json"
    
    import json
    import seaborn as sns
    import matplotlib.pyplot as plt
    import pandas as pd
    
    # Load the data
    def load_data(all_slices_path, mistake_slices_path):
        """Split every plane's slice positions into good and misclassified ones.

        Both inputs are JSON files. The file at ``all_slices_path`` maps each
        plane name to the list of all slice positions of that plane. The file
        at ``mistake_slices_path`` holds, under the ``'mistake_positions'`` key,
        a mapping from plane name to the misclassified slice positions.

        The split is a multiset difference: each misclassified position cancels
        exactly one equal entry of the plane's full list, so repeated positions
        are accounted for occurrence by occurrence.

        Args:
            all_slices_path: Path of the JSON file with all slice positions.
            mistake_slices_path: Path of the JSON file with the misclassified
                slice positions.

        Returns:
            A ``(good_positions, mistake_positions)`` tuple of dicts keyed by
            the plane names found in the first file. ``good_positions[plane]``
            lists the non-cancelled entries in their original order;
            ``mistake_positions[plane]`` is the misclassified list as loaded
            (empty when the plane is absent from the mistakes file).

        Raises:
            KeyError: If the mistakes file has no ``'mistake_positions'`` key
                while the first file lists at least one plane.
            TypeError: If a misclassified slice position is not hashable.
        """
        # Function-scope import keeps this block self-contained.
        from collections import Counter

        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(all_slices_path, 'r', encoding='utf-8') as all_file:
            all_data = json.load(all_file)

        with open(mistake_slices_path, 'r', encoding='utf-8') as mistake_file:
            mistake_data = json.load(mistake_file)

        good_positions = {}
        mistake_positions = {}

        for plane in all_data:
            misclassified_slices = mistake_data['mistake_positions'].get(plane, [])
            # Occurrences of each misclassified position still to be cancelled.
            # A Counter makes every lookup O(1) instead of scanning a list.
            pending = Counter(misclassified_slices)
            good_slices = []
            for slice_pos in all_data[plane]:
                if pending[slice_pos] > 0:
                    # This occurrence is one of the misclassified ones.
                    pending[slice_pos] -= 1
                else:
                    good_slices.append(slice_pos)
            good_positions[plane] = good_slices
            mistake_positions[plane] = misclassified_slices

        return good_positions, mistake_positions
    
    
    def prepare_data(positions, label):
        """Tabulate how often each position occurs and tag the rows with a label.

        Args:
            positions: Sequence of slice positions; repeated values are counted.
            label: Value written into the ``label`` column of every row.

        Returns:
            A DataFrame with the columns ``x_value`` (a distinct position),
            ``count`` (its number of occurrences) and ``label``.
        """
        frequency = pd.Series(positions).value_counts()
        # Turn the position index into a regular column, then name both columns.
        table = frequency.reset_index()
        table.columns = ['x_value', 'count']
        table['label'] = label
        return table
    
    # Input files: all slice positions per plane, and the misclassified ones.
    all_slices_path = 'all_slices_N3.json'
    mistake_slices_path = "comparision_N1_N2_N3_graphs/slices_evaluation/resnet18_basic_N1_misclassified_test_set_N3slices_ev.json"
    good_data, mistake_data = load_data(all_slices_path, mistake_slices_path)

    # For every plane, build the 'Good' and the 'Mistake' count tables, tag
    # them with the plane name and stack them into one table per plane.
    dfs = []
    for plane in ['sagittal', 'axial', 'coronial']:
        labelled_tables = []
        for label, positions_by_plane in (('Good', good_data), ('Mistake', mistake_data)):
            table = prepare_data(positions_by_plane[plane], label)
            table['plane'] = plane
            labelled_tables.append(table)
        dfs.append(pd.concat(labelled_tables))

    # Single long-format table covering all planes.
    all_data = pd.concat(dfs)
    
    # Global seaborn theme and the fixed colour used for each label.
    sns.set(style="whitegrid")
    colors = {'Good': 'green', 'Mistake': 'red'}

    # One figure per plane: bars of the occurrence count per position, with
    # the 'Good' and 'Mistake' labels distinguished by colour.
    for plane in ['sagittal', 'axial', 'coronial']:
        fig, ax = plt.subplots(figsize=(10, 6))
        plane_data = all_data.loc[all_data['plane'] == plane]
        sns.barplot(
            data=plane_data,
            x='x_value',
            y='count',
            hue='label',
            palette=colors,
            ax=ax,
        )
        ax.set_title(f'{plane.capitalize()} Plane Bar Chart')
        ax.set_xlabel('X Value')
        ax.set_ylabel('Count')
        ax.legend(title='Position Type')
        # Rotate the x tick labels of the current axes by 45 degrees.
        plt.xticks(rotation=45)
        fig.tight_layout()

    # Show all figures created in the loop above.
    plt.show()