$$ \newcommand{\mybold}[1]{\boldsymbol{#1}} \newcommand{\trans}{\intercal} \newcommand{\norm}[1]{\left\Vert#1\right\Vert} \newcommand{\abs}[1]{\left|#1\right|} \newcommand{\bbr}{\mathbb{R}} \newcommand{\bbz}{\mathbb{Z}} \newcommand{\bbc}{\mathbb{C}} \newcommand{\gauss}[1]{\mathcal{N}\left(#1\right)} \newcommand{\chisq}[1]{\mathcal{\chi}^2_{#1}} \newcommand{\studentt}[1]{\mathrm{StudentT}_{#1}} \newcommand{\fdist}[2]{\mathrm{FDist}_{#1,#2}} \newcommand{\iid}{\overset{\mathrm{IID}}{\sim}} \newcommand{\argmin}[1]{\underset{#1}{\mathrm{argmin}}\,} \newcommand{\projop}[1]{\underset{#1}{\mathrm{Proj}}\,} \newcommand{\proj}[1]{\underset{#1}{\mybold{P}}} \newcommand{\expect}[1]{\mathbb{E}\left[#1\right]} \newcommand{\prob}[1]{\mathbb{P}\left(#1\right)} \newcommand{\dens}[1]{\mathit{p}\left(#1\right)} \newcommand{\var}[1]{\mathrm{Var}\left(#1\right)} \newcommand{\cov}[1]{\mathrm{Cov}\left(#1\right)} \newcommand{\sumn}{\sum_{n=1}^N} \newcommand{\meann}{\frac{1}{N} \sumn} \newcommand{\cltn}{\frac{1}{\sqrt{N}} \sumn} \newcommand{\trace}[1]{\mathrm{trace}\left(#1\right)} \newcommand{\diag}[1]{\mathrm{Diag}\left(#1\right)} \newcommand{\grad}[2]{\nabla_{#1} \left. #2 \right.} \newcommand{\gradat}[3]{\nabla_{#1} \left. #2 \right|_{#3}} \newcommand{\fracat}[3]{\left. 
\frac{#1}{#2} \right|_{#3}} \newcommand{\W}{\mybold{W}} \newcommand{\w}{w} \newcommand{\wbar}{\bar{w}} \newcommand{\wv}{\mybold{w}} \newcommand{\X}{\mybold{X}} \newcommand{\x}{x} \newcommand{\xbar}{\bar{x}} \newcommand{\xv}{\mybold{x}} \newcommand{\Xcov}{\mybold{M}_{\X}} \newcommand{\Xcovhat}{\hat{\mybold{M}}_{\X}} \newcommand{\Covsand}{\Sigmam_{\mathrm{sand}}} \newcommand{\Covsandhat}{\hat{\Sigmam}_{\mathrm{sand}}} \newcommand{\Z}{\mybold{Z}} \newcommand{\z}{z} \newcommand{\zv}{\mybold{z}} \newcommand{\zbar}{\bar{z}} \newcommand{\Y}{\mybold{Y}} \newcommand{\Yhat}{\hat{\Y}} \newcommand{\y}{y} \newcommand{\yv}{\mybold{y}} \newcommand{\yhat}{\hat{\y}} \newcommand{\ybar}{\bar{y}} \newcommand{\res}{\varepsilon} \newcommand{\resv}{\mybold{\res}} \newcommand{\resvhat}{\hat{\mybold{\res}}} \newcommand{\reshat}{\hat{\res}} \newcommand{\betav}{\mybold{\beta}} \newcommand{\betavhat}{\hat{\betav}} \newcommand{\betahat}{\hat{\beta}} \newcommand{\betastar}{{\beta^{*}}} \newcommand{\betavstar}{{\betav^{*}}} \newcommand{\loss}{\mathscr{L}} \newcommand{\losshat}{\hat{\loss}} \newcommand{\f}{f} \newcommand{\fhat}{\hat{f}} \newcommand{\bv}{\mybold{\b}} \newcommand{\bvhat}{\hat{\bv}} \newcommand{\alphav}{\mybold{\alpha}} \newcommand{\alphavhat}{\hat{\av}} \newcommand{\alphahat}{\hat{\alpha}} \newcommand{\omegav}{\mybold{\omega}} \newcommand{\gv}{\mybold{\gamma}} \newcommand{\gvhat}{\hat{\gv}} \newcommand{\ghat}{\hat{\gamma}} \newcommand{\hv}{\mybold{\h}} \newcommand{\hvhat}{\hat{\hv}} \newcommand{\hhat}{\hat{\h}} \newcommand{\gammav}{\mybold{\gamma}} \newcommand{\gammavhat}{\hat{\gammav}} \newcommand{\gammahat}{\hat{\gamma}} \newcommand{\new}{\mathrm{new}} \newcommand{\zerov}{\mybold{0}} \newcommand{\onev}{\mybold{1}} \newcommand{\id}{\mybold{I}} \newcommand{\sigmahat}{\hat{\sigma}} \newcommand{\etav}{\mybold{\eta}} \newcommand{\muv}{\mybold{\mu}} \newcommand{\Sigmam}{\mybold{\Sigma}} \newcommand{\rdom}[1]{\mathbb{R}^{#1}} \newcommand{\RV}[1]{{#1}} \def\A{\mybold{A}} 
\def\A{\mybold{A}} \def\av{\mybold{a}} \def\a{a} \def\B{\mybold{B}} \def\b{b} \def\S{\mybold{S}} \def\sv{\mybold{s}} \def\s{s} \def\R{\mybold{R}} \def\rv{\mybold{r}} \def\r{r} \def\V{\mybold{V}} \def\vv{\mybold{v}} \def\v{v} \def\vhat{\hat{v}} \def\U{\mybold{U}} \def\uv{\mybold{u}} \def\u{u} \def\W{\mybold{W}} \def\wv{\mybold{w}} \def\w{w} \def\tv{\mybold{t}} \def\t{t} \def\Sc{\mathcal{S}} \def\ev{\mybold{e}} \def\Lammat{\mybold{\Lambda}} \def\Q{\mybold{Q}} \def\eps{\varepsilon} $$

Cross validation and variable selection

\(\,\)

library(tidyverse)
library(sandwich)
library(gridExtra)

# Notebook-wide display settings: larger text in every ggplot, and wide, short
# figures (repr.* options, presumably read by the IRkernel notebook display).
theme_update(text = element_text(size = 24))
options(
    repr.plot.width = 12,
    repr.plot.height = 6
)
── Attaching core tidyverse packages ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse 2.0.0 ──
✔ dplyr     1.1.2     ✔ readr     2.1.4
✔ forcats   1.0.0     ✔ stringr   1.5.0
✔ ggplot2   3.4.2     ✔ tibble    3.2.1
✔ lubridate 1.9.2     ✔ tidyr     1.3.0
✔ purrr     1.0.1     
── Conflicts ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors

Attaching package: ‘gridExtra’


The following object is masked from ‘package:dplyr’:

    combine

# Create a feature matrix at the points x (intended to lie in [0, 1]) using the
# first pmax sinusoidal Fourier features.
#
# Column j, named "f<j>", holds sin(pi * j * x). The result has length(x) rows
# and pmax columns; pmax = 0 gives a valid length(x) x 0 matrix.
EvalFeatures <- function(x, pmax) {
    stopifnot(length(pmax) == 1, pmax >= 0)
    freqs <- pi * seq_len(pmax)
    # outer() forms every (observation, frequency) product directly.
    # as.numeric() drops any names on x so the matrix gets no row names.
    f_mat <- sin(outer(as.numeric(x), freqs))
    # sprintf() returns character(0) for pmax = 0, unlike paste0().
    colnames(f_mat) <- sprintf("f%d", seq_len(pmax))
    return(f_mat)
}

# Draw a regression dataset from the sinusoidal-feature model
#     y = F %*% beta_true + noise,  noise ~ N(0, sigma_true^2) IID,
# where F = EvalFeatures(x, length(beta_true)).
#
# Arguments:
# - n_obs: number of observations
# - sigma_true: residual standard deviation
# - beta_true: regression coefficient vector (its length sets the number of features)
# - x: A set of x in [0,1] at which to evaluate the features, or NULL for
#   n_obs uniform draws. If supplied, it must have length n_obs.
#
# Returns a data frame with feature columns f1, f2, ..., plus x, y,
# ey_true (the noiseless mean of y) and a row index n.
DrawData <- function(n_obs, sigma_true, beta_true, x=NULL) {
    if (is.null(x)) {
        x <- runif(n_obs)
    }
    # Otherwise rnorm(n_obs) below would be silently recycled against length(x).
    stopifnot(length(x) == n_obs)
    f_mat <- EvalFeatures(x, length(beta_true))
    # drop() turns the n_obs x 1 matrix product into a plain vector, so y and
    # ey_true become ordinary numeric columns rather than matrix columns.
    ey_true <- drop(f_mat %*% beta_true)
    y <- ey_true + rnorm(n_obs, sd=sigma_true)
    data_df <- data.frame(f_mat) %>%
        mutate(x=x, y=y, ey_true=ey_true) %>%
        mutate(n=seq_len(n()))
    return(data_df)
}


# Select a beta with decaying content at higher frequencies,
# but with only the first p_true components non-zero.
#
# A full vector with entries (Uniform(0, 1) + 0.1) / j is drawn and scaled to
# unit Euclidean norm. Every entry after the first p_true is then zeroed, so
# the returned vector has norm <= 1.
#
# Arguments:
# - p_true: number of leading non-zero coefficients (0 <= p_true <= n_features)
# - n_features: length of the returned vector (defaults to the global pmax)
GetBeta <- function(p_true, n_features=pmax) {
    stopifnot(length(p_true) == 1, p_true >= 0, p_true <= n_features)
    beta_full <- (runif(n_features) + 0.1) / seq_len(n_features)
    beta_full <- beta_full / sqrt(sum(beta_full^2))
    beta_true <- rep(0, n_features)
    # seq_len() rather than 1:p_true, so p_true = 0 leaves every entry zero.
    keep <- seq_len(p_true)
    beta_true[keep] <- beta_full[keep]
    return(beta_true)
}

Simulated example: A sinusoidal basis

# Construct a basis of sin functions of increasing frequency

pmax <- 50

# Evaluate all pmax features on 100 evenly spaced points in [0, 1];
# f_names is reused later when building regression formulas.
x_grid <- seq(0, 1, length.out=100)
f_mat <- EvalFeatures(x_grid, pmax)
f_names <- colnames(f_mat)

# Long format: one row per (grid point, feature), with the integer frequency
# recovered from the "f<j>" feature name.
f_df <- data.frame(f_mat) |>
    mutate(x=x_grid) |>
    pivot_longer(cols=-x) |>
    mutate(freq=as.numeric(sub("^f", "", name)))

# Only the low-frequency features (f1 to f9) are drawn, to keep the plot legible.
ggplot(filter(f_df, freq < 10)) +
    geom_line(aes(x=x, y=value, color=name))

n_obs <- 500
sigma_true <- 0.4

# One simulated dataset for each number of true regressors p_true, stacked
# into a single data frame tagged by p_true. The pieces are built with
# lapply() and bound once, instead of growing a data frame inside a loop.
# Iterations still run in the same order, so the RNG stream is unchanged.
models_df <- bind_rows(lapply(seq(1, pmax, 7), function(p_true) {
    DrawData(n_obs, sigma_true, GetBeta(p_true)) %>%
        mutate(p_true=p_true)
}))

As we increase the number of true regressors, the functions get more wiggly, since we are including higher-frequency components. Our task is to figure out how wiggly the fitted function should be — that is, to estimate how many regressors to include in the model!

# One panel per p_true, each showing the noiseless regression function ey_true.
ggplot(models_df, aes(x=x, y=ey_true, color=p_true)) +
    geom_line() +
    facet_grid(cols=vars(p_true))

Fix a ground truth and run some simulations

We’ll fix a particular number of components and a coefficient vector, fix our regressors, and then generate data to see the bias–variance tradeoff.

# Ground truth: the first 20 coefficients are non-zero (see GetBeta).
p_true <- 20
beta_true <- GetBeta(p_true)

# Two independent draws from the same model: a training set and a test set.
data_df <- DrawData(n_obs, sigma_true, beta_true)
test_data_df <- DrawData(n_obs, sigma_true, beta_true)

# The true regression function, with the noisy training observations on top.
ggplot(data_df, aes(x=x)) +
    geom_line(aes(y=ey_true)) +
    geom_point(aes(y=y))

# Build the formula string for a no-intercept regression of y on the first p
# features, e.g. RegressionFormula(2) gives "y ~ -1 + f1 + f2".
#
# Arguments:
# - p: number of leading features to include (0 gives the empty model "y ~ -1")
# - feature_names: ordered feature column names (defaults to the global f_names)
RegressionFormula <- function(p, feature_names=f_names) {
    stopifnot(length(p) == 1, p >= 0, p <= length(feature_names))
    # Joining the response part and the terms as one vector avoids a dangling
    # " + " when p = 0; seq_len() avoids 1:0 == c(1, 0).
    form <- paste(c("y ~ -1", feature_names[seq_len(p)]), collapse=" + ")
    return(form)
}


# Run a regression with the first p predictors and return the fitted values.
#
# Returns one row per row of data_df with:
# - y: the observed response,
# - ey: the true mean ey_true,
# - y_pred: the in-sample fitted value,
# - n: the row index.
ComputePredictions <- function(data_df, p) {
    stopifnot(p <= pmax)
    form <- RegressionFormula(p)
    lm_fit <- lm(formula(form), data_df)
    # fitted() rather than `lm_fit$fitted.value`, which only worked through
    # `$`'s partial matching of the component name "fitted.values".
    return(data.frame(y=data_df$y, ey=data_df$ey_true, y_pred=fitted(lm_fit), n=data_df$n))
}

Now we run some simulations. For each simulation we draw a new \(y\), but keep the same \(x\) and the same \(\beta\). For each dataset, we run regressions for a range of \(p\). Because we know the ground truth, we can compare the predictions to the truth.

n_obs <- 500
sigma_true <- 0.4
n_sims <- 20

# Draw multiple datasets with x fixed (so we can estimate bias and variance).
# Only the x of this first draw is kept; each simulated dataset reuses it
# with fresh noise in y.
x <- DrawData(n_obs, sigma_true, beta_true)$x
data_df_list <- lapply(1:n_sims, \(s) DrawData(n_obs, sigma_true, beta_true, x=x))

# Sanity check that every dataset uses exactly the same x.
stopifnot(all(vapply(data_df_list, \(d) identical(d$x, x), logical(1))))

p_seq <- unique(c(1, 2, 3, 4, 5, seq(1, pmax, 5)))

# Fit each model size p to each simulated dataset and keep the per-row
# predictions. Results go into a preallocated list (simulation-major, as
# before) and are bound once at the end, instead of growing err_df with
# bind_rows() on every iteration.
err_list <- vector("list", n_sims * length(p_seq))
pb <- txtProgressBar(min=0, max=n_sims, style=3)
for (sim in seq_len(n_sims)) {
    setTxtProgressBar(pb, sim)
    data_df <- data_df_list[[sim]]  # invariant across the inner loop over p
    for (p_ind in seq_along(p_seq)) {
        p <- p_seq[p_ind]
        err_list[[(sim - 1) * length(p_seq) + p_ind]] <-
            ComputePredictions(data_df, p) %>% mutate(p=p, sim=sim)
    }
}
close(pb)
err_df <- bind_rows(err_list)
  |======================================================================| 100%

We decompose the prediction error at each datapoint into the following components:

  • pred_err = \(\hat y_n - \beta^T x_n\)
  • reg_err = \(\hat y_n - \mathbb{E}[\hat\beta]^T x_n\)
  • bias = \(\mathbb{E}[\hat\beta]^T x_n - \beta^T x_n\)

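so that pred_err = reg_err + bias. At each \(x_n\), reg_err averages to zero over the simulations. The cross term therefore vanishes when we square and average, and the MSE equals var + bias2 (this is what mse_check verifies).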

# Estimate the expected prediction at each datapoint, which is needed for the bias.
# n indexes individual datapoints, and x is shared across simulations, so
# ey_pn is the average of y_pred over all simulations for a given (p, n).
err_df <- err_df |>
    group_by(p, n) |>
    mutate(ey_pn=mean(y_pred)) |>
    ungroup()

# Per-p MSE and its decomposition, averaged over datapoints and simulations:
# - pred_err = y_pred - ey     (total error of the fitted mean)
# - reg_err  = y_pred - ey_pn  (variance part)
# - bias     = ey_pn - ey      (bias part)
# rand_err is the irreducible noise variance sigma_true^2, the same for every
# p. Predicting a new y would add it on top of mse.
mse_df <- err_df |>
    mutate(
        pred_err=y_pred - ey,
        reg_err=y_pred - ey_pn,
        bias=ey_pn - ey
    ) |>
    group_by(p) |>
    summarize(
        mse=mean(pred_err^2),
        var=mean(reg_err^2),
        bias2=mean(bias^2),
        rand_err=sigma_true^2
    ) |>
    # Relative gap between mse and var + bias2; should be zero up to rounding.
    mutate(mse_check=(mse - (var + bias2)) / mse)
mse_df
A tibble: 14 × 6
p mse var bias2 rand_err mse_check
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 0.154295538 0.0002612088 0.1540343294 0.16 -1.798858e-16
2 0.071492438 0.0004750056 0.0710174322 0.16 0.000000e+00
3 0.025757448 0.0007859902 0.0249714576 0.16 0.000000e+00
4 0.024913654 0.0010975353 0.0238161190 0.16 0.000000e+00
5 0.023965436 0.0013810972 0.0225843386 0.16 0.000000e+00
6 0.013597196 0.0015696131 0.0120275824 0.16 0.000000e+00
11 0.004800805 0.0026488117 0.0021519935 0.16 -1.806701e-16
16 0.005012249 0.0039636371 0.0010486124 0.16 0.000000e+00
21 0.006122883 0.0059289516 0.0001939315 0.16 0.000000e+00
26 0.007831561 0.0075814637 0.0002500975 0.16 0.000000e+00
31 0.009338688 0.0090346805 0.0003040077 0.16 0.000000e+00
36 0.010682421 0.0103274960 0.0003549251 0.16 0.000000e+00
41 0.012504986 0.0119679744 0.0005370119 0.16 0.000000e+00
46 0.014585520 0.0139529689 0.0006325514 0.16 0.000000e+00

Look at how the MSE and its components vary with p, the number of regressors included in the regression.

# MSE and its variance / squared-bias components against p, shown on linear
# and log scales. The purple line marks the true number of regressors p_true.
mse_graph <- ggplot(mse_df, aes(x=p)) +
    geom_line(aes(y=mse, color="MSE")) +
    geom_line(aes(y=var, color="var")) +
    geom_line(aes(y=bias2, color="bias2")) +
    # A constant intercept goes outside aes(): mapped inside aes() it is drawn
    # once per row of mse_df, stacking identical lines.
    geom_vline(xintercept=p_true, color="purple")
grid.arrange(
    mse_graph,
    mse_graph + scale_y_log10(),
    ncol=2)

Cross-validation

Now let’s see whether we can accurately recover the above curves using cross-validation.

We’ll use k-fold cross-validation. For each \(p\) and each simulated dataset, we fit the regression once per fold, so we can see the variability of the CV procedure across datasets.

# Perform k-fold cross-validation.
# `mse` is the estimated held-out fold mean squared error.

n_folds <- 10
# Balanced fold assignment: fold sizes differ by at most one. Sampling labels
# with replacement could leave folds unbalanced or even empty. The same
# assignment is reused for every dataset and every p.
fold_index <- sample(rep_len(1:n_folds, n_obs))

n_data_sets <- length(data_df_list)

# Each (p, dataset, fold) combination yields one row. Rows are collected in a
# preallocated list and bound once, instead of growing a data frame.
n_fits <- length(p_seq) * n_data_sets * n_folds
err_cv_list <- vector("list", n_fits)
pb <- txtProgressBar(min=0, max=n_fits, style=3)
pb_ind <- 0
for (p in p_seq) {
    form <- formula(RegressionFormula(p))
    for (data_ind in 1:n_data_sets) {
        # Fold labels don't depend on which fold is held out, so attach them
        # once per dataset.
        data_df <- data_df_list[[data_ind]] %>% mutate(fold_index=fold_index)
        for (fold in 1:n_folds) {
            pb_ind <- pb_ind + 1
            setTxtProgressBar(pb, pb_ind)
            # `!!fold` injects the loop value, so a data column named `fold`
            # could never shadow it.
            lm_fit <- lm(form, data_df %>% filter(fold_index != !!fold))
            data_fold_df <- data_df %>% filter(fold_index == !!fold)
            y_pred <- predict(lm_fit, data_fold_df)
            mse <- mean((y_pred - data_fold_df$y)^2)
            err_cv_list[[pb_ind]] <- data.frame(mse=mse, fold=fold, p=p, data_ind=data_ind)
        }
    }
}
close(pb)
err_cv_df <- bind_rows(err_cv_list)
  |======================================================================| 100%
# For each p, and each dataset, average the mse over folds, and flag the p
# that minimizes the fold-averaged mse within each dataset (ties are all flagged).
err_cv_agg_df <-
    err_cv_df %>%
        group_by(data_ind, p) %>%
        summarize(mse=mean(mse), .groups="drop") %>%
        group_by(data_ind) %>%
        mutate(is_mse_min=mse <= min(mse)) %>%
        # Don't leave the data_ind grouping attached to the stored result.
        ungroup()

# Variability of the fold-averaged CV mse across simulated datasets, per p.
# (In practice, you only see one dataset.)
# mse_sd must come before mse: summarize() evaluates its arguments in order,
# and mse is overwritten by its mean.
err_cv_agg_agg_df <- err_cv_agg_df |>
    group_by(p) |>
    summarize(
        mse_sd=sd(mse),
        mse=mean(mse),
        .groups="drop"
    )

# The p with the lowest average CV mse, and the corresponding mse and sd.
min_ind <- which.min(err_cv_agg_agg_df[["mse"]])
min_mse <- err_cv_agg_agg_df[["mse"]][min_ind]
min_mse_sd <- err_cv_agg_agg_df[["mse_sd"]][min_ind]
min_p <- err_cv_agg_agg_df[["p"]][min_ind]

# "One standard deviation" rule: the smallest p whose average CV mse is within
# one across-dataset sd of the minimum. This sd is available only because we
# simulated many datasets.
p_cv <- with(err_cv_agg_agg_df, min(p[which(mse <= min_mse + min_mse_sd)]))
print(p_cv)
[1] 6

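We see that

  • CV does not recover the true number of variables \(p_{\mathrm{true}}\) (in the run shown above it selected fewer than \(p_{\mathrm{true}} = 20\))
  • There is a fair amount of variability in the estimated MSE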
# Fold-averaged CV mse for every simulated dataset (thin lines) against the
# true out-of-sample MSE (thick line: estimation MSE plus noise variance).
# Points mark each dataset's CV-minimizing p; the vertical line marks p_true.
ggplot(err_cv_agg_df, aes(x=p)) +
    geom_line(aes(y=mse, group=data_ind, color="CV MSE")) +
    geom_line(aes(x=p, y=mse + sigma_true^2, color="True MSE"), data=mse_df, linewidth=4) +
    geom_point(aes(p, mse), data=filter(err_cv_agg_df, is_mse_min)) +
    # Constant intercept outside aes(), so it is drawn once rather than once per row.
    geom_vline(xintercept=p_true) +
    scale_y_log10()



# Average CV mse over datasets with a +/- one sd band, against the true
# out-of-sample MSE. The points mark the CV minimum (min_p, min_mse) and the
# one-sd threshold (p_cv, min_mse + min_mse_sd); the vertical line marks p_true.
ggplot(err_cv_agg_agg_df, aes(x=p)) +
    geom_line(aes(y=mse, color="Average CV MSE")) +
    geom_ribbon(aes(ymin=mse - mse_sd, ymax=mse + mse_sd), alpha=0.3) +
    geom_line(aes(x=p, y=mse + sigma_true^2, color="True MSE"), data=mse_df, linewidth=2) +
    # annotate() draws exactly one point. A geom_point() with constant aes()
    # would repeat it once per row of the plot data.
    annotate("point", x=p_cv, y=min_mse + min_mse_sd) +
    annotate("point", x=min_p, y=min_mse) +
    geom_vline(xintercept=p_true) +
    scale_y_log10()