
[R] Tree-Based Methods : Bayesian Additive Regression Trees (BART)

1. What are Bayesian Additive Regression Trees (BART)?

  • Bayesian additive regression trees (BART) is another ensemble method that uses decision trees as its building blocks.
  • The BART method combines ideas from the other ensemble methods:
    • Each tree is constructed in a random manner, as in bagging and random forests.
    • Each tree tries to capture signal not yet accounted for by the current model, as in boosting.
  • The BART method can be viewed as a Bayesian approach:
    • \(\theta_1 \sim p(\theta_1)\) : prior distribution
    • \(\theta_1 \mid \theta_2, \ldots, \theta_K\) : posterior (full conditional) distribution
    • Each time we randomly perturb a tree in order to fit the residuals, we are in fact drawing a new tree from a posterior distribution.
    • The fitting procedure is an MCMC (Markov chain Monte Carlo) algorithm.
    • Predictions from the burn-in period are discarded (see the short BART-package sketch after this list).
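In the BART package these ideas correspond to explicit arguments: nskip sets the number of burn-in draws that are discarded and ndpost the number of posterior draws that are kept; the kept draws are returned row-wise in yhat.train / yhat.test, and yhat.test.mean is their column-wise average. Below is a minimal sketch on a small simulated data set (the data and variable names here are illustrative only, not from this post):

library(BART)

set.seed(1)
n <- 100
x <- matrix(runif(2 * n), n, 2)
y <- x[, 1]^2 + rnorm(n, sd = 0.1)

# Keep 1000 posterior draws after discarding 100 burn-in draws
fit <- gbart(x.train = x, y.train = y, x.test = x, ndpost = 1000, nskip = 100)

dim(fit$yhat.test)                                       # 1000 draws x 100 observations
all.equal(colMeans(fit$yhat.test), fit$yhat.test.mean)   # posterior mean prediction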

 

1.1 The BART algorithm

  • \(\hat{f}^b_k(x)\) represents the prediction at \(x\) for the \(k\)th tree used in the \(b\)th iteration, for \(k = 1, \ldots, K\) and \(b = 1, \ldots, B\).
  • Initialize \(\hat{f}^1_1(x) = \cdots = \hat{f}^1_K(x) = \frac{1}{nK}\sum_{i=1}^n y_i\).
  • Compute \(\hat{f}^1(x) = \sum_{k=1}^K \hat{f}^1_k(x)\).
  • For \(b = 2, \ldots, B\):
    • For \(k = 1, 2, \ldots, K\):
      • For \(i = 1, \ldots, n\), compute the current partial residuals \(r_i = y_i - \sum_{k' < k} \hat{f}^b_{k'}(x_i) - \sum_{k' > k} \hat{f}^{b-1}_{k'}(x_i)\).
      • Fit a new tree \(\hat{f}^b_k(x)\) to \(r_i\) by randomly perturbing the \(k\)th tree from the previous iteration.
    • Compute \(\hat{f}^b(x) = \sum_{k=1}^K \hat{f}^b_k(x)\).
  • Compute the mean after \(L\) burn-in samples: \(\hat{f}(x) = \frac{1}{B - L}\sum_{b = L+1}^B \hat{f}^b(x)\). (A toy R illustration of this backfitting loop follows this list.)
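The loop structure above can be mimicked with a short toy script. This is only a sketch of the backfitting skeleton, not actual BART: instead of drawing each tree from its posterior via MCMC, it simply refits a small rpart tree to the partial residuals (rpart and the simulated data are stand-ins introduced here for illustration).

# Toy illustration of the additive backfitting loop (NOT real BART:
# each tree is refit to the partial residuals instead of being drawn
# from its posterior distribution)
library(rpart)

set.seed(1)
n <- 200
dat <- data.frame(x1 = runif(n), x2 = runif(n))
y <- sin(2 * pi * dat$x1) + dat$x2 + rnorm(n, sd = 0.3)

K <- 20    # number of trees
B <- 50    # number of iterations
L <- 10    # burn-in iterations to discard

# Initialization: f^1_k(x_i) = (1/(nK)) * sum(y_i) = mean(y)/K for every tree
fhat <- matrix(mean(y) / K, nrow = n, ncol = K)
fit_trace <- matrix(NA, nrow = B, ncol = n)
fit_trace[1, ] <- rowSums(fhat)

for (b in 2:B) {
  for (k in 1:K) {
    # Partial residuals: y minus the current fit of all the other trees
    # (trees k' < k are already updated in iteration b, trees k' > k are still from b-1)
    r <- y - rowSums(fhat[, -k, drop = FALSE])
    tree_k <- rpart(r ~ x1 + x2, data = cbind(dat, r = r),
                    control = rpart.control(maxdepth = 2, cp = 0))
    fhat[, k] <- predict(tree_k, newdata = dat)
  }
  fit_trace[b, ] <- rowSums(fhat)
}

# Final prediction: average of the per-iteration fits after burn-in
fhat_final <- colMeans(fit_trace[(L + 1):B, ])
mean((fhat_final - y)^2)   # training MSE of the averaged fit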

 

2. Building a BART classifier using the lbart and pbart functions

# Prerequisites
# The Heart data (Heart.csv from the ISLR book website) is assumed to be
# already loaded into a data frame named Heart, with the response AHD as a
# factor in column 14.
library(BART)

# Train-test split 
set.seed(123)
train <- sample(1:nrow(Heart), nrow(Heart)/2)
test <- setdiff(1:nrow(Heart), train)
x <- Heart[, -14]
y <- as.numeric(Heart[, 14])-1
xtrain <- x[train, ]
ytrain <- y[train]
xtest <- x[-train, ]
ytest <- y[-train]

# Logistic BART 
set.seed(11)
fit1 <- lbart(xtrain, ytrain, x.test=xtest)
names(fit1)

# Make predictions 
prob1 <- rep(0, length(ytest))
prob1[fit1$prob.test.mean > 0.5] <- 1
mean(prob1!=ytest)

# Probit BART 
set.seed(22)
fit2 <- pbart(xtrain, ytrain, x.test=xtest)

# Make Prediction 
prob2 <- rep(0, length(ytest))
prob2[fit2$prob.test.mean > 0.5] <- 1
mean(prob2!=ytest)

# Visualize results : lbart ~ pbart 
cbind(fit1$prob.test.mean, fit2$prob.test.mean)
plot(fit1$prob.test.mean, fit2$prob.test.mean, col=ytest+2,
     xlab="Logistic BART", ylab="Probit BART")
abline(0, 1, lty=3, col="grey")
abline(v=0.5, lty=1, col="grey")
abline(h=0.5, lty=1, col="grey")
legend("topleft", col=c(2,3), pch=1,
       legend=c("AHD = No", "AHD = Yes"))

 

 

  • Misclassification error rate : 0.1946309
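To see where this error rate comes from, the class predictions can be cross-tabulated against the true labels, reusing prob1 and ytest from the code above:

# Confusion table for the logistic BART predictions
table(predicted = prob1, actual = ytest)
mean(prob1 != ytest)   # misclassification error rate quoted above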
 

 

3. Comparison of test MSE among different models : Regression Tree, LSE, Bagging, Random Forest, Boosting, BART

# Revisit Boston data set with a quantitative response
library(MASS)
summary(Boston)
dim(Boston)

# Train-test split
set.seed(111)
train <- sample(1:nrow(Boston), floor(nrow(Boston)*2/3))
boston.test <- Boston[-train, "medv"]

# Calculate the test MSE of a regression tree
library(tree)
tree.boston <- tree(medv ~ ., Boston, subset=train)
yhat <- predict(tree.boston, newdata=Boston[-train, ])
mean((yhat - boston.test)^2)

# Calculate the test MSE of LSE: least squares estimates (linear regression)
g0 <- lm(medv ~ ., Boston, subset=train)
pred0 <- predict(g0, Boston[-train,])
mean((pred0 - boston.test)^2)

# Calculate the test MSE of bagging
library(randomForest)
g1 <- randomForest(medv ~ ., data=Boston, mtry=13, subset=train)
yhat1 <- predict(g1, newdata=Boston[-train, ])
mean((yhat1 - boston.test)^2)

# Calculate the test MSE of random forest (m = 4)
g2 <- randomForest(medv ~ ., data=Boston, mtry=4, subset=train)
yhat2 <- predict(g2, newdata=Boston[-train, ])
mean((yhat2 - boston.test)^2)

# Calculate the test MSE of boosting (d = 4)
library(gbm)
g3 <- gbm(medv~., data = Boston[train, ], distribution="gaussian", n.trees=5000, interaction.depth=4)
yhat3 <- predict(g3, newdata=Boston[-train, ], n.trees=5000)
mean((yhat3 - boston.test)^2)

# Calculate the test MSE of BART
library(BART)
g4 <- gbart(Boston[train, 1:13], Boston[train, "medv"], x.test=Boston[-train, 1:13])
yhat4 <- g4$yhat.test.mean
mean((yhat4 - boston.test)^2)

 

Test MSE on the Boston test set:

Regression tree : 24.59443
LSE             : 21.52607
Bagging         : 12.43788
Random Forest   :  8.304427
Boosting        : 12.2265
BART            : 11.68592
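For easier comparison, the individual test MSEs computed above can be gathered into one named vector and sorted, reusing the prediction objects (yhat, pred0, yhat1, ..., yhat4) from the code in this section:

# Collect the six test MSEs into one named, sorted vector
mse <- c("Regression tree" = mean((yhat  - boston.test)^2),
         "LSE"             = mean((pred0 - boston.test)^2),
         "Bagging"         = mean((yhat1 - boston.test)^2),
         "Random Forest"   = mean((yhat2 - boston.test)^2),
         "Boosting"        = mean((yhat3 - boston.test)^2),
         "BART"            = mean((yhat4 - boston.test)^2))
sort(mse)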

 

4. Simulation study of the comparison in Section 3

# Simulation: 4 ensemble methods
set.seed(1111)
N <- 20
ERR <- matrix(0, N, 4)

# replicate 20 times 
for (i in 1:N) {
    train <- sample(1:nrow(Boston), floor(nrow(Boston)*2/3))
    boston.test <- Boston[-train, "medv"]

    # Bagging
    g1 <- randomForest(medv ~ ., data=Boston, mtry=13, subset=train)
    yhat1 <- predict(g1, newdata=Boston[-train, ])
    ERR[i,1] <- mean((yhat1 - boston.test)^2)

    # Random forest
    g2 <- randomForest(medv ~ ., data=Boston, mtry=4, subset=train)
    yhat2 <- predict(g2, newdata=Boston[-train, ])
    ERR[i, 2] <- mean((yhat2 - boston.test)^2)

    # Boosting
    g3 <- gbm(medv~., data = Boston[train, ], n.trees=5000, distribution="gaussian", interaction.depth=4)
    yhat3 <- predict(g3, newdata=Boston[-train, ], n.trees=5000)
    ERR[i, 3] <- mean((yhat3 - boston.test)^2)

    # BART
    invisible(capture.output(g4 <- gbart(Boston[train, 1:13], Boston[train, "medv"], x.test=Boston[-train, 1:13])))
    yhat4 <- g4$yhat.test.mean
    ERR[i, 4] <- mean((yhat4 - boston.test)^2)
}

# Visualize simulation results 
labels <- c("Bagging", "RF", "Boosting", "BART")
boxplot(ERR, boxwex=0.5, main="Ensemble Methods", col=2:5, names=labels, ylab="Mean Squared Errors", ylim=c(0,30))
colnames(ERR) <- labels

# Check statistical reports 
apply(ERR, 2, summary)
apply(ERR, 2, var)

# Check rankings 
RA <- t(apply(ERR, 1, rank))
RA
apply(RA, 2, table)
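The rank table can be condensed into the average rank of each method over the 20 replications (lower is better), together with the mean test MSE:

# Average rank of each method across the replications
apply(RA, 2, mean)
sort(colMeans(ERR))   # mean test MSE, from best to worst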

 

 

Mean test MSE over the 20 replications

Bagging  : 12.265594
RF       : 12.104095
Boosting : 11.323421
BART     : 11.057668

 

Variance of test MSE over the 20 replications

Bagging  :  8.16965
RF       : 12.691002
Boosting :  4.768590
BART     :  7.754700

 

Rank counts over the 20 replications (rank 1 = lowest test MSE)

Rank   Bagging   RF   Boosting   BART
  1       3       3       8        6
  2       4       5       3        8
  3       6       3       6        6
  4       7       9       3        1

 

  • Boosting shows the smallest variance of test MSE across the replications, while BART attains the smallest mean test MSE.
