# load needed packages
library(ggplot2)
library(rpart)
library(rpart.plot)
library(rsample) # <--- install tidymodels package first

# load in the data set (`diamonds` is provided by ggplot2)
# NOTE(review): the name `data` masks utils::data(); it is kept unchanged in
# case code further down the file refers to it.
data <- diamonds

# split the data into a "training" set and a "testing" set using rsample.
# Seeding the RNG makes the random split -- and therefore the fitted tree and
# the reported accuracy -- reproducible from run to run. Stratifying on the
# outcome `cut` keeps the class proportions similar in both sets, which
# matters because some cut grades are much rarer than others.
set.seed(123)
split <- initial_split(data, prop = 0.7, strata = cut)
train <- training(split)
test <- testing(split)

# Fit a decision tree on the training rows that predicts `cut` from depth,
# table, price and x; method = "class" makes rpart grow a classification tree.
model <- rpart(
  formula = cut ~ depth + table + price + x,
  data = train,
  method = "class"
)

# Draw the fitted tree so its splits can be inspected.
rpart.plot(model)

# Score the held-out rows; type = "class" asks for the predicted category of
# each row rather than class probabilities.
predictions <- predict(model, newdata = test, type = "class")

# build the confusion matrix: rows are the true `cut` of each test row,
# columns are the predicted class. Naming the two dimensions makes the
# printed table self-describing (otherwise the axes are unlabeled).
confusion_matrix <- table(actual = test$cut, predicted = predictions)
print(confusion_matrix)

# compute the accuracy: the share of test rows on the diagonal of the
# confusion matrix, i.e. rows whose predicted class equals the true class.
# NOTE(review): this assumes the table's rows and columns list the classes in
# the same order (both come from the `cut` levels here) -- confirm if the
# matrix is ever built differently.
accuracy <- sum(diag(confusion_matrix)) / sum(confusion_matrix)
# Terminate the line so the console prompt / later output starts on a new line.
cat("Accuracy: ", accuracy, "\n", sep = "")