# Source code for pyreal.visualize.feature_based_vis

import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from pyreal.explanation_types import FeatureContributionExplanation, FeatureImportanceExplanation
from pyreal.realapp import realapp
from pyreal.utils import get_top_contributors
from pyreal.visualize.visualize_config import (
    NEGATIVE_COLOR,
    NEGATIVE_COLOR_LIGHT,
    NEUTRAL_COLOR,
    PALETTE_CMAP,
    POSITIVE_COLOR,
    POSITIVE_COLOR_LIGHT,
)


def _parse_multi_contribution(explanation):
    """
    Split a multi-instance contribution explanation into contribution and value tables.

    Args:
        explanation (FeatureContributionExplanation or dict-like):
            Either an explanation object (its ``get()`` and ``get_values()`` results are
            returned as-is), or a keyed collection whose entries each provide
            "Feature Name", "Contribution" and "Feature Value" columns.

    Returns:
        (contributions, values)
            For keyed input, two DataFrames with one row per entry and columns labelled
            with the "Feature Name" values of the first entry.
    """
    if isinstance(explanation, FeatureContributionExplanation):
        return explanation.get(), explanation.get_values()

    # One table per explained instance, in the iteration order of the keys
    instances = [explanation[key] for key in explanation]
    contribution_rows = [instance["Contribution"] for instance in instances]
    value_rows = [instance["Feature Value"] for instance in instances]

    # Feature names are read from the first entry only
    # NOTE(review): presumably every entry lists features in the same order - confirm
    first_key = next(iter(explanation))
    feature_names = explanation[first_key]["Feature Name"].values

    contributions = pd.DataFrame(contribution_rows)
    contributions.columns = feature_names
    values = pd.DataFrame(value_rows)
    values.columns = feature_names
    return contributions, values


