# Lesson: K-Nearest Neighbour classification and regression
```{r echo=FALSE, warning=FALSE}
library(palmerpenguins)
library(tidyverse)
library(tidymodels)
library(kknn)
library(janitor)
# library(caret)
```
## Introduction
The k-nearest neighbours algorithm classifies an observation, or predicts a response variable, by combining the values of the $k$ observations nearest to it. The result is a very flexible prediction method that incorporates only information near a point. Changing the number of neighbours adjusts the amount of smoothing: a larger $k$ gives smoother predictions. If the observations are very unevenly spread out, however, this can lead to surprising results. Similarly, prediction for points that are far from other data points suffers from some of the same problems as extrapolation and may have undesirable outcomes.
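
To make the idea concrete, here is a minimal base R sketch of the method, written for illustration only. The helper `knn_predict` and the toy data are made up for this example and are not part of any package or of the penguin analysis below; it assumes Euclidean distance and an unweighted vote or average.

```{r}
# Hypothetical helper: predict for one new point from its k nearest neighbours
knn_predict <- function(train_x, train_y, new_x, k = 3) {
  train_x <- as.matrix(train_x)
  new_x <- as.numeric(new_x)
  # Euclidean distance from the new point to every training point
  d <- sqrt(rowSums(sweep(train_x, 2, new_x)^2))
  nearest <- order(d)[seq_len(k)]
  if (is.numeric(train_y)) {
    mean(train_y[nearest])            # regression: average the neighbours
  } else {
    votes <- table(train_y[nearest])  # classification: majority vote
    names(votes)[which.max(votes)]
  }
}

# Made-up toy data: two well-separated clusters in two dimensions
toy_x <- rbind(c(1, 1), c(1, 2), c(2, 1), c(6, 6), c(6, 7), c(7, 6))
toy_class <- c("a", "a", "a", "b", "b", "b")
toy_value <- c(10, 12, 11, 30, 32, 31)

knn_predict(toy_x, toy_class, c(1.5, 1.5), k = 3)  # classification
knn_predict(toy_x, toy_value, c(6.5, 6.5), k = 3)  # regression
```

With `k = 3` each prediction uses only the three points of the nearby cluster; setting `k = 6` would average over both clusters, which is the smoothing effect described above.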
## Example
Let's apply the method to the Palmer Penguins data. I will use some measurements on the penguins to determine the species of a few observations. I've concealed the species identity of a few observations, labeled as "missing" in this plot. A few penguins have no recorded sex; I won't use those observations in my model, so I'll remove them.
```{r echo=FALSE, warning=FALSE}
# Conceal the species of ten observations by relabelling them as "missing".
# row_number(species) ranks the rows by species name, so the indices refer to that ordering.
penguins %>%
  mutate(species = as.character(species),
         species2 = case_when(
           row_number(species) %in% c(213, 188, 173, 165, 146, 100, 231, 96, 217, 254) ~ "missing",
           TRUE ~ species),
         species2 = as.factor(species2)) %>%
  filter(!is.na(sex)) -> penguins2

penguins2 %>% ggplot(aes(x = bill_length_mm, y = flipper_length_mm, color = species2)) +
  geom_point() # + geom_text(aes(label=row_number(species)), color="black", label.size=0.1)
```
Now we make a model and try to predict the unknown species.
```{r}
# Randomly hold out about 25% of the labelled observations;
# the "missing" observations are never used for training
penguins2 %>%
  mutate(random = runif(n()),
         train = (random > 0.25) & (species2 != "missing"),
         test = (!train)) -> penguins2

penguin_recipe <-
  recipe(species2 ~ bill_length_mm + flipper_length_mm, # + bill_depth_mm + body_mass_g,  # add to improve model
         data = penguins2) %>%
  step_scale(all_predictors()) %>%
  step_center(all_predictors())

knn_specification <- nearest_neighbor(mode = "classification") # optional: neighbors = <integer>

penguin_knn <- workflow() %>%
  add_recipe(penguin_recipe) %>%
  add_model(knn_specification) %>%
  fit(data = penguins2 %>% filter(train))
penguin_knn

# Predict for every observation, including the held-out and concealed ones
penguin_predicted <- bind_cols(penguins2,
                               predict(penguin_knn, penguins2))
```
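
The specification above leaves the number of neighbours at the engine default. As a sketch of how it could be set explicitly, the chunk below builds a second specification using parsnip's `neighbors`, `weight_func` and `dist_power` arguments. The object name `knn_spec_k5` and the particular values are arbitrary choices for illustration, not settings used elsewhere in this lesson.

```{r}
library(parsnip)

# Hypothetical specification: the values 5, "rectangular" and 2 are arbitrary
knn_spec_k5 <- set_engine(
  nearest_neighbor(
    mode = "classification",
    neighbors = 5,                # number of neighbours, k
    weight_func = "rectangular",  # unweighted vote; other kernels weight by distance
    dist_power = 2                # Minkowski exponent: 2 = Euclidean, 1 = Manhattan
  ),
  "kknn"
)
knn_spec_k5
```

A specification like this could be passed to `add_model()` in place of `knn_specification`.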
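Compare the original species with the predicted species, first for the concealed observations and then for all observations not used in training. When I ran this, the concealed species were recovered perfectly! The split is random and no seed is set, so the tables can change from run to run.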
```{r}
penguin_predicted %>% filter(species2 == "missing") %>% tabyl(species, .pred_class)
penguin_predicted %>% filter(test) %>% tabyl(species, .pred_class)
```
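
Each table above is a confusion matrix: rows are the true species and columns are the predicted species. A single summary is the accuracy, the proportion of observations on the diagonal. The sketch below shows the calculation in base R on made-up label vectors (`truth` and `predicted` are hypothetical and are not results from the penguin model).

```{r}
# Made-up labels for illustration; these are not results from the penguin model
truth <- factor(c("Adelie", "Adelie", "Chinstrap", "Gentoo", "Gentoo", "Chinstrap"))
predicted <- factor(c("Adelie", "Chinstrap", "Chinstrap", "Gentoo", "Gentoo", "Chinstrap"),
                    levels = levels(truth))

# Cross-tabulate true against predicted classes
confusion <- table(truth = truth, predicted = predicted)
confusion

# Accuracy: proportion of observations on the diagonal
accuracy <- sum(diag(confusion)) / sum(confusion)
accuracy
```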
Plot the data, using colour to mark misclassified observations and larger symbols for observations that were not used for training.
```{r}
penguin_predicted %>% mutate(correct = species == .pred_class) %>%
  ggplot(aes(x = bill_length_mm, y = flipper_length_mm, shape = species,
             size = train, color = correct)) +
  geom_point(alpha = 0.75) +
  scale_color_viridis_d(begin = 0.2, end = 0.8) +
  scale_size_manual(values = c(3, 2))
```
Now try regression. Predict body mass (`body_mass_g`) from species and the other measurements.
```{r}
# Standardize the numeric variables by hand (mean 0, standard deviation 1)
penguins3 <- penguins2 %>%
  mutate(flipper_length_s = scale(flipper_length_mm),
         body_mass_s = scale(body_mass_g),
         bill_depth_s = scale(bill_depth_mm),
         bill_length_s = scale(bill_length_mm))

penguin_recipe2 <-
  recipe(body_mass_g ~ bill_length_s + species + flipper_length_s + bill_depth_s,
         data = penguins3)

knn_specification2 <- nearest_neighbor(mode = "regression") # optional: neighbors = <integer>

penguin_knn2 <- workflow() %>%
  add_recipe(penguin_recipe2) %>%
  add_model(knn_specification2) %>%
  fit(data = penguins3 %>% filter(train))
penguin_knn2

penguin_predicted <- bind_cols(penguins3,
                               predict(penguin_knn2, new_data = penguins3))

# Observed against predicted body mass; points on the line are predicted exactly
penguin_predicted %>% ggplot(aes(x = body_mass_g, y = .pred,
                                 color = species, shape = train)) +
  geom_point() +
  geom_abline()
```
Compare with linear regression. The fit appears to be rank deficient, since there is an intercept and an estimate for each species. It is not clear how to suppress the intercept.
```{r}
linear_specification2 <- linear_reg(mode = "regression")

penguin_linear <- workflow() %>%
  add_recipe(penguin_recipe2) %>%
  add_model(linear_specification2) %>%
  fit(data = penguins3 %>% filter(train))
penguin_linear

penguin_predicted <- bind_cols(penguins3,
                               predict(penguin_linear, new_data = penguins3))

# Same observed-against-predicted plot as for the knn regression
penguin_predicted %>% ggplot(aes(x = body_mass_g, y = .pred,
                                 color = species, shape = train)) +
  geom_point() +
  geom_abline()
```
What does it do?

* Picture
* Classification
* Regression
* Local method
* Weighting by distance
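
As an illustration of weighting by distance, the sketch below compares an inverse-distance-weighted average with a plain average of the same neighbours. The helper `knn_weighted` and the data are made up for this example; inverse-distance weights are one common choice and are not necessarily what kknn uses.

```{r}
# Hypothetical helper: inverse-distance-weighted knn regression in one dimension
knn_weighted <- function(x, y, x0, k = 3, eps = 1e-8) {
  d <- abs(x - x0)                 # distance to each training point
  nearest <- order(d)[seq_len(k)]  # indices of the k closest points
  w <- 1 / (d[nearest] + eps)      # closer points get larger weights
  sum(w * y[nearest]) / sum(w)     # weighted average of the neighbours' responses
}

# Made-up data: the point at x = 2 has a very different response from its neighbours
x <- c(1, 2, 3, 10)
y <- c(0, 10, 0, 5)

weighted <- knn_weighted(x, y, x0 = 2.1, k = 3)
unweighted <- mean(y[order(abs(x - 2.1))[1:3]])
c(weighted = weighted, unweighted = unweighted)
```

The weighted prediction is pulled towards the response of the closest point, while the unweighted prediction treats all three neighbours equally.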
## How to use it

* Requires pairs of predictors and a response (a classification or a value)
* Assigns a classification or value to new points based on nearby points
* Nonparametric
* Standardize datasets
* Can use any distance metric
* Can be slow on large datasets
* Use PCA for dimension reduction if there is a large number of correlated predictors
* Test/train
* Confusion matrix
* Determining k by optimization
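
Two of the notes above, standardizing the data and determining k by optimization, can be sketched together in base R. This is only an illustration under stated assumptions: the built-in `iris` data stands in for any labelled dataset, `loo_accuracy` is a hypothetical helper, the candidate values of k are arbitrary, and leave-one-out accuracy is used as the criterion to optimize.

```{r}
# Sketch only: `iris` stands in for any labelled data
x <- scale(as.matrix(iris[, 1:4]))  # standardize: mean 0, sd 1 per column
y <- as.character(iris$Species)
dist_mat <- as.matrix(dist(x))      # Euclidean distances between all pairs
diag(dist_mat) <- Inf               # a point may not be its own neighbour

# Hypothetical helper: leave-one-out classification accuracy for a given k
loo_accuracy <- function(k) {
  pred <- vapply(seq_len(nrow(x)), function(i) {
    nearest <- order(dist_mat[i, ])[seq_len(k)]
    votes <- table(y[nearest])
    names(votes)[which.max(votes)]  # majority vote; ties go to the first class
  }, character(1))
  mean(pred == y)
}

candidate_k <- c(1, 3, 5, 7, 9, 11, 15)
accuracy_by_k <- vapply(candidate_k, loo_accuracy, numeric(1))
data.frame(k = candidate_k, accuracy = accuracy_by_k)

# Pick the k with the highest leave-one-out accuracy
best_k <- candidate_k[which.max(accuracy_by_k)]
best_k
```

In a tidymodels workflow the same search would more likely be done with resampling and a tuning grid; the loop here just shows what is being optimized.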

Good problems

* Penguins
* Diamonds dataset?
* Cars dataset?

Packages

* caret
* mlr
* tidymodels / parsnip
## Sources
* [Wikipedia page](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm)
* UBC data science course on knn [classification](https://ubc-dsci.github.io/introduction-to-datascience/classification.html) and [regression](https://ubc-dsci.github.io/introduction-to-datascience/regression1.html)
* Tidymodels tutorial on [k-means clustering](https://www.tidymodels.org/learn/statistics/k-means/), which, despite the similar name, is a different method from knn