2.8 K Nearest Neighbors

In assignment 2.7, we explored how omitted variables can badly bias parameter estimates in OLS. OLS estimates can also be biased when the underlying data generating process is nonlinear but we fit a linear model anyway: we call this problem functional misspecification.
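
To see the problem in miniature, here’s a sketch with simulated data (not the assignment’s dataset): the true process is quadratic, but we fit a straight line.

set.seed(1)
x_sim <- runif(100, -2, 2)
y_sim <- x_sim^2 + rnorm(100, sd = 0.3) # the true DGP is quadratic, not linear
coef(lm(y_sim ~ x_sim)) # the straight-line fit misses the curvature entirely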

Part 1: Load and Visualize the Data

library(tidyverse)

nonlinear <- read_csv("https://raw.githubusercontent.com/cobriant/320data/refs/heads/master/nonlinear.csv")

nonlinear_train <- nonlinear %>% filter(train == 1)
nonlinear_test <- nonlinear %>% filter(train == 0)
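
It’s worth a quick look at how the observations split (the exact counts depend on the csv):

nonlinear %>% count(train) # how many observations are in each split?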

a) Data Visualization

Create a scatter plot of the data, coloring points by whether they’re in the training or test set.

# nonlinear %>%
#   ___(aes(x = x, y = y, color = as.factor(train))) +
#   ___()

Part 2: OLS Limitations

Let’s fit three OLS models with polynomial terms of increasing degree to see why linear models can struggle with nonlinear data.

b) Polynomial Regression

Complete the code to create three plots showing predictions from:

  1. Linear model (y ~ x)
  2. Quadratic model (y ~ x + I(x^2))
  3. Cubic model (y ~ x + I(x^2) + I(x^3))
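
A note on the syntax in models 2 and 3: inside an R model formula, ^ means formula crossing, not arithmetic, so y ~ x + x^2 silently collapses to y ~ x. Wrapping a term in I() tells R to square x literally:

# lm(y ~ x + x^2, nonlinear_train)    # same model as y ~ x: the squared term is dropped
# lm(y ~ x + I(x^2), nonlinear_train) # actually includes the squared term
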
# nonlinear_test %>%
#   mutate(
#     prediction = predict(
#       lm(_____, nonlinear_train), 
#       newdata = nonlinear_test)
#   ) %>%
#   _____(_____(x = x, y = y)) +
#   geom_point() +
#   geom_line(_____(y = prediction), color = "blue")
# 
# nonlinear_test %>%
#   mutate(
#     prediction = predict(
#       lm(_____ + I(x^2), nonlinear_train), 
#       newdata = nonlinear_test)
#   ) %>%
#   _____(_____(x = x, y = y)) +
#   geom_point() +
#   geom_line(_____(y = prediction), color = "purple")
#   
# nonlinear_test %>%
#   mutate(
#     prediction = predict(
#       lm(_____ + I(x^2) + I(x^3), nonlinear_train), 
#       newdata = nonlinear_test)
#   ) %>%
#   _____(_____(x = x, y = y)) +
#   geom_point() +
#   geom_line(_____(y = prediction), color = "orchid")
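
Once all three plots render, you can also compare the fits numerically. Here’s a sketch of a hypothetical helper (test_rmse isn’t part of the assignment) that computes test-set RMSE for any fitted model:

# test_rmse <- function(model) {
#   sqrt(mean((nonlinear_test$y - predict(model, newdata = nonlinear_test))^2))
# }
# Call test_rmse() on each of your three lm() fits: lower is better.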

Part 3: K Nearest Neighbors Implementation

OLS is a parametric approach: you assume a functional form for the relationship and then estimate its parameters \(\beta_0\) and \(\beta_1\).

K Nearest Neighbors instead makes predictions for the test data set based on nearby observations from the training data set. It’s a nonparametric approach: there are no parameters to estimate. This is both a strength and a weakness. The strength is that KNN makes no assumptions about the functional form of the data generating process, so you can’t get the functional form wrong. The weakness is that without parameter estimates, you can’t do causal inference: with KNN you can never say “a one-unit increase in X leads to a ___-unit increase in Y on average”. KNN lets you do prediction in a flexible way, but not inference.
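
Concretely, the KNN prediction at a test point is just a local average of training outcomes:

\[
\hat{y}(x_{\text{test}}) = \frac{1}{k} \sum_{i \in N_k(x_{\text{test}})} y_i
\]

where \(N_k(x_{\text{test}})\) is the set of the \(k\) training observations whose \(x\) values lie closest to \(x_{\text{test}}\).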

Let’s understand how KNN works step by step.

c) Fill in the missing pieces to implement KNN from scratch:

k <- 2 # We'll find the 2 nearest neighbors. You could also make this value 3, 4, etc.

# Step 1: For each test data point x, we need to:
#   a) Calculate the absolute distances to all training points x
#   b) Find the k closest training points x
#   c) Take the mean of their y values as our prediction

# Grab the first test data point x:
x_test <- nonlinear_test %>%
  slice(1) %>%
  pull(x)

# Find x_test's nearest neighbors:
# nonlinear_train %>%
#   mutate(distance = abs(x - _____)) %>%
#   _____(distance) %>%
#   _____(1:k) %>%
#   _____(prediction = mean(y)) %>%
#   pull(prediction)
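
Once your pipeline runs, you can sanity-check it against this base R version of the same three steps (a sketch, reusing the k and x_test objects from above):

# distances <- abs(nonlinear_train$x - x_test) # step a: distances to every training x
# neighbors <- order(distances)[1:k]           # step b: row indices of the k closest
# mean(nonlinear_train$y[neighbors])           # step c: mean of their y values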

# Step 2: Use `map_dbl()` to compute a KNN prediction for every row in the test data set.
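
If `map_dbl()` is new to you: it’s the variant of `map()` that applies a function to each element of a vector and returns a numeric vector. A tiny example:

map_dbl(c(1, 2, 3), function(x) x^2) # returns c(1, 4, 9)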

# nonlinear_test %>%
#   mutate(
#     yhat = map_dbl(
#       pull(nonlinear_test, x), 
#       function(x_test) nonlinear_train %>%
#         mutate(distance = _______) %>%
#         _____(distance) %>% 
#         _____(1:k) %>% 
#         _____(prediction = mean(y)) %>% 
#         pull(prediction)
#       )) %>%
#   ggplot(aes(x = x, y = y)) +
#   geom_point() +
#   geom_line(aes(y = yhat), color = "red")
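
With the plot working, try re-running Step 2 with a few different values of k to see the tradeoff that flexibility buys you:

# k <- 1  # the red line chases individual training points (very wiggly)
# k <- 20 # each prediction averages many neighbors (much smoother, flatter)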

Part 4: Using the FNN Package (Fast Nearest Neighbors)

Let’s see if FNN’s knn.reg() gives results similar to our from-scratch implementation.

d) Using knn.reg()

# Install this package, then delete the line of code so you don't keep installing it over and over:
# install.packages("FNN")
library(FNN)
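
Before filling in the blanks below, here’s a minimal sketch of knn.reg()’s interface on made-up numbers (the toy_fit object is illustrative, not part of the assignment):

toy_fit <- knn.reg(
  train = matrix(c(1, 2, 3, 4)), # training x values, as a one-column matrix
  test = matrix(c(1.5, 3.5)),    # test x values
  y = c(10, 20, 30, 40),         # training y values
  k = 2
)
toy_fit$pred # c(15, 35): each prediction is the mean y of the 2 nearest training x's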

# predictions <- knn.reg(
#   train = matrix(pull(nonlinear_train, _____)),
#   test = matrix(pull(nonlinear_test, _____)),
#   y = pull(nonlinear_train, y),
#   k = 2
# )

# nonlinear_test %>%
#   bind_cols(pred = predictions[["pred"]]) %>%
#   ggplot(aes(x = x, y = y)) +
#   geom_point() +
#   geom_line(aes(y = pred), color = "blue")
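
If both implementations are correct, this plot should match the one from Step 2. To compare numerically (a sketch; yhat_scratch is a hypothetical name — you’d first need to save your Step 2 yhat column to an object called that):

# all.equal(yhat_scratch, predictions[["pred"]])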