def feature_bar_plot(
    explanation,
    select_by="absolute",
    num_features=5,
    transparent=False,
    flip_colors=False,
    precision=2,
    prediction=None,
    include_averages=False,
    include_axis=True,
    show=False,
    filename=None,
    **kwargs
):
    """
    Plot the most contributing features

    Args:
        explanation (DataFrame or FeatureBased):
            One output DataFrame from RealApp.produce_feature_contributions or
            RealApp.prepare_feature_importance OR FeatureBased explanation object
        select_by (one of "absolute", "max", "min"):
            Method to use when selecting features.
        num_features (int):
            Number of features to plot
        transparent (Boolean):
            If False, the figure is created with a white face color. If True, the figure
            is created with the default face color.
        flip_colors (Boolean):
            If True, make the positive explanation red and negative explanation blue.
            Useful if the target variable has a negative connotation
        precision (int):
            Number of decimal places to print for numeric float values
        prediction (numeric or string):
            Prediction to display in the title
        include_averages (Boolean):
            If True, include the mean values in the visualization (if provided in explanation)
        include_axis (Boolean):
            If True, include the contribution axis
        show (Boolean):
            Show the figure
        filename (string or None):
            If not None, save the figure as filename
        **kwargs:
            Additional parameters to pass into the horizontal bar plot call (barh)

    Returns:
        None
            The bar plot of top contributors is drawn on a newly created pyplot figure,
            which is optionally saved (filename) and/or shown (show).

    Raises:
        ValueError
            If explanation is a dict, is not a DataFrame or supported explanation object,
            or has neither a "Contribution" nor an "Importance" column.
    """
    if isinstance(explanation, FeatureContributionExplanation):
        explanation = realapp.format_feature_contribution_output(explanation)
        # Only the first formatted entry is plotted
        explanation = explanation[next(iter(explanation))]
    elif isinstance(explanation, FeatureImportanceExplanation):
        explanation = realapp.format_feature_importance_output(explanation)
    if isinstance(explanation, dict):
        raise ValueError(
            "Invalid explanation. Expected feature contribution explanation on a single instance"
            " or feature importance explanation. If you are passing in an explanation from"
            " RealApp.produce_feature_contributions(), please index to get a single instance, ie"
            " explanation[0]."
        )
    if not isinstance(explanation, pd.DataFrame):
        raise ValueError(
            "Invalid explanation type, expected DataFrame or"
            " FeatureContributionExplanation/FeatureImportanceExplanation object"
        )

    explanation = get_top_contributors(explanation, num_features=num_features, select_by=select_by)
    features = explanation["Feature Name"].to_numpy()

    # np.floating / np.integer cover NumPy scalar types. The np.float alias previously used
    # here was removed in NumPy 1.24 and raised AttributeError on access.
    float_types = (float, np.floating)
    numeric_types = (float, np.floating, int, np.integer)

    if "Feature Value" in explanation:
        values = explanation["Feature Value"].to_numpy()
        if include_averages and "Average/Mode" in explanation:
            averages = explanation["Average/Mode"].to_numpy()
            features = np.array(
                [
                    (
                        "%s - %.*g (mean: %.*g)" % (name, precision, value, precision, average)
                        if isinstance(value, numeric_types)
                        else "%s - %s (mode: %s)" % (name, value, average)
                    )
                    for name, value, average in zip(features, values, averages)
                ]
            )
        else:
            features = np.array(
                [
                    (
                        "%s (%.*f)" % (name, precision, value)
                        if isinstance(value, float_types)
                        else "%s (%s)" % (name, value)
                    )
                    for name, value in zip(features, values)
                ]
            )

    are_importances = False
    if "Contribution" in explanation:
        contributions = explanation["Contribution"]
    elif "Importance" in explanation:
        contributions = explanation["Importance"]
        are_importances = True
    else:
        raise ValueError("Provided DataFrame has neither Contribution nor Importance column")
    if contributions.ndim == 2:
        # Column selection produced a 2-D object; use its first row
        contributions = contributions.iloc[0]
    contributions = contributions.to_numpy()

    # Bars are drawn in reverse order so the first (top) contributor ends up at the top
    if flip_colors:
        negative_color, positive_color = POSITIVE_COLOR, NEGATIVE_COLOR
    else:
        negative_color, positive_color = NEGATIVE_COLOR, POSITIVE_COLOR
    colors = [negative_color if (c < 0) else positive_color for c in contributions[::-1]]

    if transparent:
        fig, ax = plt.subplots()
    else:
        fig, ax = plt.subplots(facecolor="w")
    ax.barh(features[::-1], contributions[::-1], color=colors, **kwargs)

    title = "Feature Importance Scores" if are_importances else "Feature Contributions"
    ax.set_title(title, fontsize=18)
    if prediction is not None:
        # The prediction replaces the axes title; the main title moves to the figure level
        ax.set_title("Overall prediction: %s" % prediction, fontsize=12)
        fig.suptitle(title, fontsize=18, y=1)

    if include_axis:
        ax.tick_params(axis="x", which="both", bottom=True, top=False, labelbottom=True)
        ax.set_xlabel("Importance" if are_importances else "Contribution")
    else:
        ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)

    for side in ("top", "bottom", "right", "left"):
        ax.spines[side].set_visible(False)
    ax.axvline(x=0, color="black")

    if filename is not None:
        fig.savefig(filename, bbox_inches="tight")
    if show:
        plt.show()
