library(tidyverse)
library(sandwich)
library(gridExtra)
source("sin_basis_lib.R")
# Notebook-wide plotting defaults: larger text on every ggplot2 figure, and
# the repr.plot.* figure dimensions (presumably read by the notebook's plot
# renderer -- confirm against the kernel in use).
theme_update(text = element_text(size = 24))
options(
  repr.plot.width = 12,
  repr.plot.height = 6
)
# Cross validation and variable selection

## Simulated example
# Construct a basis of sin functions of increasing frequency.
# EvalFeatures() is defined in sin_basis_lib.R; it is used here as returning
# a matrix-like object with one named column per basis function.
pmax <- 50
x_grid <- seq(0, 1, length.out = 100)
f_mat <- EvalFeatures(x_grid, pmax)
f_names <- colnames(f_mat)

# Long format: one row per (x, basis function). The numeric frequency is
# parsed from the column name by stripping a leading "f".
f_df <- data.frame(f_mat) %>%
  mutate(x = x_grid) %>%
  pivot_longer(cols = -x) %>%
  mutate(freq = as.numeric(sub("^f", "", name)))

# Show only the basis functions with freq < 10.
ggplot(f_df %>% filter(freq < 10)) +
  geom_line(aes(x = x, y = value, color = name))
# Simulation settings: number of observations and noise level passed to
# DrawData() (defined in sin_basis_lib.R).
n_obs <- 500
sigma_true <- 0.4

# Draw one dataset for each of several true model sizes and stack them,
# tagging every row with the model size that generated it. The pieces are
# collected in a list and bound once instead of growing a data frame inside
# a loop. Note: unlike a for loop, this does not leave a global `p_true`
# behind; `p_true` is assigned explicitly further down.
p_true_grid <- seq(1, pmax, 7)
models_list <- lapply(p_true_grid, \(p_true) {
  DrawData(n_obs, sigma_true, GetBeta(p_true)) %>%
    mutate(p_true = p_true)
})
models_df <- bind_rows(models_list)
#models_df

# One panel per true model size, each showing the ey_true column against x.
ggplot(models_df) +
  geom_line(aes(x = x, y = ey_true, color = p_true)) +
  facet_grid(~ p_true)
# Fix the true model used for the rest of the analysis.
p_true <- 20
beta_true <- GetBeta(p_true)

# Two independent draws from the same model; `test_data_df` is not used in
# the visible part of this file (presumably a held-out set -- confirm).
data_df <- DrawData(n_obs, sigma_true, beta_true)
test_data_df <- DrawData(n_obs, sigma_true, beta_true)

# The ey_true column as a line, with the observed y as points.
ggplot(data_df) +
  geom_line(aes(x = x, y = ey_true)) +
  geom_point(aes(x = x, y = y))
# Build the regression formula string for the first `p` basis functions.
#
# Args:
#   p: Number of leading entries of the global `f_names` to use as
#      regressors. Must be a single number with 1 <= p <= pmax and
#      p <= length(f_names).
#
# Returns:
#   A length-one character vector of the form "y ~ -1 + <name1> + <name2>..."
#   (the "-1" removes the intercept).
RegressionFormula <- function(p) {
  # Reject p < 1 explicitly: `1:p` counts down (1:0 is c(1, 0)) and would
  # silently produce a formula with the wrong regressors.
  stopifnot(length(p) == 1, p >= 1, p <= pmax, p <= length(f_names))
  regressors <- paste(f_names[seq_len(p)], collapse = " + ")
  sprintf("y ~ -1 + %s", regressors)
}
# Fit the p-term regression by least squares and return in-sample predictions.
#
# Args:
#   data_df: Data frame with columns `y`, `ey_true`, `n`, and the regressor
#            columns named in the formula from RegressionFormula(p).
#   p:       Number of leading basis functions to include; must be <= pmax.
#
# Returns:
#   A data frame with one row per row of `data_df` and columns
#   y (copied from data_df$y), ey (data_df$ey_true), y_pred (fitted values
#   of the lm fit) and n (copied from data_df$n).
ComputePredictions <- function(data_df, p) {
  stopifnot(p <= pmax)
  form <- RegressionFormula(p)
  lm_fit <- lm(formula(form), data_df)
  data.frame(
    y = data_df$y,
    ey = data_df$ey_true,
    # Full element name: `$fitted.value` only resolved via partial matching.
    y_pred = lm_fit$fitted.values,
    n = data_df$n
  )
}
n_obs <- 500
sigma_true <- 0.4
n_sims <- 20

