# Solving differential equations in Python with ode-explorer

*by Nicholas Junge*
This post is a work in progress, just like the project itself! You can find the source code on GitHub.
I have been interested in numerical computing since about 2015. At that time, I left the blackboard and "pure" calculus for the first time to actually calculate and simulate things on the computer, and I remember I was absolutely blown away by how I could "make math alive" by translating theory and equations into numerical programs.
Since then, I have learned a good bit about numerics, but my fascination has only grown. There is so much to discover - how to compute efficiently, how to compute in parallel, how to scale to large problems, how to scale across many machines, how to tune performance in microcode - it simply does not stop!
So today, I want to share a small thing I have been working on - recreating my first encounter with numerical programming, ordinary differential equation solving. Meet the ODE Explorer!
## A quick primer on ODEs
In numerical ODE solving, you compute a solution curve $y(t)$ that solves an ordinary differential equation

$$y'(t) = f(t, y(t)).$$

It is easy to verify (fundamental theorem of calculus) that the solution, if it exists, must obey the relation

$$y(t) = y_0 + \int_{t_0}^{t} f(s, y(s)) \, ds$$

for some known starting point $(t_0, y_0)$. From here, we also see that a solution to an ordinary differential equation is just the superposition of all values of its right-hand side at previous times; it is thus fair to say that an ODE solution state is the sum of all of its previous "experience". Because of this formula, the act of solving the ODE is also commonly called integration.
You can also prove that, when imposing (fairly) manageable restrictions on the right-hand side $f$ (so-called Lipschitz continuity), a solution to the above ODE exists and is unique. This is known as the Picard-Lindelöf theorem.
The problem is that ODEs admitting a closed-form solution are actually a very small minority. There are linear differential equations, which introduce the idea of a matrix exponential, and a number of special cases that allow direct integration, such as exact ODEs, which are locally representable as an exact differential form in $t$ and $y$.
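For example, a linear system with a constant coefficient matrix $A$ is solved directly by the matrix exponential:

$$y'(t) = A y(t), \quad y(t_0) = y_0 \quad \Longrightarrow \quad y(t) = e^{A(t - t_0)} y_0.$$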
## Numerical solutions to ODEs
Since we cannot expect to find a closed-form solution we can write down on paper, we need to devise numerical schemes to find our solution approximately. This topic could span another series of articles, but the basic strategy is to flow along the vector field created by $f$, while making sure that we do not stray too far from our (unique) solution. This can be established by creating an algorithm with three main properties:
- Convergence, the reproduction of the real solution in the limit of exact integration (see above).
- Consistency, the ability to approximate the solution better and better with finer steps, and
- Stability, the production of well-behaved solutions of the ODE.
The most common class of methods is that of single-step methods, which consider only the latest previous step to advance the system in time. The two main families within those are explicit and implicit methods; the latter are more computationally expensive, but better suited for certain classes of equations. For a small introduction, see this article on numerical methods for ODEs on Wikipedia.
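To make the single-step idea concrete, here is a minimal sketch of the simplest explicit scheme, the forward Euler method. This is a standalone illustration, not ode-explorer's actual implementation:

```python
import numpy as np

def forward_euler(f, t0, y0, t_end, h):
    """Integrate y'(t) = f(t, y) from t0 to t_end with fixed step size h."""
    ts, ys = [t0], [y0]
    t, y = t0, y0
    while t < t_end:
        y = y + h * f(t, y)  # follow the vector field for one step of length h
        t = t + h
        ts.append(t)
        ys.append(y)
    return np.array(ts), np.array(ys)

# example: y' = -y with y(0) = 1, whose exact solution is exp(-t)
ts, ys = forward_euler(lambda t, y: -y, t0=0.0, y0=1.0, t_end=5.0, h=0.01)
```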
## Anatomy of the ode-explorer package
There are three main components to ode-explorer: the `ODEModel`, representing the right-hand side $f$, the `StepFunction`, which represents the solving algorithm, and the `Integrator`, which acts as a database for caching integration runs.
### The model class
The `ODEModel` class is designed to represent the right-hand side $f$ of the ODE. It is basically only a light wrapper around a callable:
```python
class ODEModel:
    def __init__(self, **kwargs):
        ...

    def __call__(self, t, y):
        ...
```
The main thing here is that the calling convention is standardized to not include any auxiliary model parameters, like masses in a Hamiltonian system, spring constants, or chemical reaction speeds. These are instead stored inside a `kwargs` dict on the class to avoid trouble with JAX later.
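Filling in the sketch above, a minimal version of this wrapper could look as follows; note that passing the wrapped callable as a constructor argument is my assumption, not necessarily ode-explorer's actual API:

```python
import numpy as np

class ODEModel:
    def __init__(self, ode_fn, **kwargs):
        # auxiliary model parameters (masses, spring constants, ...) are
        # stashed in a dict instead of being baked into the call signature
        self.ode_fn = ode_fn
        self.kwargs = dict(kwargs)

    def __call__(self, t, y):
        return self.ode_fn(t, y, **self.kwargs)

# hypothetical usage: exponential decay with rate constant k as a parameter
model = ODEModel(lambda t, y, k: -k * y, k=0.5)
model(0.0, np.array(1.0))  # evaluates f(0.0, 1.0) with k=0.5 -> -0.5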
The challenge here is supporting a variable-length state: for a normal ODE like the one above, the state is a tuple $(t, y)$, whereas for a Hamiltonian system, the state is a tuple $(t, q, p)$ of time, position, and momentum. So broadening the signature to

```python
def f(t, *args, **kwargs):
    ...
```

might be beneficial here.
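Under such a broadened signature, a Hamiltonian right-hand side could be written like this (a hypothetical illustration):

```python
import numpy as np

def harmonic_oscillator(t, q, p, m=1.0, k=1.0):
    # Hamilton's equations for H = p**2 / (2 * m) + k * q**2 / 2
    dq = p / m   # dq/dt = dH/dp
    dp = -k * q  # dp/dt = -dH/dq
    return np.array([dq, dp])
```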
### The step function
This class is responsible for advancing an equation in time. Given a current state tuple $(t_n, y_n)$, it computes the next state $(t_{n+1}, y_{n+1})$ with a given integration method. Common methods are the classical Runge-Kutta method (often called `rk4`), the Dormand-Prince 5(4) method (DOPRI5), and the forward Euler method.
Alluding to neural network forward passes, any step function can be implemented by subclassing the `StepFunction` class and overriding the `forward` method:
```python
class MyStepFunc(StepFunction):
    def forward(self, model, state):
        # complicated integration logic (the exact argument list is a sketch)
        ...
        return new_state
```
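For instance, an explicit Euler step function could look roughly like this; the `forward` signature here is my assumption, not necessarily the package's exact interface:

```python
class EulerStep(StepFunction):
    def forward(self, model, state, h):
        # state is assumed to be the tuple (t, y); h is the step size
        t, y = state
        y_new = y + h * model(t, y)  # one explicit Euler step along f
        return t + h, y_new
```

The model is evaluated exactly once per step here, which is what makes explicit Euler cheap (and not particularly accurate).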
### The Integrator
This is simply an abstract container that holds previous runs. It should behave like a database, but can be implemented by holding the previous runs in memory for small systems, or by saving results to an actual database and querying them.
```python
# integration is designed to work as quickly as possible
my_int = Integrator()

# example pseudocode integration
my_int.integrate(my_model, start=0.0, end=10.0, step_size=0.001)

# or additionally log the distance to a known solution
my_int.integrate(my_model, start=0.0, end=10.0, step_size=0.001,
                 metrics=[SolutionDistance(sol=known_solution)])

# inspect the results by querying the latest run
res = my_int.get_result(result_id="latest")

# plotting, saving, sending the results
...
```
## Logging and metrics calculations
Another design emphasis is that one should be able to calculate metrics along an integration curve. This could be the distance to a known solution, which can be used to test the performance of an integration algorithm on a certain equation, or stability indicators like eigenvalue ratios or Lyapunov exponents.
The `Metric` interface is again a simple callable, like the `ODEModel` before:
```python
class Metric:
    def __init__(self, *args, **kwargs):
        # initialize variables
        ...

    def __call__(self, *args, **kwargs):
        # perform complex metric calculations
        ...
```
Multiple metrics can be tracked in parallel, enabling detailed instrumentation of the integration process.
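As an illustration, the `SolutionDistance` metric used in the integrator example above might be implemented roughly like this (the call signature is an assumption):

```python
import numpy as np

class SolutionDistance(Metric):
    def __init__(self, sol):
        # sol is the known exact solution, a callable t -> y(t)
        self.sol = sol

    def __call__(self, t, y):
        # Euclidean distance between the numerical state and the exact solution
        return np.linalg.norm(y - self.sol(t))
```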
## Future outlook
This project has just undergone a major change: `numpy` usage was switched to `jax.numpy` to enable GPU computations. This, however, presents some challenges:
- The computation is inherently a sequential loop; calculating the solution in this manner in functional-style programming is not as straightforward and requires changes to work on GPU. (This is currently a big todo; see the sketch after this list.)
- Performance! This requires using some low-level `jax` routines, which might not play nicely with metric logging out of the box.
- Making the `ODEModel` and `Metric` wrappers play nicely with JAX requires some type registration, so that XLA and the JAX core know what to expect.
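One common way to express such a sequential loop in JAX is `jax.lax.scan`. A minimal sketch of a fixed-step Euler loop, independent of ode-explorer's actual internals, might look like this:

```python
import jax
import jax.numpy as jnp

def euler_integrate(f, t0, y0, h, num_steps):
    # fixed-step forward Euler, expressed as a scan so JAX/XLA can compile
    # the whole loop instead of unrolling it in Python
    def step(carry, _):
        t, y = carry
        y_new = y + h * f(t, y)
        return (t + h, y_new), y_new  # new carry, plus the emitted state

    _, ys = jax.lax.scan(step, (t0, y0), None, length=num_steps)
    return ys

# example: y' = -y on [0, 5] with step size 0.01
ys = euler_integrate(lambda t, y: -y, 0.0, jnp.array(1.0), 0.01, 500)
```

The state is threaded through as the scan carry, which keeps the computation functional while still being a sequential loop under the hood.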
All in all, I am excited about heaving the project onto accelerators with JAX; it might be a while before a substantial uplift over `numpy` is achieved, though.