def strip_plot(
    explanation,
    type="strip",
    num_features=5,
    discrete=False,
    show=False,
    filename=None,
    marker_size=3,
    palette=None,
    show_legend=True,
    **kwargs
):
    """
    Generates a strip plot (type="strip") or a swarm plot (type="swarm") from a set of feature
    contributions.

    Args:
        explanation (DataFrame or FeatureBased):
            One output DataFrame from RealApp.produce_feature_contributions OR
            FeatureContributions explanation object
        type (String, one of ["strip", "swarm"]):
            The type of plot to generate
        num_features (int):
            Number of features to show. Values larger than the number of available features
            are reduced to that number.
        discrete (Boolean):
            If true, give discrete legends for each row. Otherwise, give a colorbar legend
        show (Boolean):
            If True, show the figure
        filename (string or None):
            If not None, save the figure as filename
        marker_size (int):
            Size of markers to use in plot. Overridden by a "size" entry in kwargs.
        palette (seaborn palette name, list, or dict):
            Colors to use in the plot. See seaborn.color_palette for more info
        show_legend (Boolean):
            If False, hide the legend
        **kwargs:
            Additional arguments to pass to seaborn.swarmplot or seaborn.stripplot

    Raises:
        ValueError
            If type is not one of "strip" or "swarm".
    """
    # Validate up front so that nothing is drawn for an invalid plot type
    if type == "strip":
        plot_func = sns.stripplot
    elif type == "swarm":
        plot_func = sns.swarmplot
    else:
        raise ValueError("Invalid type %s. Type must be one of [strip, swarm]." % type)

    contributions, values = _parse_multi_contribution(explanation)

    # Rank features by mean absolute contribution, most important first
    average_importance = np.mean(abs(contributions), axis=0)
    order = np.argsort(np.asarray(average_importance))[::-1]
    num_features = min(num_features, len(order))

    legend = "brief" if (discrete and show_legend) else False
    generate_palette = palette is None

    # A "size" kwarg takes precedence over marker_size and must not be passed twice
    if "size" in kwargs:
        marker_size = kwargs.pop("size")

    num_cats = []  # number of legend entries added by each feature row
    ax = None
    for i in range(num_features):
        column = order[i : i + 1]
        hues = values.iloc[:, column].melt()["value"]
        if generate_palette:
            # One color per distinct feature value; missing values do not get a color
            num_colors = len(np.unique(hues.astype("str")))
            if hues.isna().any():
                num_colors -= 1
            palette = sns.blend_palette(
                [NEGATIVE_COLOR_LIGHT, NEUTRAL_COLOR, POSITIVE_COLOR_LIGHT], n_colors=num_colors
            )
        ax = plot_func(
            x="value",
            y="variable",
            hue=hues,
            data=contributions.iloc[:, column].melt(),
            palette=palette,
            legend=legend,
            size=marker_size,
            **kwargs
        )
        # Legend labels accumulate on the axes; record how many this row contributed
        handles, labels = ax.get_legend_handles_labels()
        num_cats.append(len(labels) - sum(num_cats))

    if ax is None:
        # Nothing was plotted (no features); fall back to the current axes
        ax = plt.gca()

    plt.axvline(x=0, color="black", linewidth=1)
    ax.grid(axis="y")
    ax.set_ylabel("")
    ax.set_xlabel("Contributions")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_visible(False)

    # Always bound, because it is passed to savefig below even when no legend is drawn
    legends = []
    if show_legend:
        if discrete:
            if num_cats:
                handles, labels = ax.get_legend_handles_labels()
                shift = 1 / len(num_cats)
                start = 0
                for row, count in enumerate(num_cats):
                    # Show at most about five entries per row by subsampling larger groups
                    step = 1 if count <= 5 else math.ceil(count / 5)
                    row_legend = ax.legend(
                        handles[start : start + count : step],
                        labels[start : start + count : step],
                        bbox_to_anchor=(1, 1 - (row * shift)),
                        loc="upper left",
                        ncol=count,
                        labelspacing=0.2,
                        columnspacing=0.2,
                        handletextpad=0.1,
                        frameon=False,
                    )
                    legends.append(row_legend)
                    start += count
                # Each ax.legend call replaces the previous legend, so re-add the earlier ones
                for earlier_legend in legends[:-1]:
                    ax.add_artist(earlier_legend)
        else:
            ax = plt.gca()
            norm = plt.Normalize(0, 1)
            sm = plt.cm.ScalarMappable(cmap=PALETTE_CMAP, norm=norm)
            sm.set_array([])
            # The mappable is not attached to any Axes, so state where the colorbar goes
            cbar = ax.figure.colorbar(sm, ax=ax)
            cbar.ax.get_yaxis().set_ticks([])
            cbar.ax.text(1.5, 0.05, "low", ha="left", va="center")
            cbar.ax.text(1.5, 0.95, "high", ha="left", va="center")
            cbar.ax.set_ylabel("Feature Value", rotation=270)
            cbar.ax.get_yaxis().labelpad = 15

    if filename is not None:
        plt.gcf().savefig(filename, bbox_extra_artists=legends, bbox_inches="tight")
    if show:
        plt.show()
