Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import json
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
# NOTE(review): file_path is not referenced anywhere in the code that follows;
# the file actually loaded is mistake_slices_path (the N3 evaluation JSON).
# This looks like a leftover from the N2 "averaging" comparison — confirm
# nothing else uses it before removing.
file_path = "comparision_N1_N2_N3_graphs/slices_evaluation/resnet18_basic_N1_misclassified_test_set_N2_averagingslices_ev.json"
import json
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
def load_data(all_slices_path, mistake_slices_path):
    """Split each plane's slice positions into correctly and wrongly classified.

    Args:
        all_slices_path: Path to a JSON file mapping plane name -> list of
            slice positions (the full evaluated set).
        mistake_slices_path: Path to a JSON file whose top-level
            ``'mistake_positions'`` entry maps plane name -> list of
            misclassified slice positions.

    Returns:
        ``(good_positions, mistake_positions)``, two dicts keyed by every plane
        present in the all-slices file:

        * ``good_positions[plane]`` is the all-slices list with one occurrence
          removed for each misclassified entry (a multiset difference that keeps
          the original order). Misclassified positions that do not occur in the
          all-slices list are ignored.
        * ``mistake_positions[plane]`` is the misclassified list for that plane,
          or ``[]`` when the mistake file has no entry for it.

        Planes that appear only in the mistake file are not included.

    Raises:
        KeyError: if the mistake file has no ``'mistake_positions'`` key (only
            checked when the all-slices file contains at least one plane).

    Note:
        Positions are matched via a ``Counter``, so they must be hashable JSON
        scalars (ints/floats/strings). The downstream ``value_counts`` call in
        ``prepare_data`` already requires the same thing.
    """
    from collections import Counter

    with open(all_slices_path, 'r', encoding='utf-8') as all_file:
        all_data = json.load(all_file)
    with open(mistake_slices_path, 'r', encoding='utf-8') as mistake_file:
        mistake_data = json.load(mistake_file)

    good_positions = {}
    mistake_positions = {}
    for plane, all_slices in all_data.items():
        misclassified_slices = mistake_data['mistake_positions'].get(plane, [])
        # How many more times each position still has to be discarded; this
        # matches duplicates one-for-one without the O(n*m) list.remove scan.
        pending = Counter(misclassified_slices)
        kept = []
        for slice_pos in all_slices:
            if pending[slice_pos] > 0:  # Counter returns 0 for unseen keys
                pending[slice_pos] -= 1
            else:
                kept.append(slice_pos)
        good_positions[plane] = kept
        mistake_positions[plane] = misclassified_slices
    return good_positions, mistake_positions
def prepare_data(positions, label):
    """Tabulate how often each slice position occurs, tagged with a label.

    Args:
        positions: List of slice positions; values must be hashable so that
            ``Series.value_counts`` can group them.
        label: Value written into every row of the ``'label'`` column (the
            script passes ``'Good'`` or ``'Mistake'``).

    Returns:
        DataFrame with columns ``['x_value', 'count', 'label']``. It has one row
        per distinct position, ordered as ``value_counts`` orders them (most
        frequent first), and a fresh ``RangeIndex``.
    """
    frequencies = pd.Series(positions).value_counts()
    table = pd.DataFrame({
        'x_value': frequencies.index,
        'count': frequencies.to_numpy(),
    })
    table['label'] = label
    return table
# Planes to process, in plotting order. good_data / mistake_data are keyed by
# the plane names found in the all-slices JSON, so this spelling (including
# 'coronial') has to match those keys or the lookups below raise KeyError.
PLANES = ('sagittal', 'axial', 'coronial')

# Load slices from the JSON files
all_slices_path = 'all_slices_N3.json'
mistake_slices_path = "comparision_N1_N2_N3_graphs/slices_evaluation/resnet18_basic_N1_misclassified_test_set_N3slices_ev.json"
good_data, mistake_data = load_data(all_slices_path, mistake_slices_path)

# Build one long-format table with (x_value, count, label, plane) rows for both
# the 'Good' and the 'Mistake' position sets of every plane.
dfs = []
for plane in PLANES:
    good_df = prepare_data(good_data[plane], 'Good')
    good_df['plane'] = plane
    mistake_df = prepare_data(mistake_data[plane], 'Mistake')
    mistake_df['plane'] = plane
    dfs.append(pd.concat([good_df, mistake_df]))

# Combine all data. ignore_index gives every row a unique label instead of the
# repeated 0..k ranges that each prepare_data call produces.
all_data = pd.concat(dfs, ignore_index=True)

# Plot settings
sns.set(style="whitegrid")
colors = {'Good': 'green', 'Mistake': 'red'}

# One bar chart per plane: x = slice position, bar height = how many slices sit
# at that position, coloured by Good / Mistake.
for plane in PLANES:
    plt.figure(figsize=(10, 6))
    plane_data = all_data[all_data['plane'] == plane]
    sns.barplot(data=plane_data, x='x_value', y='count', hue='label', palette=colors)
    plt.title(f'{plane.capitalize()} Plane Bar Chart')
    plt.xlabel('X Value')
    plt.ylabel('Count')
    plt.legend(title='Position Type')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()