1. What is a Classification Decision Tree?
- Predict a qualitative response rather than a quantitative one.
- Predict that each observation belongs to the most commonly occurring class of training observations in the region to which it belongs.
- Use recursive binary splitting to grow a classification tree.
- Use the classification error rate (misclassification rate) as the evaluation metric.
Splitting metrics (a short numeric sketch in R follows this list)
- The classification error rate : \(Error = 1 - \max_{k}(\hat{p}_{mk})\)
- Gini index
- Term : \(G_m = \sum_{k=1}^K\hat{p}_{mk}(1 - \hat{p}_{mk})\)
- The Gini index is referred to as a measure of node purity; a small value indicates that a node contains predominantly observations from a single class.
- Cross-entropy
- Term : \(C_m = -\sum_{k=1}^K \hat{p}_{mk} \log(\hat{p}_{mk})\)
- \(\hat{p}_{mk} = \frac{n_{mk}}{n_m}\)
- Deviance
- Term : \(D_m = -\sum_{k=1}^K n_{mk} \log(\hat{p}_{mk})\)
- \(n_{mk}\) : The number of observations in the \(m\)th node that belong to the \(k\)th class.
- Residual mean deviance : \(\frac{2}{n - |T_0|}\sum_m D_m\)
- \(|T_0|\) : The number of terminal nodes (the size) of the tree.
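- As a quick numeric check of the metrics above (a hypothetical node with 30 observations, 20 in one class and 10 in the other; the counts are invented for illustration):
# Hypothetical node: 30 observations, 20 of class 1 and 10 of class 2
n.mk <- c(20, 10)
p.mk <- n.mk / sum(n.mk)          # class proportions p_mk
1 - max(p.mk)                     # classification error rate = 1/3
sum(p.mk * (1 - p.mk))            # Gini index, about 0.444
-sum(p.mk * log(p.mk))            # cross-entropy, about 0.637
-sum(n.mk * log(p.mk))            # node deviance D_m, about 19.1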
2. [Ex] Building a Classification Decision Tree
Step 1 : Prerequisites
# Importing dataset
url.ht <- "https://www.statlearning.com/s/Heart.csv"
Heart <- read.csv(url.ht, header=TRUE)
# Preview dataset
summary(Heart)
# Preprocessing dataset
Heart <- Heart[, colnames(Heart)!="X"]
Heart[,"Sex"] <- factor(Heart[,"Sex"], 0:1, c("female", "male"))
Heart[,"Fbs"] <- factor(Heart[,"Fbs"], 0:1, c("false", "true"))
Heart[,"ExAng"] <- factor(Heart[,"ExAng"], 0:1, c("no", "yes"))
Heart[,"ChestPain"] <- as.factor(Heart[,"ChestPain"])
Heart[,"Thal"] <- as.factor(Heart[,"Thal"])
Heart[,"AHD"] <- as.factor(Heart[,"AHD"])
# Basic EDA
summary(Heart)
dim(Heart)
sum(is.na(Heart))
Heart <- na.omit(Heart)
dim(Heart)
summary(Heart)
Step 2 : Visualize the proportion of heart disease by class
library(ggplot2)
library(gridExtra)
g1 <- ggplot(Heart, aes(x=Sex, fill=AHD)) + geom_bar(position="stack")
g2 <- ggplot(Heart, aes(x=ChestPain, fill=AHD)) + geom_bar(position="stack")
g3 <- ggplot(Heart, aes(x=Fbs, fill=AHD)) + geom_bar(position="stack")
g4 <- ggplot(Heart, aes(x=ExAng, fill=AHD)) + geom_bar(position="stack")
g5 <- ggplot(Heart, aes(x=Thal, fill=AHD)) + geom_bar(position="stack")
grid.arrange(g1, g2, g3, g4, g5, nrow=2)
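- The stacked bars show counts rather than rates; a quick cross-tabulation (a minimal sketch with base R's prop.table; Sex is shown as one example, the other factors work the same way) gives the actual proportion of heart disease within each level.
# Proportion of AHD within each level of Sex (rows sum to 1)
round(prop.table(table(Heart$Sex, Heart$AHD), margin=1), 3)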
Step 3 : Training a logistic regression model
g <- glm(AHD ~., family="binomial", Heart)
summary(g)
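- As a rough baseline to compare with the tree below, the fitted logistic model can be turned into class labels with a 0.5 probability cutoff (a minimal sketch; the same thresholding is reused in the simulation study of section 6).
# Training-set misclassification rate of the logistic model (0.5 cutoff)
p.glm <- predict(g, type="response")      # fitted probabilities of "Yes"
pred.glm <- ifelse(p.glm > 0.5, "Yes", "No")
mean(pred.glm != Heart$AHD)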
Step 4 : Training a classification decision tree model
library(tree)
tree.heart <- tree(AHD ~., Heart)
summary(tree.heart)
tree.heart
plot(tree.heart)
text(tree.heart)
plot(tree.heart)
text(tree.heart, pretty=0)
- Terminal nodes (leaf nodes) are labeled with the predicted class, Yes or No.
- If the split condition at a node is true, the observation moves down the left branch; otherwise it moves down the right branch.
- Splits can appear below the node \(C_a < 0.5\) even when they do not change the predicted class, because the tree is grown with the Gini index or cross-entropy, which reward increased node purity rather than a lower classification error rate (see the sketch below).
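- The toy split below (counts are invented, not read off the fitted Heart tree) makes the point concrete: both children predict Yes, so the weighted classification error is unchanged, but the weighted Gini index still decreases.
# Error rate and Gini index of a single node given its class counts
err  <- function(counts) 1 - max(counts) / sum(counts)
gini <- function(counts) { p <- counts / sum(counts); sum(p * (1 - p)) }
parent <- c(No=6, Yes=10)
left   <- c(No=5, Yes=6)                      # still predicts "Yes"
right  <- c(No=1, Yes=4)                      # still predicts "Yes"
w <- c(sum(left), sum(right)) / sum(parent)   # child weights
c(err(parent),  w[1] * err(left)  + w[2] * err(right))    # 0.375 vs 0.375
c(gini(parent), w[1] * gini(left) + w[2] * gini(right))   # 0.469 vs 0.441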
Step 5 : Calculating the misclassification rate on the training set
# Predict class probabilities or predicted class labels
predict(tree.heart, Heart)
predict(tree.heart, Heart, type="class")
# Compute classification error rate of training observations
pred <- predict(tree.heart, Heart, type="class")
table(pred, Heart$AHD)
mean(pred!=Heart$AHD)
| pred | No | Yes |
|---|---|---|
| No | 152 | 21 |
| Yes | 8 | 116 |
- Misclassification rate : (21 + 8) / 297 = 0.0976431
3. [Ex] Calculating the Misclassification Rate on a Validation Set
# Separate training and test sets
set.seed(123)
train <- sample(1:nrow(Heart), nrow(Heart)/2)
test <- setdiff(1:nrow(Heart), train)
heart.test <- Heart[test, ]
# Training model
heart.tran <- tree(AHD ~., Heart, subset=train)
# Make prediction of validation set
heart.pred <- predict(heart.tran, heart.test, type="class")
# Compute classification error rate
table(heart.pred, Heart$AHD[test])
mean(heart.pred!=Heart$AHD[test])
| heart.pred | No | Yes |
|---|---|---|
| No | 59 | 22 |
| Yes | 19 | 49 |
- Misclassification rate : (22 + 19) / 149 = 0.2751678
4. [Ex] Pruning using Cross Validation
# Run 5-fold cross validation
set.seed(1234)
cv.heart <- cv.tree(heart.tran, FUN=prune.misclass, K=5)
cv.heart
# Visualize cross-validation misclassification results
par(mfrow=c(1,2))
plot(cv.heart$size, cv.heart$dev, type="b")
plot(cv.heart$k, cv.heart$dev, type="b")
# Find the optimal tree size
w <- cv.heart$size[which.min(cv.heart$dev)]
# Prune the tree with the optimal size
prune.heart <- prune.misclass(heart.tran, best=w)
# Visualize unpruned tree vs pruned tree
par(mfrow=c(1,2))
plot(heart.tran)
text(heart.tran)
plot(prune.heart)
text(prune.heart, pretty=0)
# Compute classification error of the subtree
heart.pred <- predict(prune.heart, heart.test, type="class")
table(heart.pred, Heart$AHD[test])
mean(heart.pred!=Heart$AHD[test])
- Misclassification rate of the unpruned tree : 0.2751678
- Misclassification rate of the pruned tree : 0.261745
- The misclassification rates of the unpruned and pruned trees are not very different.
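- Note that cv.heart$size runs from the largest tree to the smallest, so which.min returns the largest size whenever several sizes tie for the minimum CV error. An alternative (a minimal sketch reusing the objects above; w2 and prune.heart2 are names introduced here) is to take the smallest tied size, which gives a simpler tree.
# Smallest tree size attaining the minimum cross-validated error
w2 <- min(cv.heart$size[cv.heart$dev == min(cv.heart$dev)])
prune.heart2 <- prune.misclass(heart.tran, best=w2)
mean(predict(prune.heart2, heart.test, type="class") != Heart$AHD[test])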
5. [Ex] Simulation Study : Iterating 100 times
set.seed(111)
K <- 100
RES1 <- matrix(0, K, 2)
# Iterate 100 times
for (i in 1:K) {
train <- sample(1:nrow(Heart), floor(nrow(Heart)*2/3))
test <- setdiff(1:nrow(Heart), train)
heart.test <- Heart[test, ]
heart.tran <- tree(AHD ~., Heart, subset=train)
heart.pred <- predict(heart.tran, heart.test, type="class")
RES1[i, 1] <- mean(heart.pred != Heart$AHD[test])
cv.heart <- cv.tree(heart.tran, FUN=prune.misclass, K=5)
w <- cv.heart$size[which.min(cv.heart$dev)]
prune.heart <- prune.misclass(heart.tran, best=w)
heart.pred.cv <- predict(prune.heart, heart.test, type="class")
RES1[i, 2] <- mean(heart.pred.cv != Heart$AHD[test])
}
# Calculate the mean test error rate over the 100 iterations
apply(RES1, 2, mean)
boxplot(RES1, col=c("orange", "lightblue"), boxwex=0.6, names=c("unpruned tree", "pruned tree"),
ylab="Classification Error Rate")
- Mean misclassification rate, unpruned vs pruned tree : 0.2404040 vs 0.2351515
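- The means alone hide the variability across the 100 splits; the standard deviations and the paired differences (a minimal sketch computed from the RES1 matrix above) show how consistent the gain from pruning is.
# Spread of the 100 error rates and the paired (unpruned - pruned) differences
apply(RES1, 2, sd)
summary(RES1[, 1] - RES1[, 2])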
6. [Ex] Simulation Study : Iterating 100 times and comparing with other models
library(MASS)
library(e1071)
set.seed(111)
K <- 100
RES2 <- matrix(0, K, 4)
for (i in 1:K) {
train <- sample(1:nrow(Heart), floor(nrow(Heart)*2/3))
test <- setdiff(1:nrow(Heart), train)
y.test <- Heart$AHD[test]
# Logistic Regression
g1 <- glm(AHD~., data=Heart, family="binomial", subset=train)
p1 <- predict(g1, Heart[test,], type="response")
pred1 <- rep("No", length(y.test))
pred1[p1 > 0.5] <- "Yes"
RES2[i, 1] <- mean(pred1!=y.test)
# LDA/QDA
g2 <- lda(AHD~., data=Heart, subset=train)
g3 <- qda(AHD~., data=Heart, subset=train)
pred2 <- predict(g2, Heart[test,])$class
pred3 <- predict(g3, Heart[test,])$class
RES2[i, 2] <- mean(pred2!=y.test)
RES2[i, 3] <- mean(pred3!=y.test)
# Naive Bayes
g4 <- naiveBayes(AHD~., data=Heart, subset=train)
pred4 <- predict(g4, Heart[test,])
RES2[i, 4] <- mean(pred4!= y.test)
}
apply(RES2, 2, mean)
boxplot(cbind(RES1, RES2), col=2:7, boxwex=0.6,
        names=c("unpruned tree", "pruned tree", "Logistic", "LDA", "QDA", "NaiveBayes"),
ylab="Classification Error Rate")
- The misclassification rates of the tree-based methods are higher than those of the other models.
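- To read the boxplot numerically, the mean error rates can be collected into one named vector (a minimal sketch; the column order follows cbind(RES1, RES2) above).
# Mean misclassification rate of each method over the 100 iterations
RES <- cbind(RES1, RES2)
colnames(RES) <- c("unpruned tree", "pruned tree", "Logistic", "LDA", "QDA", "NaiveBayes")
round(apply(RES, 2, mean), 4)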