# load needed packages
library(ggplot2)
library(rsample)
library(class)

# load in the data set (`diamonds` is provided by ggplot2)
data <- diamonds

# split the data set into a "training" set (70%) and a "testing" set (30%).
# Seed the RNG first: initial_split() samples rows at random (and knn() breaks
# ties at random), so without a seed every run gives a different accuracy.
set.seed(42)
split <- initial_split(data, prop = 0.7)
train <- training(split)
test <- testing(split)

# standardize the predictors to z-scores. The mean/sd come from the TRAINING
# set only and are reused for the test set, so no test-set information leaks
# into the model. kNN is distance-based, so unscaled predictors with larger
# ranges would otherwise dominate the distance.
scale_columns <- c("carat", "depth")
train_means <- colMeans(train[, scale_columns])
train_sds <- vapply(train[, scale_columns], sd, numeric(1))
train_scaled <- as.data.frame(
  scale(train[, scale_columns], center = train_means, scale = train_sds)
)
test_scaled <- as.data.frame(
  scale(test[, scale_columns], center = train_means, scale = train_sds)
)

# carry the class label over unchanged
train_scaled$cut <- train$cut
test_scaled$cut <- test$cut

train <- train_scaled
test <- test_scaled

# kNN has no separate training step: class::knn() labels each test row by a
# majority vote over its 5 nearest training rows in (carat, depth) space.
# The helper vector lives inside local() so it does not leak into the globals.
predictions <- local({
  features <- c("carat", "depth")
  knn(
    train = train[, features],
    test = test[, features],
    cl = factor(train$cut),
    k = 5
  )
})

# build the confusion matrix: rows = actual cut, columns = predicted cut.
# The predictions are re-leveled to match test$cut so the matrix is square and
# its diagonal pairs each class with itself, even if the level sets differ.
confusion_matrix <- table(
  actual = test$cut,
  predicted = factor(predictions, levels = levels(test$cut))
)

# compute and show the accuracy of the kNN classifier on the test set:
# correct predictions lie on the diagonal of the confusion matrix
accuracy <- sum(diag(confusion_matrix)) / sum(confusion_matrix)
print(accuracy)

# one way to visualize the kNN model: plot each point in the 2-D predictor
# space (this only works when the model uses exactly 2 predictors!)

test$predictions <- predictions

# actual classes in the training data. The axes are the standardized
# (z-score) predictors the model was fit on, not the raw measurements.
actual_plot <- ggplot(train, aes(x = carat, y = depth, color = cut)) +
  geom_point(alpha = 0.5) +
  labs(
    title = "Training Data (Actual Cut)",
    x = "Carat (standardized)",
    y = "Depth (standardized)",
    color = "Cut"
  ) +
  theme_minimal()
# print explicitly so the plot also renders when the script is source()d
print(actual_plot)

# predicted classes for the test data
predicted_plot <- ggplot(test, aes(x = carat, y = depth, color = predictions)) +
  geom_point(alpha = 0.7) +
  labs(
    title = "Test Data (kNN Predictions)",
    x = "Carat (standardized)",
    y = "Depth (standardized)",
    color = "Predicted Cut"
  ) +
  theme_minimal()
print(predicted_plot)
