# load required packages
library(tidymodels)
library(ranger)
library(randomForest)

# Split the built-in iris data into a training set (65% of rows) and a test
# set (the remaining 35%). strata = Species keeps the class proportions
# similar in both sets, and the seed makes the split reproducible.
set.seed(123)
split <- initial_split(iris, prop = 0.65, strata = Species)
train <- training(split)
test <- testing(split)

# Preprocessing recipe, estimated on the training set only: Species is the
# outcome and every other column is a predictor. step_corr() filters out
# predictors that are highly correlated with another predictor (its
# threshold is left at the package default).
rec <- recipe(Species ~ ., data = train)
rec <- step_corr(rec, all_predictors())

# Model specifications: the same 100-tree random-forest classifier, declared
# once per engine so the two implementations can be compared like for like.
ranger_spec <- rand_forest(
  mode = "classification",
  engine = "ranger",
  trees = 100
)

rf_spec <- rand_forest(
  mode = "classification",
  engine = "randomForest",
  trees = 100
)

# Workflows: each one pairs the shared preprocessing recipe with one model
# spec, so fitting a workflow both preprocesses the data and fits the model.
ranger_wf <- workflow(preprocessor = rec, spec = ranger_spec)
rf_wf <- workflow(preprocessor = rec, spec = rf_spec)


# Cross-validation on the training set (optional but good practice): it
# estimates performance without touching the test set. The 10 folds are
# stratified on Species, and the seed makes both the folds and the forests
# fitted on them reproducible.
set.seed(234)
folds <- vfold_cv(train, v = 10, strata = Species)

# One shared metric set, so both engines are scored identically. roc_auc is
# computed from class-probability predictions.
cv_metrics <- metric_set(accuracy, roc_auc)

ranger_res <- fit_resamples(ranger_wf, resamples = folds, metrics = cv_metrics)
rf_res <- fit_resamples(rf_wf, resamples = folds, metrics = cv_metrics)

# View the CV results: each metric summarised across the folds.
collect_metrics(ranger_res)
collect_metrics(rf_res)


# Refit each workflow on the whole training set. These are the models used
# for the test-set predictions below.
ranger_fit <- ranger_wf |> fit(data = train)
rf_fit <- rf_wf |> fit(data = train)

# Predicted classes for the held-out test rows. The original test columns,
# including the true Species, are appended alongside the predictions.
ranger_pred <- bind_cols(predict(ranger_fit, new_data = test), test)
rf_pred <- bind_cols(predict(rf_fit, new_data = test), test)

# Evaluate test-set performance.

# Default yardstick metric set for hard class predictions (.pred_class
# against the true Species), one table per engine.
metrics(ranger_pred, truth = Species, estimate = .pred_class)
metrics(rf_pred, truth = Species, estimate = .pred_class)

# Confusion matrices. Both engines are presented the same way: the matrix is
# printed as a table and then drawn as a heatmap. Previously ranger was
# only printed and randomForest was only plotted.
ranger_cm <- ranger_pred |>
  conf_mat(truth = Species, estimate = .pred_class)
ranger_cm
autoplot(ranger_cm, type = "heatmap")

rf_cm <- rf_pred |>
  conf_mat(truth = Species, estimate = .pred_class)
rf_cm
autoplot(rf_cm, type = "heatmap")

# ROC curves for the ranger model only, as an example. With three classes,
# roc_curve() is given one probability column per species (every column
# whose name starts with ".pred_").
ranger_probs <- bind_cols(
  predict(ranger_fit, new_data = test, type = "prob"),
  test
)

ranger_probs |>
  roc_curve(truth = Species, starts_with(".pred_")) |>
  autoplot()
