import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import xgboost as xgb
%matplotlib inline
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier, export_text, plot_tree
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction import DictVectorizer
from sklearn.impute import KNNImputer
from imblearn.over_sampling import SMOTE
from sklearn.metrics import roc_auc_score, auc
Study Case
Classification Problem: Stroke Prediction
Load library
Data
Stroke Dataset
According to the World Health Organization (WHO) stroke is the 2nd leading cause of death globally, responsible for approximately 11% of total deaths.
This dataset is used to predict whether a patient is likely to have a stroke based on input parameters such as gender, age, various diseases, and smoking status. Each row in the data provides relevant information about the patient.
Attribute Information:
- id: unique identifier
- gender: “Male”, “Female” or “Other”
- age: age of the patient
- hypertension: 0 if the patient doesn’t have hypertension, 1 if the patient has hypertension
- heart_disease: 0 if the patient doesn’t have any heart diseases, 1 if the patient has a heart disease
- ever_married: “No” or “Yes”
- work_type: “children”, “Govt_job”, “Never_worked”, “Private” or “Self-employed”
- Residence_type: “Rural” or “Urban”
- avg_glucose_level: average glucose level in blood
- bmi: body mass index
- smoking_status: “formerly smoked”, “never smoked”, “smokes” or “Unknown”*
- stroke: 1 if the patient had a stroke or 0 if not
# Load the stroke dataset into a DataFrame directly from a Google Cloud Storage URL.
# NOTE(review): this is a signed URL carrying an expiry (X-Goog-Date=20251120T010128Z,
# X-Goog-Expires=259200); presumably the request is rejected once that window has
# passed, so re-running this later likely needs a fresh link or a local copy of
# healthcare-dataset-stroke-data.csv — TODO confirm.
df = pd.read_csv("https://storage.googleapis.com/kagglesdsdata/datasets/1120859/1882037/healthcare-dataset-stroke-data.csv?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20251120%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20251120T010128Z&X-Goog-Expires=259200&X-Goog-SignedHeaders=host&X-Goog-Signature=8c4301a1478060dbf299493e46d7f3129a844cb7d9a60acdb701482e347fb480e07720edbb494bb426d84da9391bfb6438b7feb9eb9c3b8e3dc43f020fd44486c10096a540e42bf3f2e4a11b02b2fea91998884fdb2c525b651d96a5b090e8fd472db49c13fcc6e27586a22e80015322b23d56274ced92020504e184e04bfb457ea63ec1ef2ed14bd04e8809f27048ae5ee0e4a3a690ea8c04ea27b1e58cf73e29b2f3ab642bd23c2507a6b13c176c33df6b257e2b5790ed81cfff8da4fb3213eb8e36266bbd6c64a278df640b349c8c053d9da103b691ff0d75384584b4ac0ecc754701e48a1b5e7e1077f2e4d408f44b45ed038553ea7c08df03b09db1537a")
df.head()| id | gender | age | hypertension | heart_disease | ever_married | work_type | Residence_type | avg_glucose_level | bmi | smoking_status | stroke | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 9046 | Male | 67.0 | 0 | 1 | Yes | Private | Urban | 228.69 | 36.6 | formerly smoked | 1 |
| 1 | 51676 | Female | 61.0 | 0 | 0 | Yes | Self-employed | Rural | 202.21 | NaN | never smoked | 1 |
| 2 | 31112 | Male | 80.0 | 0 | 1 | Yes | Private | Rural | 105.92 | 32.5 | never smoked | 1 |
| 3 | 60182 | Female | 49.0 | 0 | 0 | Yes | Private | Urban | 171.23 | 34.4 | smokes | 1 |
| 4 | 1665 | Female | 79.0 | 1 | 0 | Yes | Self-employed | Rural | 174.12 | 24.0 | never smoked | 1 |
We exclude id and group the categorical and numerical features to start exploration.
df.columns = df.columns.str.lower()
colnames = list(df.columns[1:,].values)
Missing values checking
df.isnull().sum()id 0
gender 0
age 0
hypertension 0
heart_disease 0
ever_married 0
work_type 0
residence_type 0
avg_glucose_level 0
bmi 201
smoking_status 0
stroke 0
dtype: int64
Column bmi has some missing values. Can they be imputed? Alternatively, we could try several strategies (filling with the mean, the median, or leaving the values null) and keep whichever gives the best result.
Let’s explore the data further.
Distribution of stroke
# Count the patients in each stroke class and give the two columns readable names.
stroke_counts = df['stroke'].value_counts()
stroke_rate = stroke_counts.reset_index()
stroke_rate.columns = ['Stroke', 'Count']

# One bar per stroke class, coloured by class.
sns.barplot(data=stroke_rate, x='Stroke', y='Count', hue='Stroke')

# Replace the 0/1 tick labels with No/Yes and label the figure.
plt.xticks(ticks=[0, 1], labels=['No', 'Yes'])
plt.xlabel('Stroke status')
plt.title('Distribution of stroke among patients')
plt.ylabel('Number of Patients')Text(0, 0.5, 'Number of Patients')
len(df[df.stroke == 1]) / len(df) * 100
4.87279843444227
There is an imbalance in the response class: only about 4.87% of the patients had a stroke.
Independent features exploration
Dataset has both categorical and numerical features
categorical = ['gender', 'hypertension', 'heart_disease', 'ever_married', 'work_type', 'residence_type', 'smoking_status']
numerical = ['age', 'avg_glucose_level', 'bmi']
Let’s see the distribution of numerical features first
df[numerical].describe().T| count | mean | std | min | 25% | 50% | 75% | max | |
|---|---|---|---|---|---|---|---|---|
| age | 5110.0 | 43.226614 | 22.612647 | 0.08 | 25.000 | 45.000 | 61.00 | 82.00 |
| avg_glucose_level | 5110.0 | 106.147677 | 45.283560 | 55.12 | 77.245 | 91.885 | 114.09 | 271.74 |
| bmi | 4909.0 | 28.893237 | 7.854067 | 10.30 | 23.500 | 28.100 | 33.10 | 97.60 |
Correlation matrix for the numerical features
corr_mtx = df[numerical].corr()
corr_mtx.style.background_gradient(cmap='coolwarm')| age | avg_glucose_level | bmi | |
|---|---|---|---|
| age | 1.000000 | 0.238171 | 0.333398 |
| avg_glucose_level | 0.238171 | 1.000000 | 0.175502 |
| bmi | 0.333398 | 0.175502 | 1.000000 |
fig, ax = plt.subplots(nrows = 1, ncols = 3, figsize = (15, 4))
ax = ax.flatten()
for i, col in enumerate(numerical):
sns.histplot(data = df[numerical].dropna(), x = col, ax = ax[i]).set_title(f'Distribution of {col}')
Based on the histograms above, the average patient is:
- Age: around 45 years old; middle-aged; wide range from infancy to elderly
- Avg glucose level: median ~92; slightly elevated; may indicate a pre-diabetic condition
- BMI: median around 28.1; overweight; suggests many patients are overweight or obese
The next one, let’s explore categorical features!
# Draw one count plot per categorical feature on a flattened 3x3 grid of axes.
fig, ax = plt.subplots(3, 3, figsize=(15, 10))
ax = ax.flatten()
for i, col in enumerate(categorical):
    panel = ax[i]
    sns.countplot(df[categorical], x=col, ax=panel)
    panel.set_title(f'Distribution of {col}')
    # Rotate the category labels so the longer ones do not overlap.
    panel.tick_params(axis='x', rotation=45)
for j in range(len(categorical), 9):
ax[j].axis('off')
From the categorical feature distributions, we can see:
- the dataset is female-dominant (~58%)
- most patients are healthy (no history of hypertension or heart disease), perhaps because most have never smoked
- most patients are or have been married and are employed, working in the private or government sector or running a business
Association between independent features and stroke status
# Median of every numerical feature within the stroke and no-stroke groups,
# with the 0/1 group key replaced by a readable label.
stroke_labels = {0: 'No', 1: 'Yes'}
stroke_group_median = df.groupby('stroke')[numerical].median().reset_index()
stroke_group_median['stroke'] = stroke_group_median['stroke'].map(stroke_labels)
stroke_group_median| stroke | age | avg_glucose_level | bmi | |
|---|---|---|---|---|
| 0 | No | 43.0 | 91.47 | 28.0 |
| 1 | Yes | 71.0 | 105.22 | 29.7 |
Patients who had stroke tend to be significantly older. They also show higher median glucose levels and slightly higher BMI. This suggests that age and glucose levels may be important factors associated with stroke risk.
# Stroke rate (in percent) for every level of each categorical feature.
for col in categorical:
    # Mean of the 0/1 stroke flag within a level is that level's stroke rate.
    rate = df.groupby(col)['stroke'].mean().reset_index()
    rate.columns = [col, 'stroke_rate']
    # Scale to percent first and round afterwards: rounding before the
    # multiplication lets binary floating-point noise back into the stored
    # values (e.g. 0.017 * 100). One decimal of a percentage is the same
    # precision as three decimals of the raw rate.
    rate['stroke_rate'] = (rate['stroke_rate'] * 100).round(1)
    print(f'\nStroke Rate by {col}:\n', rate)
