Source code for pyreal.explainers.dte.decision_tree_explainer
from pyreal.explainers import DecisionTreeExplainerBase, SurrogateDecisionTree
def choose_algorithm():
    """
    Choose the decision-tree explanation algorithm to use.

    Currently, surrogate tree is the only supported algorithm, so this always returns
    "surrogate_tree"; no model information is inspected.

    Returns:
        string (one of ["surrogate_tree"])
            Explanation algorithm to use
    """
    return "surrogate_tree"
class DecisionTreeExplainer(DecisionTreeExplainerBase):
    """
    Generic DecisionTreeExplainer wrapper

    A DecisionTreeExplainer object wraps multiple decision tree-based explanations. If no
    specific algorithm is requested, one will be chosen automatically.
    Currently, only surrogate tree is supported.

    Args:
        model (string filepath or model object):
            Filepath to the pickled model to explain, or model object with .predict() function
        x_train_orig (dataframe of shape (n_instances, x_orig_feature_count)):
            The training set for the explainer
        e_algorithm (string, one of ["surrogate_tree"]):
            Explanation algorithm to use. If None, one will be chosen automatically
            (currently always "surrogate_tree")
        is_classifier (bool):
            Set this True for a classification model, False for a regression model.
        max_depth (int):
            The max_depth of the tree
        **kwargs: see DecisionTreeExplainerBase args

    Raises:
        ValueError: if `e_algorithm` is not a supported algorithm name
    """

    def __init__(
        self,
        model,
        x_train_orig=None,
        e_algorithm=None,
        is_classifier=True,
        max_depth=None,
        **kwargs
    ):
        self.is_classifier = is_classifier
        self.max_depth = max_depth
        if e_algorithm is None:
            e_algorithm = choose_algorithm()
        if e_algorithm == "surrogate_tree":
            # The wrapped explainer receives the same model/data/settings as this wrapper.
            self.base_decision_tree = SurrogateDecisionTree(
                model, x_train_orig, is_classifier, max_depth, **kwargs
            )
        else:
            # Reject unknown names explicitly: without this branch base_decision_tree would
            # never be assigned and the caller would see an AttributeError instead.
            raise ValueError("Invalid algorithm type %s" % e_algorithm)
        super().__init__(model, x_train_orig, **kwargs)

    def fit(self, x_train_orig=None, y_train=None):
        """
        Fit this explainer object by fitting the wrapped decision tree explainer

        Args:
            x_train_orig (DataFrame of shape (n_instances, n_features)):
                Training set to fit on, required if not provided on initialization
            y_train:
                Targets of training set, required if not provided on initialization

        Returns:
            DecisionTreeExplainer: this explainer, to allow call chaining
        """
        self.base_decision_tree.fit(x_train_orig, y_train)
        return self

    def produce_explanation(self, **kwargs):
        """
        Return the decision tree produced by the wrapped explainer
        (documented as either a DecisionTreeClassifier or DecisionTreeRegressor).

        Args:
            **kwargs:
                Accepted only to match the explainer interface; ignored here

        Returns:
            Whatever the wrapped explainer's produce() returns
        """
        return self.base_decision_tree.produce()

    def produce_importances(self):
        """
        Return the feature importances computed by the wrapped decision tree explainer

        Returns:
            Whatever the wrapped explainer's produce_importances() returns
        """
        return self.base_decision_tree.produce_importances()