Commit d97807a9 authored by Scheele, Stephan

Initial commit of trepan notebook

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using Trepan: a demonstration\n",
"\n",
"By: Yuriy Sverchkov\n",
"\n",
"Assistant Scientist\n",
"\n",
"Department of Biostatistics and Medical Informatics\n",
"\n",
"University of Wisconsin - Madison, USA\n",
"\n",
"## Trepan\n",
"\n",
"[M. Craven and J. Shavlik. 1995. Extracting tree-structured representations of trained networks. In Proceedings of the 8th International Conference on Neural Information Processing Systems (NIPS 1995)](https://dl.acm.org/doi/10.5555/2998828.2998832)\n",
"\n",
"Trepan is one of the first explanation-by-model-translation methods for neural networks.\n",
"Model translation is an approach to model explanation in which an uninterpretable black-box model is translated into an interpretable model (in this case a decision tree) that aims for high fidelity to the black-box model, that is, it makes similar predictions.\n",
"\n",
"In this notebook we will use a modern implementation in the [`generalizedtrees`](https://github.com/Craven-Biostat-Lab/generalizedtrees) package to run Trepan on one of the original datasets used in the paper. `generalizedtrees` is under active development, with major changes from version to version. It is a project that attempts to bring many different variants of tree learning together into a single framework.\n",
"\n",
"For this notebook we will use `generalizedtrees` version 1.1.0."
]
},
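{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the notion of fidelity concrete, here is a minimal sketch of model translation using only scikit-learn. It is not Trepan itself: Trepan additionally queries the network on generated instances and can use m-of-n splits, whereas this sketch simply fits an ordinary `DecisionTreeClassifier` to the black-box predictions. The synthetic data from `make_classification`, the tree depth, and all variable names are assumptions made for illustration.\n",
"\n",
"```python\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.neural_network import MLPClassifier\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Synthetic data stands in for a real dataset (an assumption for illustration).\n",
"X, y = make_classification(n_samples=500, n_features=8, random_state=0)\n",
"\n",
"# The black-box model we want to explain.\n",
"black_box = MLPClassifier(hidden_layer_sizes=(5,), max_iter=2000, random_state=0).fit(X, y)\n",
"\n",
"# The surrogate is trained on the black box's predictions, not on the true labels.\n",
"bb_labels = black_box.predict(X)\n",
"surrogate = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, bb_labels)\n",
"\n",
"# Fidelity: fraction of instances on which the tree agrees with the black box.\n",
"fidelity = accuracy_score(bb_labels, surrogate.predict(X))\n",
"print(f'Fidelity to the black box: {fidelity:.2f}')\n",
"```\n",
"\n",
"Fidelity is measured against the black-box predictions rather than the true labels, so a surrogate can have high fidelity even where the black box itself is wrong."
]
},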
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: generalizedtrees==1.1.0 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (1.1.0)\n",
"Requirement already satisfied: scipy>=1.5.2 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from generalizedtrees==1.1.0) (1.6.0)\n",
"Requirement already satisfied: pandas>=1.1.0 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from generalizedtrees==1.1.0) (1.2.2)\n",
"Requirement already satisfied: scikit-learn>=0.23.2 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from generalizedtrees==1.1.0) (0.24.1)\n",
"Requirement already satisfied: numpy>=1.19.1 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from generalizedtrees==1.1.0) (1.20.1)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from pandas>=1.1.0->generalizedtrees==1.1.0) (2.8.1)\n",
"Requirement already satisfied: pytz>=2017.3 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from pandas>=1.1.0->generalizedtrees==1.1.0) (2021.1)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from scikit-learn>=0.23.2->generalizedtrees==1.1.0) (2.1.0)\n",
"Requirement already satisfied: joblib>=0.11 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from scikit-learn>=0.23.2->generalizedtrees==1.1.0) (1.0.1)\n",
"Requirement already satisfied: six>=1.5 in /Users/sees/.virtualenvs/trepan-demo/lib/python3.8/site-packages (from python-dateutil>=2.7.3->pandas>=1.1.0->generalizedtrees==1.1.0) (1.15.0)\n",
"\u001b[33mWARNING: You are using pip version 20.2.3; however, version 21.0.1 is available.\n",
"You should consider upgrading via the '/Users/sees/.virtualenvs/trepan-demo/bin/python -m pip install --upgrade pip' command.\u001b[0m\n"
]
}
],
"source": [
"import sys\n",
"!{sys.executable} -m pip install generalizedtrees==1.1.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data\n",
"\n",
"We will use the Cleveland Heart Disease dataset, available from the UCI repository: http://archive.ics.uci.edu/ml/datasets/Heart+Disease.\n",
"The file we load here is specifically http://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/cleve.mod."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.preprocessing import OneHotEncoder, LabelEncoder\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"np_rng = np.random.default_rng(8372234)\n",
"sk_rng = np.random.RandomState(3957458)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>sex</th>\n",
" <th>chest pain type</th>\n",
" <th>resting bp</th>\n",
" <th>cholesterol</th>\n",
" <th>fasting blood sugar &lt; 120</th>\n",
" <th>resting ecg</th>\n",
" <th>max heart rate</th>\n",
" <th>exercise induced angina</th>\n",
" <th>oldpeak</th>\n",
" <th>slope</th>\n",
" <th>number of vessels colored</th>\n",
" <th>thal</th>\n",
" <th>class</th>\n",
" <th>stage</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>63.0</td>\n",
" <td>male</td>\n",
" <td>angina</td>\n",
" <td>145.0</td>\n",
" <td>233.0</td>\n",
" <td>True</td>\n",
" <td>hyp</td>\n",
" <td>150.0</td>\n",
" <td>False</td>\n",
" <td>2.3</td>\n",
" <td>down</td>\n",
" <td>0.0</td>\n",
" <td>fix</td>\n",
" <td>buff</td>\n",
" <td>H</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>67.0</td>\n",
" <td>male</td>\n",
" <td>asympt</td>\n",
" <td>160.0</td>\n",
" <td>286.0</td>\n",
" <td>False</td>\n",
" <td>hyp</td>\n",
" <td>108.0</td>\n",
" <td>True</td>\n",
" <td>1.5</td>\n",
" <td>flat</td>\n",
" <td>3.0</td>\n",
" <td>norm</td>\n",
" <td>sick</td>\n",
" <td>S2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>67.0</td>\n",
" <td>male</td>\n",
" <td>asympt</td>\n",
" <td>120.0</td>\n",
" <td>229.0</td>\n",
" <td>False</td>\n",
" <td>hyp</td>\n",
" <td>129.0</td>\n",
" <td>True</td>\n",
" <td>2.6</td>\n",
" <td>flat</td>\n",
" <td>2.0</td>\n",
" <td>rev</td>\n",
" <td>sick</td>\n",
" <td>S1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>37.0</td>\n",
" <td>male</td>\n",
" <td>notang</td>\n",
" <td>130.0</td>\n",
" <td>250.0</td>\n",
" <td>False</td>\n",
" <td>norm</td>\n",
" <td>187.0</td>\n",
" <td>False</td>\n",
" <td>3.5</td>\n",
" <td>down</td>\n",
" <td>0.0</td>\n",
" <td>norm</td>\n",
" <td>buff</td>\n",
" <td>H</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>41.0</td>\n",
" <td>fem</td>\n",
" <td>abnang</td>\n",
" <td>130.0</td>\n",
" <td>204.0</td>\n",
" <td>False</td>\n",
" <td>hyp</td>\n",
" <td>172.0</td>\n",
" <td>False</td>\n",
" <td>1.4</td>\n",
" <td>up</td>\n",
" <td>0.0</td>\n",
" <td>norm</td>\n",
" <td>buff</td>\n",
" <td>H</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>298</th>\n",
" <td>48.0</td>\n",
" <td>male</td>\n",
" <td>notang</td>\n",
" <td>124.0</td>\n",
" <td>255.0</td>\n",
" <td>True</td>\n",
" <td>norm</td>\n",
" <td>175.0</td>\n",
" <td>False</td>\n",
" <td>0.0</td>\n",
" <td>up</td>\n",
" <td>2.0</td>\n",
" <td>norm</td>\n",
" <td>buff</td>\n",
" <td>H</td>\n",
" </tr>\n",
" <tr>\n",
" <th>299</th>\n",
" <td>57.0</td>\n",
" <td>male</td>\n",
" <td>asympt</td>\n",
" <td>132.0</td>\n",
" <td>207.0</td>\n",
" <td>False</td>\n",
" <td>norm</td>\n",
" <td>168.0</td>\n",
" <td>True</td>\n",
" <td>0.0</td>\n",
" <td>up</td>\n",
" <td>0.0</td>\n",
" <td>rev</td>\n",
" <td>buff</td>\n",
" <td>H</td>\n",
" </tr>\n",
" <tr>\n",
" <th>300</th>\n",
" <td>49.0</td>\n",
" <td>male</td>\n",
" <td>notang</td>\n",
" <td>118.0</td>\n",
" <td>149.0</td>\n",
" <td>False</td>\n",
" <td>hyp</td>\n",
" <td>126.0</td>\n",
" <td>False</td>\n",
" <td>0.8</td>\n",
" <td>up</td>\n",
" <td>3.0</td>\n",
" <td>norm</td>\n",
" <td>sick</td>\n",
" <td>S1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>301</th>\n",
" <td>74.0</td>\n",
" <td>fem</td>\n",
" <td>abnang</td>\n",
" <td>120.0</td>\n",
" <td>269.0</td>\n",
" <td>False</td>\n",
" <td>hyp</td>\n",
" <td>121.0</td>\n",
" <td>True</td>\n",
" <td>0.2</td>\n",
" <td>up</td>\n",
" <td>1.0</td>\n",
" <td>norm</td>\n",
" <td>buff</td>\n",
" <td>H</td>\n",
" </tr>\n",
" <tr>\n",
" <th>302</th>\n",
" <td>54.0</td>\n",
" <td>fem</td>\n",
" <td>notang</td>\n",
" <td>160.0</td>\n",
" <td>201.0</td>\n",
" <td>False</td>\n",
" <td>norm</td>\n",
" <td>163.0</td>\n",
" <td>False</td>\n",
" <td>0.0</td>\n",
" <td>up</td>\n",
" <td>1.0</td>\n",
" <td>norm</td>\n",
" <td>buff</td>\n",
" <td>H</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>296 rows × 15 columns</p>\n",
"</div>"
],
"text/plain": [
" age sex chest pain type resting bp cholesterol \\\n",
"0 63.0 male angina 145.0 233.0 \n",
"1 67.0 male asympt 160.0 286.0 \n",
"2 67.0 male asympt 120.0 229.0 \n",
"3 37.0 male notang 130.0 250.0 \n",
"4 41.0 fem abnang 130.0 204.0 \n",
".. ... ... ... ... ... \n",
"298 48.0 male notang 124.0 255.0 \n",
"299 57.0 male asympt 132.0 207.0 \n",
"300 49.0 male notang 118.0 149.0 \n",
"301 74.0 fem abnang 120.0 269.0 \n",
"302 54.0 fem notang 160.0 201.0 \n",
"\n",
" fasting blood sugar < 120 resting ecg max heart rate \\\n",
"0 True hyp 150.0 \n",
"1 False hyp 108.0 \n",
"2 False hyp 129.0 \n",
"3 False norm 187.0 \n",
"4 False hyp 172.0 \n",
".. ... ... ... \n",
"298 True norm 175.0 \n",
"299 False norm 168.0 \n",
"300 False hyp 126.0 \n",
"301 False hyp 121.0 \n",
"302 False norm 163.0 \n",
"\n",
" exercise induced angina oldpeak slope number of vessels colored thal \\\n",
"0 False 2.3 down 0.0 fix \n",
"1 True 1.5 flat 3.0 norm \n",
"2 True 2.6 flat 2.0 rev \n",
"3 False 3.5 down 0.0 norm \n",
"4 False 1.4 up 0.0 norm \n",
".. ... ... ... ... ... \n",
"298 False 0.0 up 2.0 norm \n",
"299 True 0.0 up 0.0 rev \n",
"300 False 0.8 up 3.0 norm \n",
"301 True 0.2 up 1.0 norm \n",
"302 False 0.0 up 1.0 norm \n",
"\n",
" class stage \n",
"0 buff H \n",
"1 sick S2 \n",
"2 sick S1 \n",
"3 buff H \n",
"4 buff H \n",
".. ... ... \n",
"298 buff H \n",
"299 buff H \n",
"300 sick S1 \n",
"301 buff H \n",
"302 buff H \n",
"\n",
"[296 rows x 15 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"full_data = pd.read_fwf(\n",
" 'cleveland.txt',\n",
" skiprows=20,\n",
" names = [\n",
" 'age', 'sex', 'chest pain type', 'resting bp', 'cholesterol', 'fasting blood sugar < 120', 'resting ecg',\n",
" 'max heart rate', 'exercise induced angina', 'oldpeak', 'slope', 'number of vessels colored', 'thal', 'class', 'stage'\n",
" ],\n",
" true_values = ['true'],\n",
" false_values = ['fal'],\n",
" na_values = '?').dropna(axis=0)\n",
"\n",
"full_data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since we will be using scikit-learn to learn our black-box model, we need to convert categorical variables to numeric vectors.\n",
"We will also make a train-test split."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Drop the two label columns ('class' and 'stage') to obtain the features\n",
"data_df = full_data.drop(['class', 'stage'], axis=1)\n",
"\n",
"# One-hot encoder for the features; binary features collapse to a single column\n",
"encoder = OneHotEncoder(drop = 'if_binary')\n",
"lencoder = LabelEncoder()\n",
"\n",
"# Numeric columns are kept as-is; all other columns are one-hot encoded\n",
"numeric_features = data_df.select_dtypes(include = 'number')\n",
"categorical_features_df = data_df.select_dtypes(exclude = 'number')\n",
"categorical_features = encoder.fit_transform(categorical_features_df).toarray()\n",
"feature_names = np.append(numeric_features.columns, encoder.get_feature_names(categorical_features_df.columns))\n",
"\n",
"# Feature matrix: numeric columns followed by the one-hot encoded columns\n",
"x = np.append(\n",
" numeric_features,\n",
" categorical_features,\n",
" axis = 1)\n",
"\n",
"# Encode the class labels as integers\n",
"y = lencoder.fit_transform(full_data['class'])\n",
"\n",
"# Hold out 10% of the data for testing\n",
"x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state = sk_rng)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our new feature set is below. From here on, features are referenced by their 0-based index."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Feature Name</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>age</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>resting bp</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>cholesterol</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>max heart rate</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>oldpeak</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>number of vessels colored</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>sex_male</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>chest pain type_abnang</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>chest pain type_angina</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>chest pain type_asympt</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>chest pain type_notang</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>fasting blood sugar &lt; 120_True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>resting ecg_abn</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>resting ecg_hyp</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>resting ecg_norm</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>exercise induced angina_True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>slope_down</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>slope_flat</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>slope_up</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>thal_fix</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <td>thal_norm</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <td>thal_rev</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Feature Name\n",
"0 age\n",
"1 resting bp\n",
"2 cholesterol\n",
"3 max heart rate\n",
"4 oldpeak\n",
"5 number of vessels colored\n",
"6 sex_male\n",
"7 chest pain type_abnang\n",
"8 chest pain type_angina\n",
"9 chest pain type_asympt\n",
"10 chest pain type_notang\n",
"11 fasting blood sugar < 120_True\n",
"12 resting ecg_abn\n",
"13 resting ecg_hyp\n",
"14 resting ecg_norm\n",
"15 exercise induced angina_True\n",
"16 slope_down\n",
"17 slope_flat\n",
"18 slope_up\n",
"19 thal_fix\n",
"20 thal_norm\n",
"21 thal_rev"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame({'Feature Name': feature_names})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Black-box model\n",
"\n",
"As in the paper, our black-box model will be a fully connected neural network with one hidden layer, with the size of that layer chosen by cross-validation within the training set."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('mlpclassifier',\n",
" MLPClassifier(alpha=1e-05, hidden_layer_sizes=(5,),\n",
" random_state=RandomState(MT19937) at 0x13B59DA40,\n",
" solver='lbfgs'))])"
]
},
"execution_count": 7,
"metadata": {},