Johann D. Gaebler

Autodiff for Implicit Functions in Stan

Fast derivatives for functions you can't write down

September 13th, 2021

Post replication materials

Implicit curve as a level curve.
Image: Ag2gaeh via Wikimedia

One of the things that makes Stan powerful is that, in addition to a large library of standard mathematical functions (e.g., \(\exp(x)\), \(x^y\), \(x + y\), \(\Gamma(x)\), etc.), it also supports higher-order functions, such as solving a user-specified system of ODEs. This greatly expands the range of Bayesian models Stan can handle.

One such higher-order function is Stan’s algebra solver, which is useful for building models that require implicit functions.1 Implicit functions show up often when one is looking for “steady states.” For instance, in pharmacokinetics, one often wants to know how much of a drug will build up in a patient’s body if they take a fixed dose at regular intervals. If the patient takes doses of size \(\delta\) at intervals of length \(\tau\), then we’re looking for the concentration \(x_0\) such that \(f(x_0 + \delta, \tau) = x_0\), where \(f(x, t)\) is the concentration of the drug in the patient’s body at time \(t\) if the concentration at time \(0\) was \(x\).2
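To make the steady-state condition concrete, here is a minimal Python sketch. It assumes a hypothetical one-compartment drug with first-order elimination, \(f(x, t) = x e^{-kt}\); the rate `k`, the dose, and the interval are made-up illustrative values, and the bisection root-finder is a stand-in for a real algebraic solver.

```python
import math

def concentration(x, t, k=0.3):
    """Hypothetical f(x, t): first-order elimination at rate k (an assumption)."""
    return x * math.exp(-k * t)

def steady_state(delta, tau, k=0.3, lo=0.0, hi=1e6, tol=1e-12):
    """Find x0 with f(x0 + delta, tau) = x0 by bisection.

    h(x) = f(x + delta, tau) - x is strictly decreasing in x for this model,
    positive at x = 0 and negative for large x, so bisection is safe.
    """
    def h(x):
        return concentration(x + delta, tau, k) - x

    while hi - lo > tol * max(1.0, hi):
        mid = 0.5 * (lo + hi)
        if h(mid) > 0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

# Illustrative values only.
delta, tau, k = 100.0, 12.0, 0.3
x0 = steady_state(delta, tau, k)

# For this simple model the fixed point can also be written down by hand.
closed_form = delta * math.exp(-k * tau) / (1 - math.exp(-k * tau))
print(x0, closed_form)
```

In this toy case the answer has a closed form, which is what makes it a useful check; the interesting cases are the ones where it does not.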

The challenge with higher-order functions is finding a way to efficiently implement automatic differentiation for them. The HMC sampler repeatedly differentiates the density of the distribution from which we’re sampling. If our model involves, for example, the solution to an ODE, then we have to differentiate the solution with respect to the inputs, even though the “solution” itself is the result of numerical integration and not necessarily something we can write down in closed form. This requires some cleverness. The algebra solver presents a similar challenge, as typically we cannot write down \(x\) as a closed-form function of \(\delta\) and \(\tau\) when all we know is the dependence \(f(x + \delta, \tau) = x\). Over the past few months, I’ve been working to help Stan calculate these derivatives more efficiently.

Implicit Functions

The easiest way to think about the algebraic solvers is through the implicit function theorem. The implicit function theorem states that if we have an equation of the form \(f(x, y) = 0\),3 and \(f\) is reasonably well-behaved, then, given a solution \((x_0, y_0)\), we can find a function \(g\) such that \((x, g(x))\) is a solution (i.e., \(f(x, g(x)) = 0\)) for all \(x\) in some neighborhood around \(x_0\). The function \(g(x)\), which “traces out” solutions to \(f(x, y) = 0\) as a function of \(x\), is our implicit function. There is an obstacle to \(g(x)\)’s existence: at least locally, there needs to be only a single value of \(y\) corresponding to each \(x\) such that \(f(x, y) = 0\). This will fail if the curve of solutions doubles back on itself, which, in turn, can only happen where \( \tfrac {\text{d} y} {\text{d} x} \to \infty\). So, to get an implicit function \(g(x)\) in a neighborhood of \(x_0\), we need that \( \tfrac {\text{d} y} {\text{d} x} \) is not infinite, or, what comes to the same thing, that the derivative \( \tfrac {\partial f} {\partial y} \) exists and is non-zero at the solution \((x_0, y_0)\).

An implicitly defined limaçon trisectrix, given by \(x^2 + y^2 = (x^2 + y^2 - 2x)^2\).

This figure illustrates a limaçon trisectrix, defined implicitly by \(x^2 + y^2 = (x^2 + y^2 - 2x)^2\).4 The variable \(y\) is a continuously differentiable function of \(x\) locally near the blue line, but not the red, where \( \tfrac {\partial f} {\partial y} = 0 \).

More formally, the implicit function theorem states that if \(f : \mathbb{R}^{n+m} \to \mathbb{R}^m\) is continuously differentiable, \(\mathbf{x}_0 \in \mathbb{R}^n\), \(\mathbf{y}_0 \in \mathbb{R}^m\), and \(f(\mathbf{x}_0, \mathbf{y}_0) = 0\), and \(\tfrac {\partial f} {\partial \mathbf{y}} \upharpoonright_{\mathbf{y} = \mathbf{y}_0} \) is invertible,5 then there is an (implicit) function \(g : \mathbb{R}^n \to \mathbb{R}^m \) defined locally around \(\mathbf{x}_0\) such that \(f(\mathbf{x}, g(\mathbf{x})) = 0\). What’s more, it follows directly from differentiating \(f(\mathbf{x}, g(\mathbf{x})) = 0\) that6 \[ \frac {\partial g} {\partial \mathbf{x}} = - \left[ \frac {\partial f} {\partial \mathbf{y}} \upharpoonright_{\mathbf{y} = g(\mathbf{x})} \right]^{-1} \frac {\partial f} {\partial \mathbf{x}}. \]
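To spell out the differentiation step: applying the chain rule to the identity \(f(\mathbf{x}, g(\mathbf{x})) = 0\) gives \[ 0 = \frac {\partial} {\partial \mathbf{x}} \, f(\mathbf{x}, g(\mathbf{x})) = \frac {\partial f} {\partial \mathbf{x}} + \frac {\partial f} {\partial \mathbf{y}} \, \frac {\partial g} {\partial \mathbf{x}}, \] with both Jacobians of \(f\) evaluated at \((\mathbf{x}, g(\mathbf{x}))\). Because \(\tfrac {\partial f} {\partial \mathbf{y}}\) is invertible there, moving \(\tfrac {\partial f} {\partial \mathbf{x}}\) to the other side and multiplying on the left by the inverse yields the formula for \(\tfrac {\partial g} {\partial \mathbf{x}}\). As a quick sanity check in one dimension, the unit circle \(f(x, y) = x^2 + y^2 - 1\) gives \[ \frac {\text{d} y} {\text{d} x} = - \frac {2x} {2y} = - \frac x y, \] which agrees with differentiating \(y = \sqrt{1 - x^2}\) directly on the upper half of the circle.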

This is the key point that makes it possible to use autodiff on implicit functions. The implicit function theorem doesn’t tell us how to calculate the function \(g(\mathbf{x})\) (that’s what the algebra solvers are for), but it does allow us to back out the gradient of \(g\) using only known quantities: the solution, \((\mathbf{x}, g(\mathbf{x}))\), and the algebraic system function, \(f(\mathbf{x}, \mathbf{y})\). (One could, in principle, implement the algebra solver itself on the autodiff stack and try to extract the gradient that way, but that would be impracticably slow.7)
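As a small numerical check of this idea, the Python sketch below uses the limaçon \(f(x, y) = x^2 + y^2 - (x^2 + y^2 - 2x)^2\) at the point \((0, 1)\), which lies on the curve. The slope from the implicit function theorem needs only \(f\) and the solution; the finite-difference slope instead re-runs a solver on either side of \(x_0\). The hand-written Newton iteration and the step size `h` are illustrative assumptions, not what Stan does internally.

```python
def f(x, y):
    r2 = x * x + y * y
    return r2 - (r2 - 2 * x) ** 2

def df_dx(x, y):
    r2 = x * x + y * y
    return 2 * x - 2 * (r2 - 2 * x) * (2 * x - 2)

def df_dy(x, y):
    r2 = x * x + y * y
    return 2 * y - 2 * (r2 - 2 * x) * (2 * y)

def g(x, y_init=1.0, steps=50):
    """Implicit function near (0, 1): Newton's method in y with x held fixed."""
    y = y_init
    for _ in range(steps):
        y -= f(x, y) / df_dy(x, y)
    return y

x0, y0 = 0.0, 1.0

# Slope from the implicit function theorem: uses only f and the solution.
slope_ift = -df_dx(x0, y0) / df_dy(x0, y0)

# Slope from central finite differences: needs two extra solves.
h = 1e-6
slope_fd = (g(x0 + h) - g(x0 - h)) / (2 * h)

print(slope_ift, slope_fd)
```

Both approaches agree on the slope at \((0, 1)\), but only the first avoids differentiating through the solver.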

Two Algorithms

The implicit function theorem makes autodiff using implicit functions possible. But the real question is, how fast can we make it?

If you’re unfamiliar with autodiff, there are some great overviews (and reference texts). Here, it is enough to understand that to make reverse-mode autodiff work, all we need to know is, for each “atomic function” \(f : \mathbb{R}^n \to \mathbb{R}^m \) that shows up in our model (e.g., \(\cos(x)\) or \(x + y\) or an implicit function calculated using the algebraic solver), how to calculate the product \(\boldsymbol{\xi} \tfrac {\partial f} {\partial \mathbf{x}}\), where \(\boldsymbol{\xi}\) is a size \(m\) row vector or “cotangent.”
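For readers who have not seen a cotangent product before, here is a toy Python illustration. The function `f` and the numbers are made up for the example; the point is only that `vjp` returns \(\boldsymbol{\xi} \tfrac {\partial f} {\partial \mathbf{x}}\) directly, without ever building the Jacobian, and agrees with the explicit product.

```python
import math

def f(x):
    """Toy atomic function R^2 -> R^2 (illustrative only)."""
    x1, x2 = x
    return [x1 * x2, math.sin(x1)]

def vjp(x, xi):
    """Return the row vector xi * (df/dx) without forming the Jacobian."""
    x1, x2 = x
    xi1, xi2 = xi
    return [xi1 * x2 + xi2 * math.cos(x1), xi1 * x1]

def jacobian(x):
    """Explicit Jacobian, used here only to check vjp."""
    x1, x2 = x
    return [[x2, x1],
            [math.cos(x1), 0.0]]

x = [0.5, 2.0]
xi = [1.0, -3.0]

J = jacobian(x)
explicit = [sum(xi[i] * J[i][j] for i in range(2)) for j in range(2)]
print(vjp(x, xi), explicit)
```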

In the case of an implicit function \(g(\mathbf{x})\), the implicit function theorem above suggests an obvious way to do this.

Algorithm 1: The Naïve Method

Data: The algebraic system function \(f(\mathbf{x}, \mathbf{y})\), a solution \((\mathbf{x}_0, \mathbf{y}_0)\), and an initial cotangent \(\boldsymbol{\xi}\).

Result: The adjoint of the implicit function \(\mathbf{x} \mapsto \mathbf{y}\) with respect to the initial cotangent \(\boldsymbol{\xi}\) evaluated at \( \mathbf{x}_0 \), i.e., the product \( \boldsymbol{\xi}_{\operatorname{out}} = \boldsymbol{\xi} \left[ \tfrac {\partial \mathbf{y}} {\partial \mathbf{x}} \upharpoonright_{\mathbf{x}_0} \right] \).

  1. For \(i = 1, \ldots, n\)
    • Calculate \(\tfrac {\partial f} {\partial x_i}\). (One forward-mode pass.)
  2. For \(i = 1, \ldots, m\)
    • Calculate \(\tfrac {\partial f} {\partial y_i}\). (One forward-mode pass.)
  3. Calculate the LU-decomposition of \(\tfrac {\partial f} {\partial \mathbf{y}}\).
  4. For \(i = 1, \ldots, n\)
    • Calculate \(\tfrac {\partial \mathbf{y}} {\partial x_i} = - \left[ \tfrac {\partial f} {\partial \mathbf{y}} \right]^{-1} \frac {\partial f} {\partial x_i}\). (One matrix solve.)
  5. Calculate \(\boldsymbol{\xi}_{\text{out}} = \boldsymbol{\xi} \tfrac {\partial \mathbf{y}} {\partial \mathbf{x}}\). (One matrix multiplication.)
  6. Return \(\boldsymbol{\xi}_{\operatorname{out}}\).

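Forming the full Jacobian \(\tfrac {\partial \mathbf{y}} {\partial \mathbf{x}}\) this way is quite expensive: in addition to the \(m\) forward-mode passes and the LU-decomposition needed for \(\tfrac {\partial f} {\partial \mathbf{y}}\), it also requires \(n\) forward-mode autodiff passes and \(n\) matrix solves.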
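A rough Python sketch of Algorithm 1 on a made-up \(2 \times 2\) system may help fix ideas. Everything here is an assumption for illustration: the system `f`, its hand-written Jacobians (standing in for forward-mode autodiff passes), and a plain Gaussian-elimination `solve` that repeats the factorization on every call rather than reusing one LU-decomposition.

```python
def solve(A, b):
    """Solve A z = b by Gaussian elimination with partial pivoting."""
    n = len(A)
    M = [row[:] + [b[i]] for i, row in enumerate(A)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    z = [0.0] * n
    for r in reversed(range(n)):
        s = sum(M[r][c] * z[c] for c in range(r + 1, n))
        z[r] = (M[r][n] - s) / M[r][r]
    return z

# Hypothetical system: f(x, y) = (y1^2 + y2 - x1 - x2, y1 * y2 - x2).
def f(x, y):
    return [y[0] ** 2 + y[1] - x[0] - x[1], y[0] * y[1] - x[1]]

def df_dy(x, y):  # m x m, hand-written in place of autodiff
    return [[2 * y[0], 1.0], [y[1], y[0]]]

def df_dx(x, y):  # m x n, hand-written in place of autodiff
    return [[-1.0, -1.0], [0.0, -1.0]]

def naive_adjoint(x, y, xi):
    A, B = df_dy(x, y), df_dx(x, y)
    m, n = len(A), len(B[0])
    # Step 4: one solve per input gives column i of dy/dx = -A^{-1} B.
    cols = [solve(A, [-B[r][i] for r in range(m)]) for i in range(n)]
    # Step 5: xi_out[i] = sum_r xi[r] * dy_r/dx_i.
    return [sum(xi[r] * cols[i][r] for r in range(m)) for i in range(n)]

x0, y0 = [1.0, 6.0], [2.0, 3.0]   # satisfies f(x0, y0) = 0
xi = [1.0, 2.0]
xi_out = naive_adjoint(x0, y0, xi)
print(xi_out)
```

Note that the loop in step 4 performs one solve per input, so its cost grows with \(n\) even though only the single row vector \(\boldsymbol{\xi}_{\text{out}}\) is needed at the end.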

However, if we’re slightly more clever, we can avoid much of the expense of the matrix inversion, and reduce the number of autodiff passes.8

Algorithm 2: The Adjoint Method

Data: The algebraic system function \(f(\mathbf{x}, \mathbf{y})\), a solution \((\mathbf{x}_0, \mathbf{y}_0)\), and an initial cotangent \(\boldsymbol{\xi}\).

Result: The adjoint of the implicit function \(\mathbf{x} \mapsto \mathbf{y}\) with respect to the initial cotangent \(\boldsymbol{\xi}\) evaluated at \( \mathbf{x}_0 \), i.e., the product \( \boldsymbol{\xi}_{\operatorname{out}} = \boldsymbol{\xi} \left[ \tfrac {\partial \mathbf{y}} {\partial \mathbf{x}} \upharpoonright_{\mathbf{x}_0} \right] \).

  1. For \(i = 1, \ldots, m\)
    • Calculate \(\tfrac {\partial f} {\partial y_i}\). (One forward-mode pass.)
  2. Calculate the LU-decomposition of \(\tfrac {\partial f} {\partial \mathbf{y}}\).
  3. Calculate \(\boldsymbol{\eta} = - \boldsymbol{\xi} \left[ \frac {\partial f} {\partial \mathbf{y}} \right]^{-1} \). (One matrix solve.)
  4. Calculate \(\boldsymbol{\xi}_{\text{out}} = \boldsymbol{\eta} \tfrac {\partial f} {\partial \mathbf{x}}\). (One reverse-mode pass.)
  5. Return \(\boldsymbol{\xi}_{\operatorname{out}}\).

Algorithm 2, by way of contrast, replaces the \(n\) forward-mode passes and \(n\) matrix solves of Algorithm 1 with a single matrix solve and a single reverse-mode pass. Since each autodiff pass requires (roughly) the same number of operations as calculating the value of a function itself,9 the savings are significant.
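The trick behind Algorithm 2 is just associativity: \(\boldsymbol{\xi} \left( - \left[ \tfrac {\partial f} {\partial \mathbf{y}} \right]^{-1} \tfrac {\partial f} {\partial \mathbf{x}} \right) = \left( - \boldsymbol{\xi} \left[ \tfrac {\partial f} {\partial \mathbf{y}} \right]^{-1} \right) \tfrac {\partial f} {\partial \mathbf{x}}\). Below is a rough Python sketch on a made-up \(2 \times 2\) system. As assumptions for illustration, the Jacobians are written by hand, `solve` is plain Gaussian elimination rather than a stored LU-decomposition, and the final product with an explicit \(\tfrac {\partial f} {\partial \mathbf{x}}\) stands in for what would be a single reverse-mode sweep in a real autodiff library.

```python
def solve(A, b):
    """Solve A z = b by Gaussian elimination with partial pivoting."""
    n = len(A)
    M = [row[:] + [b[i]] for i, row in enumerate(A)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    z = [0.0] * n
    for r in reversed(range(n)):
        s = sum(M[r][c] * z[c] for c in range(r + 1, n))
        z[r] = (M[r][n] - s) / M[r][r]
    return z

# Hypothetical system: f(x, y) = (y1^2 + y2 - x1 - x2, y1 * y2 - x2).
def df_dy(x, y):  # m x m, hand-written in place of autodiff
    return [[2 * y[0], 1.0], [y[1], y[0]]]

def df_dx(x, y):  # m x n, hand-written in place of autodiff
    return [[-1.0, -1.0], [0.0, -1.0]]

def adjoint_method(x, y, xi):
    A, B = df_dy(x, y), df_dx(x, y)
    m, n = len(A), len(B[0])
    # Step 3: eta A = -xi, i.e. solve the transposed system A^T eta^T = -xi^T.
    A_t = [[A[r][c] for r in range(m)] for c in range(m)]
    eta = solve(A_t, [-v for v in xi])
    # Step 4: xi_out = eta * df/dx (one reverse-mode sweep in real autodiff).
    return [sum(eta[r] * B[r][i] for r in range(m)) for i in range(n)]

x0, y0 = [1.0, 6.0], [2.0, 3.0]   # a solution of the system above
xi = [1.0, 2.0]
xi_out = adjoint_method(x0, y0, xi)
print(xi_out)
```

Only one linear solve appears, no matter how many inputs \(\mathbf{x}\) has.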

Testing it Out: An Example from Pharmacology

To test out the two algorithms, we borrow a simple example from pharmacology. We consider a two-compartment model. Patient \(i\) consumes a dose \(\delta\) of a drug at intervals of length \(\tau\). The concentration of the drug in the central and peripheral compartments of patient \(i\) satisfies the ODE10 \begin{align*} y_{i, \text{cen}}'(t) &= -\kappa_{i, \text{cen}} \cdot y_{i, \text{cen}}(t), \\
y_{i, \text{per}}'(t) &= \kappa_{i, \text{cen}} \cdot y_{i, \text{cen}}(t) - \kappa_{i, \text{per}} \cdot y_{i, \text{per}}(t), \end{align*} where \(t\) is the time since the last dose. The patient is at a “steady state” when \begin{align*} y_{i, \text{cen}}(0) - y_{i, \text{cen}}(\tau) &= \delta, \\
y_{i, \text{per}}(0) - y_{i, \text{per}}(\tau) &= 0. \end{align*} Measurements are taken of the concentration in the peripheral compartment for each patient at steady state. Measurement \(m_{i, k}\) taken at time \(t_{i, k}\) after the last dose satisfies \[ \log(m_{i, k}) \sim \mathcal{N} \left( \log(y_{i, \text{per}}(t_{i, k})), \tfrac 1 4 \right). \]

To complete the model, we put priors on \(\kappa_{i, \text{cen}}\) and \(\kappa_{i, \text{per}}\). In particular, our prior is that they are i.i.d. lognormal random variables, i.e., \[ \log(\kappa_{i, j}) \sim \mathcal{N} \left( 0, \tfrac 1 4 \right), \] for \(j \in \{ \text{cen}, \text{per} \}\).
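This ODE is linear, so it can be solved in closed form, and the Stan model's `drug_conc` function uses that closed form. As a sanity check, the Python sketch below compares the closed form against a hand-rolled fourth-order Runge-Kutta integration. The rate constants, initial state, and step count are arbitrary illustrative values, and the closed form assumes \(\kappa_{\text{cen}} \neq \kappa_{\text{per}}\).

```python
import math

def drug_conc(y_cen, y_per, k_cen, k_per, t):
    """Closed-form solution at time t (requires k_cen != k_per)."""
    e_cen, e_per = math.exp(-k_cen * t), math.exp(-k_per * t)
    y_cen_out = e_cen * y_cen
    y_per_out = (k_cen / (k_per - k_cen)) * (e_cen - e_per) * y_cen + e_per * y_per
    return [y_cen_out, y_per_out]

def rhs(y, k_cen, k_per):
    """Right-hand side of the two-compartment ODE."""
    return [-k_cen * y[0], k_cen * y[0] - k_per * y[1]]

def rk4(y, k_cen, k_per, t, steps=2000):
    """Integrate from 0 to t with classical fourth-order Runge-Kutta."""
    h = t / steps
    for _ in range(steps):
        k1 = rhs(y, k_cen, k_per)
        k2 = rhs([y[i] + 0.5 * h * k1[i] for i in range(2)], k_cen, k_per)
        k3 = rhs([y[i] + 0.5 * h * k2[i] for i in range(2)], k_cen, k_per)
        k4 = rhs([y[i] + h * k3[i] for i in range(2)], k_cen, k_per)
        y = [y[i] + h / 6 * (k1[i] + 2 * k2[i] + 2 * k3[i] + k4[i])
             for i in range(2)]
    return y

# Illustrative values only.
k_cen, k_per, t_end = 0.8, 0.5, 2.0
y_start = [10.0, 4.0]

exact = drug_conc(y_start[0], y_start[1], k_cen, k_per, t_end)
numeric = rk4(y_start, k_cen, k_per, t_end)
print(exact, numeric)
```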

Stan Implementation and Results

It’s straightforward to represent the pharmacological model above directly in Stan.

functions {
  /* Solution to drug concentration ODE given initial concentration, time elapsed,
   * and diffusion parameters.
   */
  vector[] drug_conc(vector y_cen, vector y_per, vector kappa_cen,
                     vector kappa_per, vector ts, int N) {
    vector[N] y_cen_out = exp(-kappa_cen .* ts) .* y_cen;
    vector[N] y_per_out = (kappa_cen ./ (kappa_per - kappa_cen))
                            .* (exp(-kappa_cen .* ts) - exp(-kappa_per .* ts))
                            .* y_cen + exp(-kappa_per .* ts) .* y_per;

    return { y_cen_out, y_per_out };
  }

  // Functor with appropriate signature for algebraic solver.
  vector f(vector y, vector kappas, real[] x_r, int[] x_i) {
    // Unpack x_r.
    real delta = x_r[1];
    real tau = x_r[2];

    // Unpack x_i.
    int n = x_i[1];

    // All of the intervals are tau.
    vector[n] ts = rep_vector(tau, n);

    /* The first n entries of y are the concentrations in the central
     * compartment, while the last n entries of y are the concentrations in the
     * peripheral compartments. Likewise for kappa.
     */
    vector[n] y_cen = y[:n];
    vector[n] y_per = y[(n+1):];
    vector[n] kappa_cen = kappas[:n];
    vector[n] kappa_per = kappas[(n+1):];

    // Calculate the concentrations after tau.
    vector[n] y_res[2] = drug_conc(y_cen, y_per, kappa_cen, kappa_per, ts, n);

    /* If a steady state has been reached, the concentration after tau, plus
     * a dose delta in the central compartment, should equal the current
     * state, so the returned residual is zero.
     */
    return append_row(y_res[1] + rep_vector(delta, n), y_res[2]) - y;
  }
}

data {
  real<lower=0> delta;              // Dosage
  real<lower=0> tau;                // Dose interval
  int n;                            // Number of patients
  int m;                            // Number of observations
  vector<lower=0,upper=tau>[m] ts;  // Times of observations (since last dose)
  int<lower=1,upper=n> idx[m];      // Patient corresponding to observation
  vector<lower=0>[m] obs;           // Observed concentrations
  vector<lower=0>[n] y_guess_cen;   // Guess for central compartment.
  vector<lower=0>[n] y_guess_per;   // Guess for peripheral compartment.
}

transformed data {
  // Reshape the data for the algebra solver.
  vector[2*n] y_guess = append_row(y_guess_cen, y_guess_per);
  real x_r[2]         = { delta, tau };
  int  x_i[1]         = { n };
}

parameters {
  vector<lower=0>[n] kappa_cen; // Dispersion from central compartment
  vector<lower=0>[n] kappa_per; // Dispersion from peripheral compartment
}

transformed parameters {
  // Reshape dispersion parameters for algebra solver.
  vector[2*n] kappas = append_row(kappa_cen, kappa_per);

  // Get the steady-state for each patient.
  vector[2*n] y_steady = algebra_solver(f, y_guess, kappas, x_r, x_i);
  vector[n] y_steady_cen = y_steady[:n];
  vector[n] y_steady_per = y_steady[(n+1):];

  /* Get the concentrations in both compartments at each observation time,
   * given the currently sampled diffusion parameters.
   */
  vector[m] y_true[2] = drug_conc(y_steady_cen[idx], y_steady_per[idx],
                                  kappa_cen[idx], kappa_per[idx], ts, m);
}

model {
  kappa_cen ~ lognormal(0, 1.0/4);
  kappa_per ~ lognormal(0, 1.0/4);
  obs ~ lognormal(log(y_true[2]), 1.0/4);
}

generated quantities {
  real fake_obs[m] = lognormal_rng(log(y_true[2]), 1.0/4);
}

To test it out, we simulate fake data for patient populations of a range of different sizes. For each population size, we generate 100 fake datasets, each with approximately 100 observations for each patient, and then fit the Stan model shown above. We repeat this experiment using both the naïve and adjoint algorithms.11 The results are shown below.
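For a sense of what generating such a dataset involves, here is a hedged Python sketch. It is not the code from the replication materials: the function names, the fixed number of observations per patient, the uniform observation times, and the example values of \(\delta\) and \(\tau\) are all assumptions. The steady state is computed by solving the fixed-point condition by hand, with the scale \(\tfrac 1 4\) read as a standard deviation to match the Stan model.

```python
import math
import random

def drug_conc(y_cen, y_per, k_cen, k_per, t):
    """Closed-form ODE solution at time t (requires k_cen != k_per)."""
    e_cen, e_per = math.exp(-k_cen * t), math.exp(-k_per * t)
    return (e_cen * y_cen,
            (k_cen / (k_per - k_cen)) * (e_cen - e_per) * y_cen + e_per * y_per)

def steady_state(k_cen, k_per, delta, tau):
    """State right after a dose, solving drug_conc(y, tau) + (delta, 0) = y."""
    e_cen, e_per = math.exp(-k_cen * tau), math.exp(-k_per * tau)
    y_cen = delta / (1 - e_cen)
    y_per = y_cen * (k_cen / (k_per - k_cen)) * (e_cen - e_per) / (1 - e_per)
    return y_cen, y_per

def simulate(n_patients, obs_per_patient, delta, tau, seed=0):
    """Return (patient index, time since last dose, observation) rows."""
    rng = random.Random(seed)
    rows = []
    for i in range(1, n_patients + 1):        # 1-based, as in Stan
        k_cen = rng.lognormvariate(0.0, 0.25)  # prior draws
        k_per = rng.lognormvariate(0.0, 0.25)
        y_cen, y_per = steady_state(k_cen, k_per, delta, tau)
        for _ in range(obs_per_patient):
            t = rng.uniform(0.0, tau)
            _, conc_per = drug_conc(y_cen, y_per, k_cen, k_per, t)
            obs = rng.lognormvariate(math.log(conc_per), 0.25)
            rows.append((i, t, obs))
    return rows

# Illustrative values only.
rows = simulate(n_patients=3, obs_per_patient=5, delta=1.0, tau=1.0)
print(rows[0])
```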

A plot of the raw runtimes of the naïve and adjoint algorithms.

Raw runtimes of the naïve and adjoint methods. Runtimes of the adjoint algorithm have been translated to the right for better readability.

As expected, there is a clear speedup as the size of the problem gets larger. To better visualize the speedup across all orders of magnitude, we also calculate the relative speedup for each problem size.

A plot of the relative speedup of the adjoint algorithm over the naïve algorithm.

Relative average speedup of adjoint method over naïve method. Note that the plot shows the ratio of average speeds, rather than an average ratio of speeds.

In general, the adjoint method is as fast as or faster than the naïve method, fitting the model in roughly 5–10% less time for smaller patient populations, and more than 30% less time for populations on the order of a hundred patients.12

Future Directions

While a 30% speedup is substantial, more remains to be done. In addition to the Powell and Newton algebraic solvers, the Stan math library also has a fixed point solver, for which this adjoint method has not yet been implemented. It may also be possible to more efficiently use the Jacobians calculated by the Powell and Newton solvers themselves.

Lastly, profiling shows that the most expensive portion of the computation tends to be the LU decomposition of the \(m \times m\) Jacobian \(\tfrac {\partial f} {\partial \mathbf{y}}\), which is \(O(m^3)\), rather than the matrix solves, which are \(O(m^2)\).13 While support for sparse matrices in Stan is currently limited, Krylov subspace or other sparse matrix methods could be applied in the future in cases, such as the example considered above, where the Jacobian of the algebraic system has a sparse structure.

  1. Here by implicit functions, we refer exclusively to relations of the form \(R(x_1, \ldots, x_n) = 0\) for real variables \(x_1, \ldots, x_n \in \mathbb{R}\) and \(R : \mathbb{R}^n \to \mathbb{R}^k\) that are locally functions, rather than implicit functions defined on more general Banach spaces or in terms of differential operators, which require more care. See, e.g., Efficient Automatic Differentiation of Implicit Functions for more information. 

  2. Thanks to Charles Margossian for initially sharing this example with me, as well as for a number of helpful suggestions and corrections on this blog post. 

  3. We refer to such an \(f\) as the “algebraic system function” and to \(f(x, y) = 0\) as the “algebraic system” that we are trying to solve. 

  4. That is, our algebraic system function is \(f(x, y) = x^2 + y^2 - (x^2 + y^2 - 2x)^2\), and we are interested in the set of \(x\) and \(y\) such that \(f(x, y) = 0\). 

  5. This is the appropriate multivariate analogue of requiring that \( \tfrac {\partial f} {\partial y} \neq 0\). 

  6. We denote the Jacobian of \(f\) by \(\tfrac {\partial f} {\partial \mathbf{x}}\). Individual partial derivatives are written \(\tfrac {\partial f_i} {\partial x_j}\), gradients \( \tfrac {\partial f_i} {\partial \mathbf{x}} \), etc. 

  7. Even this is not the whole story—strictly speaking, the algebraic solvers actually can have local discontinuities because, e.g., the number of Newton steps varies. This means that our implementation would need to apply step functions to parameters, which can seriously impact sampling, since step functions are not themselves differentiable around the “step.” 

  8. Charles Margossian originally introduced this technique to me, but, e.g., Kolter, Duvenaud, and Johnson had previously proposed it in the autodiff literature. 

  9. See Chapter 4 of Griewank & Walther.

  10. This ODE has a closed-form solution which can be derived from the matrix exponential. Note that we can write, equivalently, \[ \mathbf{y}_i'(t) = \begin{bmatrix} - \kappa_{i, \text{cen}} & 0 \\
    \kappa_{i, \text{cen}} & - \kappa_{i, \text{per}} \end{bmatrix} \mathbf{y}_i(t). \] Note that, provided \(\kappa_{i, \text{cen}} \neq \kappa_{i, \text{per}}\), the matrix factors: \begin{align*} \begin{bmatrix} - \kappa_{i, \text{cen}} & 0 \\
    \kappa_{i, \text{cen}} & - \kappa_{i, \text{per}} \end{bmatrix} &= \mathbf{X}^{-1} \mathbf{\Lambda} \mathbf{X} \\
    &= \begin{bmatrix} 1 & 0 \\
    \tfrac {\kappa_{i, \text{cen}}} {\kappa_{i, \text{per}}-\kappa_{i, \text{cen}}} & 1 \end{bmatrix} \begin{bmatrix} -\kappa_{i, \text{cen}} & 0 \\
    0 & -\kappa_{i, \text{per}} \end{bmatrix} \begin{bmatrix} 1 & 0 \\
    \tfrac {\kappa_{i, \text{cen}}} {\kappa_{i, \text{cen}}-\kappa_{i, \text{per}}} & 1 \end{bmatrix}. \end{align*} Therefore, the solution is given by \begin{align*} \mathbf{y}_i(t) &= \exp \left(t \cdot \begin{bmatrix} - \kappa_{i, \text{cen}} & 0 \\
    \kappa_{i, \text{cen}} & - \kappa_{i, \text{per}} \end{bmatrix} \right) \mathbf{y}_i(0) \\
    &= [\, \mathbf{X}^{-1} \exp(t \mathbf{\Lambda}) \mathbf{X} \,] \, \mathbf{y}_i(0) \\
    &= \begin{bmatrix} \exp(-\kappa_{i, \text{cen}}t) & 0 \\
    \tfrac {\kappa_{i, \text{cen}}} {\kappa_{i, \text{per}} - \kappa_{i, \text{cen}}} \cdot (\exp(-\kappa_{i, \text{cen}}t) - \exp(-\kappa_{i, \text{per}}t)) & \exp(-\kappa_{i,\text{per}}t) \end{bmatrix} \mathbf{y}_i(0), \end{align*} which is what is used in the Stan model, although we could have used the numerical ODE solver.

    We can also use this closed form solution to explicitly calculate fixed points given \(\kappa_{i, \text{cen}}\), \(\kappa_{i, \text{per}}\), \(\delta\), and \(\tau\). In particular, straightforward algebraic manipulation yields that the steady state (immediately after a dose \(\delta\), i.e., at \(t = 0\)) is \begin{align*} y_{i, \text{cen}} &= \frac \delta {1 - \exp(-\kappa_{i, \text{cen}} \cdot \tau)}, \\
    y_{i, \text{per}} &= y_{i, \text{cen}} \cdot \frac {\kappa_{i, \text{cen}}} {\kappa_{i, \text{per}} - \kappa_{i, \text{cen}}} \cdot \frac {\exp(-\kappa_{i, \text{cen}} \cdot \tau) - \exp(-\kappa_{i, \text{per}} \cdot \tau)} {1 - \exp(-\kappa_{i, \text{per}} \cdot \tau)}. \end{align*} This expression is needed to simulate fake data. 

  11. That is, we use commit 74c5354, which introduced the adjoint algorithm, and commit cfb93f5, which immediately preceded it. Note that due to technical limitations, the actual implementation in Stan differs from the algorithms as written in minor ways (for instance, in the adjoint method, using reverse-mode rather than forward-mode to calculate the Jacobian and adding a small amount of memory-management–related overhead). 

  12. The fact that it appears to be slightly slower for problems with on the order of 30 parameters is likely an artifact of the memory-management–related overhead currently necessary to implement the adjoint method in Stan. This overhead will be removed in the future. 

  13. That is, the most expensive portion of the computation outside of evaluating the algebraic system itself, which can be arbitrarily expensive.