Solving differential equations in Python with ode-explorer

Nicholas Junge

This post is a work in progress, just like the project itself! You can find the source code on GitHub.

I have been interested in numerical computing since about 2015. At that time, I left the blackboard and "pure" calculus for the first time to actually calculate and simulate things on the computer, and I remember being absolutely blown away by how I could "make math come alive" by translating theory and equations into numerical programs.

Since then, I have learned a good bit about numerics, but my fascination has only increased. There is so much to discover: how to compute efficiently, how to compute in parallel, how to scale to large problems, how to scale across many machines, how to tune performance in microcode. It simply does not stop!

So today, I want to share a small thing I have been working on - recreating my first encounter with numerical programming, ordinary differential equation solving. Meet the ODE Explorer!

A quick primer on ODEs

In numerical ODE solving, you compute a solution curve $y(t)$ that solves an ordinary differential equation

$$y'(t) = f(t, y).$$

It is easy to verify (fundamental theorem of calculus) that the solution, if it exists, must obey the relation

$$y(t) = y(t_0) + \int\limits_{t_0}^t f(x, y(x)) \, \mathrm{d}x$$

for some known starting point $t_0$. From here, we also see that a solution to an ordinary differential equation is just the superposition of all values of its right-hand side $f$ at previous times; it is thus fair to say that an ODE solution state is the sum of all of its previous "experience". Because of this formula, the act of solving the ODE is also commonly called integration.

You can also prove that, when imposing (fairly) manageable restrictions on the right-hand side $f$ (so-called Lipschitz continuity), a solution to the above ODE exists and is unique. This is known as the Picard-Lindelöf theorem.
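Concretely, Lipschitz continuity of $f$ in $y$ means that there is a constant $L > 0$ such that

$$\lVert f(t, y_1) - f(t, y_2) \rVert \leq L \, \lVert y_1 - y_2 \rVert$$

holds for all $t$ and all states $y_1, y_2$.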

The problem is that only a small minority of ODEs admit a closed-form solution. There are linear differential equations, which introduce the idea of a matrix exponential (for example, the linear system $y'(t) = Ay(t)$ is solved by $y(t) = e^{At}y(0)$), and a number of special cases that allow direct integration, such as exact ODEs, which are locally representable as an exact differential form in $t$ and $y$.

Numerical solutions to ODEs

Since we cannot expect to find a closed-form solution we can write down on paper, we need to devise numerical schemes that approximate it. This topic could span another series of articles, but the basic strategy is to flow along the vector field created by $f$ while taking care not to stray too far from our (unique) solution. This can be achieved by constructing an algorithm with three main properties:

  1. Convergence, the reproduction of the true solution in the limit of exact integration (see above),
  2. Consistency, the ability to approximate the solution better and better with finer steps, and
  3. Stability, the production of well-behaved solutions of the ODE.

The most common class of methods is the single-step methods, which consider only the most recent step to advance the system in time. The two main families within those are explicit and implicit methods; the latter are more computationally expensive, but better suited to certain classes of equations. For a short introduction, see this article on numerical methods for ODEs on Wikipedia.
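To make this concrete, here is a minimal sketch (plain Python with numpy, not part of ode-explorer) of the simplest explicit single-step method, the forward Euler method:

import numpy as np

def euler(f, t0, y0, h, num_steps):
  # Flow along the vector field f one step at a time:
  # y_{k+1} = y_k + h * f(t_k, y_k)
  t, y = t0, y0
  ts, ys = [t0], [y0]
  for _ in range(num_steps):
    y = y + h * f(t, y)
    t = t + h
    ts.append(t)
    ys.append(y)
  return np.array(ts), np.array(ys)

# Example: y'(t) = -y(t) with y(0) = 1, whose exact solution is exp(-t).
ts, ys = euler(lambda t, y: -y, t0=0.0, y0=1.0, h=0.01, num_steps=1000)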

Anatomy of the ode-explorer package

There are three main components to ode-explorer: the ODEModel, representing the right-hand side $f$; the StepFunction, which implements the solving algorithm; and the Integrator, which acts as a database for caching integration runs.

The model class

The ODEModel class is designed to represent the right-hand side $f$ of the ODE. It is essentially just a light wrapper around a callable:

class ODEModel:
  def __init__(self, **kwargs):
    # auxiliary model parameters (masses, spring constants, ...) are
    # stored in a dict on the instance; see below
    self.kwargs = kwargs

  def __call__(self, t, y):
    # evaluate the right-hand side f at (t, y)
    ...

The main point here is that the calling convention is standardized to exclude any auxiliary model parameters, like masses in a Hamiltonian system, spring constants, or chemical reaction rates. These are stored inside a kwargs dict inside the class to avoid trouble with JAX later.
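As a hypothetical usage sketch (the exact constructor arguments here are assumptions for illustration, not the package's actual API), a simple decay model could look like this:

# Hypothetical: wrap a decay equation y' = -k * y, with the rate
# constant k stored as an auxiliary parameter in kwargs.
def decay(t, y, k):
  return -k * y

model = ODEModel(ode_fn=decay, k=0.5)  # assumed keyword for the callable
model(0.0, 1.0)  # would evaluate f(0, 1) = -0.5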

The challenge here is supporting a variable-length state: for a normal ODE like the one above, the state is a tuple $(t, y)$, whereas for a Hamiltonian system, the state is a tuple $(t, q, p)$ of time, position, and momentum. So broadening the signature to

def f(t, *args, **kwargs):
  ...

might be beneficial here.

The step function

This class is responsible for advancing an equation in time. Given a current state tuple $(t_k, y_k)$, it computes the next state $(t_{k+1}, y_{k+1})$ with a given integration method. Common methods are the classical Runge-Kutta method (often called rk4), the Dormand-Prince 5(4) method (DOPRI5), and the forward Euler method.

Alluding to neural network forward passes, any step function can be implemented by subclassing the StepFunction class and overriding the forward method:

class MyStepFunc(StepFunction):
  def forward(self, state):
    ...
    # complicated integration logic
    return new_state
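For example, a forward Euler step in this style might look as follows; note that receiving the model and the step size h through forward is my assumption about the interface, not necessarily the package's actual signature:

class EulerStep(StepFunction):
  def forward(self, model, state, h):
    # forward Euler: y_{k+1} = y_k + h * f(t_k, y_k)
    t, y = state
    return (t + h, y + h * model(t, y))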

The Integrator

This is simply an abstract container that holds previous runs. It should behave like a database, but can be implemented by holding previous runs in memory for small systems, or by saving results to an actual database and querying them.

# integration is designed to work as quickly as possible
my_int = Integrator()

# example pseudocode integration
my_int.integrate(my_model, start=0.0, end=10.0, step_size=0.001)

# or log the distance to a known solution additionally
my_int.integrate(my_model, start=0.0, end=10.0, step_size=0.001,
                 metrics=[SolutionDistance(sol=known_solution)])

# inspect the results by querying the latest run
res = my_int.get_result(result_id="latest")

# plotting, saving, sending the results
...

Logging and metrics calculations

Another design emphasis is that one should be able to calculate metrics along an integration curve. These could be the distance to a known solution, which can be used to test the performance of an integration algorithm on a certain equation, or stability indicators like eigenvalue ratios or Lyapunov exponents.

The Metric interface is again a simple callable, like the ODEModel before:

class Metric:
  def __init__(self, *args, **kwargs):
    # initialize metric configuration and state
    ...

  def __call__(self, *args, **kwargs):
    # perform the (possibly complex) metric calculation
    ...

Multiple metrics can be chained in parallel, enabling detailed instrumentation of the integration process.
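As an illustration, a metric along the lines of the SolutionDistance used in the Integrator example might look roughly like this; the (t, y) call signature is an assumption, not the package's actual interface:

import numpy as np

class SolutionDistance(Metric):
  def __init__(self, sol):
    # known closed-form solution to compare against
    self.sol = sol

  def __call__(self, t, y, **kwargs):
    # distance between the numerical state and the true solution at time t
    return np.linalg.norm(y - self.sol(t))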

Future outlook

This project has just undergone a major change: numpy usage was switched to jax.numpy to enable GPU computation. This, however, presents some challenges:

  1. The computation is inherently a sequential loop; calculating the solution in this manner in functional-style programming is not as straightforward and requires changes to work on GPU. (This is currently a big todo; see the sketch after this list.)
  2. Performance! This requires using some low-level JAX routines, which might not play nicely with metric logging out of the box.
  3. Making the ODEModel and Metric wrappers play nicely with JAX requires some type registration, so that XLA and the JAX core know what to expect.
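As a rough idea of what such a functional formulation could look like, here is a minimal sketch (not from the package) of a fixed-step forward Euler loop expressed with jax.lax.scan:

import jax
import jax.numpy as jnp

def euler_scan(f, t0, y0, h, num_steps):
  # The carry is the current state (t, y); each scan iteration performs
  # one Euler step and emits y to collect the whole trajectory.
  def step(carry, _):
    t, y = carry
    y_next = y + h * f(t, y)
    return (t + h, y_next), y_next

  _, ys = jax.lax.scan(step, (jnp.asarray(t0), jnp.asarray(y0)),
                       xs=None, length=num_steps)
  return ys

# Example: y'(t) = -y(t), y(0) = 1
ys = euler_scan(lambda t, y: -y, 0.0, 1.0, 0.01, 1000)

With the loop written as a scan, JAX can jit-compile the whole trajectory computation and run it on an accelerator.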

All in all, I am excited about heaving the project onto accelerators with JAX; it might be a while before a substantial uplift over numpy is achieved, though.