Building a Bayesian counterpart to R's lm() function


One of my recent projects used a simple Bayesian version of the linear model. While the function I wrote to fit the model certainly worked, I was interested in making it more robust and easier to use by imitating the form of R’s built-in lm() function, which handles the frequentist linear model. In this post I’ll take apart the existing lm() function, describe a simple Bayesian linear model with conjugate priors, and build a counterpart blm() function. Along the way I’ll go through some properties of the R language that don’t come up in everyday data analytical use, but can be immensely helpful when writing your own methods. Specifically, I’ll explain how to take a full data frame and use a model formula and a subsetting criterion to produce a response vector and a design matrix that can be used directly for model fitting.

Dissecting R’s lm() function

The source code of any function in R can be viewed by entering the name of the function into the console. Sometimes the source code isn’t terribly helpful, especially if it just sends you to a C call, but we get lucky in the case of lm(), which provides a good explanation of what’s going on.
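As a quick illustration of the inspection tools involved, here is a minimal sketch using only base R. The variable names arg_names and print_lm are mine, introduced just for this example.

```r
# Typing a function's name without parentheses prints its source.
lm

# args() shows only the signature, which is handy for long functions.
args(lm)

# formals() returns the arguments as a list we can inspect programmatically.
arg_names <- names(formals(lm))
head(arg_names, 3)

# S3 methods such as print.lm() may not be visible by name from the console,
# but they can still be retrieved with getS3method().
print_lm <- getS3method("print", "lm")
```

Printing print_lm at the console shows the method's source in the same way.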

function (formula, data, subset, weights, na.action, method = "qr",
          model = TRUE, x = FALSE, y = FALSE, qr = TRUE, singular.ok = TRUE,
          contrasts = NULL, offset, ...)
{
  ret.x <- x
  ret.y <- y
  cl <- match.call()
  mf <- match.call(expand.dots = FALSE)
  m <- match(c("formula", "data", "subset", "weights", "na.action",
               "offset"), names(mf), 0L)
  mf <- mf[c(1L, m)]
  mf$drop.unused.levels <- TRUE
  mf[[1L]] <- quote(stats::model.frame)
  mf <- eval(mf, parent.frame())
  if (method == "model.frame")
    return(mf)
  else if (method != "qr")
    warning(gettextf("method = '%s' is not supported. Using 'qr'",
                     method), domain = NA)
  mt <- attr(mf, "terms")
  y <- model.response(mf, "numeric")
  w <- as.vector(model.weights(mf))
  if (!is.null(w) && !is.numeric(w))
    stop("'weights' must be a numeric vector")
  offset <- as.vector(model.offset(mf))
  if (!is.null(offset)) {
    if (length(offset) != NROW(y))
      stop(gettextf("number of offsets is %d, should equal %d (number of observations)",
                    length(offset), NROW(y)), domain = NA)
  }
  if (is.empty.model(mt)) {
    x <- NULL
    z <- list(coefficients = if (is.matrix(y)) matrix(, 0, 3) else numeric(),
              residuals = y, fitted.values = 0 * y, weights = w, rank = 0L,
              df.residual = if (!is.null(w)) sum(w != 0)
                            else if (is.matrix(y)) nrow(y) else length(y))
    if (!is.null(offset)) {
      z$fitted.values <- offset
      z$residuals <- y - offset
    }
  }
  else {
    x <- model.matrix(mt, mf, contrasts)
    z <- if (is.null(w))
      lm.fit(x, y, offset = offset, singular.ok = singular.ok,
             ...)
    else lm.wfit(x, y, w, offset = offset, singular.ok = singular.ok,
                 ...)
  }
  class(z) <- c(if (is.matrix(y)) "mlm", "lm")
  z$na.action <- attr(mf, "na.action")
  z$offset <- offset
  z$contrasts <- attr(x, "contrasts")
  z$xlevels <- .getXlevels(mt, mf)
  z$call <- cl
  z$terms <- mt
  if (model)
    z$model <- mf
  if (ret.x)
    z$x <- x
  if (ret.y)
    z$y <- y
  if (!qr)
    z$qr <- NULL
  z
}

If you look closely, it becomes apparent that lm() doesn’t perform any of the actual model fitting, but is really a wrapper function that converts the arguments into a friendly format for lm.fit(), which is called in the final else branch, just after the design matrix is built. This is okay with us, because the method for fitting the Bayesian model will be entirely different from the method for fitting the frequentist model. We will have a blm.fit() function that does all the fitting work, given data in a friendly format, and a user-facing function blm() that makes the whole process easy for the user. This blm() function performs many of the same functions as lm() and thus could look fairly similar.

Arguments

First, we’ll take a look at the list of arguments (formula, data, subset, etc.). A description of the purpose of each argument is given in the help file, which can be accessed by typing ?lm. The arguments we’ll be focusing on are the following:

  1. formula, which specifies a model structure, e.g. Ozone ~ Wind + Temp + Wind:Temp,
  2. data, a data frame containing the observations,
  3. subset, which indicates which observations from data to use.

The arguments weights, na.action, and offset refer to more complicated ways of running a model, and are outside the scope of this post. The rest of the arguments primarily tell the function what sorts of additional information to return at the end, and will be dealt with later.

Constructing model matrices

The lm.fit() function, like our soon-to-be blm.fit() function, takes as arguments a response vector y and a design matrix x. This is the mathematically convenient starting point for a model fit, but can be cumbersome to a user. Part of the job of lm() is to accept the complete data frame, take only a specified subset of observations (rows), and then use the given model formula to produce a response vector and a design matrix for lm.fit().

Much of the actual work toward this is done by the model.frame() function, which takes a formula, data frame, subsetting criterion, and instructions on what to do with unused levels of factors, and returns a data frame containing the variables used in the formula along with information on how to extract the response vector and design matrix. In lm(), a good chunk of code takes the arguments passed in the original function call and builds them into a call to model.frame(), which is then evaluated within the environment from which lm() was called. I’ve reproduced the corresponding section of code and commented each step.

# get the function call with full argument names,
# ignoring arguments passed to the ...
mf <- match.call(expand.dots = FALSE)

# get the indices of the arguments matching the following names
# if a listed name isn't found in the arguments,
# return a 0 for that index
m <- match(c("formula", "data", "subset", "weights", "na.action",
             "offset"), names(mf), 0L)

# trim the function call to only the function name
# and the arguments matched in m
mf <- mf[c(1L, m)]

# when subsetting the data frame, remove levels of factors
# that don't match any observations in the subset
mf$drop.unused.levels <- TRUE

# replace the original function name with "model.frame"
mf[[1L]] <- quote(stats::model.frame)

# call model.frame with the associated arguments
mf <- eval(mf, parent.frame())

There’s some more logic for handling edge cases in lm() before extracting the response vector and design matrix, but I’ve pulled out the parts I’m interested in and commented them. It turns out that once we have the model frame, extracting these objects is quite easy.

# extract the "terms" attribute from the model frame
mt <- attr(mf, "terms")

# extract the response vector
y <- model.response(mf, "numeric")

# extract the design matrix
x <- model.matrix(mt, mf, contrasts)

If we’re not dealing with contrasts, we can drop the final argument from model.matrix(). That’s it! We can now pass x and y to lm.fit() as the design matrix and response vector, respectively.
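To see the whole pipeline in one place, here is a minimal sketch that wraps the steps above in a small function and checks the result against lm(). The name frame_to_xy() is hypothetical, and the model and subset are chosen only for illustration.

```r
# Hypothetical helper: turn (formula, data, subset) into a response vector
# and a design matrix, following the same steps as lm().
frame_to_xy <- function(formula, data, subset) {
  mf <- match.call(expand.dots = FALSE)
  m <- match(c("formula", "data", "subset"), names(mf), 0L)
  mf <- mf[c(1L, m)]
  mf$drop.unused.levels <- TRUE
  mf[[1L]] <- quote(stats::model.frame)
  mf <- eval(mf, parent.frame())

  mt <- attr(mf, "terms")
  list(y = model.response(mf, "numeric"),
       x = model.matrix(mt, mf))
}

# subset is written as an expression in terms of the columns of data
xy <- frame_to_xy(Ozone ~ Wind + Temp, data = airquality, subset = Month == 5)
colnames(xy$x)

# lm.fit() on these pieces should reproduce lm()'s coefficients
fit_direct <- lm.fit(xy$x, xy$y)
fit_lm <- lm(Ozone ~ Wind + Temp, data = airquality, subset = Month == 5)
all.equal(fit_direct$coefficients, coef(fit_lm))
```

Note that subset = Month == 5 is never evaluated inside frame_to_xy() itself. It travels along unevaluated in the matched call and is only resolved by model.frame(), which is why the call has to be rebuilt rather than the arguments passed on directly.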

Return values

The lm() function returns a lot of information encapsulated in an object of class "lm". The minimum components required by the specification of the "lm" class are the following:

  1. coefficients, a named vector of coefficients,
  2. residuals, the residuals, that is response minus fitted values,
  3. fitted.values, the fitted mean values,
  4. rank, the numeric rank of the fitted linear model,
  5. weights, (only for weighted fits) the specified weights,
  6. df.residual, the residual degrees of freedom,
  7. several more items that simply regurgitate information passed to the function or used internally

In fact, items 1–6 are returned by lm.fit() and are simply passed along. The values in item 7 include things like the original function call, response vector, and design matrix, the latter two returned only if the y and x arguments are set to TRUE in the function call, respectively.

The following excerpt from lm() takes the list z returned by lm.fit(), assigns it the "lm" class designation, adds the remaining items, and then returns it.

class(z) <- c(if (is.matrix(y)) "mlm", "lm")
z$na.action <- attr(mf, "na.action")
z$offset <- offset
z$contrasts <- attr(x, "contrasts")
z$xlevels <- .getXlevels(mt, mf)
z$call <- cl
z$terms <- mt
if (model)
  z$model <- mf
if (ret.x)
  z$x <- x
if (ret.y)
  z$y <- y
if (!qr)
  z$qr <- NULL
z
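A short sketch makes the structure of the returned object concrete. The model here is arbitrary, and fit and fit_xy are just illustrative names.

```r
fit <- lm(Ozone ~ Wind + Temp, data = airquality)

# an "lm" object is an ordinary list with a class attribute
class(fit)
names(fit)

# the design matrix and response are only stored when requested
fit_xy <- lm(Ozone ~ Wind + Temp, data = airquality, x = TRUE, y = TRUE)
setdiff(names(fit_xy), names(fit))

# components can be read directly with $
head(fit$residuals)
```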

A simple Bayesian linear model with conjugate priors

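Supposing we have a response vector Y and a design matrix X, we can specify the Bayesian hierarchical linear model

$$
\begin{aligned}
Y \mid \beta, \sigma^2 &\sim N_n\left(X\beta,\ \sigma^2 \Sigma\right), \\
\beta \mid \sigma^2 &\sim N_p\left(\nu,\ \sigma^2 R\right), \\
\sigma^2 &\sim \text{Inverse-Gamma}\left(a_0,\ b_0\right),
\end{aligned}
$$

in which β is a vector of regression parameters, σ² is the variance parameter, and Σ, ν, R, a0, and b0 are hyperparameters assumed known. In the usual uncorrelated data case, Σ is the identity matrix. For vague priors, we can set ν to be the zero vector, R to be a diagonal matrix with very large diagonal entries, and a0 and b0 very small.

Given the observed response vector y, the joint posterior distribution is

$$
\beta \mid \sigma^2, y \sim N_p\left(Hh,\ \sigma^2 H\right), \qquad
\sigma^2 \mid y \sim \text{Inverse-Gamma}\left(a,\ b\right),
$$

where

$$
\begin{aligned}
H &= \left(X^\top \Sigma^{-1} X + R^{-1}\right)^{-1}, &
h &= X^\top \Sigma^{-1} y + R^{-1} \nu, \\
a &= a_0 + \frac{n}{2}, &
b &= b_0 + \frac{1}{2}\left(y^\top \Sigma^{-1} y + \nu^\top R^{-1} \nu - h^\top H h\right).
\end{aligned}
$$

The marginal distribution of β is the multivariate Student’s t distribution,

$$
\beta \mid y \sim t_{2a}\left(Hh,\ \frac{b}{a} H\right).
$$

All the information we need to make inferences is contained in these posteriors.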
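To make the posterior concrete, here is a minimal sketch that computes it by hand and draws from it. It assumes the standard normal-inverse-gamma conjugate update, H = (X'X + R⁻¹)⁻¹, h = X'y + R⁻¹ν, a = a0 + n/2 and b = b0 + (y'y + ν'R⁻¹ν − h'Hh)/2, with Σ taken to be the identity. The hyperparameter values and all variable names are my own choices for illustration.

```r
set.seed(1)

# complete cases only, with Sigma taken to be the identity
d <- na.omit(airquality[, c("Ozone", "Wind", "Temp")])
X <- model.matrix(Ozone ~ Wind + Temp, d)
y <- d$Ozone
n <- nrow(X)
p <- ncol(X)

# vague prior hyperparameters (assumed values, for illustration only)
nu    <- rep(0, p)
R_inv <- diag(1e-4, p)
a0 <- 5e-5
b0 <- 5e-5

# conjugate update
H_inv <- crossprod(X) + R_inv
H <- solve(H_inv)
h <- crossprod(X, y) + R_inv %*% nu
a <- a0 + n / 2
b <- drop(b0 + (crossprod(y) + t(nu) %*% R_inv %*% nu - t(h) %*% H %*% h) / 2)
post_mean <- drop(H %*% h)

# draw from the joint posterior: sigma^2 first, then beta given sigma^2
n_draws <- 2000
sigma2 <- 1 / rgamma(n_draws, shape = a, rate = b)
L <- t(chol(H))                              # H = L %*% t(L)
z <- matrix(rnorm(p * n_draws), nrow = p)    # one column per draw
beta <- post_mean + L %*% (z * rep(sqrt(sigma2), each = p))

# with vague priors the posterior mean should sit close to least squares
rbind(posterior = post_mean, ols = coef(lm(Ozone ~ Wind + Temp, d)))
```

Each column of beta is one draw of the coefficient vector, so rowMeans(beta) and apply(beta, 1, quantile, c(0.025, 0.975)) give Monte Carlo estimates of the posterior means and 95% intervals.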

Building a blm() function

The first order of business in creating the blm() (and blm.fit()) function is to determine what sort of input we want to take and what should be returned given that input. In order to maintain the user-friendliness of lm(), we want to take the three basic arguments: a model formula, an optional full data frame, and an optional subset vector that can alternatively be specified as an inclusion criterion.

Prior specification

In order to fit the Bayesian model, we also need the hyperparameters. Thankfully we can construct some default values which lead to sensible vague priors most of the time. This will allow users to call blm() just as they would call lm(), ignoring prior specification. However, it is important to allow the manual specification of priors through setting hyperparameters, and we must think carefully about how those should be passed as arguments. We want the hyperparameter arguments to be clearly named and flexible in use, allowing each argument to be included or excluded separately, and for variances to be specified both as diagonals (for independence models) and as full matrices (for correlated models). I choose the following arguments and default values:

  • prior.mean = 0, a scalar or vector, corresponding to ν, which determines the prior mean of the regression coefficients (if a scalar, treat as a vector with all elements equal to that scalar),
  • prior.precision = 0.0001, a scalar, vector or matrix, corresponding to R⁻¹, which determines the prior precision of the regression coefficients (if a vector, R⁻¹ is diagonal with the given vector as the diagonal, if a scalar, treat as a vector with all elements equal to that scalar),
  • prior.df = 0.0001, a scalar, corresponding to 2a0, which determines the corresponding degrees of freedom (and can be thought of as the number of “prior observations”),
  • prior.scale = 1, a scalar, corresponding to sqrt(b0 / a0), the square root of the harmonic mean of the prior mode and prior mean, when the latter exists, of σ²,
  • cov.structure = 1, a scalar, vector or matrix, corresponding to Σ, which determines the conditional covariance structure (but not magnitude) of the responses (if a vector, Σ is diagonal with the given vector as the diagonal, if a scalar, treat as a vector with all elements equal to that scalar).

As a courtesy to those who wish to specify the prior on σ² directly, we’ll include optional arguments prior.a and prior.b which can be used as alternatives to prior.df and prior.scale as long as arguments from only one pair are used. When prior.df and prior.scale are specified, we compute prior.a <- prior.df / 2 and prior.b <- (prior.df * prior.scale ^ 2) / 2.
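The conversion between the two parameterizations, and the harmonic-mean interpretation of prior.scale, can be checked numerically. This is a small sketch with arbitrarily chosen values. It uses the inverse-gamma prior mean b0 / (a0 - 1) and prior mode b0 / (a0 + 1).

```r
prior.df <- 4
prior.scale <- 3

# (df, scale) -> (a, b)
a0 <- prior.df / 2
b0 <- prior.df * prior.scale^2 / 2

# moments of the inverse-gamma prior on the variance
prior_mean <- b0 / (a0 - 1)   # exists only when a0 > 1
prior_mode <- b0 / (a0 + 1)
harmonic_mean <- 2 / (1 / prior_mean + 1 / prior_mode)

# all three quantities coincide
c(scale_squared = prior.scale^2, b_over_a = b0 / a0, harmonic_mean = harmonic_mean)

# (a, b) -> (df, scale)
c(df = 2 * a0, scale = sqrt(b0 / a0))
```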

The prior coefficient precision matrix is used instead of the prior coefficient covariance matrix because it’s the traditional Bayesian parameterization of normal priors. Additionally, prior.df and prior.scale are the defaults because they are more intuitive interpretations than prior.a and prior.b.

Finally, scalar and vector arguments are expanded using two dimensions taken from the design matrix: p, the number of regression parameters (equal to its number of columns), and n, the number of observations (equal to its number of rows).

More on prior specification

Things can get a little tricky with the prior specifications because of the way terms in model formulae are translated into the design matrix. Roughly, the columns of the design matrix correspond to the terms of the model formula in order of appearance. The nuances most commonly affecting this heuristic are the fact that the intercept term implicitly comes first, and that interactions specified as a * b are translated to a + b + a:b. However, I don’t know of a better way to specify the priors.

Additionally, model.matrix() converts categorical covariates into a set of indicator covariates with the first level treated as the baseline. Levels are taken in the order returned by levels(). The specification of the prior must account for this.

All this can be avoided in many use cases by passing the arguments as scalars, as they’ll automatically be expanded to the correct forms. However, care must be taken when specifying more complicated priors.
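Before writing down a vector-valued prior, it helps to look at the column names of the design matrix. The sketch below uses a small made-up data frame (the variables y, a, and g are hypothetical) to show the ordering rules described above.

```r
d <- data.frame(
  y = c(1.2, 0.7, 3.1, 2.2, 4.0, 3.6),
  a = c(0.5, 1.0, 1.5, 2.0, 2.5, 3.0),
  g = factor(c("low", "high", "mid", "low", "high", "mid"),
             levels = c("low", "mid", "high"))
)

# a * g expands to a + g + a:g, with the intercept first and
# the first factor level ("low") absorbed as the baseline
cols <- colnames(model.matrix(y ~ a * g, d))
cols

# the order of the indicator columns follows levels()
levels(d$g)

# a vector-valued prior mean must follow the same column order;
# naming it makes the intended correspondence explicit
x <- model.matrix(y ~ a + g, d)
prior.mean <- setNames(c(0, 1, 0, 0), colnames(x))
prior.mean
```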

Return value

The return value of blm() will use the return value of lm() as a starting point. We’ll return an object of class "blm" containing the following components:

  1. coefficients, a named vector of coefficients,
  2. residuals, the residuals,
  3. fitted.values, the fitted mean values,
  4. prior, a list containing the prior hyperparameters passed as arguments, expanded to full form where applicable,
  5. posterior, a list containing the posterior hyperparameters,
  6. cov.structure, the conditional covariance structure passed as an argument, expanded to full form,
  7. call, the matched call,
  8. model, the model frame used.

The values of prior and cov.structure are produced within the body of blm() during the handling of passed arguments, and posterior will be returned by blm.fit() along with the model coefficients, residuals, fitted values, etc.

The blm.fit() function

library(matrixcalc)
library(Matrix)

blm.fit <- function(x, y,
                    prior.mean, prior.precision,
                    prior.a, prior.b,
                    cov.structure) {
  n <- nrow(x)
  p <- ncol(x)
  rank <- as.numeric(rankMatrix(x))

  if (is.null(colnames(x))) {
    colnames(x) <- paste0("X", 1:p)
  }
  x.names <- colnames(x)

  # error checking
  if (length(y) != n) {
    stop("Length of y must match number of rows of x.")
  }

  if (rank < p) {
    stop("x is rank deficient.")
  }

  # cov.structure is inverted below, so it must be strictly positive definite
  if (!is.positive.definite(cov.structure)) {
    stop("cov.structure must be positive definite.")
  }

  if (!is.positive.definite(prior.precision)) {
    stop("prior.precision must be positive definite.")
  }

  # fit model
  Sigma.inv <- svd.inverse(cov.structure)

  H.inv <- t(x) %*% Sigma.inv %*% x + prior.precision
  H <- svd.inverse(H.inv)
  h <- t(x) %*% Sigma.inv %*% y + prior.precision %*% prior.mean
  a <- drop(prior.a + n / 2)
  b <- drop(prior.b + (t(y) %*% Sigma.inv %*% y +
                       t(prior.mean) %*% prior.precision %*% prior.mean -
                       t(h) %*% H %*% h
                      ) / 2)

  # prepare return value
  mean <- as.vector(H %*% h)
  names(mean) <- x.names
  precision <- H.inv
  colnames(precision) <- rownames(precision) <- x.names
  df <- 2 * a
  scale <- sqrt(b / a)

  t.mean <- mean
  t.cov  <- scale ^ 2 * H
  t.df   <- 2 * a

  coefficients <- mean
  fitted.values <- as.vector(x %*% coefficients)
  residuals <- y - fitted.values

  posterior.params <- list(mean=mean,
                           precision=precision,
                           df=df,
                           scale=scale,
                           a=a,
                           b=b,
                           t.mean=t.mean,
                           t.cov=t.cov,
                           t.df=t.df)

  list(coefficients=coefficients,
       residuals=residuals,
       fitted.values=fitted.values,
       rank=rank,
       df.residual=df - p,
       posterior=posterior.params)
}

The blm() function

blm <- function(formula, data, subset,
                prior.mean=0, prior.precision=0.0001,
                prior.df=0.0001, prior.scale=1,
                cov.structure=1,
                prior.a=prior.df / 2,
                prior.b=(prior.df * prior.scale ^ 2) / 2) {

  # build call to stats::model.frame from passed arguments
  # (data is optional and is not touched here; model.frame() is evaluated
  # on the caller's original arguments)
  mf <- match.call(expand.dots = FALSE)
  m <- match(c("formula", "data", "subset"), names(mf), 0L)
  mf <- mf[c(1L, m)]
  mf$drop.unused.levels <- TRUE
  mf[[1L]] <- quote(stats::model.frame)
  mf <- eval(mf, parent.frame())

  # extract terms from the model frame
  mt <- attr(mf, "terms")
  y <- model.response(mf, "numeric")   # response vector
  x <- model.matrix(mt, mf)            # design matrix
  x.names <- colnames(x)

  # make sure arguments are passed from at most one pair
  # (prior.df, prior.scale), (prior.a, prior.b)
  if ((!missing(prior.df) || !missing(prior.scale)) &&
      (!missing(prior.a)  || !missing(prior.b))) {
    stop(paste("Can only set arguments from one pair:",
                "(prior.df, prior.scale) or (prior.a, prior.b)."))
  }

  # expand prior.mean, prior.precision, cov.structure
  n <- nrow(x)
  p <- ncol(x)

  if (length(prior.mean) == 0) {
    stop("prior.mean cannot have length zero.")
  } else if (length(prior.mean) == 1) {
    prior.mean <- rep(prior.mean, p)
  } else if (length(prior.mean) != p) {
    stop(paste("Length of prior.mean is not 1 and does not",
               "match number of columns in design matrix."))
  }
  names(prior.mean) <- x.names

  if (is.null(ncol(prior.precision))) {
    if (length(prior.precision) == 0) {
      stop("prior.precision cannot have length zero.")
    } else if (length(prior.precision) == 1) {
      prior.precision <- rep(prior.precision, p)
    } else if (length(prior.precision) != p) {
      stop(paste("Length of prior.precision is not 1 and does not",
                 "match the number of columns in design matrix."))
    }
    prior.precision <- diag(prior.precision, p)
  } else if (ncol(prior.precision) != p ||
             nrow(prior.precision) != p) {
    stop(paste("Number of rows and columns of prior.precision must",
               "match the number of columns in design matrix."))
  }
  colnames(prior.precision) <- rownames(prior.precision) <- x.names

  if (is.null(ncol(cov.structure))) {
    if (length(cov.structure) == 0) {
      stop("cov.structure cannot have length zero.")
    } else if (length(cov.structure) == 1) {
      cov.structure <- rep(cov.structure, n)
    } else if (length(cov.structure) != n) {
      stop(paste("Length of cov.structure is not 1 and does not",
                 "match the number of rows in design matrix."))
    }
    cov.structure <- diag(cov.structure, n)
  } else if (ncol(cov.structure) != n ||
             nrow(cov.structure) != n) {
    stop(paste("Number of rows and columns of cov.structure must",
               "match the number of rows in design matrix."))
  }

  # fit model
  fit <- blm.fit(x, y,
                 prior.mean, prior.precision,
                 prior.a, prior.b,
                 cov.structure)

  # prepare return value
  prior.params <- list(mean=prior.mean,
                       precision=prior.precision,
                       df=prior.df,
                       scale=prior.scale,
                       a=prior.a,
                       b=prior.b)
  fit$prior <- prior.params
  fit$cov.structure <- cov.structure
  fit$call <- match.call()
  fit$model <- mf
  class(fit) <- c("blm")
  fit
}

Helper functions print() and summary()

After using fit <- lm(), the result is often then either printed immediately using fit or summarized using summary(fit). There are other methods that are often used, for example coefficients() and vcov(), but once the basics of creating new methods in R are learned, creating these particular methods for class "blm" is straightforward. When a user types fit, the function that’s actually called is print(fit), which is dispatched to print.lm(fit). We then need to define the function print.blm():

# the signature matches the print() generic: print(x, ...)
print.blm <- function(x, ...) {
  cat("\nCall:\n")
  print(x$call)
  cat("\nCoefficients:\n")
  print(x$posterior$mean)
  invisible(x)
}
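Other accessors follow the same pattern. As a rough sketch, here is what coef() and vcov() methods for class "blm" might look like, assuming the object carries a posterior list with components mean, t.cov, and t.df as produced by blm.fit(). The mock object below is hand-built so the example runs on its own. It is not the output of a real fit.

```r
# Hypothetical accessor methods for class "blm"
coef.blm <- function(object, ...) {
  object$posterior$mean
}

vcov.blm <- function(object, ...) {
  post <- object$posterior
  # covariance of a multivariate t: df / (df - 2) times the scale matrix
  post$t.cov * post$t.df / (post$t.df - 2)
}

# hand-built stand-in for a fitted object, for illustration only
mock <- structure(
  list(posterior = list(mean  = c("(Intercept)" = 1, x = 2),
                        t.cov = diag(c(0.5, 0.25)),
                        t.df  = 10)),
  class = "blm"
)

# the generics dispatch on the class attribute
coef(mock)
vcov(mock)
```

Because coef() and vcov() are already generics in the stats package, defining functions named coef.blm() and vcov.blm() is all that is needed for dispatch to work.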

The summary() function works similarly. For now we’ll just print out the function call, a coefficient table containing posterior means and (marginal) standard errors, an estimate of the variance parameter, and the degrees of freedom.

# the signature matches the summary() generic: summary(object, ...)
summary.blm <- function(object, ...) {
  post <- object$posterior

  cat("\nCall:\n")
  print(object$call)

  # marginal posterior SE: the covariance of a multivariate t with
  # df = 2a is a / (a - 1) times its scale matrix (requires a > 1)
  se <- sqrt(diag(post$t.cov) * (post$a / (post$a - 1)))
  coeff <- cbind(post$t.mean, se)
  colnames(coeff) <- c("Post. Mean", "Marg. Post. SE")
  cat("\nCoefficients:\n")
  print(coeff)

  cat("\nEstimated Variance:\n")
  cat(post$b / post$a)

  cat("\n\nResidual degrees of freedom:\n")
  cat(post$t.df - length(post$mean))
  cat("\n")

  invisible(object)
}

A test run

Now that we have everything we need, let’s run a basic test of our code.

data(airquality)

fit.f <- lm(Ozone ~ Wind + Temp + Wind:Temp, data=airquality)
fit.b <- blm(Ozone ~ Wind + Temp + Wind:Temp, data=airquality)

summary(fit.f)
summary(fit.b)

The results:

> summary(fit.b)

Call:
blm(formula = Ozone ~ Wind + Temp + Wind:Temp, data = airquality)

Coefficients:
              Post. Mean Marg. Post. SE
(Intercept) -248.3768291    47.70604708
Wind          14.3237294     4.20065758
Temp           4.0740781     0.58223882
Wind:Temp     -0.2237744     0.05350456

Estimated Variance:
403.3925

Residual degrees of freedom:
112.0001

> summary(fit.f)

Call:
lm(formula = Ozone ~ Wind + Temp + Wind:Temp, data = airquality)

Residuals:
    Min      1Q  Median      3Q     Max 
-39.906 -13.048  -2.263   8.726  99.306 

Coefficients:
              Estimate Std. Error t value Pr(>|t|)    
(Intercept) -248.51530   48.14038  -5.162 1.07e-06 ***
Wind          14.33503    4.23874   3.382 0.000992 ***
Temp           4.07575    0.58754   6.937 2.73e-10 ***
Wind:Temp     -0.22391    0.05399  -4.147 6.57e-05 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 20.44 on 112 degrees of freedom
  (37 observations deleted due to missingness)
Multiple R-squared:  0.6261,	Adjusted R-squared:  0.6161 
F-statistic: 62.52 on 3 and 112 DF,  p-value: < 2.2e-16

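Everything matches up! Under the vague default priors, the posterior means and marginal standard errors agree closely with the least-squares estimates and their standard errors, and the residual degrees of freedom differ only by the tiny prior.df. There’s certainly some more work to be done as far as handling edge cases and making the output cleaner and more informative, but this constitutes a solid foundation on which to build.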