Stroke Rate by gender:
gender stroke_rate
0 Female 4.7
1 Male 5.1
2 Other 0.0
Stroke Rate by hypertension:
hypertension stroke_rate
0 0 4.0
1 1 13.3
Stroke Rate by heart_disease:
heart_disease stroke_rate
0 0 4.2
1 1 17.0
Stroke Rate by ever_married:
ever_married stroke_rate
0 No 1.7
1 Yes 6.6
Stroke Rate by work_type:
work_type stroke_rate
0 Govt_job 5.0
1 Never_worked 0.0
2 Private 5.1
3 Self-employed 7.9
4 children 0.3
Stroke Rate by residence_type:
residence_type stroke_rate
0 Rural 4.5
1 Urban 5.2
Stroke Rate by smoking_status:
smoking_status stroke_rate
0 Unknown 3.0
1 formerly smoked 7.9
2 never smoked 4.8
3 smokes 5.3
The summary above shows that the patients at risk of a stroke are:
- gender: male (slightly higher risk) and female
- unhealthy: smoke or formerly smoked, have hypertension, have heart disease
- are or have been married and are working professionals
Other gender, children, and unemployed groups have low stroke rates.
BMI and rare categories (gender, work_type)
Patients with gender Other and patients with work type Never_worked are very rare categories (as their counts in the data show). I will explore them to find an appropriate way to handle them; otherwise they add noise to the model.
df[df['gender'] == 'Other']| id | gender | age | hypertension | heart_disease | ever_married | work_type | residence_type | avg_glucose_level | bmi | smoking_status | stroke | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 3116 | 56156 | Other | 26.0 | 0 | 0 | No | Private | Rural | 143.33 | 22.4 | formerly smoked | 0 |
Only one patient has gender Other, and that patient is a non-stroke case. We can probably drop this row, as it won’t contribute to the model.
df['bmi_missing'] = df['bmi'].isnull()
df.groupby('bmi_missing')[['age', 'avg_glucose_level']].mean()| age | avg_glucose_level | |
|---|---|---|
| bmi_missing | ||
| False | 42.865374 | 105.305150 |
| True | 52.049154 | 126.724627 |
# Share of rows with a missing bmi (in percent) for every level of each
# categorical feature.
pct_col = 'Missing BMI Rate'
for col in categorical:
    # Mean of the boolean bmi_missing flag within a level is its missing rate.
    per_level = df.groupby(col)['bmi_missing'].mean()
    missing_rate = per_level.reset_index()
    missing_rate.columns = [col, pct_col]
    as_percent = missing_rate[pct_col] * 100
    missing_rate[pct_col] = as_percent.round(3)
    print(f'\nMissing BMI Rate by {col}:\n', missing_rate)
Missing BMI Rate by gender:
gender Missing BMI Rate
0 Female 3.240
1 Male 4.917
2 Other 0.000
Missing BMI Rate by hypertension:
hypertension Missing BMI Rate
0 0 3.339
1 1 9.438
Missing BMI Rate by heart_disease:
heart_disease Missing BMI Rate
0 0 3.475
1 1 11.957
Missing BMI Rate by ever_married:
ever_married Missing BMI Rate
0 No 2.960
1 Yes 4.444
Missing BMI Rate by work_type:
work_type Missing BMI Rate
0 Govt_job 4.110
1 Never_worked 0.000
2 Private 3.897
3 Self-employed 5.372
4 children 2.329
Missing BMI Rate by residence_type:
residence_type Missing BMI Rate
0 Rural 3.779
1 Urban 4.083
Missing BMI Rate by smoking_status:
smoking_status Missing BMI Rate
0 Unknown 3.951
1 formerly smoked 5.424
2 never smoked 2.114
3 smokes 6.591
The missing values are not random. They cluster around:
- Older individuals
- Those with hypertension, heart disease, and higher glucose levels
- Smokers and employed individuals
Preparation and Imputation
The next steps are to drop id and the Other gender row, then impute bmi using KNNImputer and one-hot encode the categorical features.
df_full = df.drop(columns=['id'])
df_full = df_full[df_full['gender'] != 'Other'].reset_index(drop=True)
df_full.groupby('work_type')['age'].mean()
work_type
Govt_job 50.879756
Never_worked 16.181818
Private 45.510602
Self-employed 60.201465
children 6.841339
Name: age, dtype: float64
The children and Never_worked groups are very small. We can collapse work_type into a binary variable: employed vs. not employed.
# Collapse work_type into a binary flag: 1 for the three employed categories,
# 0 for everything else (children, Never_worked).
employed_types = ['Private', 'Self-employed', 'Govt_job']
# Vectorised membership test instead of a Python-level lambda per row.
df_full['employed'] = df_full['work_type'].isin(employed_types).astype(int)
# The original column is no longer needed once the flag exists.
df_full = df_full.drop(columns=['work_type'])
df_full.head()| gender | age | hypertension | heart_disease | ever_married | residence_type | avg_glucose_level | bmi | smoking_status | stroke | bmi_missing | employed | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Male | 67.0 | 0 | 1 | Yes | Urban | 228.69 | 36.6 | formerly smoked | 1 | False | 1 |
| 1 | Female | 61.0 | 0 | 0 | Yes | Rural | 202.21 | NaN | never smoked | 1 | True | 1 |
| 2 | Male | 80.0 | 0 | 1 | Yes | Rural | 105.92 | 32.5 | never smoked | 1 | False | 1 |
| 3 | Female | 49.0 | 0 | 0 | Yes | Urban | 171.23 | 34.4 | smokes | 1 | False | 1 |
| 4 | Female | 79.0 | 1 | 0 | Yes | Rural | 174.12 | 24.0 | never smoked | 1 | False | 1 |
# Build the matrix handed to the imputer: the numerical columns first,
# followed by the dummy-encoded categorical columns.
categorical = [
    'gender',
    'hypertension',
    'heart_disease',
    'ever_married',
    'employed',
    'residence_type',
    'smoking_status',
]
# drop_first removes one dummy per encoded feature.
df_cat_ohe = pd.get_dummies(df_full[categorical], drop_first=True)
numeric_part = df_full[numerical]
df_ready_impute = pd.concat([numeric_part, df_cat_ohe], axis=1)
df_ready_impute.head()| age | avg_glucose_level | bmi | hypertension | heart_disease | employed | gender_Male | ever_married_Yes | residence_type_Urban | smoking_status_formerly smoked | smoking_status_never smoked | smoking_status_smokes | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 67.0 | 228.69 | 36.6 | 0 | 1 | 1 | True | True | True | True | False | False |
| 1 | 61.0 | 202.21 | NaN | 0 | 0 | 1 | False | True | False | False | True | False |
| 2 | 80.0 | 105.92 | 32.5 | 0 | 1 | 1 | True | True | False | False | True | False |
| 3 | 49.0 | 171.23 | 34.4 | 0 | 0 | 1 | False | True | True | False | False | True |
| 4 | 79.0 | 174.12 | 24.0 | 1 | 0 | 1 | False | True | False | False | True | False |
# Fill the missing bmi values with KNNImputer, using 5 neighbours per row.
# NOTE(review): the imputer is fitted on every row of df_ready_impute, i.e. before the
# train/validation/test split performed later in this file, so held-out rows take part
# in the imputation — confirm this leakage is acceptable.
# NOTE(review): the columns are not rescaled before imputation; presumably the
# wide-range features (age, avg_glucose_level) dominate the neighbour distance —
# TODO confirm.
imputer = KNNImputer(n_neighbors=5)
df_imputed = imputer.fit_transform(df_ready_impute)
df_imputedarray([[ 67. , 228.69, 36.6 , ..., 1. , 0. , 0. ],
[ 61. , 202.21, 32.4 , ..., 0. , 1. , 0. ],
[ 80. , 105.92, 32.5 , ..., 0. , 1. , 0. ],
...,
[ 35. , 82.99, 30.6 , ..., 0. , 1. , 0. ],
[ 51. , 166.29, 25.6 , ..., 1. , 0. , 0. ],
[ 44. , 85.28, 26.2 , ..., 0. , 0. , 0. ]])
# Copy the imputed bmi values back into both data frames.
# The position of 'bmi' inside `numerical` is used as its column position in
# the imputed array.
bmi_pos = numerical.index('bmi')
bmi_imputed = df_imputed[:, bmi_pos]
df_full['bmi'] = bmi_imputed
df_ready_impute['bmi'] = bmi_imputed
df_ready_impute.head()| age | avg_glucose_level | bmi | hypertension | heart_disease | employed | gender_Male | ever_married_Yes | residence_type_Urban | smoking_status_formerly smoked | smoking_status_never smoked | smoking_status_smokes | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 67.0 | 228.69 | 36.6 | 0 | 1 | 1 | True | True | True | True | False | False |
| 1 | 61.0 | 202.21 | 32.4 | 0 | 0 | 1 | False | True | False | False | True | False |
| 2 | 80.0 | 105.92 | 32.5 | 0 | 1 | 1 | True | True | False | False | True | False |
| 3 | 49.0 | 171.23 | 34.4 | 0 | 0 | 1 | False | True | True | False | False | True |
| 4 | 79.0 | 174.12 | 24.0 | 1 | 0 | 1 | False | True | False | False | True | False |
Setup validation framework
Split the data in train/val/test sets, with 60%/20%/20% distribution.
# Drop the helper flag, hold out 20% of the rows for testing, then split the
# remainder 75/25 into training and validation (60%/20% of the total).
df_full = df_full.drop(columns=['bmi_missing'])
df_full_train, df_test = train_test_split(df_full, test_size=0.2, random_state=42)
df_train, df_val = train_test_split(df_full_train, test_size=0.25, random_state=42)

