Skip to content
Snippets Groups Projects
visualization_functions.py 2.76 KiB
Newer Older
  • Learn to ignore specific revisions
  • Vajay Mónika's avatar
    Vajay Mónika committed
    import seaborn as sns
    import matplotlib.pyplot as plt
    from parameters import INDIVIDUALS
    
    def plt_vis(data, x_axis, y_axis1, color_attr, shape_attr, colors, shapes,
                title='Matplotlib Visualization'):
        """Scatter-plot ``data`` with plain matplotlib, one marker per row.

        Each row is drawn with the colour looked up in ``colors`` by its
        ``color_attr`` value and the marker looked up in ``shapes`` by its
        ``shape_attr`` value. The legend shows one entry per distinct
        "<color value> - <shape value>" combination and is placed outside the
        axes, to the upper right.

        Args:
            data: DataFrame-like object supporting ``iterrows()``.
            x_axis: column plotted on the x-axis (also used as the x label).
            y_axis1: column plotted on the y-axis (also used as the y label).
            color_attr: column whose values select the point colour.
            shape_attr: column whose values select the point marker.
            colors: lookup from ``color_attr`` values to matplotlib colours.
            shapes: lookup from ``shape_attr`` values to matplotlib markers.
            title: axes title; defaults to the previously hard-coded title.

        Returns:
            The matplotlib Axes the points were drawn on.

        Raises:
            KeyError: if ``colors``/``shapes`` are dicts and a row's value is
                missing from them.
        """
        plt.figure(figsize=(8, 6))
        ax = plt.gca()

        # Rows are drawn one at a time (rather than grouped per colour/shape)
        # so that, if the x values are categorical, matplotlib assigns the
        # category positions in row order.
        for _, row in data.iterrows():
            color_value = row[color_attr]
            shape_value = row[shape_attr]
            ax.scatter(row[x_axis], row[y_axis1],
                       color=colors[color_value],
                       marker=shapes[shape_value],
                       label=f"{color_value} - {shape_value}")

        # Every row produced its own labelled artist; keep one handle per label
        # (dict keeps first-seen label order) so the legend has no duplicates.
        handles, labels = ax.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        if by_label:  # avoid drawing an empty legend box when data has no rows
            ax.legend(by_label.values(), by_label.keys(),
                      bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.xticks(rotation=20)

        ax.set_xlabel(x_axis)
        ax.set_ylabel(y_axis1)
        ax.set_title(title)
        plt.tight_layout()
        return ax
    
    def sns_vis(data, x_axis, y_axis1, color_attr, shape_attr, colors, shapes,
                title="", legend="out", ylim=(0.89, 1.0)):
        """Seaborn scatter plot with a secondary "Misclassified instances" axis.

        The primary y-axis shows ``y_axis1``. Presumably this is an
        accuracy-like fraction; TODO confirm. The right-hand axis shows the
        same values converted with ``INDIVIDUALS * (1 - value)``. The two
        y-axes are kept in sync in both directions: changing either one's
        limits updates the other.

        Args:
            data: data passed to ``sns.scatterplot``.
            x_axis: column plotted on the x-axis (also used as the x label).
            y_axis1: column plotted on the primary y-axis (also the y label).
            color_attr: column mapped to colour (seaborn ``hue``).
            shape_attr: column mapped to marker style (seaborn ``style``).
            colors: seaborn ``palette`` for the hue levels.
            shapes: seaborn ``markers`` for the style levels.
            title: axes title.
            legend: ``"out"`` puts the legend outside on the right,
                ``"left"``/``"right"`` put it in the lower-left/lower-right
                corner. Any other value keeps seaborn's own legend.
            ylim: ``(bottom, top)`` limits for the primary y-axis, or ``None``
                to keep matplotlib's autoscaled limits. The default keeps the
                previously hard-coded range.

        Returns:
            ``(primary_ax, secondary_ax)``. After the call, ``plt.gca()`` is
            the secondary axis, as before.
        """
        plt.figure(figsize=(8.5, 4))
        # NOTE: this updates the *global* rcParams, so these font sizes persist
        # for every figure created afterwards in the same process.
        plt.rcParams.update({
            'font.size': 14,         # Global font size
            'axes.titlesize': 18,    # Title font size
            'axes.labelsize': 16,    # Axis labels font size
            'xtick.labelsize': 14,   # X-axis tick labels
            'ytick.labelsize': 14,   # Y-axis tick labels
            'legend.fontsize': 14    # Legend font size
        })
        # sns.scatterplot draws on (and returns) the current Axes.
        primary_ax = sns.scatterplot(data=data, x=x_axis, y=y_axis1,
                                     hue=color_attr, style=shape_attr,
                                     palette=colors, markers=shapes, s=100)

        primary_ax.set_xlabel(x_axis)
        primary_ax.set_ylabel(y_axis1)
        plt.xticks(rotation=20)
        primary_ax.set_title(title)
        if legend == "out":
            # x=1.15 leaves room for the secondary axis' tick labels.
            primary_ax.legend(bbox_to_anchor=(1.15, 1), loc='upper left')
        elif legend == "left":
            primary_ax.legend(loc='lower left', frameon=True)
        elif legend == "right":
            primary_ax.legend(loc='lower right', frameon=True)
        if ylim is not None:
            primary_ax.set_ylim(*ylim)

        # Conversion between the primary value and the misclassified count.
        def to_misclassified(value):
            return INDIVIDUALS * (1 - value)

        def to_primary(value):
            return 1 - (value / INDIVIDUALS)

        secondary_ax = primary_ax.twinx()
        bottom, top = primary_ax.get_ylim()
        secondary_ax.set_ylim(to_misclassified(bottom), to_misclassified(top))
        secondary_ax.set_ylabel('Misclassified instances')

        # Keep both y-axes in sync. emit=False stops the mirrored set_ylim
        # from firing the other axis' callback, which would recurse forever.
        def sync_primary_from_secondary(ax):
            low, high = ax.get_ylim()
            primary_ax.set_ylim(to_primary(low), to_primary(high), emit=False)

        def sync_secondary_from_primary(ax):
            low, high = ax.get_ylim()
            secondary_ax.set_ylim(to_misclassified(low), to_misclassified(high),
                                  emit=False)

        secondary_ax.callbacks.connect("ylim_changed", sync_primary_from_secondary)
        primary_ax.callbacks.connect("ylim_changed", sync_secondary_from_primary)

        # Lay out only after the twin axis and its label exist, so the
        # secondary tick labels / ylabel are not clipped at the figure edge.
        plt.tight_layout()
        return primary_ax, secondary_ax