# Draw multiple datasets with x fixed (so we can estimate bias and variance).
# The x values come from one initial draw and are passed to every later draw.
x <- DrawData(n_obs, sigma_true, beta_true)$x
data_df_list <- lapply(
  seq_len(n_sims),
  \(s) DrawData(n_obs, sigma_true, beta_true, x = x)
)
err_df <- data.frame()

# Sanity check that all the datasets have the same x: compare every dataset's
# full x column with the first dataset's (this also works when n_sims is 1).
x_ref <- data_df_list[[1]]$x
stopifnot(all(vapply(
  data_df_list,
  \(df) length(df$x) == length(x_ref) && all(abs(df$x - x_ref) < 1e-8),
  logical(1)
)))
# Model sizes to evaluate: 1 through 5, then every fifth size from 1 to pmax.
p_seq <- unique(c(1, 2, 3, 4, 5, seq(1, pmax, 5)))

# Fit every model size on every simulated dataset. Results go into a
# preallocated list and are bound once at the end; this avoids re-copying a
# growing data frame on each iteration, and re-running the cell no longer
# appends duplicate rows to an existing `err_df`.
pb <- txtProgressBar(min = 0, max = n_sims, style = 3)
err_list <- vector("list", n_sims * length(p_seq))
err_ind <- 0
for (sim in seq_len(n_sims)) {
  setTxtProgressBar(pb, sim)
  sim_df <- data_df_list[[sim]]
  for (p in p_seq) {
    err_ind <- err_ind + 1
    err_list[[err_ind]] <- ComputePredictions(sim_df, p) %>%
      mutate(p = p, sim = sim)
  }
}
close(pb)
err_df <- bind_rows(err_list)
|======================================================================| 100%
# Estimate the expected prediction at each datapoint (to estimate the bias):
# ey_pn is the mean of y_pred over all rows sharing the same (p, n), i.e.
# the average over simulations for one model size and one datapoint.
err_df <- err_df %>%
  group_by(p, n) %>%
  mutate(ey_pn = mean(y_pred)) %>%
  ungroup()
# Compute the MSE and related quantities for each model size p.
#   pred_err: prediction minus the ey column (so `mse` excludes the noise
#             term; sigma_true^2 is reported separately as rand_err)
#   reg_err:  prediction minus its per-datapoint average (variance part)
#   bias:     per-datapoint average prediction minus the ey column
# The error terms are row-wise, so they are computed before grouping.
mse_df <- err_df %>%
  mutate(
    pred_err = y_pred - ey,
    reg_err = y_pred - ey_pn,
    bias = ey_pn - ey
  ) %>%
  group_by(p) %>%
  summarize(
    mse = mean(pred_err^2),
    var = mean(reg_err^2),
    bias2 = mean(bias^2),  # The averaging over new data
    rand_err = sigma_true^2
  ) %>%
  # Sanity check -- should be zero: relative gap between mse and var + bias2.
  mutate(mse_check = (mse - (var + bias2)) / mse)