# Give each subset a fresh 0..n-1 index.
df_train, df_val, df_test = (
    part.reset_index(drop=True) for part in (df_train, df_val, df_test)
)

# Pull the target out as plain arrays.
y_train, y_val, y_test = (
    part.stroke.values for part in (df_train, df_val, df_test)
)

# Remove the target from the training and validation feature frames.
del df_train["stroke"]
del df_val["stroke"]
del df_test["stroke"]
One-hot encoding
df_train.to_dict(orient='records')[0]{'gender': 'Female',
'age': 72.0,
'hypertension': 0,
'heart_disease': 1,
'ever_married': 'No',
'residence_type': 'Rural',
'avg_glucose_level': 124.38,
'bmi': 23.4,
'smoking_status': 'formerly smoked',
'employed': 1}
# Turn every row into a feature dict and vectorise it with DictVectorizer.
# The vectorizer is fitted on the training rows only; the validation rows are
# transformed with that fitted vectorizer.
dv = DictVectorizer(sparse=False)

train_dicts = df_train.to_dict(orient='records')
X_train = dv.fit_transform(train_dicts)

val_dicts = df_val.to_dict(orient='records')
X_val = dv.transform(val_dicts)
dv.get_feature_names_out()array(['age', 'avg_glucose_level', 'bmi', 'employed', 'ever_married=No',
'ever_married=Yes', 'gender=Female', 'gender=Male',
'heart_disease', 'hypertension', 'residence_type=Rural',
'residence_type=Urban', 'smoking_status=Unknown',
'smoking_status=formerly smoked', 'smoking_status=never smoked',
'smoking_status=smokes'], dtype=object)
Handling imbalanced classes
I apply SMOTE to oversample the minority class in the training dataset
# Resample the training matrix with SMOTE (seeded for repeatability).
# Only X_train / y_train are resampled; the validation matrix is not touched here.
smote = SMOTE(random_state=42)
X_train_res, y_train_res = smote.fit_resample(X_train, y_train)
X_train.shape, X_train_res.shape((3065, 16), (5838, 16))
The training set size increased from 3065 to 5838 after applying SMOTE, balancing the classes for better model training.
Modelling
Logistic Regression
# Baseline logistic regression fitted on the SMOTE-resampled training data.
model = LogisticRegression(solver='liblinear', C=1.0, max_iter=1000, random_state=42)
model.fit(X_train_res, y_train_res)
# Column 1 of predict_proba is taken as the score passed to roc_auc_score.
y_pred = model.predict_proba(X_train_res)[:, 1]
# NOTE(review): assigning to `auc` rebinds the `auc` name imported from
# sklearn.metrics at the top of the file.
auc = roc_auc_score(y_train_res, y_pred)
print(f"AUC on train dataset: {round(auc, 4)}")
# Same model scored on the validation matrix, which was not resampled.
y_pred = model.predict_proba(X_val)[:, 1]
auc = roc_auc_score(y_val, y_pred)
print(f"AUC on validation dataset: {round(auc, 4)}")AUC on train dataset: 0.8596
AUC on validation dataset: 0.8132
Tuning the parameter C
# Try several values of the logistic regression parameter C: for each value the
# model is refitted on the resampled training data and its ROC AUC is computed
# on the validation matrix.
params = [0.01, 0.1, 1, 10, 100]
for C in params:
model = LogisticRegression(solver='liblinear', C=C, max_iter=1000, random_state=42)
model.fit(X_train_res, y_train_res)
y_pred = model.predict_proba(X_val)[:, 1]
auc = roc_auc_score(y_val, y_pred)
print(f"{auc:.4f} for parameter C={C}")0.7946 for parameter C=0.01
0.8110 for parameter C=0.1
0.8132 for parameter C=1
0.8134 for parameter C=10
0.8133 for parameter C=100
The best AUC so far using Logistic Regression on the validation dataset is 0.8134, for parameter C=10.
Full training set
# Refit on the combined train + validation rows, then score on the test rows.
y_full_train = df_full_train.stroke.values
del df_full_train["stroke"]

# Refit the vectorizer on the full training rows; the test rows are only transformed.
full_train_dicts = df_full_train.to_dict(orient='records')
dicts_test = df_test.to_dict(orient='records')
X_full_train = dv.fit_transform(full_train_dicts)
X_test = dv.transform(dicts_test)

# Resample the full training matrix only; the test matrix is left as is.
smote = SMOTE(random_state=42)
X_full_train_res, y_full_train_res = smote.fit_resample(X_full_train, y_full_train)

# Logistic regression with C=10, fitted on the resampled full training data.
model = LogisticRegression(solver='liblinear', C=10, max_iter=1000, random_state=42)
model.fit(X_full_train_res, y_full_train_res)

