[R] Tree-Based Methods : Classification Decision Tree

1. What is a Classification Decision Tree?

  • Predicts a qualitative response rather than a quantitative one.
  • Each observation is predicted to belong to the most commonly occurring class among the training observations in its region (terminal node).
  • Recursive binary splitting is used to grow the classification tree.
  • The classification error rate (misclassification rate) is used as the evaluation metric.

 

Splitting metrics

  • The classification error rate : \(Error = 1 - \max_{k}(\hat{p}_{mk})\)
  • Gini index
    • Term : \(G_m = \sum_{k=1}^K\hat{p}_{mk}(1 - \hat{p}_{mk})\)
    • The Gini index is referred to as a measure of node purity: a small value means the node contains mostly observations from a single class.
  • Cross-entropy
    • Term : \(C_m = -\sum_{k=1}^K \hat{p}_{mk} \log(\hat{p}_{mk})\)
    • \(\hat{p}_{mk} = \frac{n_{mk}}{n_m}\) : The proportion of observations in the \(m\)th node that belong to the \(k\)th class, where \(n_m\) is the number of observations in the \(m\)th node.
  • Deviance
    • Term : \(D_m = -\sum_{k=1}^K n_{mk} \log(\hat{p}_{mk})\)
    • \(n_{mk}\) : The number of observations in the \(m\)th node that belong to the \(k\)th class.
    • Residual mean deviance : \(\frac{2}{n - |T_0|}\sum_m D_m\)
    • \(|T_0|\) : The size of the tree (the number of terminal nodes)
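
The four measures above can be checked numerically for a single node. The following is a minimal base R sketch; the function name `node_impurity` and the example class counts are made up for illustration and are not part of the `tree` package.

```r
# Impurity measures for one node, computed from its class counts n_mk.
# node_impurity is a hypothetical helper, not a function of the tree package.
node_impurity <- function(counts) {
  p <- counts / sum(counts)      # class proportions p_mk
  p_pos <- p[p > 0]              # drop empty classes so 0 * log(0) is treated as 0
  c(error    = 1 - max(p),
    gini     = sum(p * (1 - p)),
    entropy  = -sum(p_pos * log(p_pos)),
    deviance = -sum(counts[counts > 0] * log(p_pos)))
}

pure  <- node_impurity(c(No = 10, Yes = 0))   # all observations in one class
mixed <- node_impurity(c(No = 5,  Yes = 5))   # evenly mixed node
print(round(rbind(pure, mixed), 4))
```

All four measures are zero for the pure node and reach their largest two-class values for the evenly mixed node, which is why they are read as measures of node purity.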

 

2. [Ex] Building Classification Decision Tree

Step 1 : Prerequisites

# Importing dataset 
url.ht <- "https://www.statlearning.com/s/Heart.csv"
Heart <- read.csv(url.ht, header=TRUE)

# Preview dataset 
summary(Heart)

# Preprocessing dataset 
Heart <- Heart[, colnames(Heart)!="X"]
Heart[,"Sex"] <- factor(Heart[,"Sex"], 0:1, c("female", "male"))
Heart[,"Fbs"] <- factor(Heart[,"Fbs"], 0:1, c("false", "true"))
Heart[,"ExAng"] <- factor(Heart[,"ExAng"], 0:1, c("no", "yes"))
Heart[,"ChestPain"] <- as.factor(Heart[,"ChestPain"])
Heart[,"Thal"] <- as.factor(Heart[,"Thal"])
Heart[,"AHD"] <- as.factor(Heart[,"AHD"])

# Basic EDA 
summary(Heart)
dim(Heart)
sum(is.na(Heart))
Heart <- na.omit(Heart)
dim(Heart)
summary(Heart)

 

Step 2 : Visualize the proportion of heart disease by category

library(ggplot2)
library(gridExtra)
g1 <- ggplot(Heart, aes(x=Sex, fill=AHD)) + geom_bar(position="stack")
g2 <- ggplot(Heart, aes(x=ChestPain, fill=AHD)) + geom_bar(position="stack")
g3 <- ggplot(Heart, aes(x=Fbs, fill=AHD)) + geom_bar(position="stack")
g4 <- ggplot(Heart, aes(x=ExAng, fill=AHD)) + geom_bar(position="stack")
g5 <- ggplot(Heart, aes(x=Thal, fill=AHD)) + geom_bar(position="stack")
grid.arrange(g1, g2, g3, g4, g5, nrow=2)

 

 

Step 3 : Training using logistic regression model

g <- glm(AHD ~., family="binomial", Heart)
summary(g)

 

Step 4 : Training using classification decision tree model

library(tree)
tree.heart <- tree(AHD ~., Heart)
summary(tree.heart)
tree.heart
plot(tree.heart)
text(tree.heart)
plot(tree.heart)
text(tree.heart, pretty=0)

 

 

 

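  • Each terminal node (leaf node) is labeled with a predicted class, Yes or No.
  • If the condition at a node is true, move to the left subtree; otherwise move to the right subtree.
  • Some splits, such as those under the \(Ca < 0.5\) subtree, produce terminal nodes with the same predicted class. These splits occur because the tree is grown using a purity measure (Gini index or cross-entropy) rather than the classification error: the split does not change the error, but it increases node purity.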
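
A small worked example may make the last point concrete. The sketch below uses made-up class counts (not taken from the Heart tree) for a parent node and two children that both predict Yes; the helper names `gini`, `error_rate`, and `weighted` are hypothetical.

```r
# Hypothetical helpers: impurity of one node from its class counts
gini       <- function(counts) { p <- counts / sum(counts); sum(p * (1 - p)) }
error_rate <- function(counts) 1 - max(counts) / sum(counts)

# Size-weighted average of an impurity measure over a list of child nodes
weighted <- function(children, f) {
  n <- sapply(children, sum)
  sum(n / sum(n) * sapply(children, f))
}

# Made-up counts: both children have Yes as the majority class
parent   <- c(No = 10, Yes = 30)
children <- list(left = c(No = 0, Yes = 15), right = c(No = 10, Yes = 15))

err_before  <- error_rate(parent)
err_after   <- weighted(children, error_rate)
gini_before <- gini(parent)
gini_after  <- weighted(children, gini)

print(c(err_before = err_before, err_after = err_after,
        gini_before = gini_before, gini_after = gini_after))
```

Here the classification error stays at 0.25 before and after the split, while the weighted Gini index drops from 0.375 to 0.3, so a purity-based criterion still prefers the split.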

 

Step 5 : Calculating the misclassification rate of the training set

# predict the probability of each class or class type
predict(tree.heart, Heart)
predict(tree.heart, Heart, type="class")

# Compute classification error rate of training observations
pred <- predict(tree.heart, Heart, type="class")
table(pred, Heart$AHD)
mean(pred!=Heart$AHD)

 

 

Confusion table (rows: predicted class, columns: actual class)

pred   No Yes
  No  152  21
  Yes   8 116

  • Misclassification rate : 0.0976431
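
The reported rate can be reproduced from the four counts alone. This is a minimal sketch that re-enters the table by hand (rows are predictions, columns are actual classes); the object names `conf` and `train_error` are chosen here for illustration.

```r
# Confusion table from Step 5, entered column by column
conf <- matrix(c(152, 8, 21, 116), nrow = 2,
               dimnames = list(pred = c("No", "Yes"), actual = c("No", "Yes")))

# Off-diagonal cells are the misclassified observations
train_error <- 1 - sum(diag(conf)) / sum(conf)
print(train_error)
```

The off-diagonal cells sum to 29 out of 297 observations, which matches the value returned by `mean(pred != Heart$AHD)`.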

 

3. [Ex] Calculating the Misclassification Rate of the Validation Set

# Separate training and test sets
set.seed(123)
train <- sample(1:nrow(Heart), nrow(Heart)/2)
test <- setdiff(1:nrow(Heart), train)
heart.test <- Heart[test, ]

# Training model 
heart.tran <- tree(AHD ~., Heart, subset=train)

# Make prediction of validation set 
heart.pred <- predict(heart.tran, heart.test, type="class")

# Compute classification error rate
table(heart.pred, Heart$AHD[test])
mean(heart.pred!=Heart$AHD[test])

 

Confusion table (rows: predicted class, columns: actual class)

heart.pred No Yes
       No  59  22
       Yes 19  49

  • Misclassification rate : 0.2751678
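
The validation table holds 149 observations rather than exactly half of the data, because the training table in section 2 totals 297 rows and `nrow(Heart)/2` is therefore 148.5. The sketch below works on row indices only and assumes `n = 297`; it writes the rounding explicitly with `floor()` instead of relying on how `sample()` handles a fractional size.

```r
n <- 297   # assumption: number of complete cases, taken from the table totals in section 2

set.seed(123)
train <- sample(seq_len(n), floor(n / 2))   # 148 training indices
test  <- setdiff(seq_len(n), train)         # remaining 149 indices

print(c(train = length(train), test = length(test)))
```

Using `seq_len(n)` and an explicit `floor()` makes the sizes of both sets clear at a glance and keeps the two index sets disjoint by construction.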

 

4. [Ex] Pruning using Cross Validation

# Run 5-fold cross validation 
set.seed(1234)
cv.heart <- cv.tree(heart.tran, FUN=prune.misclass, K=5) 
cv.heart 

# Visualize the cross-validation misclassification count by tree size and by cost-complexity parameter k 
par(mfrow=c(1,2)) 
plot(cv.heart$size, cv.heart$dev, type="b") 
plot(cv.heart$k, cv.heart$dev, type="b")

 

 

# Find the optimal tree size 
w <- cv.heart$size[which.min(cv.heart$dev)] 

# Prune the tree with the optimal size 
prune.heart <- prune.misclass(heart.tran, best=w) 

# Visualize unpruned tree vs pruned tree 
par(mfrow=c(1,2)) 
plot(heart.tran) 
text(heart.tran) 
plot(prune.heart) 
text(prune.heart, pretty=0)

 

 

# Compute classification error of the subtree 
heart.pred <- predict(prune.heart, heart.test, type="class") 
table(heart.pred, Heart$AHD[test]) 
mean(heart.pred!=Heart$AHD[test])
  • Misclassification rate of unpruned tree : 0.2751678
  • Misclassification rate of pruned tree : 0.261745
  • The misclassification rates of the unpruned and pruned trees differ only slightly.
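
One detail of the size selection is worth noting: `which.min()` returns the index of the first minimum, so when several tree sizes tie on the cross-validation error, the chosen size depends on the order of `size`. The sketch below uses a hand-made list shaped like a `cv.tree()` result; its values are hypothetical and are not output from the Heart data.

```r
# Hypothetical values shaped like a cv.tree() result (sizes listed largest first)
cv_mock <- list(size = c(12, 8, 5, 3, 1),
                dev  = c(40, 36, 36, 39, 60))

# First minimum, as in the code above: picks size 8
w_first <- cv_mock$size[which.min(cv_mock$dev)]

# Alternative: the smallest size among all tied minima, which picks size 5
w_smallest <- min(cv_mock$size[cv_mock$dev == min(cv_mock$dev)])

print(c(w_first = w_first, w_smallest = w_smallest))
```

Choosing the smallest tied size favors the simpler tree; whether that is preferable here is a modeling choice rather than something the results above decide.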

 

5. [Ex] Simulation Study : Iterating 100 times

set.seed(111) 
K <- 100 
RES1 <- matrix(0, K, 2) 

# Iterate 100 times 
for (i in 1:K) {
    train <- sample(1:nrow(Heart), floor(nrow(Heart)*2/3)) 
    test <- setdiff(1:nrow(Heart), train) 
    heart.test <- Heart[test, ]
    heart.tran <- tree(AHD ~., Heart, subset=train) 
    heart.pred <- predict(heart.tran, heart.test, type="class") 
    RES1[i, 1] <- mean(heart.pred != Heart$AHD[test]) 
    cv.heart <- cv.tree(heart.tran, FUN=prune.misclass, K=5) 
    w <- cv.heart$size[which.min(cv.heart$dev)] 
    prune.heart <- prune.misclass(heart.tran, best=w) 
    heart.pred.cv <- predict(prune.heart, heart.test, type="class") 
    RES1[i, 2] <- mean(heart.pred.cv != Heart$AHD[test]) 
} 

# Calculate the mean test error rate over the 100 iterations 
apply(RES1, 2, mean) 
boxplot(RES1, col=c("orange", "lightblue"), boxwex=0.6, names=c("unpruned tree", "pruned tree"),
        ylab="Classification Error Rate")
  • Mean test error rate of unpruned vs pruned tree : 0.2404040 vs 0.2351515
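
Because both columns of RES1 are computed on the same 100 splits, the two error rates are paired, and the per-split difference can be summarized directly. The following is only a sketch: the matrix `res` holds simulated placeholder values standing in for RES1, not the results reported above, so its numbers carry no meaning for the Heart data.

```r
set.seed(1)

# Simulated placeholder for RES1 (column 1: unpruned, column 2: pruned); NOT real results
unpruned <- runif(100, 0.15, 0.35)
pruned   <- unpruned + rnorm(100, mean = 0, sd = 0.02)
res      <- cbind(unpruned, pruned)

# Per-split difference, its mean, and the standard error of that mean
diffs     <- res[, "unpruned"] - res[, "pruned"]
mean_diff <- mean(diffs)
se_diff   <- sd(diffs) / sqrt(length(diffs))

# Paired t-test on the same two columns
tt <- t.test(res[, "unpruned"], res[, "pruned"], paired = TRUE)
print(c(mean_diff = mean_diff, se_diff = se_diff, p_value = tt$p.value))
```

With the real RES1, replacing `res` by `RES1` (and indexing columns 1 and 2) would give a rough sense of whether the gap between the two means is larger than the split-to-split variation.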

 

 

 

6. [Ex] Simulation Study : Iterating 100 times and comparing with other models

library(MASS) 
library(e1071) 

set.seed(111) 
K <- 100
RES2 <- matrix(0, K, 4) 

for (i in 1:K) { 
    train <- sample(1:nrow(Heart), floor(nrow(Heart)*2/3)) 
    test <- setdiff(1:nrow(Heart), train) 
    y.test <- Heart$AHD[test] 

    # Logistic Regression 
    g1 <- glm(AHD~., data=Heart, family="binomial", subset=train) 
    p1 <- predict(g1, Heart[test,], type="response") 
    pred1 <- rep("No", length(y.test)) 
    pred1[p1 > 0.5] <- "Yes" 
    RES2[i, 1] <- mean(pred1!=y.test) 

    # LDA/QDA 
    g2 <- lda(AHD~., data=Heart, subset=train) 
    g3 <- qda(AHD~., data=Heart, subset=train) 
    pred2 <- predict(g2, Heart[test,])$class
    pred3 <- predict(g3, Heart[test,])$class
    RES2[i, 2] <- mean(pred2!=y.test) 
    RES2[i, 3] <- mean(pred3!=y.test) 

    # Naive Bayes 
    g4 <- naiveBayes(AHD~., data=Heart, subset=train) 
    pred4 <- predict(g4, Heart[test,]) 
    RES2[i, 4] <- mean(pred4!= y.test) 
} 

apply(RES2, 2, mean) 
boxplot(cbind(RES1, RES2), col=2:8, boxwex=0.6, 
        names=c("unpruned tree", "pruned tree", "Logistic", "LDA", "QDA", "NaiveBayes"),
        ylab="Classification Error Rate")

 

  • The misclassification rates of the tree methods are higher than those of the other models.