mse_df
p | mse | var | bias2 | rand_err | mse_check |
---|---|---|---|---|---|
<dbl> | <dbl> | <dbl> | <dbl> | <dbl> | <dbl> |
1 | 0.279678923 | 0.0002413654 | 0.2794375575 | 0.16 | 0.000000e+00 |
2 | 0.074893516 | 0.0004585316 | 0.0744349847 | 0.16 | 0.000000e+00 |
3 | 0.047569732 | 0.0009167353 | 0.0466529970 | 0.16 | 0.000000e+00 |
4 | 0.043212298 | 0.0011564507 | 0.0420558478 | 0.16 | 0.000000e+00 |
5 | 0.038724965 | 0.0015370072 | 0.0371879578 | 0.16 | 1.791840e-16 |
6 | 0.028645389 | 0.0017793718 | 0.0268660174 | 0.16 | 0.000000e+00 |
11 | 0.016470018 | 0.0030544002 | 0.0134156180 | 0.16 | 0.000000e+00 |
16 | 0.009757164 | 0.0045852866 | 0.0051718777 | 0.16 | 0.000000e+00 |
21 | 0.006371954 | 0.0060273716 | 0.0003445819 | 0.16 | -1.361218e-16 |
26 | 0.008059485 | 0.0076840986 | 0.0003753863 | 0.16 | 0.000000e+00 |
31 | 0.009651518 | 0.0092509537 | 0.0004005644 | 0.16 | 0.000000e+00 |
36 | 0.011174181 | 0.0107215434 | 0.0004526378 | 0.16 | 0.000000e+00 |
41 | 0.012872124 | 0.0123678116 | 0.0005043121 | 0.16 | 1.347659e-16 |
46 | 0.014374860 | 0.0138330781 | 0.0005417821 | 0.16 | 0.000000e+00 |
# MSE, variance and squared bias as functions of the model size p, with a
# vertical marker at the true model size.
mse_graph <- ggplot(mse_df, aes(x = p)) +
  geom_line(aes(y = mse, color = "MSE")) +
  geom_line(aes(y = var, color = "var")) +
  geom_line(aes(y = bias2, color = "bias2")) +
  # p_true is a constant, so pass it as a parameter rather than mapping it
  # through aes(), which would draw one identical line per row of mse_df.
  geom_vline(xintercept = p_true, color = "purple")

# Same figure side by side on a linear and a log10 y scale.
grid.arrange(
  mse_graph,
  mse_graph + scale_y_log10(),
  ncol = 2
)
# Perform cross-validation
n_folds <- 10
# Balanced random fold assignment: shuffle the labels 1..n_folds repeated to
# length n_obs. Sampling labels with replacement instead can leave folds very
# uneven or even empty, and an empty fold yields a NaN held-out MSE.
fold_index <- sample(rep(seq_len(n_folds), length.out = n_obs))

# Number of simulated datasets to cross-validate; may be lowered by hand to
# shorten the run, hence the bound check.
n_data_sets <- length(data_df_list)
stopifnot(n_data_sets <= length(data_df_list))

# One result row per (p, dataset, fold), stored in a preallocated list and
# bound once at the end instead of growing a data frame inside the loops.
err_cv_list <- vector("list", length(p_seq) * n_data_sets * n_folds)
pb <- txtProgressBar(min = 0, max = length(err_cv_list), style = 3)
pb_ind <- 0
for (p in p_seq) {
  # The formula depends only on p, so build it once per model size.
  cv_formula <- formula(RegressionFormula(p))
  for (data_ind in seq_len(n_data_sets)) {
    cv_df <- data_df_list[[data_ind]]
    for (fold in seq_len(n_folds)) {
      pb_ind <- pb_ind + 1
      setTxtProgressBar(pb, pb_ind)
      # Fit on every row outside the fold, then score on the held-out rows.
      is_held_out <- fold_index == fold
      lm_fit <- lm(cv_formula, cv_df[!is_held_out, , drop = FALSE])
      data_fold_df <- cv_df[is_held_out, , drop = FALSE]
      y_pred <- predict(lm_fit, data_fold_df)
      err_cv_list[[pb_ind]] <- data.frame(
        mse = mean((y_pred - data_fold_df$y)^2),
        fold = fold,
        p = p,
        data_ind = data_ind
      )
    }
  }
}
close(pb)
err_cv_df <- bind_rows(err_cv_list)
|================================================================ | 91%
# Average the held-out MSE over folds for each (dataset, p), then flag, within
# each dataset, the model size(s) attaining that dataset's smallest CV MSE.
# The result is ungrouped so later verbs do not inherit the data_ind grouping.
err_cv_agg_df <- err_cv_df %>%
  group_by(data_ind, p) %>%
  summarize(mse = mean(mse), .groups = "drop") %>%
  group_by(data_ind) %>%
  mutate(is_mse_min = mse <= min(mse)) %>%
  ungroup()