test_scores = model.predict_proba(X_test)
y_pred = test_scores[:, 1]
auc = roc_auc_score(y_test, y_pred)
print(f"AUC on testing dataset: {round(auc, 4)}")AUC on testing dataset: 0.8426
Decision tree
# Fit a decision tree with default settings on the original (non-resampled)
# training matrix.
# random_state is pinned to 42, like for the other estimators in this file, so
# that the fitted tree and the rules printed from it are reproducible between runs.
dt = DecisionTreeClassifier(random_state=42)
dt.fit(X_train, y_train)
print(export_text(dt, feature_names=dv.get_feature_names_out()))|--- age <= 67.50
| |--- age <= 47.50
| | |--- bmi <= 56.30
| | | |--- age <= 37.50
| | | | |--- age <= 1.36
| | | | | |--- age <= 1.28
| | | | | | |--- class: 0
| | | | | |--- age > 1.28
| | | | | | |--- avg_glucose_level <= 72.80
| | | | | | | |--- class: 1
| | | | | | |--- avg_glucose_level > 72.80
| | | | | | | |--- class: 0
| | | | |--- age > 1.36
| | | | | |--- class: 0
| | | |--- age > 37.50
| | | | |--- avg_glucose_level <= 58.14
| | | | | |--- avg_glucose_level <= 57.94
| | | | | | |--- class: 0
| | | | | |--- avg_glucose_level > 57.94
| | | | | | |--- class: 1
| | | | |--- avg_glucose_level > 58.14
| | | | | |--- bmi <= 30.85
| | | | | | |--- bmi <= 30.73
| | | | | | | |--- smoking_status=formerly smoked <= 0.50
| | | | | | | | |--- smoking_status=Unknown <= 0.50
| | | | | | | | | |--- class: 0
| | | | | | | | |--- smoking_status=Unknown > 0.50
| | | | | | | | | |--- avg_glucose_level <= 83.43
| | | | | | | | | | |--- avg_glucose_level <= 83.18
| | | | | | | | | | | |--- class: 0
| | | | | | | | | | |--- avg_glucose_level > 83.18
| | | | | | | | | | | |--- class: 1
| | | | | | | | | |--- avg_glucose_level > 83.43
| | | | | | | | | | |--- class: 0
| | | | | | | |--- smoking_status=formerly smoked > 0.50
| | | | | | | | |--- bmi <= 29.90
| | | | | | | | | |--- age <= 39.00
| | | | | | | | | | |--- avg_glucose_level <= 83.53
| | | | | | | | | | | |--- class: 1
| | | | | | | | | | |--- avg_glucose_level > 83.53
| | | | | | | | | | | |--- class: 0
| | | | | | | | | |--- age > 39.00
| | | | | | | | | | |--- class: 0
| | | | | | | | |--- bmi > 29.90
| | | | | | | | | |--- bmi <= 30.35
| | | | | | | | | | |--- class: 1
| | | | | | | | | |--- bmi > 30.35
| | | | | | | | | | |--- class: 0
| | | | | | |--- bmi > 30.73
| | | | | | | |--- smoking_status=formerly smoked <= 0.50
| | | | | | | | |--- class: 1
| | | | | | | |--- smoking_status=formerly smoked > 0.50
| | | | | | | | |--- class: 0
| | | | | |--- bmi > 30.85
| | | | | | |--- class: 0
| | |--- bmi > 56.30
| | | |--- age <= 44.00
| | | | |--- class: 0
| | | |--- age > 44.00
| | | | |--- class: 1
| |--- age > 47.50
| | |--- avg_glucose_level <= 150.81
| | | |--- gender=Female <= 0.50
| | | | |--- bmi <= 27.25
| | | | | |--- class: 0
| | | | |--- bmi > 27.25
| | | | | |--- bmi <= 32.05
| | | | | | |--- bmi <= 31.60
| | | | | | | |--- smoking_status=smokes <= 0.50
| | | | | | | | |--- age <= 48.50
| | | | | | | | | |--- ever_married=No <= 0.50
| | | | | | | | | | |--- class: 0
| | | | | | | | | |--- ever_married=No > 0.50
| | | | | | | | | | |--- class: 1
| | | | | | | | |--- age > 48.50
| | | | | | | | | |--- bmi <= 28.59
| | | | | | | | | | |--- bmi <= 28.54
| | | | | | | | | | | |--- truncated branch of depth 4
| | | | | | | | | | |--- bmi > 28.54
| | | | | | | | | | | |--- class: 1
| | | | | | | | | |--- bmi > 28.59
| | | | | | | | | | |--- age <= 64.50
| | | | | | | | | | | |--- class: 0
| | | | | | | | | | |--- age > 64.50
| | | | | | | | | | | |--- truncated branch of depth 3
| | | | | | | |--- smoking_status=smokes > 0.50
| | | | | | | | |--- bmi <= 27.40
| | | | | | | | | |--- age <= 60.00
| | | | | | | | | | |--- class: 0
| | | | | | | | | |--- age > 60.00
| | | | | | | | | | |--- class: 1
| | | | | | | | |--- bmi > 27.40
| | | | | | | | | |--- bmi <= 31.00
| | | | | | | | | | |--- residence_type=Rural <= 0.50
| | | | | | | | | | | |--- class: 0
| | | | | | | | | | |--- residence_type=Rural > 0.50
| | | | | | | | | | | |--- truncated branch of depth 2
| | | | | | | | | |--- bmi > 31.00
| | | | | | | | | | |--- bmi <= 31.32
| | | | | | | | | | | |--- class: 1
| | | | | | | | | | |--- bmi > 31.32
| | | | | | | | | | | |--- class: 0
| | | | | | |--- bmi > 31.60
| | | | | | | |--- smoking_status=never smoked <= 0.50
| | | | | | | | |--- age <= 60.00
| | | | | | | | | |--- class: 1
| | | | | | | | |--- age > 60.00
| | | | | | | | | |--- bmi <= 31.80
| | | | | | | | | | |--- class: 1
| | | | | | | | | |--- bmi > 31.80
| | | | | | | | | | |--- class: 0
| | | | | | | |--- smoking_status=never smoked > 0.50
| | | | | | | | |--- class: 0
| | | | | |--- bmi > 32.05
| | | | | | |--- heart_disease <= 0.50
| | | | | | | |--- smoking_status=Unknown <= 0.50
| | | | | | | | |--- class: 0
| | | | | | | |--- smoking_status=Unknown > 0.50
| | | | | | | | |--- bmi <= 36.45
| | | | | | | | | |--- class: 0
| | | | | | | | |--- bmi > 36.45
| | | | | | | | | |--- bmi <= 37.35
| | | | | | | | | | |--- class: 1
| | | | | | | | | |--- bmi > 37.35
| | | | | | | | | | |--- class: 0
| | | | | | |--- heart_disease > 0.50
| | | | | | | |--- hypertension <= 0.50
| | | | | | | | |--- residence_type=Rural <= 0.50
| | | | | | | | | |--- class: 1
| | | | | | | | |--- residence_type=Rural > 0.50
| | | | | | | | | |--- class: 0
| | | | | | | |--- hypertension > 0.50
| | | | | | | | |--- class: 0
| | | |--- gender=Female > 0.50
| | | | |--- avg_glucose_level <= 96.42
| | | | | |--- age <= 49.50
| | | | | | |--- smoking_status=never smoked <= 0.50
| | | | | | | |--- class: 0
| | | | | | |--- smoking_status=never smoked > 0.50
| | | | | | | |--- avg_glucose_level <= 63.12
| | | | | | | | |--- class: 1
| | | | | | | |--- avg_glucose_level > 63.12
| | | | | | | | |--- bmi <= 23.00
| | | | | | | | | |--- class: 1
| | | | | | | | |--- bmi > 23.00
| | | | | | | | | |--- class: 0
| | | | | |--- age > 49.50
| | | | | | |--- hypertension <= 0.50
| | | | | | | |--- class: 0
| | | | | | |--- hypertension > 0.50
| | | | | | | |--- avg_glucose_level <= 68.11
| | | | | | | | |--- avg_glucose_level <= 64.77
| | | | | | | | | |--- class: 0
| | | | | | | | |--- avg_glucose_level > 64.77
| | | | | | | | | |--- class: 1
| | | | | | | |--- avg_glucose_level > 68.11
| | | | | | | | |--- class: 0
| | | | |--- avg_glucose_level > 96.42
| | | | | |--- avg_glucose_level <= 96.68
| | | | | | |--- class: 1
| | | | | |--- avg_glucose_level > 96.68
| | | | | | |--- age <= 56.50
| | | | | | | |--- smoking_status=formerly smoked <= 0.50
| | | | | | | | |--- class: 0
| | | | | | | |--- smoking_status=formerly smoked > 0.50
| | | | | | | | |--- residence_type=Urban <= 0.50
| | | | | | | | | |--- avg_glucose_level <= 106.35
| | | | | | | | | | |--- avg_glucose_level <= 101.53
| | | | | | | | | | | |--- class: 0
| | | | | | | | | | |--- avg_glucose_level > 101.53
| | | | | | | | | | | |--- class: 1
| | | | | | | | | |--- avg_glucose_level > 106.35
| | | | | | | | | | |--- class: 0
| | | | | | | | |--- residence_type=Urban > 0.50
| | | | | | | | | |--- class: 0
| | | | | | |--- age > 56.50
| | | | | | | |--- smoking_status=never smoked <= 0.50
| | | | | | | | |--- heart_disease <= 0.50
| | | | | | | | | |--- avg_glucose_level <= 97.51
| | | | | | | | | | |--- class: 1
| | | | | | | | | |--- avg_glucose_level > 97.51
| | | | | | | | | | |--- age <= 62.50
| | | | | | | | | | | |--- truncated branch of depth 4
| | | | | | | | | | |--- age > 62.50
| | | | | | | | | | | |--- truncated branch of depth 4
| | | | | | | | |--- heart_disease > 0.50
| | | | | | | | | |--- class: 1
| | | | | | | |--- smoking_status=never smoked > 0.50
| | | | | | | | |--- class: 0
| | |--- avg_glucose_level > 150.81
| | | |--- avg_glucose_level <= 151.36
| | | | |--- class: 1
| | | |--- avg_glucose_level > 151.36
| | | | |--- heart_disease <= 0.50
| | | | | |--- avg_glucose_level <= 197.61
| | | | | | |--- avg_glucose_level <= 197.19
| | | | | | | |--- gender=Male <= 0.50
| | | | | | | | |--- age <= 51.50
| | | | | | | | | |--- avg_glucose_level <= 183.19
| | | | | | | | | | |--- class: 1
| | | | | | | | | |--- avg_glucose_level > 183.19
| | | | | | | | | | |--- class: 0
| | | | | | | | |--- age > 51.50
| | | | | | | | | |--- age <= 66.00
| | | | | | | | | | |--- age <= 56.50
| | | | | | | | | | | |--- truncated branch of depth 4
| | | | | | | | | | |--- age > 56.50
| | | | | | | | | | | |--- class: 0
| | | | | | | | | |--- age > 66.00
| | | | | | | | | | |--- class: 1
| | | | | | | |--- gender=Male > 0.50
| | | | | | | | |--- smoking_status=Unknown <= 0.50
| | | | | | | | | |--- class: 0
| | | | | | | | |--- smoking_status=Unknown > 0.50
| | | | | | | | | |--- residence_type=Urban <= 0.50
| | | | | | | | | | |--- avg_glucose_level <= 180.38
| | | | | | | | | | | |--- class: 0
| | | | | | | | | | |--- avg_glucose_level > 180.38
| | | | | | | | | | | |--- class: 1
| | | | | | | | | |--- residence_type=Urban > 0.50
| | | | | | | | | | |--- class: 0
| | | | | | |--- avg_glucose_level > 197.19
| | | | | | | |--- class: 1
| | | | | |--- avg_glucose_level > 197.61
| | | | | | |--- bmi <= 20.80
| | | | | | | |--- bmi <= 17.20
| | | | | | | | |--- class: 0
| | | | | | | |--- bmi > 17.20
| | | | | | | | |--- class: 1
| | | | | | |--- bmi > 20.80
| | | | | | | |--- bmi <= 42.00
| | | | | | | | |--- smoking_status=never smoked <= 0.50
| | | | | | | | | |--- class: 0
| | | | | | | | |--- smoking_status=never smoked > 0.50
| | | | | | | | | |--- avg_glucose_level <= 246.81
| | | | | | | | | | |--- age <= 60.50
| | | | | | | | | | | |--- class: 0
| | | | | | | | | | |--- age > 60.50
| | | | | | | | | | | |--- truncated branch of depth 3
| | | | | | | | | |--- avg_glucose_level > 246.81
| | | | | | | | | | |--- gender=Female <= 0.50
| | | | | | | | | | | |--- class: 1
| | | | | | | | | | |--- gender=Female > 0.50
| | | | | | | | | | | |--- class: 0
| | | | | | | |--- bmi > 42.00
| | | | | | | | |--- bmi <= 42.25
| | | | | | | | | |--- class: 1
| | | | | | | | |--- bmi > 42.25
| | | | | | | | | |--- residence_type=Rural <= 0.50
| | | | | | | | | | |--- bmi <= 45.60
| | | | | | | | | | | |--- class: 0
| | | | | | | | | | |--- bmi > 45.60
| | | | | | | | | | | |--- truncated branch of depth 3
| | | | | | | | | |--- residence_type=Rural > 0.50
| | | | | | | | | | |--- class: 0
| | | | |--- heart_disease > 0.50
| | | | | |--- avg_glucose_level <= 240.27
| | | | | | |--- bmi <= 36.10
| | | | | | | |--- avg_glucose_level <= 213.97
| | | | | | | | |--- class: 0
| | | | | | | |--- avg_glucose_level > 213.97
| | | | | | | | |--- bmi <= 31.60
| | | | | | | | | |--- class: 1
| | | | | | | | |--- bmi > 31.60
| | | | | | | | | |--- class: 0
| | | | | | |--- bmi > 36.10
| | | | | | | |--- avg_glucose_level <= 215.32
| | | | | | | | |--- class: 1
| | | | | | | |--- avg_glucose_level > 215.32
| | | | | | | | |--- class: 0
| | | | | |--- avg_glucose_level > 240.27
| | | | | | |--- class: 1
|--- age > 67.50
| |--- avg_glucose_level <= 103.64
| | |--- hypertension <= 0.50
| | | |--- bmi <= 43.80
| | | | |--- bmi <= 34.50
| | | | | |--- smoking_status=smokes <= 0.50
| | | | | | |--- bmi <= 31.95
| | | | | | | |--- bmi <= 29.55
| | | | | | | | |--- ever_married=Yes <= 0.50
| | | | | | | | | |--- age <= 79.50
| | | | | | | | | | |--- class: 1
| | | | | | | | | |--- age > 79.50
| | | | | | | | | | |--- class: 0
| | | | | | | | |--- ever_married=Yes > 0.50
| | | | | | | | | |--- bmi <= 29.45
| | | | | | | | | | |--- bmi <= 28.45
| | | | | | | | | | | |--- truncated branch of depth 9
| | | | | | | | | | |--- bmi > 28.45
| | | | | | | | | | | |--- truncated branch of depth 5
| | | | | | | | | |--- bmi > 29.45
| | | | | | | | | | |--- class: 1
| | | | | | | |--- bmi > 29.55
| | | | | | | | |--- class: 0
| | | | | | |--- bmi > 31.95
| | | | | | | |--- smoking_status=never smoked <= 0.50
| | | | | | | | |--- bmi <= 32.35
| | | | | | | | | |--- heart_disease <= 0.50
| | | | | | | | | | |--- class: 0
| | | | | | | | | |--- heart_disease > 0.50
| | | | | | | | | | |--- class: 1
| | | | | | | | |--- bmi > 32.35
| | | | | | | | | |--- class: 0
| | | | | | | |--- smoking_status=never smoked > 0.50
| | | | | | | | |--- bmi <= 33.10
| | | | | | | | | |--- avg_glucose_level <= 84.72
| | | | | | | | | | |--- class: 0
| | | | | | | | | |--- avg_glucose_level > 84.72
| | | | | | | | | | |--- class: 1
| | | | | | | | |--- bmi > 33.10
| | | | | | | | | |--- age <= 79.50
| | | | | | | | | | |--- age <= 75.50
| | | | | | | | | | | |--- class: 1
| | | | | | | | | | |--- age > 75.50
| | | | | | | | | | | |--- class: 0
| | | | | | | | | |--- age > 79.50
| | | | | | | | | | |--- class: 1
| | | | | |--- smoking_status=smokes > 0.50
| | | | | | |--- age <= 71.50
| | | | | | | |--- gender=Female <= 0.50
| | | | | | | | |--- class: 1
| | | | | | | |--- gender=Female > 0.50
| | | | | | | | |--- class: 0
| | | | | | |--- age > 71.50
| | | | | | | |--- age <= 78.50
| | | | | | | | |--- class: 0
| | | | | | | |--- age > 78.50
| | | | | | | | |--- avg_glucose_level <= 68.45
| | | | | | | | | |--- class: 0
| | | | | | | | |--- avg_glucose_level > 68.45
| | | | | | | | | |--- age <= 81.50
| | | | | | | | | | |--- class: 1
| | | | | | | | | |--- age > 81.50
| | | | | | | | | | |--- class: 0
| | | | |--- bmi > 34.50
| | | | | |--- class: 0
| | | |--- bmi > 43.80
| | | | |--- class: 1
| | |--- hypertension > 0.50
| | | |--- avg_glucose_level <= 66.01
| | | | |--- class: 0
| | | |--- avg_glucose_level > 66.01
| | | | |--- avg_glucose_level <= 93.89
| | | | | |--- ever_married=Yes <= 0.50
| | | | | | |--- residence_type=Urban <= 0.50
| | | | | | | |--- class: 0
| | | | | | |--- residence_type=Urban > 0.50
| | | | | | | |--- class: 1
| | | | | |--- ever_married=Yes > 0.50
| | | | | | |--- bmi <= 27.45
| | | | | | | |--- avg_glucose_level <= 76.09
| | | | | | | | |--- residence_type=Urban <= 0.50
| | | | | | | | | |--- class: 1
| | | | | | | | |--- residence_type=Urban > 0.50
| | | | | | | | | |--- smoking_status=formerly smoked <= 0.50
| | | | | | | | | | |--- class: 1
| | | | | | | | | |--- smoking_status=formerly smoked > 0.50
| | | | | | | | | | |--- class: 0
| | | | | | | |--- avg_glucose_level > 76.09
| | | | | | | | |--- bmi <= 26.70
| | | | | | | | | |--- class: 0
| | | | | | | | |--- bmi > 26.70
| | | | | | | | | |--- class: 1
| | | | | | |--- bmi > 27.45
| | | | | | | |--- avg_glucose_level <= 88.98
| | | | | | | | |--- smoking_status=Unknown <= 0.50
| | | | | | | | | |--- age <= 80.50
| | | | | | | | | | |--- class: 0
| | | | | | | | | |--- age > 80.50
| | | | | | | | | | |--- heart_disease <= 0.50
| | | | | | | | | | | |--- class: 1
| | | | | | | | | | |--- heart_disease > 0.50
| | | | | | | | | | | |--- class: 0
| | | | | | | | |--- smoking_status=Unknown > 0.50
| | | | | | | | | |--- class: 1
| | | | | | | |--- avg_glucose_level > 88.98
| | | | | | | | |--- avg_glucose_level <= 91.32
| | | | | | | | | |--- class: 1
| | | | | | | | |--- avg_glucose_level > 91.32
| | | | | | | | | |--- class: 0
| | | | |--- avg_glucose_level > 93.89
| | | | | |--- class: 0
| |--- avg_glucose_level > 103.64
| | |--- age <= 73.50
| | | |--- bmi <= 40.40
| | | | |--- ever_married=Yes <= 0.50
| | | | | |--- gender=Female <= 0.50
| | | | | | |--- class: 1
| | | | | |--- gender=Female > 0.50
| | | | | | |--- class: 0
| | | | |--- ever_married=Yes > 0.50
| | | | | |--- smoking_status=smokes <= 0.50
| | | | | | |--- residence_type=Rural <= 0.50
| | | | | | | |--- age <= 70.50
| | | | | | | | |--- class: 0
| | | | | | | |--- age > 70.50
| | | | | | | | |--- hypertension <= 0.50
| | | | | | | | | |--- heart_disease <= 0.50
| | | | | | | | | | |--- bmi <= 34.95
| | | | | | | | | | | |--- truncated branch of depth 5
| | | | | | | | | | |--- bmi > 34.95
| | | | | | | | | | | |--- class: 0
| | | | | | | | | |--- heart_disease > 0.50
| | | | | | | | | | |--- class: 0
| | | | | | | | |--- hypertension > 0.50
| | | | | | | | | |--- class: 0
| | | | | | |--- residence_type=Rural > 0.50
| | | | | | | |--- class: 0
| | | | | |--- smoking_status=smokes > 0.50
| | | | | | |--- bmi <= 23.90
| | | | | | | |--- class: 1
| | | | | | |--- bmi > 23.90
| | | | | | | |--- age <= 69.50
| | | | | | | | |--- bmi <= 32.85
| | | | | | | | | |--- class: 1
| | | | | | | | |--- bmi > 32.85
| | | | | | | | | |--- class: 0
| | | | | | | |--- age > 69.50
| | | | | | | | |--- class: 0
| | | |--- bmi > 40.40
| | | | |--- age <= 70.50
| | | | | |--- class: 1
| | | | |--- age > 70.50
| | | | | |--- class: 0
| | |--- age > 73.50
| | | |--- avg_glucose_level <= 105.92
| | | | |--- gender=Female <= 0.50
| | | | | |--- class: 1
| | | | |--- gender=Female > 0.50
| | | | | |--- smoking_status=never smoked <= 0.50
| | | | | | |--- class: 0
| | | | | |--- smoking_status=never smoked > 0.50
| | | | | | |--- class: 1
| | | |--- avg_glucose_level > 105.92
| | | | |--- avg_glucose_level <= 200.66
| | | | | |--- avg_glucose_level <= 123.81
| | | | | | |--- residence_type=Urban <= 0.50
| | | | | | | |--- class: 0
| | | | | | |--- residence_type=Urban > 0.50
| | | | | | | |--- bmi <= 32.00
| | | | | | | | |--- bmi <= 28.07
| | | | | | | | | |--- bmi <= 25.21
| | | | | | | | | | |--- class: 0
| | | | | | | | | |--- bmi > 25.21
| | | | | | | | | | |--- hypertension <= 0.50
| | | | | | | | | | | |--- class: 1
| | | | | | | | | | |--- hypertension > 0.50
| | | | | | | | | | | |--- class: 0
| | | | | | | | |--- bmi > 28.07
| | | | | | | | | |--- class: 0
| | | | | | | |--- bmi > 32.00
| | | | | | | | |--- class: 1
| | | | | |--- avg_glucose_level > 123.81
| | | | | | |--- bmi <= 33.95
| | | | | | | |--- age <= 81.50
| | | | | | | | |--- bmi <= 23.05
| | | | | | | | | |--- class: 1
| | | | | | | | |--- bmi > 23.05
| | | | | | | | | |--- bmi <= 31.20
| | | | | | | | | | |--- avg_glucose_level <= 128.01
| | | | | | | | | | | |--- class: 1
| | | | | | | | | | |--- avg_glucose_level > 128.01
| | | | | | | | | | | |--- truncated branch of depth 6
| | | | | | | | | |--- bmi > 31.20
| | | | | | | | | | |--- bmi <= 31.85
| | | | | | | | | | | |--- class: 1
| | | | | | | | | | |--- bmi > 31.85
| | | | | | | | | | | |--- truncated branch of depth 3
| | | | | | | |--- age > 81.50
| | | | | | | | |--- class: 1
| | | | | | |--- bmi > 33.95
| | | | | | | |--- class: 0
| | | | |--- avg_glucose_level > 200.66
| | | | | |--- bmi <= 26.35
| | | | | | |--- class: 0
| | | | | |--- bmi > 26.35
| | | | | | |--- bmi <= 28.29
| | | | | | | |--- residence_type=Rural <= 0.50
| | | | | | | | |--- avg_glucose_level <= 228.81
| | | | | | | | | |--- class: 1
| | | | | | | | |--- avg_glucose_level > 228.81
| | | | | | | | | |--- smoking_status=never smoked <= 0.50
| | | | | | | | | | |--- class: 0
| | | | | | | | | |--- smoking_status=never smoked > 0.50
| | | | | | | | | | |--- class: 1
| | | | | | | |--- residence_type=Rural > 0.50
| | | | | | | | |--- class: 0
| | | | | | |--- bmi > 28.29
| | | | | | | |--- age <= 74.50
| | | | | | | | |--- residence_type=Urban <= 0.50
| | | | | | | | | |--- class: 1
| | | | | | | | |--- residence_type=Urban > 0.50
| | | | | | | | | |--- class: 0
| | | | | | | |--- age > 74.50
| | | | | | | | |--- avg_glucose_level <= 207.46
| | | | | | | | | |--- avg_glucose_level <= 203.61
| | | | | | | | | | |--- class: 0
| | | | | | | | | |--- avg_glucose_level > 203.61
| | | | | | | | | | |--- residence_type=Rural <= 0.50
| | | | | | | | | | | |--- class: 1
| | | | | | | | | | |--- residence_type=Rural > 0.50
| | | | | | | | | | | |--- class: 0
| | | | | | | | |--- avg_glucose_level > 207.46
| | | | | | | | | |--- avg_glucose_level <= 236.64
| | | | | | | | | | |--- smoking_status=Unknown <= 0.50
| | | | | | | | | | | |--- class: 0
| | | | | | | | | | |--- smoking_status=Unknown > 0.50
| | | | | | | | | | | |--- truncated branch of depth 3
| | | | | | | | | |--- avg_glucose_level > 236.64
| | | | | | | | | | |--- gender=Female <= 0.50
| | | | | | | | | | | |--- class: 1
| | | | | | | | | | |--- gender=Female > 0.50
| | | | | | | | | | | |--- class: 0
# Score the fitted tree `dt` on the training split and then on the validation
# split, printing one AUC line per split. After the loop, y_pred and auc_dt
# hold the validation values (validation is evaluated last).
for split_label, split_X, split_y in (
    ("training", X_train_res, y_train_res),
    ("validation", X_val, y_val),
):
    # Column 1 of predict_proba is the probability of the positive class.
    y_pred = dt.predict_proba(split_X)[:, 1]
    auc_dt = roc_auc_score(split_y, y_pred)
    print(f"AUC DTree on {split_label} dataset: {round(auc_dt, 4)}")
