Source code for pyreal.explainers.gfi.global_feature_importance

import logging

from pyreal.explainers import (
    GlobalFeatureImportanceBase,
    PermutationFeatureImportance,
    ShapFeatureImportance,
)

log = logging.getLogger(__name__)


def choose_algorithm():
    """
    Choose an explanation algorithm based on the model type.
    Currently, shap is the only algorithm that will be chosen automatically.

    Returns:
        string (one of ["shap"])
            Explanation algorithm to use
    """
    return "shap"


class GlobalFeatureImportance(GlobalFeatureImportanceBase):
    """
    Generic GlobalFeatureImportance wrapper

    A GlobalFeatureImportance object wraps multiple global feature-based
    explanations. If no specific algorithm is requested, one will be chosen
    based on the information given. Currently, shap is the only algorithm
    that will be chosen automatically.

    Args:
        model (string filepath or model object):
            Filepath to the pickled model to explain, or model object with a
            .predict() function
        x_train_orig (DataFrame of shape (n_instances, x_orig_feature_count)):
            The training set for the explainer
        e_algorithm (string, one of ["shap", "permutation"]):
            Explanation algorithm to use. If None, one will be chosen
            automatically based on the model type
        shap_type (string, one of ["kernel", "linear"]):
            Type of shap algorithm to use, if e_algorithm="shap"
        **kwargs: see GlobalFeatureImportanceBase args
    """
    def __init__(self, model, x_train_orig=None, e_algorithm=None, shap_type=None, **kwargs):
        if e_algorithm is None:
            e_algorithm = choose_algorithm()
        self.base_global_feature_importance = None
        if e_algorithm == "shap":
            self.base_global_feature_importance = ShapFeatureImportance(
                model, x_train_orig, shap_type=shap_type, **kwargs
            )
        if e_algorithm == "permutation":
            self.base_global_feature_importance = PermutationFeatureImportance(
                model, x_train_orig, **kwargs
            )
        if self.base_global_feature_importance is None:
            raise ValueError("Invalid algorithm type %s" % e_algorithm)
        super(GlobalFeatureImportance, self).__init__(model, x_train_orig, **kwargs)
    def fit(self, x_train_orig=None, y_train=None):
        """
        Fit this explainer object

        Args:
            x_train_orig (DataFrame of shape (n_instances, n_features)):
                Training set to fit on, required if not provided on initialization
            y_train:
                Targets of training set, required if not provided on initialization
        """
        self.base_global_feature_importance.fit(x_train_orig, y_train)
        return self
    def produce_explanation(self, **kwargs):
        """
        Gets the raw explanation.

        Returns:
            FeatureImportanceExplanation
                Importance of each feature
        """
        return self.base_global_feature_importance.produce_explanation(**kwargs)
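
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the pyreal module): fit the
# wrapper and produce a global feature importance explanation. The names
# `model`, `x_train`, and `y_train` are hypothetical placeholders for a
# trained model object and pandas training data.
#
#     from pyreal.explainers import GlobalFeatureImportance
#
#     explainer = GlobalFeatureImportance(
#         model, x_train_orig=x_train, e_algorithm="permutation"
#     )
#     explainer.fit(x_train, y_train)
#     explanation = explainer.produce_explanation()
# ---------------------------------------------------------------------------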