1. What are Bayesian Additive Regression Trees (BART)?
- Bayesian additive regression trees (BART) is another ensemble method that uses decision trees as its building blocks.
- BART combines ideas from other ensemble methods:
- Each tree is constructed in a random manner, as in bagging and random forests.
- Each tree tries to capture signal not yet accounted for by the current model, as in boosting.
- BART can also be viewed as a Bayesian approach:
- \(\theta \sim p(\theta)\) : Prior distribution
- \(\theta \mid y_1, \ldots, y_n \sim p(\theta \mid y_1, \ldots, y_n)\) : Posterior distribution (see the display after this list)
- Each time we randomly perturb a tree in order to fit the residuals, we are in fact drawing a new tree from a posterior distribution.
- The fitting procedure is a Markov chain Monte Carlo (MCMC) algorithm.
- Predictions made during the burn-in period are discarded.
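- As a compact statement of the update behind the prior/posterior bullets above (this is just standard Bayes' rule, with \(\theta\) standing for the collection of \(K\) trees), the posterior is proportional to the likelihood times the prior:

\[ p(\theta \mid y_1, \ldots, y_n) \;\propto\; p(y_1, \ldots, y_n \mid \theta)\, p(\theta) \]

- The MCMC algorithm produces draws from this posterior, and predictions are averaged over the draws that remain after burn-in, which is exactly the final step of the algorithm in 1.1.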
1.1 The BART algorithm
- \(\hat{f}^b_k(x)\) represents the prediction at \(x\) for the \(k\)th tree used in the \(b\)th iteration : \(k = 1, ..., K\) and \(b = 1, ..., B\)
- Let \(\hat{f}^1_1(x) = ... = \hat{f}^1_K(x) = \frac{1}{nK}\sum_{i=1}^n y_i\)
- Compute \(\hat{f}^1(x) = \sum_{k=1}^K \hat{f}^1_k(x)\)
- For \(b = 2, ..., B\) :
- For \(k = 1, 2, ..., K\) :
- For \(i = 1, ..., n\), compute the current partial residuals : \(r_i = y_i - \sum_{k' < k} \hat{f}^b_{k'}(x_i) - \sum_{k' > k} \hat{f}^{b-1}_{k'}(x_i)\)
- Fit a new tree, \(\hat{f}^b_k(x)\), to \(r_i\) by randomly perturbing the \(k\)th tree from the previous iteration, \(\hat{f}^{b-1}_k(x)\)
- Compute \(\hat{f}^b(x) = \sum_{k=1}^K \hat{f}^b_k(x)\)
- Compute the mean after \(L\) burn-in samples : \(\hat{f}(x) = \frac{1}{B - L}\sum_{b = L+1}^B \hat{f}^b(x)\)
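- The algorithm above can be written as a short R sketch (below). This is only an illustration, not the sampler that the BART package uses: the "randomly perturb the \(k\)th tree" step is replaced by a shallow rpart fit to the partial residuals so that the code runs, which turns it into plain backfitting rather than a genuine posterior draw; the function name bart_sketch and the default values of K, B, and L are made up for this example.

# Minimal sketch of the BART backfitting loop (illustration only, not the real sampler)
bart_sketch <- function(x, y, K = 20, B = 100, L = 20) {
  n <- length(y)
  # fits[k, i]: current prediction of tree k for observation i
  fits <- matrix(mean(y) / K, nrow = K, ncol = n)   # every tree starts at (1/(nK)) * sum(y)
  fhat <- matrix(0, nrow = B, ncol = n)             # fhat[b, ]: ensemble prediction at iteration b
  fhat[1, ] <- colSums(fits)
  for (b in 2:B) {
    for (k in 1:K) {
      # partial residuals: y minus the fit of every tree except tree k
      r <- y - colSums(fits[-k, , drop = FALSE])
      # stand-in for the random tree perturbation: a shallow rpart tree fit to r
      tr <- rpart::rpart(r ~ ., data = data.frame(x, r = r),
                         control = rpart::rpart.control(maxdepth = 2, cp = 0))
      fits[k, ] <- predict(tr, newdata = data.frame(x))
    }
    fhat[b, ] <- colSums(fits)                      # ensemble prediction for this iteration
  }
  # average the ensemble predictions after L burn-in iterations (in-sample here)
  colMeans(fhat[(L + 1):B, , drop = FALSE])
}

- Calling bart_sketch(x, y) with a numeric response returns the post-burn-in average of the in-sample ensemble predictions; the gbart, lbart, and pbart functions used below do the real work of drawing tree structures and terminal-node parameters from their posteriors.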
2. Building BART models with the lbart and pbart functions
# Prerequisite
library(BART)
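- The Heart data used below is not bundled with an R package, so it has to be loaded first. A minimal loading step is sketched below, assuming the ISLR companion file Heart.csv sits in the working directory; the file name, the row-index column X, and the need to drop missing rows are all assumptions here.

# Load the Heart data (assumed: ISLR companion file Heart.csv in the working directory)
Heart <- read.csv("Heart.csv", stringsAsFactors = TRUE)
Heart$X <- NULL          # drop the row-index column so that AHD is column 14 (assumption)
Heart <- na.omit(Heart)  # drop rows with missing values before fitting (assumption)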
# Train-test split
set.seed(123)
train <- sample(1:nrow(Heart), nrow(Heart)/2)
test <- setdiff(1:nrow(Heart), train)
x <- Heart[, -14]
y <- as.numeric(Heart[, 14])-1
xtrain <- x[train, ]
ytrain <- y[train]
xtest <- x[test, ]
ytest <- y[test]
# Logistic BART
set.seed(11)
fit1 <- lbart(xtrain, ytrain, x.test=xtest)
names(fit1)
# Make predictions
prob1 <- rep(0, length(ytest))
prob1[fit1$prob.test.mean > 0.5] <- 1
mean(prob1!=ytest)
# Probit BART
set.seed(22)
fit2 <- pbart(xtrain, ytrain, x.test=xtest)
# Make predictions
prob2 <- rep(0, length(ytest))
prob2[fit2$prob.test.mean > 0.5] <- 1
mean(prob2!=ytest)
# Visualize results: lbart vs. pbart
cbind(fit1$prob.test.mean, fit2$prob.test.mean)
plot(fit1$prob.test.mean, fit2$prob.test.mean, col=ytest+2,
xlab="Logistic BART", ylab="Probit BART")
abline(0, 1, lty=3, col="grey")
abline(v=0.5, lty=1, col="grey")
abline(h=0.5, lty=1, col="grey")
legend("topleft", col=c(2,3), pch=1,
legend=c("AHD = No", "AHD = Yes"))
- Test misclassification error rate : 0.1946309
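- Since both fits store the posterior mean test probabilities, one extra check (not in the original code) shows how often the two link functions lead to different classifications at the 0.5 threshold:

# Fraction of test observations where lbart and pbart disagree at the 0.5 threshold
mean((fit1$prob.test.mean > 0.5) != (fit2$prob.test.mean > 0.5))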
3. Comparison of test MSE across models: Regression tree, LSE, Bagging, Random Forest, Boosting, BART
# Revisit Boston data set with a quantitative response
library(MASS)
summary(Boston)
dim(Boston)
# Train-test split
set.seed(111)
train <- sample(1:nrow(Boston), floor(nrow(Boston)*2/3))
boston.test <- Boston[-train, "medv"]
# Test MSE of a regression tree
library(tree)
tree.boston <- tree(medv ~ ., Boston, subset=train)
yhat <- predict(tree.boston, newdata=Boston[-train, ])
mean((yhat - boston.test)^2)
# Test MSE of linear regression (LSE: least squares estimates)
g0 <- lm(medv ~ ., Boston, subset=train)
pred0 <- predict(g0, Boston[-train,])
mean((pred0 - boston.test)^2)
# Test MSE of bagging
library(randomForest)
g1 <- randomForest(medv ~ ., data=Boston, mtry=13, subset=train)
yhat1 <- predict(g1, newdata=Boston[-train, ])
mean((yhat1 - boston.test)^2)
# Test MSE of random forest (m = 4)
g2 <- randomForest(medv ~ ., data=Boston, mtry=4, subset=train)
yhat2 <- predict(g2, newdata=Boston[-train, ])
mean((yhat2 - boston.test)^2)
# Test MSE of boosting (d = 4)
library(gbm)
g3 <- gbm(medv~., data = Boston[train, ], distribution="gaussian", n.trees=5000, interaction.depth=4)
yhat3 <- predict(g3, newdata=Boston[-train, ], n.trees=5000)
mean((yhat3 - boston.test)^2)
# Test MSE of BART
library(BART)
g4 <- gbart(Boston[train, 1:13], Boston[train, "medv"], x.test=Boston[-train, 1:13])
yhat4 <- g4$yhat.test.mean
mean((yhat4 - boston.test)^2)
| Regression tree | LSE | Bagging | Random Forest | Boosting | BART |
|---|---|---|---|---|---|
| 24.59443 | 21.52607 | 12.43788 | 8.304427 | 12.2265 | 11.68592 |
4. Simulation study of the comparison in Section 3
# Simulation: 4 ensemble methods
set.seed(1111)
N <- 20
ERR <- matrix(0, N, 4)
# replicate 20 times
for (i in 1:N) {
train <- sample(1:nrow(Boston), floor(nrow(Boston)*2/3))
boston.test <- Boston[-train, "medv"]
# Bagging
g1 <- randomForest(medv ~ ., data=Boston, mtry=13, subset=train)
yhat1 <- predict(g1, newdata=Boston[-train, ])
ERR[i,1] <- mean((yhat1 - boston.test)^2)
# Random forest
g2 <- randomForest(medv ~ ., data=Boston, mtry=4, subset=train)
yhat2 <- predict(g2, newdata=Boston[-train, ])
ERR[i, 2] <- mean((yhat2 - boston.test)^2)
# Boosting
g3 <- gbm(medv~., data = Boston[train, ], n.trees=5000, distribution="gaussian", interaction.depth=4)
yhat3 <- predict(g3, newdata=Boston[-train, ], n.trees=5000)
ERR[i, 3] <- mean((yhat3 - boston.test)^2)
# BART
invisible(capture.output(g4 <- gbart(Boston[train, 1:13], Boston[train, "medv"], x.test=Boston[-train, 1:13])))
yhat4 <- g4$yhat.test.mean
ERR[i, 4] <- mean((yhat4 - boston.test)^2)
}
# Visualize simulation results
labels <- c("Bagging", "RF", "Boosting", "BART")
boxplot(ERR, boxwex=0.5, main="Ensemble Methods", col=2:5, names=labels, ylab="Mean Squared Errors", ylim=c(0,30))
colnames(ERR) <- labels
# Check statistical reports
apply(ERR, 2, summary)
apply(ERR, 2, var)
# Check rankings
RA <- t(apply(ERR, 1, rank))
RA
apply(RA, 2, table)
Mean test MSE over the 20 replications

| Bagging | RF | Boosting | BART |
|---|---|---|---|
| 12.265594 | 12.104095 | 11.323421 | 11.057668 |
Variance of test MSE over the 20 replications

| Bagging | RF | Boosting | BART |
|---|---|---|---|
| 8.16965 | 12.691002 | 4.768590 | 7.754700 |
Ranking counts over the 20 replications (rank 1 = lowest test MSE)

| Rank | Bagging | RF | Boosting | BART |
|---|---|---|---|---|
| 1 | 3 | 3 | 8 | 6 |
| 2 | 4 | 5 | 3 | 8 |
| 3 | 6 | 3 | 6 | 6 |
| 4 | 7 | 9 | 3 | 1 |
- Boosting has the smallest variance of test MSE across replications, while BART has the lowest mean test MSE and ranks last (4th) in only 1 of the 20 replications.