# Output:
# AUC DTree on training dataset: 0.7885
# AUC DTree on validation dataset: 0.5568
Tuning
Let’s find the best split and tune some hyperparameters to improve the model’s performance. We will focus on the max_depth and min_samples_leaf parameters.
# Sweep max_depth for a single decision tree and print the validation AUC of
# each setting (None places no limit on the depth).
for d in [1, 2, 3, 4, 5, 6, 10, 15, 20, None]:
    # A fixed seed makes the sweep repeatable; the random forest cells in this
    # notebook already use random_state=42.
    dt = DecisionTreeClassifier(max_depth=d, random_state=42)
    dt.fit(X_train_res, y_train_res)
    y_pred = dt.predict_proba(X_val)[:, 1]
    # Stored as val_auc so the `auc` function imported from sklearn.metrics
    # is not rebound to a float.
    val_auc = roc_auc_score(y_val, y_pred)
    print('%4s -> %.3f' % (d, val_auc))
# Output recorded from the earlier unseeded run (first line; the rest follows below):
# 1 -> 0.749
2 -> 0.750
3 -> 0.740
4 -> 0.748
5 -> 0.766
6 -> 0.779
10 -> 0.575
15 -> 0.514
20 -> 0.540
None -> 0.535
The best depths are 4, 5, and 6.
# Grid search over (max_depth, min_samples_leaf) for the decision tree.
# Each result is stored as a (max_depth, min_samples_leaf, validation AUC) tuple.
scores = []
for d in [4, 5, 6]:
    for s in [1, 2, 5, 10, 15, 20, 100, 200, 500]:
        # Fixed seed so that the grid is comparable across runs.
        dt = DecisionTreeClassifier(max_depth=d, min_samples_leaf=s, random_state=42)
        dt.fit(X_train_res, y_train_res)
        y_pred = dt.predict_proba(X_val)[:, 1]
        # Stored as val_auc so the `auc` function imported from
        # sklearn.metrics is not rebound to a float.
        val_auc = roc_auc_score(y_val, y_pred)
        print('(%4s, %3d) -> %.3f' % (d, s, val_auc))
        scores.append((d, s, val_auc))
