Introduction to Machine Learning
Lecturer: Emi Tanaka
Department of Econometrics and Business Statistics
In R, classification trees can be fitted with the `tree` and `rpart` packages; these notes use `rpart`.
| | $\mathcal{A}_L$ | $\mathcal{A}_R$ |
|---|---|---|
| $p_M$ | 10.2% (30) | 92.4% (121) |
| $p_B$ | 89.8% (265) | 7.6% (10) |
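Given these child-node proportions, the impurity of each node can be computed directly. A minimal sketch in R, assuming the Gini index (the default splitting criterion in `rpart` for classification):

```r
# Gini impurity: 1 - sum of squared class proportions
gini <- function(p) 1 - sum(p^2)
gini(c(0.102, 0.898))  # left node  (A_L): ~0.183
gini(c(0.924, 0.076))  # right node (A_R): ~0.140
```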
Tuning parameters in `rpart.control()` include:

- `minsplit`: the minimum number of observations in any non-terminal node.
- `minbucket`: the minimum number of observations allowed in a terminal node.
- `cp`: the complexity parameter, the minimum improvement in impurity required to continue splitting.
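For a single fit, these parameters are passed via the `control` argument. A minimal sketch, reusing the `cancer_train` data and formula that appear below, with the control values set to `rpart.control()`'s documented defaults:

```r
library(tidyverse)   # for expand_grid() and the pipes below
library(rpart)

# a minimal sketch: the control values shown are rpart.control()'s
# documented defaults (minsplit = 20, minbucket = minsplit/3, cp = 0.01)
fit <- rpart(diagnosis ~ radius_mean + concave_points_mean,
             data = cancer_train, method = "class",
             control = rpart.control(minsplit = 20, minbucket = 7, cp = 0.01))
```

To tune these parameters rather than fix them, `rpart` can be re-fitted over a grid of candidate values: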
search_df <- expand_grid(minbucket = seq(10, 100, length.out = 4),
minsplit = seq(10, 100, length.out = 4),
cp = seq(0, 1, length.out = 50))
search_df
# A tibble: 800 × 3
minbucket minsplit cp
<dbl> <dbl> <dbl>
1 10 10 0
2 10 10 0.0204
3 10 10 0.0408
4 10 10 0.0612
5 10 10 0.0816
6 10 10 0.102
7 10 10 0.122
8 10 10 0.143
9 10 10 0.163
10 10 10 0.184
# ℹ 790 more rows
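The next block runs this grid over a 5-fold cross-validation: with 800 combinations and 5 folds, that is 800 × 5 = 4,000 `rpart` fits, each scored on three metrics (accuracy, balanced accuracy, and Cohen's kappa).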
library(rsample)    # vfold_cv(), training(), testing()
library(yardstick)  # accuracy, bal_accuracy, kap
set.seed(2023)
cancer_folds <- cancer_train %>%
vfold_cv(v = 5)
search_res <- cancer_folds %>%
mutate(search = map(splits, function(asplit) {
search_df %>%
rowwise() %>%
# fit the model for each row
# different row contains unique combination of
# minbucket, minsplit, and cp
mutate(fit = list(rpart(diagnosis ~ radius_mean + concave_points_mean,
data = training(asplit), method = "class",
control = rpart.control(minbucket = .data$minbucket,
minsplit = .data$minsplit,
cp = .data$cp)))) %>%
ungroup() %>%
# compute classification metric on validation fold
mutate(metrics = map(fit, function(afit) {
# get validation fold
testing(asplit) %>%
# predict from fitted model for this validation fold
mutate(pred = predict(afit, ., type = "class")) %>%
# get classification metrics
metric_set(accuracy, bal_accuracy, kap)(., truth = diagnosis, estimate = pred)
})) %>%
unnest(metrics) %>%
select(-c(fit, .estimator))
})) %>%
unnest(search)
# summarise the data for easy view
search_res_summary <- search_res %>%
group_by(minbucket, minsplit, cp, .metric) %>%
summarise(mean = mean(.estimate),
sd = sd(.estimate))
search_res_summary
# A tibble: 2,400 × 6
# Groups: minbucket, minsplit, cp [800]
minbucket minsplit cp .metric mean sd
<dbl> <dbl> <dbl> <chr> <dbl> <dbl>
1 10 10 0 accuracy 0.906 0.0275
2 10 10 0 bal_accuracy 0.896 0.0366
3 10 10 0 kap 0.794 0.0583
4 10 10 0.0204 accuracy 0.904 0.0409
5 10 10 0.0204 bal_accuracy 0.900 0.0417
6 10 10 0.0204 kap 0.792 0.0864
7 10 10 0.0408 accuracy 0.897 0.0314
8 10 10 0.0408 bal_accuracy 0.886 0.0260
9 10 10 0.0408 kap 0.774 0.0633
10 10 10 0.0612 accuracy 0.899 0.0306
# ℹ 2,390 more rows
minbucket | minsplit | cp | .metric | mean | sd |
---|---|---|---|---|---|
10 | 10 | 0.0000000 | accuracy | 0.9061286 | 0.0274646 |
10 | 10 | 0.0000000 | kap | 0.7935643 | 0.0583496 |
10 | 10 | 0.0204082 | bal_accuracy | 0.9004277 | 0.0417005 |
- `minbucket` and `minsplit` don't seem to make much difference (for the range searched, at least).
- `cp = 0` seems sufficient in this case.
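The best-performing combination can also be pulled out programmatically; a minimal sketch ranking by mean balanced accuracy (any of the three metrics could be used):

```r
search_res_summary %>%
  ungroup() %>%
  filter(.metric == "bal_accuracy") %>%
  # keep the single combination with the highest mean metric
  slice_max(mean, n = 1, with_ties = FALSE)
```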
Selecting `cp` with the cross-validation error in `rpart`

`rpart` automatically stores the cross-validation error in the resulting model object:

fit <- rpart(diagnosis ~ radius_mean + concave_points_mean,
data = cancer_train, method = "class",
control = rpart.control(cp = 0, xval = 10)) # 10 folds (default)
fit$cptable
CP nsplit rel error xerror xstd
1 0.768211921 0 1.0000000 1.0000000 0.06538424
2 0.026490066 1 0.2317881 0.2781457 0.04074821
3 0.006622517 2 0.2052980 0.2715232 0.04031258
4 0.000000000 3 0.1986755 0.2847682 0.04117673
In the `cptable`:

- `rel error` is the in-sample relative error (it always decreases with more splits).
- `xerror` is the cross-validation error.
- `xstd` is the standard error of the cross-validation error.
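A common convention for choosing `cp` from this table (the 1-SE rule; an addition here, not part of the original workflow) is to take the simplest tree whose `xerror` lies within one standard error of the minimum:

```r
cptab <- as.data.frame(fit$cptable)
# threshold: minimum cross-validation error plus its standard error
thresh <- min(cptab$xerror) + cptab$xstd[which.min(cptab$xerror)]
# the simplest (fewest-split) tree whose xerror is under the threshold
cp_1se <- cptab$CP[which(cptab$xerror <= thresh)[1]]
cp_1se
```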
The response is whether the client subscribed to a term deposit (`y` = "yes" or "no").

library(tidyverse)
bank <- read_delim("https://emitanaka.org/iml/data/bank-full.csv", delim = ";")
skimr::skim(bank)
Name | bank |
---|---|
Number of rows | 45211 |
Number of columns | 17 |
Column type frequency: | |
character | 10 |
numeric | 7 |
Group variables | None |
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
job | 0 | 1 | 6 | 13 | 0 | 12 | 0 |
marital | 0 | 1 | 6 | 8 | 0 | 3 | 0 |
education | 0 | 1 | 7 | 9 | 0 | 4 | 0 |
default | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
housing | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
loan | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
contact | 0 | 1 | 7 | 9 | 0 | 3 | 0 |
month | 0 | 1 | 3 | 3 | 0 | 12 | 0 |
poutcome | 0 | 1 | 5 | 7 | 0 | 4 | 0 |
y | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 |
---|---|---|---|---|---|---|---|---|---|
age | 0 | 1 | 40.94 | 10.62 | 18 | 33 | 39 | 48 | 95 |
balance | 0 | 1 | 1362.27 | 3044.77 | -8019 | 72 | 448 | 1428 | 102127 |
day | 0 | 1 | 15.81 | 8.32 | 1 | 8 | 16 | 21 | 31 |
duration | 0 | 1 | 258.16 | 257.53 | 0 | 103 | 180 | 319 | 4918 |
campaign | 0 | 1 | 2.76 | 3.10 | 1 | 1 | 2 | 3 | 63 |
pdays | 0 | 1 | 40.20 | 100.13 | -1 | -1 | -1 | -1 | 871 |
previous | 0 | 1 | 0.58 | 2.30 | 0 | 0 | 0 | 0 | 275 |
- `duration` is omitted as a predictor in the model as it is computed based on the response.
- `poutcome` has 4 unique values, and so 4 overall impurities are calculated, one for each class vs. the others.
library(tidyverse)
insurance <- read_csv("https://emitanaka.org/iml/data/insurance.csv")
skimr::skim(insurance)
Name | insurance |
---|---|
Number of rows | 1338 |
Number of columns | 7 |
Column type frequency: | |
character | 3 |
numeric | 4 |
Group variables | None |
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
sex | 0 | 1 | 4 | 6 | 0 | 2 | 0 |
smoker | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
region | 0 | 1 | 9 | 9 | 0 | 4 | 0 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 |
---|---|---|---|---|---|---|---|---|---|
age | 0 | 1 | 39.21 | 14.05 | 18.00 | 27.00 | 39.00 | 51.00 | 64.00 |
bmi | 0 | 1 | 30.66 | 6.10 | 15.96 | 26.30 | 30.40 | 34.69 | 53.13 |
children | 0 | 1 | 1.09 | 1.21 | 0.00 | 0.00 | 1.00 | 2.00 | 5.00 |
charges | 0 | 1 | 13270.42 | 12110.01 | 1121.87 | 4740.29 | 9382.03 | 16639.91 | 63770.43 |
Selecting `cp` based on `xerror`
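The `T0` object whose `cptable` is shown below is not defined in this excerpt; as a sketch, assume an unpruned regression tree grown on the insurance data (the formula `charges ~ .` is an assumption):

```r
library(rpart)
# assumed fit: cp = 0 grows the tree out fully so that cptable
# records the cross-validation error at every candidate pruning point
T0 <- rpart(charges ~ ., data = insurance, method = "anova",
            control = rpart.control(cp = 0))
T0$cptable %>% as_tibble()
```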
# A tibble: 86 × 5
CP nsplit `rel error` xerror xstd
<dbl> <dbl> <dbl> <dbl> <dbl>
1 0.620 0 1 1.00 0.0519
2 0.144 1 0.380 0.382 0.0190
3 0.0637 2 0.236 0.239 0.0145
4 0.00967 3 0.173 0.178 0.0133
5 0.00784 4 0.163 0.172 0.0135
6 0.00712 5 0.155 0.165 0.0131
7 0.00537 6 0.148 0.157 0.0131
8 0.00196 7 0.143 0.153 0.0132
9 0.00190 8 0.141 0.156 0.0133
10 0.00173 9 0.139 0.154 0.0132
# ℹ 76 more rows
optimal_cp <- T0$cptable %>%
as.data.frame() %>%
filter(xerror == min(xerror)) %>%
# if multiple optimal points, then select one
slice(1) %>%
pull(CP)
optimal_cp
[1] 0.0009125473
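Once `optimal_cp` is chosen, the full tree can be pruned back to that complexity with `rpart::prune()`:

```r
# snip off splits that do not achieve the selected complexity parameter
T_final <- prune(T0, cp = optimal_cp)
```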