def feature_scatter_plot(
    explanation,
    feature,
    predictions=None,
    discrete=None,
    show=False,
    filename=None,
    palette=None,
    marker_alpha=0.5,
    marker_size=3,
    **kwargs
):
    """
    Plot a contribution scatter plot for one feature

    Args:
        explanation (DataFrame or FeatureBased):
            One output DataFrame from RealApp.produce_feature_contributions OR
            FeatureContributions explanation object
        feature (column label):
            Label of column to visualize
        predictions (array-like of length n_instances, or dict):
            Predictions corresponding to explained instances. A dict is converted to an
            array of its values, in iteration order.
        discrete (Boolean):
            If true, plot x as discrete data. Defaults to True if x is not numeric.
        show (Boolean):
            If True, show the figure
        filename (string or None):
            If not None, save the figure as filename
        palette (seaborn palette name, list, or dict):
            Colors to use in the plot. See seaborn.color_palette for more info
        marker_alpha (float between (0,1]):
            Alpha value to use for markers
        marker_size (int):
            Size to use for markers
        **kwargs:
            Additional arguments to pass into seaborn.stripplot or seaborn.scatterplot
    """
    contributions, values = _parse_multi_contribution(explanation)
    contributions = contributions[feature]
    values = values[feature]

    if predictions is None:
        # No predictions: color every point the same and draw no legend
        legend_type = "none"
        predictions = np.zeros_like(contributions)
    else:
        legend_type = "discrete"
        if isinstance(predictions, dict):
            predictions = [predictions[key] for key in predictions]
        # Normalize any array-like (list, Series, column vector) to a flat ndarray
        predictions = np.asarray(predictions).reshape(-1)

    data = pd.DataFrame(
        {"Contribution": contributions.values, "Value": values.values, "Prediction": predictions}
    )

    num_colors = len(np.unique(predictions.astype("str")))
    if palette is None:
        palette = sns.blend_palette(
            [NEGATIVE_COLOR_LIGHT, NEUTRAL_COLOR, POSITIVE_COLOR_LIGHT], n_colors=num_colors
        )

    # Use a colorbar for float predictions, or integer predictions with many distinct values.
    # NumPy scalar types are checked explicitly: np.integer values are not instances of int.
    if legend_type != "none" and predictions.size > 0:
        first_prediction = predictions[0]
        is_float = isinstance(first_prediction, (float, np.floating))
        is_int = isinstance(first_prediction, (int, np.integer))
        if is_float or (is_int and num_colors > 6):
            legend_type = "continuous"

    plot_legend = legend_type == "discrete"

    if discrete is None:
        discrete = not pd.api.types.is_numeric_dtype(values)

    if discrete:
        ax = sns.stripplot(
            x="Value",
            y="Contribution",
            data=data,
            hue="Prediction",
            palette=palette,
            legend=plot_legend,
            alpha=marker_alpha,
            size=marker_size,
            zorder=0,
            **kwargs
        )
    else:
        # NOTE(review): marker_size is passed as `sizes`, which appears to apply only when a
        # `size` variable is mapped, so it may have no effect here - confirm against seaborn.
        ax = sns.scatterplot(
            x="Value",
            y="Contribution",
            data=data,
            hue="Prediction",
            palette=palette,
            legend=plot_legend,
            alpha=marker_alpha,
            sizes=marker_size,
            **kwargs
        )
    plt.axhline(0, color="black", zorder=0)
    plt.xlabel("Values for %s" % feature)

    if legend_type == "continuous":
        norm = plt.Normalize(0, 1)
        sm = plt.cm.ScalarMappable(cmap=PALETTE_CMAP, norm=norm)
        if discrete:
            # NOTE(review): tick rotation only happens together with the colorbar legend;
            # confirm whether it was meant to apply to every discrete plot.
            plt.xticks(rotation=45, ha="right")
        min_val = predictions.min()
        max_val = predictions.max()
        sm.set_array([])
        # The mappable is not attached to any Axes, so state where the colorbar goes
        cbar = ax.figure.colorbar(sm, ax=ax)
        cbar.ax.get_yaxis().set_ticks([])
        # Label the ends of the colorbar with the prediction range, trimming trailing zeros
        cbar.ax.text(1.5, 0.05, ("%.2f" % min_val).rstrip("0").rstrip("."), ha="left", va="center")
        cbar.ax.text(1.5, 0.95, ("%.2f" % max_val).rstrip("0").rstrip("."), ha="left", va="center")
        cbar.ax.set_ylabel("Prediction", rotation=270)
        cbar.ax.get_yaxis().labelpad = 15

    if filename is not None:
        plt.gcf().savefig(filename, bbox_inches="tight")
    if show:
        plt.tight_layout()
        plt.show()