Commit c1b97097 authored by Scheele, Stephan's avatar Scheele, Stephan

updated notebook

parent 9618abeb
......@@ -7,6 +7,7 @@
"# Using Trepan: a demonstration\n",
"\n",
"By: Yuriy Sverchkov\n",
"Date: Feb 2021\n",
"\n",
"Assistant Scientist\n",
"\n",
......@@ -14,6 +15,9 @@
"\n",
"University of Wisconsin - Madison, USA\n",
"\n",
"Adapted for AI-Campus course [XAI4Ing](https://learn.ki-campus.org/courses/erklaerbareki2020) by Stephan Scheele, https://www.uni-bamberg.de/en/cogsys/scheele-stephan/\n",
"\n",
"\n",
"## Trepan\n",
"\n",
"[M. Craven and J. Shalvik. 1995. Extracting tree-structured representations of trained networks. In Proceedings of the 8th International Conference on Neural Information Processing 1995](https://dl.acm.org/doi/10.5555/2998828.2998832)\n",
......@@ -23,7 +27,16 @@
"\n",
"In this notebook we will use a modern implementation in the [`generalizedtrees`](https://github.com/Craven-Biostat-Lab/generalizedtrees) package to run Trepan on one of the original datasets used in the paper. `generalizedtrees` is under development with major changes from version to version. It is a project that attempts to bring together many different variants of tree learning together into a single framework.\n",
"\n",
"For this notebook we will use `generalizedtrees` version 1.1.0"
"For this notebook we will use `generalizedtrees` version 1.1.0\n",
"\n",
"\n",
"## Contact\n",
"Via the gitlab website: Stephan Scheele, https://www.uni-bamberg.de/en/cogsys/scheele-stephan/\n",
"\n",
"## License\n",
"\n",
"This work is licensed under Creative Commons Attribution 4.0 International License.\n",
"https://creativecommons.org/licenses/by/4.0/\n"
]
},
{
......
%% Cell type:markdown id: tags:
# Using Trepan: a demonstration
By: Yuriy Sverchkov
Date: Feb 2021
Assistant Scientist
Department of Biostatistics and Medical Informatics
University of Wisconsin - Madison, USA
Adapted for AI-Campus course [XAI4Ing](https://learn.ki-campus.org/courses/erklaerbareki2020) by Stephan Scheele, https://www.uni-bamberg.de/en/cogsys/scheele-stephan/
## Trepan
[M. Craven and J. Shavlik. 1995. Extracting tree-structured representations of trained networks. In Proceedings of the 8th International Conference on Neural Information Processing Systems (NIPS'95)](https://dl.acm.org/doi/10.5555/2998828.2998832)
Trepan is one of the first explanation-by-model-translation methods for neural networks.
Model translation is an approach to model explanation in which an uninterpretable black-box model is translated into an interpretable model (in this case a decision tree) that aims to have high fidelity to the black-box model, that is, to make predictions similar to it (a generic sketch of this idea, using a plain decision tree as the surrogate, appears in the first code cell below).
In this notebook we will use a modern implementation in the [`generalizedtrees`](https://github.com/Craven-Biostat-Lab/generalizedtrees) package to run Trepan on one of the original datasets used in the paper. `generalizedtrees` is under active development, with major changes from version to version. It is a project that attempts to bring many different variants of tree learning together into a single framework.
For this notebook we will use `generalizedtrees` version 1.1.0
## Contact
Via the gitlab website: Stephan Scheele, https://www.uni-bamberg.de/en/cogsys/scheele-stephan/
## License
This work is licensed under the Creative Commons Attribution 4.0 International License.
https://creativecommons.org/licenses/by/4.0/
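%% Cell type:markdown id: tags:
The sketch below illustrates the generic model-translation idea mentioned above on a synthetic toy dataset: query a black-box model for its predictions, fit an interpretable surrogate (here a plain scikit-learn decision tree, not Trepan) to those predictions, and measure fidelity as the agreement rate. This is an added illustration and is not part of the original notebook or required for the rest of it.
%% Cell type:code id: tags:
``` python
# Generic model-translation sketch (NOT Trepan): fit an interpretable surrogate
# to the *predictions* of a black-box model, then measure fidelity.
from sklearn.datasets import make_classification
from sklearn.neural_network import MLPClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

X_toy, y_toy = make_classification(n_samples=500, n_features=10, random_state=0)

black_box = MLPClassifier(hidden_layer_sizes=(10,), max_iter=1000, random_state=0).fit(X_toy, y_toy)
y_bb = black_box.predict(X_toy)  # labels produced by the black box, not the true labels

surrogate = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_toy, y_bb)

# Fidelity: how often the surrogate agrees with the black box on the same inputs
print('Fidelity:', accuracy_score(y_bb, surrogate.predict(X_toy)))
```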
%% Cell type:code id: tags:
``` python
import sys
!{sys.executable} -m pip install generalizedtrees==1.1.0
```
%%%% Output: stream
Requirement already satisfied: generalizedtrees==1.1.0 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (1.1.0)
Requirement already satisfied: scipy>=1.5.2 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from generalizedtrees==1.1.0) (1.6.0)
Requirement already satisfied: pandas>=1.1.0 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from generalizedtrees==1.1.0) (1.2.2)
Requirement already satisfied: scikit-learn>=0.23.2 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from generalizedtrees==1.1.0) (0.24.1)
Requirement already satisfied: numpy>=1.19.1 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from generalizedtrees==1.1.0) (1.20.1)
Requirement already satisfied: python-dateutil>=2.7.3 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from pandas>=1.1.0->generalizedtrees==1.1.0) (2.8.1)
Requirement already satisfied: pytz>=2017.3 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from pandas>=1.1.0->generalizedtrees==1.1.0) (2021.1)
Requirement already satisfied: threadpoolctl>=2.0.0 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from scikit-learn>=0.23.2->generalizedtrees==1.1.0) (2.1.0)
Requirement already satisfied: joblib>=0.11 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from scikit-learn>=0.23.2->generalizedtrees==1.1.0) (1.0.1)
Requirement already satisfied: six>=1.5 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from python-dateutil>=2.7.3->pandas>=1.1.0->generalizedtrees==1.1.0) (1.15.0)
WARNING: You are using pip version 20.2.3; however, version 21.0.1 is available.
You should consider upgrading via the '/Users/sees/.virtualenvs/trepan-demo/bin/python -m pip install --upgrade pip' command.
%% Cell type:markdown id: tags:
## Data
We will use the Cleveland Heart Disease dataset, available from the UCI repository: http://archive.ics.uci.edu/ml/datasets/Heart+Disease.
The file we load here is specifically http://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/cleve.mod.
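%% Cell type:markdown id: tags:
The code below reads the data from a local file named `cleveland.txt`. If that file is not already present, a download step along the following lines can fetch it from the UCI mirror and save it under the expected name (an added convenience cell, assuming the URL above is still reachable; it was not part of the original notebook):
%% Cell type:code id: tags:
``` python
import os
import urllib.request

# Location of the raw cleve.mod file given above, saved under the name the notebook expects
DATA_URL = 'http://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/cleve.mod'
if not os.path.exists('cleveland.txt'):
    urllib.request.urlretrieve(DATA_URL, 'cleveland.txt')
```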
%% Cell type:code id: tags:
``` python
import pandas as pd
import numpy as np
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.model_selection import train_test_split
np_rng = np.random.default_rng(8372234)
sk_rng = np.random.RandomState(3957458)
```
%% Cell type:code id: tags:
``` python
full_data = pd.read_fwf(
'cleveland.txt',
skiprows=20,
names = [
'age', 'sex', 'chest pain type', 'resting bp', 'cholesterol', 'fasting blood sugar < 120', 'resting ecg',
'max heart rate', 'exercise induced angina', 'oldpeak', 'slope', 'number of vessels colored', 'thal', 'class', 'stage'
],
true_values = ['true'],
false_values = ['fal'],
na_values = '?').dropna(axis=0)
full_data
```
%%%% Output: execute_result
age sex chest pain type resting bp cholesterol \
0 63.0 male angina 145.0 233.0
1 67.0 male asympt 160.0 286.0
2 67.0 male asympt 120.0 229.0
3 37.0 male notang 130.0 250.0
4 41.0 fem abnang 130.0 204.0
.. ... ... ... ... ...
298 48.0 male notang 124.0 255.0
299 57.0 male asympt 132.0 207.0
300 49.0 male notang 118.0 149.0
301 74.0 fem abnang 120.0 269.0
302 54.0 fem notang 160.0 201.0
fasting blood sugar < 120 resting ecg max heart rate \
0 True hyp 150.0
1 False hyp 108.0
2 False hyp 129.0
3 False norm 187.0
4 False hyp 172.0
.. ... ... ...
298 True norm 175.0
299 False norm 168.0
300 False hyp 126.0
301 False hyp 121.0
302 False norm 163.0
exercise induced angina oldpeak slope number of vessels colored thal \
0 False 2.3 down 0.0 fix
1 True 1.5 flat 3.0 norm
2 True 2.6 flat 2.0 rev
3 False 3.5 down 0.0 norm
4 False 1.4 up 0.0 norm
.. ... ... ... ... ...
298 False 0.0 up 2.0 norm
299 True 0.0 up 0.0 rev
300 False 0.8 up 3.0 norm
301 True 0.2 up 1.0 norm
302 False 0.0 up 1.0 norm
class stage
0 buff H
1 sick S2
2 sick S1
3 buff H
4 buff H
.. ... ...
298 buff H
299 buff H
300 sick S1
301 buff H
302 buff H
[296 rows x 15 columns]
%% Cell type:markdown id: tags:
Since we will be using scikit-learn to learn our black-box model, we need to convert categorical variables to numeric vectors.
We will also make a train-test split.
%% Cell type:code id: tags:
``` python
# Separate the target columns from the predictors
data_df = full_data.drop(['class', 'stage'], axis=1)
encoder = OneHotEncoder(drop = 'if_binary')
lencoder = LabelEncoder()
# Numeric columns are kept as-is; categorical columns are one-hot encoded
numeric_features = data_df.select_dtypes(include = 'number')
categorical_features_df = data_df.select_dtypes(exclude = 'number')
categorical_features = encoder.fit_transform(categorical_features_df).toarray()
# Record the names of the resulting columns (numeric first, then one-hot encoded)
feature_names = np.append(numeric_features.columns, encoder.get_feature_names(categorical_features_df.columns))
x = np.append(
    numeric_features,
    categorical_features,
    axis = 1)
# Encode the 'buff'/'sick' class labels as integers
y = lencoder.fit_transform(full_data['class'])
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state = sk_rng)
```
%% Cell type:markdown id: tags:
Our new feature set is shown below. In what follows, features will be referenced by their 0-based index.
%% Cell type:code id: tags:
``` python
pd.DataFrame({'Feature Name': feature_names})
```
%%%% Output: execute_result
Feature Name
0 age
1 resting bp
2 cholesterol
3 max heart rate
4 oldpeak
5 number of vessels colored
6 sex_male
7 chest pain type_abnang
8 chest pain type_angina
9 chest pain type_asympt
10 chest pain type_notang
11 fasting blood sugar < 120_True
12 resting ecg_abn
13 resting ecg_hyp
14 resting ecg_norm
15 exercise induced angina_True
16 slope_down
17 slope_flat
18 slope_up
19 thal_fix
20 thal_norm
21 thal_rev
%% Cell type:markdown id: tags:
## Black-box model
As in the paper, our black-box model is a fully connected neural network with one hidden layer, whose size is determined by cross-validation within the training set.
%% Cell type:code id: tags:
``` python
from sklearn.model_selection import GridSearchCV
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
model = GridSearchCV(
make_pipeline(StandardScaler(), MLPClassifier(solver='lbfgs', alpha=1e-5, random_state=sk_rng)),
param_grid = {'mlpclassifier__hidden_layer_sizes': [(5,), (10,), (20,), (40,)]},
refit = True
)
model.fit(x_train, y_train)
model.best_estimator_
```
%%%% Output: execute_result
Pipeline(steps=[('standardscaler', StandardScaler()),
('mlpclassifier',
MLPClassifier(alpha=1e-05, hidden_layer_sizes=(5,),
random_state=RandomState(MT19937) at 0x13B59DA40,
solver='lbfgs'))])
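%% Cell type:markdown id: tags:
To see which hidden-layer size the grid search selected and its cross-validated accuracy, `GridSearchCV` exposes `best_params_` and `best_score_` (an added inspection cell, not part of the original notebook):
%% Cell type:code id: tags:
``` python
# Hidden-layer size chosen by the grid search and its mean cross-validated accuracy
print(model.best_params_)
print(f'Cross-validated accuracy: {model.best_score_:.3f}')
```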
%% Cell type:markdown id: tags:
## Decision tree explanation
%% Cell type:code id: tags:
``` python
import time
from generalizedtrees.recipes import trepan
from generalizedtrees.vis.vis import explanation_to_html
from generalizedtrees.features import FeatureSpec
```
%% Cell type:markdown id: tags:
Our `generalizedtrees` Python package provides a function that serves as a 'recipe' for Trepan; it returns an object that can be fit to data and a black-box model.
%% Cell type:code id: tags:
``` python
explanation = trepan(
m_of_n=False,
max_tree_size=10,
impurity='entropy',
rng = np_rng)
```
%% Cell type:markdown id: tags:
We learn the explanation from the black-box model using the `fit` method, which takes unlabeled data and the black-box model as input.
%% Cell type:code id: tags:
``` python
t0 = time.time()
explanation.fit(x_train, model)
t1 = time.time()
print(f'Time taken to learn explanation: {t1-t0} seconds')
```
%%%% Output: stream
Assuming continuous features in the absence of feature specifications
%%%% Output: stream
Time taken to learn explanation: 3.302284002304077 seconds
%% Cell type:markdown id: tags:
A console-friendly representation of the tree is available:
%% Cell type:code id: tags:
``` python
print(explanation.show_tree())
```
%%%% Output: stream
Test x[5] > 0.5
+--Test x[5] > 1.5
| +--[0.49 0.51]
| +--[0.577 0.423]
+--Test x[4] > 1.55
+--[0.539 0.461]
+--Test x[20] > 0.5
+--Test x[4] > 1.45
| +--[0.561 0.439]
| +--[0.699 0.301]
+--[0.601 0.399]
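%% Cell type:markdown id: tags:
The tests in the printed tree refer to features by index: from the feature table above, x[5] is 'number of vessels colored', x[4] is 'oldpeak', and x[20] is 'thal_norm'. Such indices can also be looked up programmatically in the `feature_names` array built earlier (an added convenience cell, not part of the original notebook):
%% Cell type:code id: tags:
``` python
# Names of the features used by the splits in the tree above
print(feature_names[[5, 4, 20]])
```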
%% Cell type:markdown id: tags:
A graphical representation (as an HTML file) can be generated as well:
%% Cell type:code id: tags:
``` python
explanation_to_html(explanation, 'explanation.html')
```
%% Cell type:markdown id: tags:
Open the generated HTML file:
%% Cell type:markdown id: tags:
Click the following link: [explanation.html](explanation.html)
%% Cell type:markdown id: tags:
## Performance comparison
The resulting tree can itself be treated as a classifier, and we can check how this learned tree performs on test data, alongside the original black-box model for comparison.
%% Cell type:code id: tags:
``` python
y_test_trepan = explanation.predict(x_test)
y_test_model = model.predict(x_test)
```
%% Cell type:code id: tags:
``` python
from sklearn.metrics import classification_report
print('Trepan:')
print(classification_report(y_test, y_test_trepan, target_names=lencoder.classes_))
print('Black Box:')
print(classification_report(y_test, y_test_model, target_names=lencoder.classes_))
```
%%%% Output: stream
Trepan:
precision recall f1-score support
buff 0.71 0.94 0.81 16
sick 0.89 0.57 0.70 14
accuracy 0.77 30
macro avg 0.80 0.75 0.75 30
weighted avg 0.80 0.77 0.76 30
Black Box:
precision recall f1-score support
buff 0.79 0.94 0.86 16
sick 0.91 0.71 0.80 14
accuracy 0.83 30
macro avg 0.85 0.83 0.83 30
weighted avg 0.85 0.83 0.83 30
%% Cell type:markdown id: tags:
---
Another metric of interest for explanation methods specifically is *fidelity*, which measures how well the explanation's predictions match the black-box predictions:
%% Cell type:code id: tags:
``` python
print('Training set fidelity')
print(classification_report(model.predict(x_train), explanation.predict(x_train), target_names=lencoder.classes_))
print('Test set fidelity')
print(classification_report(y_test_model, y_test_trepan, target_names=lencoder.classes_))
```
%%%% Output: stream
Training set fidelity
precision recall f1-score support
buff 0.65 0.93 0.77 151
sick 0.80 0.34 0.48 115
accuracy 0.68 266
macro avg 0.72 0.64 0.62 266
weighted avg 0.71 0.68 0.64 266
Test set fidelity
precision recall f1-score support
buff 0.81 0.89 0.85 19
sick 0.78 0.64 0.70 11
accuracy 0.80 30
macro avg 0.79 0.77 0.78 30
weighted avg 0.80 0.80 0.80 30
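%% Cell type:markdown id: tags:
Fidelity is often also summarized as a single number: the fraction of instances on which the explanation and the black box agree. A minimal sketch of that summary using scikit-learn's `accuracy_score` (an added cell, not part of the original notebook):
%% Cell type:code id: tags:
``` python
from sklearn.metrics import accuracy_score

# Agreement rate between the tree's predictions and the black-box predictions
print(f'Training set fidelity: {accuracy_score(model.predict(x_train), explanation.predict(x_train)):.3f}')
print(f'Test set fidelity: {accuracy_score(y_test_model, y_test_trepan):.3f}')
```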
%% Cell type:markdown id: tags:
## Remarks
There are some differences between the experiment presented here and the experiment in the paper. The version of Trepan we ran here did not use m-of-n splits at nodes (the search for splits becomes very slow when they are enabled), we used a different minimum number of samples required at each node (1000), and we used a simpler rejection-based sampling scheme for generating synthetic samples. The details of the neural network learning algorithm are also different.
This notebook is available [in a gist](https://gist.github.com/sverchkov/c87b301db1b88e0f4cc8bb7d77b889b9).
For issues or questions about the `generalizedtrees` package, contact us through https://github.com/Craven-Biostat-Lab/generalizedtrees
%% Cell type:code id: tags:
``` python
```
......