# Output recorded from the earlier unseeded run (first line; the rest follows below):
# ( 4, 1) -> 0.748
( 4, 2) -> 0.748
( 4, 5) -> 0.748
( 4, 10) -> 0.748
( 4, 15) -> 0.745
( 4, 20) -> 0.740
( 4, 100) -> 0.738
( 4, 200) -> 0.749
( 4, 500) -> 0.750
( 5, 1) -> 0.766
( 5, 2) -> 0.766
( 5, 5) -> 0.765
( 5, 10) -> 0.778
( 5, 15) -> 0.775
( 5, 20) -> 0.762
( 5, 100) -> 0.746
( 5, 200) -> 0.754
( 5, 500) -> 0.750
( 6, 1) -> 0.779
( 6, 2) -> 0.772
( 6, 5) -> 0.769
( 6, 10) -> 0.752
( 6, 15) -> 0.750
( 6, 20) -> 0.732
( 6, 100) -> 0.750
( 6, 200) -> 0.762
( 6, 500) -> 0.750
# Tabulate the grid-search results: one row per (max_depth, min_samples_leaf).
score_columns = ['max_depth', 'min_samples_leaf', 'auc']
df_scores = pd.DataFrame(scores, columns=score_columns)

# Reshape so that rows are min_samples_leaf values and columns are max_depth
# values, then draw the AUC grid as an annotated heatmap.
# (For a ranked list instead: df_scores.sort_values(by='auc', ascending=False).)
df_scores_pivot = df_scores.pivot(
    index='min_samples_leaf',
    columns=['max_depth'],
    values=['auc'],
)
sns.heatmap(df_scores_pivot, annot=True, fmt='.4f')
# Finally we find the best parameters as max_depth=6 and min_samples_leaf=1 with AUC of 0.7785 on validation dataset.
Full training set
# Refit the tuned decision tree on X_full_train_res / y_full_train_res and
# score it once on the test split.
# random_state is fixed so that the reported test AUC is repeatable; without a
# seed the score can change from run to run.
dt = DecisionTreeClassifier(max_depth=6, min_samples_leaf=1, random_state=42)
dt.fit(X_full_train_res, y_full_train_res)
y_pred = dt.predict_proba(X_test)[:, 1]
auc_dt = roc_auc_score(y_test, y_pred)
print(f"AUC Dtree on testing dataset: {round(auc_dt, 4)}")
# Output recorded from the earlier unseeded run:
# AUC Dtree on testing dataset: 0.7733
Random forest
# Baseline random forest: 10 trees with a fixed seed, fitted on the training
# split and scored on the validation split.
rf = RandomForestClassifier(random_state=42, n_estimators=10).fit(
    X_train_res, y_train_res
)
# Column 1 of predict_proba is the probability of the positive class.
y_pred = rf.predict_proba(X_val)[:, 1]
auc_rf = roc_auc_score(y_true=y_val, y_score=y_pred)
print(f"AUC RF on validation dataset: {round(auc_rf, 4)}")
# Output:
# AUC RF on validation dataset: 0.7242
Tuning number of estimators
# Sweep the number of trees (10, 20, ..., 200) and record the validation AUC
# of each forest as an (n_estimators, validation AUC) tuple.
scores = []
for n in range(10, 201, 10):
    rf = RandomForestClassifier(n_estimators=n, random_state=42)
    rf.fit(X_train_res, y_train_res)
    y_pred = rf.predict_proba(X_val)[:, 1]
    # Stored as val_auc so the `auc` function imported from sklearn.metrics
    # is not rebound to a float.
    val_auc = roc_auc_score(y_val, y_pred)
    print('%3d -> %.3f' % (n, val_auc))
    scores.append((n, val_auc))
