# load needed packages
library(tidymodels)
library(klaR)
library(naivebayes)
library(discrim)

# split the data (rsample)
# Seed so the split is reproducible, and stratify on the outcome so train and
# test keep similar proportions of each drv class.
set.seed(123)
split <- initial_split(mpg, prop = 0.7, strata = drv)
train <- training(split)
test <- testing(split)

# create a recipe, used for preprocessing the data (recipes)
# Define the recipe on the training set, matching every later step. The
# template here only supplies column names/types, but using `train` keeps the
# held-out test rows out of the preprocessing definition entirely.
# Predictor `class` is character; convert it to a factor for naive Bayes.
mpg_recipe <- recipe(drv ~ class, data = train) |>
  step_string2factor(all_nominal_predictors())

# define model specs (parsnip; the naive_Bayes() engines come from discrim)

# naivebayes model
# Naive Bayes classifier backed by the naivebayes package; Laplace = 1 applies
# a Laplace correction to smooth low-frequency counts.
model1 <- naive_Bayes(
  mode = "classification",
  engine = "naivebayes",
  Laplace = 1
)

# klaR model
# Same specification as model1, but backed by the klaR package.
model2 <- naive_Bayes(
  mode = "classification",
  engine = "klaR",
  Laplace = 1
)

# create workflows (bundle the recipe with each model spec; fitting happens later in fit_resamples()/fit())

# Each workflow pairs the shared preprocessing recipe with one model spec.
wf1 <- workflow(preprocessor = mpg_recipe, spec = model1)

wf2 <- workflow(preprocessor = mpg_recipe, spec = model2)

# cross-validation (optional but good practice)

# Seed so fold assignment is reproducible, and stratify on the outcome so each
# fold's drv class mix mirrors the training set.
set.seed(456)
folds <- vfold_cv(train, v = 10, strata = drv)

# One shared metric set so both models are scored identically.
cv_metrics <- metric_set(accuracy, roc_auc)

res1 <- fit_resamples(
  wf1,
  resamples = folds,
  metrics = cv_metrics
)

res2 <- fit_resamples(
  wf2,
  resamples = folds,
  metrics = cv_metrics
)

# view CV results (metrics averaged over the folds)
collect_metrics(res1)
collect_metrics(res2)

# final fit on training data
# Refit each workflow on the full training set for test-set evaluation.
fit1 <- wf1 |> fit(data = train)
fit2 <- wf2 |> fit(data = train)

# yardstick needs truth and .pred_class to share the same factor levels.
# factor(test$drv) alone takes its levels from the test rows, which breaks when
# the test set is missing a class (or contains one unseen in training), so
# take the levels from the training outcome the models were fitted on.
test$drv <- factor(test$drv, levels = levels(factor(train$drv)))

# make predictions on the test set
# Hard class predictions (.pred_class) with the test columns appended, so the
# truth (drv) and the estimate sit side by side for yardstick.
pred1 <- bind_cols(predict(fit1, new_data = test), test)
pred2 <- bind_cols(predict(fit2, new_data = test), test)

# evaluate performance

# accuracy and other metrics
# yardstick's default metric set for class predictions, one table per model.
metrics(pred1, truth = drv, estimate = .pred_class)
metrics(pred2, truth = drv, estimate = .pred_class)

# confusion matrices
# Cross-tabulate predicted vs. observed drive class for each model.
conf_mat(pred1, truth = drv, estimate = .pred_class)
conf_mat(pred2, truth = drv, estimate = .pred_class)

# ROC curve
# Class probabilities (.pred_<level> columns) with the test data appended,
# for both models so they can be compared on the same held-out rows.
probs <- predict(fit1, test, type = "prob") |>
  bind_cols(test)
probs2 <- predict(fit2, test, type = "prob") |>
  bind_cols(test)

# Stack both models with a `model` id column so the ROC metrics are computed
# per model. starts_with(".pred_") selects every class-probability column
# without hard-coding the first and last drv levels.
roc_data <- bind_rows(naivebayes = probs, klaR = probs2, .id = "model")

# Test-set multiclass ROC AUC for each model
roc_data |>
  group_by(model) |>
  roc_auc(drv, starts_with(".pred_"))

# One-vs-all ROC curves, one line per model
roc_data |>
  group_by(model) |>
  roc_curve(drv, starts_with(".pred_")) |>
  autoplot()
