Biostatistical Computing, PHC 6068

Decision Tree

Zhiguang Huo (Caleb)

Wednesday October 25, 2017

Outline

Supervised machine learning and unsupervised machine learning

  1. Classification (supervised machine learning):
    • With the class label known, learn the features of the classes to predict a future observation.
    • The learning performance can be evaluated by the prediction error rate.
  2. Clustering (unsupervised machine learning):
    • Without knowing the class label, cluster the data according to their similarity and learn the features.
    • Normally the performance is difficult to evaluate and depends on the context of the problem.

Classification (supervised machine learning)

Clustering (unsupervised machine learning)

Decision Tree

library(ElemStatLearn)
library(rpart)
library(rpart.plot)
afit <- rpart(svi ~ . - train, data = prostate)
rpart.plot(afit)

Decision Tree

Motivating example

Purpose: predict whether a student will play cricket in his/her leisure time.

Decision tree

Questions:

Decision tree structure

Terminology:

How does a tree decide where to split?

Goodness of split (GOS) criteria

Gini Index

Assume:

\[M_{Gini}(t) = 1 - P(Y = 0 | X \in t)^2 - P(Y = 1 | X \in t)^2\]
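The Gini index above generalises to \(K\) classes as \(1 - \sum_k P(Y = k \mid X \in t)^2\). Below is a minimal R sketch of that formula, which computes the class proportions with `table()`. The function name `gini` is our own and does not come from any package.

```r
# Gini impurity of a node, given the class labels of the observations in it
gini <- function(y) {
  p <- table(y) / length(y)   # class proportions P(Y = k | X in t)
  1 - sum(p^2)
}

gini(c(0, 0, 1, 1))  # 0.5: a 50/50 binary node is maximally impure
gini(c(1, 1, 1, 1))  # 0: a pure node
```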

Gini Index

Goodness of split (GOS) criteria using Gini Index

Given an impurity function \(M(t)\), the GOS criterion is to find the split of node \(t\) into children \(t_L\) and \(t_R\) such that the impurity measure is maximally decreased:

\[\arg \max_{t_R, t_L} M(t) - [P(X\in t_L|X\in t) M(t_L) + P(X\in t_R|X\in t) M(t_R)]\]

Therefore, we will split based on Gender.
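The split comparison can be sketched in R. The child-node counts below are hypothetical and were chosen only to illustrate the calculation; they are not read off the figure. Each child is given as `c(n(Y = 0), n(Y = 1))`. The helpers `gini_counts` and `gos` are illustrative names.

```r
# Gini impurity of a binary node from its class counts
gini_counts <- function(n0, n1) {
  p <- c(n0, n1) / (n0 + n1)
  1 - sum(p^2)
}

# Decrease in impurity: M(t) - [P(t_L | t) M(t_L) + P(t_R | t) M(t_R)]
gos <- function(left, right, impurity = gini_counts) {
  parent <- left + right                  # class counts in the parent node t
  n_l <- sum(left); n_r <- sum(right); n <- n_l + n_r
  impurity(parent[1], parent[2]) -
    (n_l / n * impurity(left[1], left[2]) + n_r / n * impurity(right[1], right[2]))
}

# Two hypothetical splits of the same 30-subject node (15 with Y = 0, 15 with Y = 1)
gos(left = c(8, 2), right = c(7, 13))   # split A
gos(left = c(8, 6), right = c(7, 9))    # split B
```

With these made-up counts, split A decreases the Gini index by 0.09, which is much more than split B, so split A would be preferred.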

Entropy

Assume:

\[M_{entropy}(t) = - P(Y = 0 | X \in t)\log P(Y = 0 | X \in t) - P(Y = 1 | X \in t)\log P(Y = 1 | X \in t)\]
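Below is a matching sketch for the entropy impurity. It uses the natural logarithm; with \(\log_2\), a 50/50 node would have entropy 1 instead of \(\log 2\). The name `entropy` is illustrative.

```r
# Entropy impurity of a node, given the class labels of the observations in it
entropy <- function(y) {
  p <- table(y) / length(y)   # table() only lists observed classes, so every p > 0
  -sum(p * log(p))
}

entropy(c(0, 0, 1, 1))  # log(2), about 0.693
entropy(c(1, 1, 1, 1))  # 0: a pure node
```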

Entropy

Goodness of split (GOS) criteria using entropy

Given an impurity function \(M(t)\), the GOS criterion is to find the split of node \(t\) into children \(t_L\) and \(t_R\) such that the impurity measure is maximally decreased:

\[\arg \max_{t_R, t_L} M(t) - [P(X\in t_L|X\in t) M(t_L) + P(X\in t_R|X\in t) M(t_R)]\]
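The same GOS criterion with entropy as \(M(t)\), often called the information gain, can be sketched as follows. It is applied to two hypothetical splits of a 30-subject node, and the counts are invented for illustration. Each child is given as `c(n(Y = 0), n(Y = 1))`, and the helper names are our own.

```r
# Entropy of a node from its class counts
entropy_counts <- function(counts) {
  p <- counts / sum(counts)
  p <- p[p > 0]                 # 0 * log(0) is taken to be 0
  -sum(p * log(p))
}

# Decrease in entropy: M(t) - [P(t_L | t) M(t_L) + P(t_R | t) M(t_R)]
gos_entropy <- function(left, right) {
  n_l <- sum(left); n_r <- sum(right); n <- n_l + n_r
  entropy_counts(left + right) -
    (n_l / n * entropy_counts(left) + n_r / n * entropy_counts(right))
}

gos_entropy(c(8, 2), c(7, 13))  # hypothetical split A, about 0.095
gos_entropy(c(8, 6), c(7, 9))   # hypothetical split B, about 0.009
```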

Summary of impurity measurement

Complexity for each split:
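One way to see the cost of a single split follows. For a numeric predictor with \(n\) distinct values, only the \(n - 1\) cut points between consecutive sorted values need to be checked, and each candidate cut requires the node to be re-scored. The sketch below is a naive exhaustive search that scores each cut by its Gini decrease. The data and the helper names (`gini_vec`, `best_split`) are hypothetical. A more efficient version could sort once and update the class counts incrementally instead of rescanning the data for every cut.

```r
# Gini impurity of a vector of class labels
gini_vec <- function(y) { p <- table(y) / length(y); 1 - sum(p^2) }

# Exhaustive search for the best split "x < s" on one numeric predictor
best_split <- function(x, y) {
  v <- sort(unique(x))
  cuts <- (head(v, -1) + tail(v, -1)) / 2     # n - 1 midpoints for n distinct values
  gain <- sapply(cuts, function(s) {
    left <- y[x < s]; right <- y[x >= s]
    gini_vec(y) -
      (length(left) * gini_vec(left) + length(right) * gini_vec(right)) / length(y)
  })
  list(cut = cuts[which.max(gain)], gain = max(gain), n_candidates = length(cuts))
}

x <- c(1.2, 3.4, 2.2, 5.0, 4.1, 0.7)
y <- c(0, 1, 0, 1, 1, 0)
res <- best_split(x, y)
res   # cut 2.8 separates the classes perfectly, so the gain is 0.5
```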

Decision tree

Construct the tree structure:

How to make a prediction:

\[\hat{p}_{mk} = \frac{1}{N_m} \sum_{x_i \in t_m} \mathbb{I}(y_i = k)\]

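Here \(N_m\) is the number of training observations in terminal node \(t_m\), and \(\hat{p}_{mk}\) is the proportion of them that belong to class \(k\). For example, if a new subject (Gender: Male, Class: X, Height: 5.8 ft) falls into node \(t_m\), we predict by majority vote, i.e. the class \(k\) with the largest \(\hat{p}_{mk}\).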
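In R, \(\hat{p}_{mk}\) and the majority vote for one terminal node can be computed as follows. The labels in `y_node` are hypothetical training responses that fell into the node.

```r
# Hypothetical training labels of the observations in terminal node t_m
y_node <- c("play", "play", "no", "play", "no")

p_hat <- table(y_node) / length(y_node)   # estimated class proportions p_hat_mk
p_hat                                     # no: 0.4, play: 0.6
names(which.max(p_hat))                   # majority vote: "play"
```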

Regression tree

\[\hat{c}_{m} = \frac{1}{N_m} \sum_{x_i \in t_m} y_i\]
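To check that a regression tree's prediction is the node average \(\hat{c}_m\), the sketch below fits `rpart` to the built-in `mtcars` data, which is a stand-in and not the lecture's data. It then compares the fitted values with the per-leaf means of the response. `fit$where` records which leaf (row of `fit$frame`) each training row falls into.

```r
library(rpart)

# Regression tree on a built-in data set (mtcars is used only for illustration)
fit <- rpart(mpg ~ wt + hp, data = mtcars)

# Mean of y_i over all training rows that share the same leaf
leaf_means <- ave(mtcars$mpg, fit$where)

# Fitted values of the tree equal these leaf averages
all.equal(unname(predict(fit)), leaf_means)
```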

Regression Trees vs Classification Trees

  1. Regression trees are used when the dependent variable is continuous. Classification trees are used when the dependent variable is categorical.
  2. In a regression tree, the value assigned to a terminal node is the mean response of the training observations falling in that region.
  3. In a classification tree, the value (class) assigned to a terminal node is the mode of the training observations falling in that region.
  4. Both types of tree divide the predictor space (independent variables) into distinct, non-overlapping regions.
  5. Both types of tree follow a top-down greedy approach known as recursive binary splitting.
  6. In both cases, the splitting process produces a fully grown tree, which is then pruned to tackle overfitting.

Prostate cancer example

library(ElemStatLearn)
library(rpart)
library(rpart.plot)
head(prostate)
##       lcavol  lweight age      lbph svi       lcp gleason pgg45       lpsa
## 1 -0.5798185 2.769459  50 -1.386294   0 -1.386294       6     0 -0.4307829
## 2 -0.9942523 3.319626  58 -1.386294   0 -1.386294       6     0 -0.1625189
## 3 -0.5108256 2.691243  74 -1.386294   0 -1.386294       7    20 -0.1625189
## 4 -1.2039728 3.282789  58 -1.386294   0 -1.386294       6     0 -0.1625189
## 5  0.7514161 3.432373  62 -1.386294   0 -1.386294       6     0  0.3715636
## 6 -1.0498221 3.228826  50 -1.386294   0 -1.386294       6     0  0.7654678
##   train
## 1  TRUE
## 2  TRUE
## 3  TRUE
## 4  TRUE
## 5  TRUE
## 6  TRUE

Apply CART

prostate_train <- subset(prostate, subset = train==TRUE)
prostate_test <- subset(prostate, subset = train==FALSE)

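# svi is coded as numeric 0/1, so rpart fits a regression ("anova") tree and each
# node's yval is the proportion of subjects with svi = 1 in that node
afit <- rpart(svi ~ . - train, data = prostate_train)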
afit
## n= 67 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
## 1) root 67 11.641790 0.22388060  
##   2) lcavol< 2.523279 52  2.826923 0.05769231  
##     4) lpsa< 2.993028 43  0.000000 0.00000000 *
##     5) lpsa>=2.993028 9  2.000000 0.33333330 *
##   3) lcavol>=2.523279 15  2.400000 0.80000000 *

Visualize the CART result

rpart.plot(afit)

For the top node:
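For a 0/1 response such as `svi`, the `yval` and `deviance` columns printed above follow from two counts per node: \(n\) subjects, of whom \(n_1\) have `svi = 1`. Then `yval` \(= n_1/n\), and `deviance` \(= \sum_i (y_i - \bar{y})^2 = n_1(n - n_1)/n\). The \(n_1\) values below are recovered from the printout as `yval * n` (e.g. 0.2238806 * 67 = 15). `node_summary` is an illustrative helper.

```r
# rpart "anova" node summary for a 0/1 response with n subjects, n1 of them equal to 1
node_summary <- function(n, n1) c(yval = n1 / n, deviance = n1 * (n - n1) / n)

node_summary(67, 15)   # root:   yval 0.2238806, deviance 11.64179
node_summary(52, 3)    # node 2: yval 0.05769231, deviance 2.826923
node_summary(15, 12)   # node 3: yval 0.8, deviance 2.4
```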

Predicting the testing dataset using CART

predProb_cart <- predict(object = afit, newdata = prostate_test)
predProb_cart
##         7         9        10        15        22        25        26 
## 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 
##        28        32        34        36        42        44        48 
## 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 
##        49        50        53        54        55        57        62 
## 0.0000000 0.0000000 0.0000000 0.0000000 0.8000000 0.0000000 0.0000000 
##        64        65        66        73        74        80        84 
## 0.0000000 0.0000000 0.0000000 0.3333333 0.3333333 0.8000000 0.8000000 
##        95        97 
## 0.8000000 0.8000000
atable <- table(predictLabel = predProb_cart>0.5, trueLabel = prostate_test$svi)
atable
##             trueLabel
## predictLabel  0  1
##        FALSE 21  4
##        TRUE   3  2
## accuracy
sum(diag(atable)) / sum(atable)
## [1] 0.7666667

Compare with logistic regression

aglm <- glm(svi ~ . - train, data = prostate_train, family = binomial(link = "logit"))
predProb_logistic <- predict(object = aglm, newdata = prostate_test, type="response")
btable <- table(predictLabel = predProb_logistic>0.5, trueLabel = prostate_test$svi)
btable
##             trueLabel
## predictLabel  0  1
##        FALSE 22  4
##        TRUE   2  2
## accuracy
sum(diag(btable)) / sum(btable)
## [1] 0.8

Why CART rather than logistic regression or linear regression?

Model complexity (how deep should we keep the tree)

Pruning the tree
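Below is a sketch of cost-complexity pruning with `rpart`. It uses the `kyphosis` data that ships with the package as a stand-in for the lecture's examples. The steps are: grow a large tree with `cp = 0`, inspect the cross-validated error (`xerror`) in the complexity table, and prune back to the `cp` value with the smallest `xerror`. Choosing the minimum `xerror` is one convention; the one-standard-error rule is a common, more conservative alternative.

```r
library(rpart)

set.seed(1)  # xerror comes from random 10-fold cross-validation
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
             control = rpart.control(cp = 0))   # grow a large tree first
printcp(fit)                                   # CP, nsplit, rel error, xerror, xstd

# Complexity parameter with the smallest cross-validated error
best_cp <- fit$cptable[which.min(fit$cptable[, "xerror"]), "CP"]
pruned <- prune(fit, cp = best_cp)
pruned
```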

Titanic example (Will be on HW)

library(titanic)
dim(titanic_train)
## [1] 891  12
head(titanic_train)
##   PassengerId Survived Pclass
## 1           1        0      3
## 2           2        1      1
## 3           3        1      3
## 4           4        1      1
## 5           5        0      3
## 6           6        0      3
##                                                  Name    Sex Age SibSp
## 1                             Braund, Mr. Owen Harris   male  22     1
## 2 Cumings, Mrs. John Bradley (Florence Briggs Thayer) female  38     1
## 3                              Heikkinen, Miss. Laina female  26     0
## 4        Futrelle, Mrs. Jacques Heath (Lily May Peel) female  35     1
## 5                            Allen, Mr. William Henry   male  35     0
## 6                                    Moran, Mr. James   male  NA     0
##   Parch           Ticket    Fare Cabin Embarked
## 1     0        A/5 21171  7.2500              S
## 2     0         PC 17599 71.2833   C85        C
## 3     0 STON/O2. 3101282  7.9250              S
## 4     0           113803 53.1000  C123        S
## 5     0           373450  8.0500              S
## 6     0           330877  8.4583              Q

Bagging

Bagging

  1. Create multiple datasets:
    • Sample with replacement from the original data.
    • Using row and column sampling fractions less than 1 helps build more robust models that are less prone to overfitting.
  2. Build multiple classifiers:
    • A classifier is built on each dataset.
  3. Combine classifiers:
    • The predictions of all the classifiers are combined using the mean, median, or mode, depending on the problem at hand.
    • The combined prediction is generally more robust than that of a single model.
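The three bagging steps can be sketched with `rpart` regression trees on the built-in `mtcars` data, which serves as a stand-in data set. The number of bootstrap samples, `B = 50`, is an arbitrary choice for illustration. For classification, the last step would be a majority vote instead of `rowMeans`.

```r
library(rpart)

set.seed(2017)
B <- 50                      # number of bootstrap data sets (arbitrary choice)
n <- nrow(mtcars)

# Steps 1 and 2: fit one tree per bootstrap sample, then predict every original row
preds <- sapply(seq_len(B), function(b) {
  idx <- sample(n, n, replace = TRUE)             # bootstrap sample of the rows
  tree <- rpart(mpg ~ ., data = mtcars[idx, ])    # tree grown on that sample
  predict(tree, newdata = mtcars)
})

# Step 3: combine the B predictions by averaging (regression)
bagged <- rowMeans(preds)
head(bagged)
```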

Random forest

Random forest algorithm

  1. Assume the number of cases in the training set is N. For each tree, sample N cases with replacement from the training set; this bootstrap sample is the training set for growing that tree.
  2. If there are M input variables, a number m < M is specified such that at each node, m variables are selected at random out of the M, and the best split on these m variables is used to split the node. The value of m is held constant while the forest is grown.
  3. Each tree is grown to the largest extent possible, with no pruning.
  4. Predict new data by aggregating the predictions of the ntree trees (i.e., majority vote for classification, average for regression).
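Two numbers in the random forest output that follows can be reproduced by hand. The prostate example has \(M = 8\) predictors (every column except the response `svi` and the `train` indicator). randomForest's default \(m\) for classification is \(\lfloor\sqrt{M}\rfloor = 2\), which matches "No. of variables tried at each split: 2". In addition, a bootstrap sample of \(N\) cases contains on average a fraction \(1 - (1 - 1/N)^N \approx 0.632\) of the distinct cases. The remaining "out-of-bag" cases are the ones the OOB error estimate is computed on. A small simulation sketch:

```r
M <- 8                 # predictors in the prostate example
floor(sqrt(M))         # randomForest's default m for classification: 2
max(floor(M / 3), 1)   # randomForest's default m for regression: 2

set.seed(1)
N <- 67                # training cases in the prostate example
frac_in_bag <- replicate(1000, length(unique(sample(N, N, replace = TRUE))) / N)
mean(frac_in_bag)      # close to 1 - (1 - 1/N)^N, about 0.635 for N = 67
```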

Random Forest on prostate cancer example

library("randomForest")
prostate_train <- subset(prostate, subset = train==TRUE)
prostate_test <- subset(prostate, subset = train==FALSE)

rfit <- randomForest(as.factor(svi) ~ . - train, data = prostate_train)
rfit
## 
## Call:
##  randomForest(formula = as.factor(svi) ~ . - train, data = prostate_train) 
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 2
## 
##         OOB estimate of  error rate: 13.43%
## Confusion matrix:
##    0 1 class.error
## 0 49 3  0.05769231
## 1  6 9  0.40000000

Random Forest on prostate cancer example

library(ggplot2)
imp <- importance(rfit)
impData <- data.frame(cov = rownames(imp), importance=imp[,1])

ggplot(impData) + aes(x=cov, y=importance, fill=cov) + geom_bar(stat="identity")

Random Forest on prostate prediction

pred_rf <- predict(rfit, prostate_test)

ctable <- table(predictLabel = pred_rf, trueLabel = prostate_test$svi)
ctable
##             trueLabel
## predictLabel  0  1
##            0 22  4
##            1  2  2
## accuracy
sum(diag(ctable)) / sum(ctable)
## [1] 0.8

Will be on homework

Apply random forest on Titanic dataset to predict survival.

reference

knitr::purl("cart.rmd", output = "cart.R", documentation = 2)