# Output (first line; the rest follows below):
# 10 -> 0.724
20 -> 0.730
30 -> 0.736
40 -> 0.741
50 -> 0.750
60 -> 0.752
70 -> 0.755
80 -> 0.773
90 -> 0.772
100 -> 0.768
110 -> 0.775
120 -> 0.773
130 -> 0.769
140 -> 0.776
150 -> 0.776
160 -> 0.778
170 -> 0.778
180 -> 0.778
190 -> 0.777
200 -> 0.776
# Plot validation AUC against the number of trees.
sweep_columns = ['n_estimators', 'auc']
df_scores = pd.DataFrame(scores, columns=sweep_columns)
plt.plot(df_scores['n_estimators'], df_scores['auc'])
# Based on the plot, we need another tweak with the number of max depth of trees in random forest 5, 10, 15
# For each max_depth in (5, 10, 15), sweep the number of trees and record the
# validation AUC as a (max_depth, n_estimators, validation AUC) tuple.
scores = []
for d in [5, 10, 15]:
    for n in range(10, 201, 10):
        rf = RandomForestClassifier(n_estimators=n, max_depth=d, random_state=42)
        rf.fit(X_train_res, y_train_res)
        y_pred = rf.predict_proba(X_val)[:, 1]
        # Stored as val_auc so the `auc` function imported from
        # sklearn.metrics is not rebound to a float.
        val_auc = roc_auc_score(y_val, y_pred)
        print('(%3d %3d) -> %.3f' % (d, n, val_auc))
        scores.append((d, n, val_auc))
# Output (first line; the rest follows below):
# ( 5 10) -> 0.765
( 5 20) -> 0.752
( 5 30) -> 0.756
( 5 40) -> 0.755
( 5 50) -> 0.757
( 5 60) -> 0.762
( 5 70) -> 0.758
( 5 80) -> 0.756
( 5 90) -> 0.755
( 5 100) -> 0.755
( 5 110) -> 0.756
( 5 120) -> 0.754
( 5 130) -> 0.754
( 5 140) -> 0.755
( 5 150) -> 0.755
( 5 160) -> 0.755
( 5 170) -> 0.754
( 5 180) -> 0.753
( 5 190) -> 0.751
( 5 200) -> 0.751
( 10 10) -> 0.726
( 10 20) -> 0.743
( 10 30) -> 0.755
( 10 40) -> 0.755
( 10 50) -> 0.755
( 10 60) -> 0.761
( 10 70) -> 0.763
( 10 80) -> 0.763
( 10 90) -> 0.763
( 10 100) -> 0.761
( 10 110) -> 0.763
( 10 120) -> 0.762
( 10 130) -> 0.763
( 10 140) -> 0.764
( 10 150) -> 0.764
( 10 160) -> 0.763
( 10 170) -> 0.763
( 10 180) -> 0.761
( 10 190) -> 0.761
( 10 200) -> 0.763
( 15 10) -> 0.688
( 15 20) -> 0.709
( 15 30) -> 0.749
( 15 40) -> 0.765
( 15 50) -> 0.763
( 15 60) -> 0.761
( 15 70) -> 0.757
( 15 80) -> 0.763
( 15 90) -> 0.760
( 15 100) -> 0.760
( 15 110) -> 0.771
( 15 120) -> 0.771
( 15 130) -> 0.773
( 15 140) -> 0.783
( 15 150) -> 0.783
( 15 160) -> 0.782
( 15 170) -> 0.780
( 15 180) -> 0.781
( 15 190) -> 0.780
( 15 200) -> 0.781
# Draw one AUC-vs-n_estimators curve per max_depth value.
df_scores = pd.DataFrame(scores, columns=['max_depth', 'n_estimators', 'auc'])
for d in [5, 10, 15]:
    depth_rows = df_scores.loc[df_scores['max_depth'] == d]
    plt.plot(
        depth_rows['n_estimators'],
        depth_rows['auc'],
        label='max_depth=%d' % d,
    )
plt.legend()
# The best max depth would be 15 then let’s tune again with min_samples_leaf
# With max_depth fixed at 15, sweep min_samples_leaf and the number of trees,
# recording (min_samples_leaf, n_estimators, validation AUC) tuples.
scores = []
max_depths = 15
for s in [1, 3, 5, 10, 50]:
    for n in range(10, 201, 10):
        rf = RandomForestClassifier(
            n_estimators=n,
            max_depth=max_depths,
            min_samples_leaf=s,
            random_state=42,
        )
        rf.fit(X_train_res, y_train_res)
        y_pred = rf.predict_proba(X_val)[:, 1]
        # Stored as val_auc so the `auc` function imported from
        # sklearn.metrics is not rebound to a float.
        val_auc = roc_auc_score(y_val, y_pred)
        print('(%3d %3d) -> %.3f' % (s, n, val_auc))
        scores.append((s, n, val_auc))