# For each model size p: mean and standard deviation, across datasets, of the
# per-dataset CV MSE.
err_cv_agg_agg_df <- err_cv_agg_df %>%
  group_by(p) %>%
  # Order matters: mse_sd must be computed before `mse` is overwritten by its
  # mean, because summarize() evaluates its arguments sequentially.
  summarize(mse_sd = sd(mse), mse = mean(mse), .groups = "drop")
# Model size with the smallest average CV MSE, and the spread at that size.
min_ind <- which.min(err_cv_agg_agg_df$mse)
min_mse <- err_cv_agg_agg_df$mse[min_ind]
min_mse_sd <- err_cv_agg_agg_df$mse_sd[min_ind]
min_p <- err_cv_agg_agg_df$p[min_ind]

# Pick the smallest p whose average CV MSE is within one standard deviation
# (across datasets) of the minimum. With a single dataset sd() is NA; fall
# back to the plain minimum in that case, since an NA threshold would empty
# the filter and make min() return Inf.
mse_threshold <- if (is.na(min_mse_sd)) min_mse else min_mse + min_mse_sd
p_cv <- err_cv_agg_agg_df %>%
  filter(mse <= mse_threshold) %>%
  pull(p) %>%
  min()
print(p_cv)
# CV MSE curve for every simulated dataset (one line per data_ind), with each
# dataset's minimizing p marked by a point, against the simulated MSE from
# mse_df shifted by the noise variance sigma_true^2.
ggplot(err_cv_agg_df, aes(x = p)) +
  geom_line(aes(y = mse, group = data_ind, color = "CV MSE")) +
  geom_line(
    aes(x = p, y = mse + sigma_true^2, color = "True MSE"),
    data = mse_df, lwd = 4
  ) +
  geom_point(aes(p, mse), data = filter(err_cv_agg_df, is_mse_min)) +
  # Constant marker: pass p_true as a parameter so a single line is drawn
  # instead of one copy per row of the plot data.
  geom_vline(xintercept = p_true) +
  scale_y_log10()
# Average CV MSE across datasets with a +/- one standard deviation band,
# against the simulated MSE from mse_df shifted by sigma_true^2. Two points
# mark the selected model (p_cv, plotted at the selection threshold) and the
# minimizer of the average CV MSE (min_p).
ggplot(err_cv_agg_df, aes(x = p)) +
  geom_line(
    aes(x = p, y = mse, color = "Average CV MSE"),
    data = err_cv_agg_agg_df
  ) +
  geom_ribbon(
    aes(x = p, ymin = mse - mse_sd, ymax = mse + mse_sd),
    data = err_cv_agg_agg_df, alpha = 0.3
  ) +
  geom_line(
    aes(x = p, y = mse + sigma_true^2, color = "True MSE"),
    data = mse_df, lwd = 2
  ) +
  # Scalars are drawn with annotate()/parameters: mapping them through aes()
  # against the inherited data would draw one identical mark per row.
  annotate("point", x = p_cv, y = min_mse + min_mse_sd) +
  annotate("point", x = min_p, y = min_mse) +
  geom_vline(xintercept = p_true) +
  scale_y_log10()