##########################################
## k-Nearest Neighbors with MNIST Dataset ##
##########################################
## See: https://www.r-craft.org/r-news/exploring-handwritten-digit-classification-a-tidy-analysis-of-the-mnist-dataset/
##########################################

library(tidymodels)
library(rio)
library(janitor)

## Plot MNIST observations (e.g., 6 of them) as 28 x 28     ##
## grayscale tiles, one facet per observation.              ##
##
## Data:        data frame with a `Label` column plus one column
##              per pixel. The first run of digits in each pixel
##              column name is parsed as that pixel's index.
## PixelOffset: value subtracted from the parsed index so pixels
##              are numbered from 0 (default 2, as in the original
##              code; presumably the pixel columns are named like
##              X2..X785 as in the tutorial referenced above --
##              TODO confirm against the .rds files).
## Returns a ggplot object; each facet is titled by the row
## number in `Data` and its Label.
FctPlotImages <- function(Data, PixelOffset = 2) {
  MnistLong <- Data %>%
    # Drop prediction columns added by tidymodels' augment()
    # (.pred_class, .pred_<level>): they are not pixels, and their
    # digit suffixes would otherwise be parsed as pixel indices.
    select(-starts_with(".pred")) %>%
    mutate(instance = row_number()) %>%
    pivot_longer(
      cols = -c(Label, instance),
      names_to = "pixel",
      values_to = "value"
    ) %>%
    tidyr::extract(pixel, "pixel", "(\\d+)", convert = TRUE) %>%
    mutate(
      pixel = pixel - PixelOffset,
      # column within the 28-pixel-wide image
      x = pixel %% 28,
      # 28 minus the row index, so the first image row is drawn at the top
      y = 28 - pixel %/% 28
    )

  MnistLong %>%
    ggplot(aes(x, y, fill = value)) +
    geom_tile() +
    facet_wrap(~ instance + Label) +
    scale_fill_gradient(
      low = "#000000",
      high = "#FFFFFF",
      guide = "colourbar",
      aesthetics = "fill"
    )
}


# Load the MNIST subset with 500 observations and convert the
# digit label to a factor so it can serve as a classification
# outcome.
DataMnist <-
  import("https://ai.lange-analytics.com/data/MN500.rds") %>%
  mutate(Label = as.factor(Label))

## Larger alternatives -- uncomment one to use it instead.

## MNIST with 1,000 observations:
# DataMnist <- import("https://ai.lange-analytics.com/data/MN1000.rds") %>%
#   mutate(Label = as.factor(Label))

## MNIST with 10,000 observations
## (fitting the workflow to this dataset later can take
##  up to 15 minutes):
# DataMnist <- import("https://ai.lange-analytics.com/data/MN10000.rds") %>%
#   mutate(Label = as.factor(Label))


# Draw 6 random images (reproducibly, via the fixed seed) and plot
# them. slice_sample() replaces the superseded sample_n().
set.seed(7890)
DataToPlot <- slice_sample(DataMnist, n = 6)
FctPlotImages(DataToPlot)

# Split the data 70% training / 30% testing. Stratifying by Label
# samples within each digit class, keeping the class proportions
# of the two sets similar.
set.seed(123)
Split7030 <- initial_split(DataMnist, prop = 0.70, strata = Label)
DataTrain <- training(Split7030)
DataTest <- testing(Split7030)

# Recipe: Label is the outcome, every other column (the pixels)
# is a predictor; no preprocessing steps are added.
RecipeMnist <- recipe(Label ~ ., data = DataTrain)

# Model design: k-nearest neighbors with k = 5 and
# "rectangular" (unweighted) neighbor votes, kknn engine,
# classification mode.
ModelDesignKNN <-
  nearest_neighbor(neighbors = 5, weight_func = "rectangular") %>%
  set_engine("kknn") %>%
  set_mode("classification")

# Bundle recipe and model design into a workflow and fit it to
# the training data. Depending on the size of the training data
# this can take a long time -- be patient and drink a coffee.
WFModelMnist <-
  workflow() %>%
  add_recipe(RecipeMnist) %>%
  add_model(ModelDesignKNN) %>%
  fit(data = DataTrain)

# Use the fitted workflow to predict the testing data.
# augment() returns DataTest with prediction columns added:
# .pred_class (the predicted digit) and, because kknn supports
# class probabilities, one .pred_<level> column per class.
# augment() for workflows takes the data as `new_data`; a `data =`
# argument is not matched and leaves `new_data` missing.
# This can take a long time for a large testing dataset --
# be patient and drink a coffee.
DataTestWithPred <- augment(WFModelMnist, new_data = DataTest)

# Print the confusion matrix (true Label vs. predicted class)
conf_mat(DataTestWithPred, Label, .pred_class)

# Build a metric-set function MetricsSetMnist, then use it to
# calculate accuracy, sensitivity, and specificity (the latter two
# are macro-averaged over the digit classes by yardstick's default).
MetricsSetMnist <- metric_set(accuracy, sensitivity, specificity)
MetricsSetMnist(DataTestWithPred, truth = Label, estimate = .pred_class)

# Find mispredicted observations and plot up to 6 random ones.
# All prediction columns (.pred_class and any .pred_<level>
# probability columns added by augment()) are dropped, so only
# Label and the pixel columns reach FctPlotImages().
set.seed(111)
DataMispredicted <- DataTestWithPred %>%
  filter(.pred_class != Label) %>%
  select(-starts_with(".pred")) %>%
  # slice_sample() returns all rows when fewer than 6 observations
  # are mispredicted, whereas sample_n(6) would stop with an error.
  slice_sample(n = 6)

if (nrow(DataMispredicted) > 0) {
  FctPlotImages(DataMispredicted)
} else {
  message("No mispredicted observations to plot.")
}