# Column name and legend label spell the hyperparameter as scikit-learn does
# (min_samples_leaf).
df_scores = pd.DataFrame(scores, columns=['min_samples_leaf', 'n_estimators', 'auc'])
for s in [1, 3, 5, 10, 50]:
    df_subset = df_scores[df_scores.min_samples_leaf == s]
    plt.plot(df_subset.n_estimators, df_subset.auc, label='min_samples_leaf=%d' % s)
plt.legend()
# Output (first line; the rest follows below):
# ( 1 10) -> 0.688
( 1 20) -> 0.709
( 1 30) -> 0.749
( 1 40) -> 0.765
( 1 50) -> 0.763
( 1 60) -> 0.761
( 1 70) -> 0.757
( 1 80) -> 0.763
( 1 90) -> 0.760
( 1 100) -> 0.760
( 1 110) -> 0.771
( 1 120) -> 0.771
( 1 130) -> 0.773
( 1 140) -> 0.783
( 1 150) -> 0.783
( 1 160) -> 0.782
( 1 170) -> 0.780
( 1 180) -> 0.781
( 1 190) -> 0.780
( 1 200) -> 0.781
( 3 10) -> 0.734
( 3 20) -> 0.781
( 3 30) -> 0.783
( 3 40) -> 0.778
( 3 50) -> 0.784
( 3 60) -> 0.784
( 3 70) -> 0.782
( 3 80) -> 0.787
( 3 90) -> 0.787
( 3 100) -> 0.783
( 3 110) -> 0.781
( 3 120) -> 0.782
( 3 130) -> 0.782
( 3 140) -> 0.786
( 3 150) -> 0.786
( 3 160) -> 0.786
( 3 170) -> 0.785
( 3 180) -> 0.786
( 3 190) -> 0.786
( 3 200) -> 0.786
( 5 10) -> 0.771
( 5 20) -> 0.780
( 5 30) -> 0.788
( 5 40) -> 0.787
( 5 50) -> 0.791
( 5 60) -> 0.792
( 5 70) -> 0.792
( 5 80) -> 0.795
( 5 90) -> 0.795
( 5 100) -> 0.793
( 5 110) -> 0.790
( 5 120) -> 0.790
( 5 130) -> 0.789
( 5 140) -> 0.789
( 5 150) -> 0.788
( 5 160) -> 0.789
( 5 170) -> 0.788
( 5 180) -> 0.789
( 5 190) -> 0.789
( 5 200) -> 0.789
( 10 10) -> 0.783
( 10 20) -> 0.793
( 10 30) -> 0.806
( 10 40) -> 0.797
( 10 50) -> 0.795
( 10 60) -> 0.799
( 10 70) -> 0.800
( 10 80) -> 0.800
( 10 90) -> 0.798
( 10 100) -> 0.795
( 10 110) -> 0.794
( 10 120) -> 0.794
( 10 130) -> 0.795
( 10 140) -> 0.793
( 10 150) -> 0.793
( 10 160) -> 0.792
( 10 170) -> 0.793
( 10 180) -> 0.793
( 10 190) -> 0.794
( 10 200) -> 0.794
( 50 10) -> 0.791
( 50 20) -> 0.784
( 50 30) -> 0.784
( 50 40) -> 0.784
( 50 50) -> 0.787
( 50 60) -> 0.790
( 50 70) -> 0.793
( 50 80) -> 0.791
( 50 90) -> 0.791
( 50 100) -> 0.790
( 50 110) -> 0.791
( 50 120) -> 0.791
( 50 130) -> 0.790
( 50 140) -> 0.791
( 50 150) -> 0.790
( 50 160) -> 0.790
( 50 170) -> 0.790
( 50 180) -> 0.791
( 50 190) -> 0.792
( 50 200) -> 0.791
Full training set
# Final hyperparameters.
# The tree count is bound to its own name and passed explicitly, so the forest
# no longer depends on whatever value a previous tuning loop left in `n`.
n_estimators = 170
max_depths = 15
min_samples_leaf = 10
rf = RandomForestClassifier(
    n_estimators=n_estimators,
    max_depth=max_depths,
    min_samples_leaf=min_samples_leaf,
    random_state=42,
)
# Fit on X_full_train_res / y_full_train_res and score once on the test split.
rf.fit(X_full_train_res, y_full_train_res)
y_pred = rf.predict_proba(X_test)[:, 1]
auc_rf = roc_auc_score(y_test, y_pred)
print(f"AUC RF on testing dataset: {round(auc_rf, 4)}")
# Output recorded from the earlier run, which passed n_estimators=n (a leftover
# loop variable) instead of 170; re-run to refresh:
# AUC RF on testing dataset: 0.8258
Feature importance
# Pair each DictVectorizer feature name with the forest's importance score,
# order from most to least important, and draw a horizontal bar chart.
feature_names = dv.get_feature_names_out()
importances = rf.feature_importances_
feature_importances = (
    pd.Series(importances, index=feature_names)
    .sort_values(ascending=False)
)
sns.barplot(x=feature_importances.to_numpy(), y=feature_importances.index)
# XGBoost
# define function to parse xgboost output
def parse_xgb_output(output):
    """Parse captured XGBoost evaluation log text into a DataFrame.

    Every non-blank line of ``output.stdout`` must consist of three
    tab-separated fields: the iteration number in square brackets, then two
    ``name:value`` fields that are read as the training and validation scores.

    Args:
        output: Any object exposing the captured log text as ``output.stdout``.

    Returns:
        A DataFrame with columns ``num_iter``, ``auc_train`` and ``auc_val``,
        one row per parsed line. It is empty (but keeps the columns) when the
        captured text contains no log lines.

    Raises:
        ValueError: If a non-blank line does not have exactly three
            tab-separated fields or a value cannot be converted to a number.
    """
    results = []
    for line in output.stdout.strip().split('\n'):
        # Blank lines (including the single '' produced by empty output)
        # carry no data and would otherwise break the three-way unpacking.
        if not line.strip():
            continue
        it_line, train_line, val_line = line.split('\t')
        it = int(it_line.strip('[]'))
        train = float(train_line.split(':')[1])
        val = float(val_line.split(':')[1])
        results.append((it, train, val))
    df_results = pd.DataFrame(results, columns=['num_iter', 'auc_train', 'auc_val'])
    return df_results
# Wrap the training and validation splits in DMatrix objects that carry the
# DictVectorizer feature names.
features = dv.get_feature_names_out().tolist()
dtrain = xgb.DMatrix(data=X_train_res, label=y_train_res, feature_names=features)
dval = xgb.DMatrix(data=X_val, label=y_val, feature_names=features)

# Baseline booster settings for a binary classifier.
xgb_params = dict(
    eta=0.3,
    max_depth=6,
    min_child_weight=1,
    objective='binary:logistic',
    nthread=8,
    seed=1,
    verbosity=1,
)

# Train for 10 boosting rounds and score the predictions on the validation split.
model = xgb.train(xgb_params, dtrain, num_boost_round=10)
y_pred = model.predict(dval)
auc_xgb = roc_auc_score(y_val, y_pred)
print(f"AUC XGBoost on validation dataset: {round(auc_xgb, 4)}")
# Output:
# AUC XGBoost on validation dataset: 0.7706
Full training set
# Train the final XGBoost model on X_full_train_res / y_full_train_res and
# score it once on the test split.
# The feature names are computed once and shared by both DMatrix objects.
final_feature_names = dv.get_feature_names_out().tolist()
dfulltrain = xgb.DMatrix(X_full_train_res, label=y_full_train_res, feature_names=final_feature_names)
dtest = xgb.DMatrix(X_test, feature_names=final_feature_names)
xgb_params = {
    'eta': 0.01,
    'max_depth': 6,
    'min_child_weight': 10,
    'objective': 'binary:logistic',
    'nthread': 8,
    'seed': 1,
    'verbosity': 1,
    'eval_metric': 'auc',
}
model = xgb.train(xgb_params, dfulltrain, num_boost_round=175, verbose_eval=5)
y_pred = model.predict(dtest)
auc_xgb = roc_auc_score(y_test, y_pred)
# The message names XGBoost: this score belongs to the booster, not to the
# random forest.
print(f"AUC XGBoost on testing dataset: {round(auc_xgb, 4)}")
# Output recorded from the earlier run, which printed the score under the
# mislabeled "AUC RF" prefix:
# AUC RF on testing dataset: 0.8005
Summary
Finally we have the following AUC scores on the testing dataset:
| Model | AUC |
|---|---|
| Logistic Regression | 0.8426 |
| Decision Tree | 0.7733 |
| Random Forest | 0.8258 |
| XGBoost | 0.8005 |
With the given data and parameters, the best-performing model is Logistic Regression.