Yes, I know, post-holiday slumps hit everyone hard. But hey, someone has to do the dirty work, right? So let's pick up, after a few weeks' break, our delightful little chat about Neural ODEs, or NODEs. If you missed the earlier parts, don't worry, they're linked at the end of the article. To recap quickly, we’ve already talked about differential and integral calculus, and analytical and numerical methods for solving differential equations. Now, finally, it’s time to put it all together and see how we arrive at the formulation of these blessed NODEs. So let’s begin...
From Discrete to Continuous
There are many neural network architectures. Too many, in fact. Among them, there's one in particular known as ResNet, or Residual Network. ResNets were a major breakthrough in the world of Deep Learning because they helped overcome certain training issues. What kind of issues, you might ask? Well, here's the idea: to improve performance, a neural network typically needs to be deeper. But the deeper it gets, the worse the performance becomes, due to vanishing/exploding gradients: nasty problems we won't delve into here, but trust me, they're trouble. What makes ResNets special is the use of residual blocks. To put it simply, a ResNet is a series of these blocks, each structured like this:
Figure 1. Residual Block
Now it should be clear why it’s called residual. The output is:
\[y = x + F(x)\]
which can be rearranged as:
\[F(x) = y - x\]
Where:
- \(x\) is the input
- \(y\) is the output
- \(F(x)\) is the transformation performed by the layers (i.e., the residual)
So what the individual layers learn is not a full mapping to \(y\), but rather the residuals.
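In code, a residual block really is just "input plus transformation". Here is a minimal NumPy sketch (the two-layer branch, the ReLU, and all shapes are illustrative choices, not the canonical ResNet block):

```python
import numpy as np

def relu(z):
    return np.maximum(0.0, z)

def residual_block(x, W1, W2):
    """y = x + F(x): the layers only learn the residual F(x)."""
    F_x = W2 @ relu(W1 @ x)  # the residual branch (two layers, for illustration)
    return x + F_x           # the skip connection adds the input back

# Tiny usage example with illustrative shapes
rng = np.random.default_rng(0)
x = rng.normal(size=4)
W1, W2 = rng.normal(size=(8, 4)), rng.normal(size=(4, 8))
y = residual_block(x, W1, W2)
```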
But what does this have to do with NODEs? Hold your horses, we’re getting there. Let's start by imagining a ResNet with one layer:
Figure 2. One-Layer ResNet
The output is given by:
\[x_1 = x_0 + f_1(x_0, \theta_1)\]
Where:
- \(x_0\) is the input to the ResNet
- \(f_1\) is the transformation layer
- \(\theta_1\) are the weights of that layer
- \(x_1\) is the output
Now let's take a two-layer ResNet:
Figure 3. Two-Layer ResNet
Now the output is:
\[x_2 = x_1 + f_2(x_1, \theta_2)\]
One more. Let's look at a three-layer ResNet:
Figure 4. Three-Layer ResNet
The output is:
\[x_3 = x_2 + f_3(x_2, \theta_3)\]
Did you get the hint...? So why not think of the depth of the network as a temporal sequence, where each layer \(f_i\) represents a discrete time step? And why not suppose we have infinite layers, and thus infinite \(f_i\), one for each time step...? Sure, let’s do it! And let’s generalize:
\[x_{t+1} = x_t + f_t(x_t, \theta_t)\]
But at this point, why not imagine having a single function \(f\) that varies with time, instead of infinite \(f_i\)? So let's write it like this:
\[x_{t+1} = x_t + f(x_t, t, \theta)\]
Does this formulation ring a bell? No!? Maybe it will if I add a multiplicative factor \(h=1\):
\[x_{t+1} = x_t + h \cdot f(x_t, t, \theta)\]
Hopefully that’s clearer now. By now you should be familiar with Euler’s method. If not, catch up here.
From this formulation, we can write:
\[\frac{x_{t+1} - x_t}{h} = f(x_t, t, \theta)\]
This should also look familiar: it's the finite difference quotient. So, assuming we can control the value of \(h\) and let it become arbitrarily small, we can write:
\[\frac{dx(t)}{dt} = f(x(t), t, \theta)\]
Formula 1. Differential Formulation of the Neural ODE
So, if we could find a way to consider small variations of our \(f\) by introducing an arbitrarily small \(h\), our problem would shift from being a sequence of large steps (i.e., discrete layers with \(h = 1\)) to infinitely small steps.
And so, at least conceptually, we've transitioned from a discrete problem with many functions \(f_t(x, \theta_t)\) to a continuous one that evolves smoothly over time, guided by a single \(f(x(t), t, \theta)\). Visually, this looks like:
Figure 5. From Discrete to Continuous
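To see the discrete-to-continuous transition numerically, here is a tiny Python experiment. The dynamics \(f\) here is a toy stand-in (not a network), chosen only because we know the exact continuous solution:

```python
import math

def f(x, t):
    return -x  # toy dynamics dx/dt = -x, a hypothetical stand-in for a network

def integrate(x0, t0, tN, h):
    """Take discrete steps x_{t+1} = x_t + h*f(x_t, t); h = 1 is a plain ResNet stack."""
    x, t = x0, t0
    while t < tN - 1e-12:
        x, t = x + h * f(x, t), t + h
    return x

# As h shrinks, the stack of discrete layers approaches the continuous
# solution x(1) = e^(-1) of dx/dt = -x with x(0) = 1.
for h in (1.0, 0.1, 0.001):
    print(f"h={h}: {integrate(1.0, 0.0, 1.0, h):.5f}  (exact: {math.exp(-1):.5f})")
```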
At this point, the next step comes naturally. If we have a derivative that tells us how our state changes, we can ask:
how much does it change over time?
The answer, as always: we integrate.
Integrate It Away
We've just obtained a neat differential equation that tells us how the state changes over time. Great, but what we really want to know is how much it changes, that is:
If I start from \(x_0\) at time \(t_0\), where will I be at time \(t_N\)?
The answer, if you've read even a single line of the previous articles, you already know: we integrate.
\[x(t_N) = x(t_0) + \int_{t_0}^{t_N} f(x(t), t, \theta)\, dt\]
Formula 2. Integral Formulation of the Neural ODE
This, ladies and gentlemen, is the heart of a NODE. No more list of layers, but a continuous dynamic evolving over time, driven by a function \(f\) that learns how the state should change. And this \(f\), surprise surprise, is nothing but a neural network.
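Since \(f\) is just a neural network, here is what it might look like in PyTorch. A minimal sketch, assuming we feed the time \(t\) as an extra input; the layer sizes and the tanh activation are illustrative choices:

```python
import torch
import torch.nn as nn

class Dynamics(nn.Module):
    """A sketch of the learnable dynamics f(x(t), t, theta)."""
    def __init__(self, dim=2, hidden=50):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden),  # +1 because t is appended to the state
            nn.Tanh(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x, t):
        t_col = torch.full_like(x[..., :1], float(t))   # broadcast t to a column
        return self.net(torch.cat([x, t_col], dim=-1))  # returns dx/dt at (x, t)
```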
The fun part begins. How do you analytically integrate a function \(f\) that, let’s not forget, is a neural network packed with ReLUs, nonlinear layers, and scattered weights?
Quick answer? You can’t.
That’s why I wrote an entire article on numerical methods. I’m no amateur. And remember, in those solvers, what was the parameter that controlled the integration step? Exactly... it was \(h\). So the idea we introduced in the previous paragraph, an \(h\) as small as we like, not only makes sense; it’s exactly what we do.
Houston, We Have a Problem... With Training
We know we can’t solve our integral analytically, but numerically? Who’s stopping us? So, for now, let’s just write this:
\[x(t_N) = \text{ODESolve}\big(f,\; x_0,\; t_0,\; t_N,\; \theta\big)\]
Formula 3. Numerical Formulation of the Neural ODE
where:
- \(f\) is not your typical problem. It’s not Lotka–Volterra or anything like that; it’s a neural network with its layers, nonlinearities, and activations.
- \(\theta\) are the weights of the neural network, i.e., the parameters to optimize.
- \(x_0\) is the initial state (the input to the neural network), and it’s known.
- \(t_0\) is the starting time.
- \(t_N\) is the ending time, the moment we observe the output.
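Concretely, ODESolve can be as simple as a fixed-step Euler loop. A minimal sketch (real implementations use higher-order or adaptive solvers such as RK4 or DOPRI5, which we will meet again in the recap):

```python
def odesolve_euler(f, x0, t0, tN, theta, h=0.01):
    """A fixed-step Euler stand-in for the ODESolve of Formula 3.
    Every call to f is one forward pass of the neural network."""
    x, t = x0, t0
    while t < tN:
        x = x + h * f(x, t, theta)  # x_{i+1} = x_i + h * f(x_i, t_i, theta)
        t += h
    return x  # the prediction x(t_N)
```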
If you’ve read the articles on neurons, you know that training a neural network means optimizing the weights. This is done by computing the Loss, which is the error between the predicted output and the desired one, and propagating it backward through backpropagation, modifying the weights based on how much they contributed to the error. Easy enough in the discrete world, right? Layer by layer, weight by weight. And now? Now, every instant in the integration time contributes, bit by bit, to the final loss.
How do we compute the error of something that evolves continuously over time?
The network is no longer made up of distinct layers. It’s now hidden inside an integrator. There’s no longer a well-defined sequence of layers, like this:
\[y = f_N\big(f_{N-1}(\dots f_1(x_0)\dots)\big)\]
where to reach the output \(y\), a single forward pass through the network is enough, whereas with NODEs we have:
\[y = x(t_N) = \text{ODESolve}\big(f,\; x_0,\; t_0,\; t_N,\; \theta\big)\]
We can no longer see or control the direct effect of each weight on the error, because the output is not directly produced by a neural network, but by an ODESolver that integrates the dynamical system defined by the network itself. As we've seen, a solver repeatedly calls the dynamics function \(f\), with a number of calls that depends on the value of \(h\). So now, each call to \(f\) is a forward pass of the network. The result is used to compute the next step's value, until we reach the output \(y\) at time \(t_N\), i.e., after \(n\) forward passes.
The differences between a NODE forward pass and that of a standard neural network are shown in the following diagrams:
Figure 6. Neural Network Forward Pass
Figure 7. Neural ODE Forward Pass
The Loss, then, is no longer a simple error computed as a function of the true and predicted outputs, but something much more complex: it depends on all time steps, as shown in the image below.
Figure 8. Loss in a Neural ODE
You can see each black dot as a forward pass that generates the next step's output, contributing to the Loss, starting from the initial instant \(t_0\) and ending at \(t_N\), which produces the prediction \(y = x(t_N)\).
To recap, the initial inputs are given by \(x_0\), and the goal is to minimize the value of
\[L\big(y,\; x(t_N)\big)\]
where:
- \(L\) is the Loss function
- \(y\) is the true output
- \(x(t_N)\) is the output from the ODESolver, obtained after dozens or hundreds of forward passes of the same neural network \(f\).
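Putting the pieces together, the forward pipeline might look like this sketch, reusing `odesolve_euler` from above and assuming `f`, `x0`, `theta`, the time bounds, and the true output `y` are defined; the mean squared error is an illustrative choice of \(L\):

```python
import numpy as np

def mse_loss(y_true, y_pred):
    """An illustrative choice of L: mean squared error."""
    return float(np.mean((np.asarray(y_true) - np.asarray(y_pred)) ** 2))

# x(t_N) comes out of the solver after n forward passes of f...
x_tN = odesolve_euler(f, x0, t0, tN, theta)
# ...and only then can we measure the error against the true output y.
loss = mse_loss(y, x_tN)  # L(y, x(t_N))
```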
At this point, though, the question remains: How do we train the network to optimize \(\theta\)? In other words, how do we perform backpropagation of the error knowing that the calculated Loss is not the result of a single forward pass of the neural network, but the result of an ODESolver, i.e., \(n\) forward passes of the neural network?
Adjoint at the Table
If you’ve made it this far and actually understand what’s going on, great! You can explain it to me too! But jokes aside, let’s continue. From the neuron series, we learned that minimizing the Loss means taking the derivative of the error with respect to the weights, that is:
\[\frac{dL}{d\theta}\]
In the previous section, however, we saw that in a NODE, the error results from many infinitesimal contributions generated by the states \(x(t_i)\), each of which uses the exact same set of weights \(\theta\). So, by the chain rule, we can write:
\[\frac{dL}{d\theta} = \sum_i \frac{\partial L}{\partial x(t_i)} \cdot \frac{dx(t_i)}{d\theta}\]
Formula 4. Gradient of the Loss
where:
- \(\frac{\partial L}{\partial x(t_i)}\) is the error contribution of the \(i\)-th state
- \(\frac{dx(t_i)}{d\theta}\) is the contribution of the weights to the state \(x(t_i)\)
Let’s focus on the quantity:
\[a(t) = \frac{\partial L}{\partial x(t)}\]
Formula 5. Sensitivity
This value is also called the sensitivity of the Loss with respect to the state at time \(t\), and it represents the instantaneous gradient that evolves over time, that is, how each state impacts the error. It answers the question:
If I changed the state value, how much would the Loss value change?
And how do we quantify how much one quantity instantaneously affects another? Exactly, through the derivative. So we can write:
\[\frac{da(t)}{dt} = \frac{d}{dt}\,\frac{\partial L}{\partial x(t)}\]
By Schwarz’s theorem, we can rewrite it as:
\[\frac{da(t)}{dt} = \frac{\partial}{\partial x(t)}\,\frac{dL}{dt}\]
Now we’ve got the Loss as a function of time, but as we know, the Loss is a function of the states, and the states are a function of time, so we can apply the chain rule and write:
\[\frac{dL}{dt} = \frac{\partial L}{\partial x(t)} \cdot \frac{dx(t)}{dt}\]
Substituting into our previous expression gives:
\[\frac{da(t)}{dt} = \frac{\partial}{\partial x(t)}\left(\frac{\partial L}{\partial x(t)} \cdot \frac{dx(t)}{dt}\right)\]
Amid this jungle of derivatives, you probably didn’t notice that part of this formula was already defined here. So, replacing the notation for the \(i\)-th state with a more general one, we can write:
\[\frac{da(t)}{dt} = \frac{\partial}{\partial x(t)}\left(a(t) \cdot \frac{dx(t)}{dt}\right)\]
And if you recall, we also wrote:
\[\frac{dx(t)}{dt} = f(x(t), t, \theta)\]
So let’s make one more substitution and write:
\[\frac{da(t)}{dt} = a(t) \cdot \frac{\partial f(x(t), t, \theta)}{\partial x}\]
Now we’re just one step away. We’ve seen that to reach the final state, we used an ODESolver that evolves forward in time: we start from the initial state \(x_0\) (the input) and reach the final state \(x_{t_N}\) (the output). It’s at the final state that we can compute the Loss, but remember, we have no information about the intermediate states, so backpropagation is off the table. To understand how each individual state contributed to the error, we must go back in time starting from the state \(x(t_N)\) and reaching the state \(x_0\). But this has some implications. Let me explain with an example: suppose you’re walking uphill and take a step forward. You’ve moved half a meter forward and find yourself higher up, say by \(5\%\). Now take that same exact step, but backward. You still move half a meter, still on a \(5\%\) slope, but this time you’re going downhill, you’ve descended by \(5\%\), which means you’re at \(-5\%\) relative to where you started. Mathematically, what we’re doing is ensuring that the derivative (i.e., the slope) remains consistent, it should correctly indicate not only how steep the change is, but also whether we’re going up or down. In short, the error at time \(t_0\) is zero, and as the dynamics evolve, it accumulates gradually until reaching the final error at time \(t_N\). Starting from the final error and retracing our steps backward, we must subtract the individual components until we reach time \(t_0\), where it becomes zero again. So, to go back in time from the final state while keeping the gradient coherent, we need a nice big minus sign. Let’s add it:
\[\frac{da(t)}{dt} = -a(t) \cdot \frac{\partial f(x(t), t, \theta)}{\partial x}\]
Formula 6. Differential Formulation of the Sensitivity
Yes, I know, I made it sound like this just happens, maybe even a bit forced, but really, you should be thanking me for sparing you the hassle of more formulas. Still curious? You can find the proof for why we need the minus sign here. But back to us. From the fundamental theorem of calculus, we can write:
\[a(t_0) = a(t_N) + \int_{t_N}^{t_0} -a(t) \cdot \frac{\partial f(x(t), t, \theta)}{\partial x}\, dt\]
Formula 7. Integral Formulation of Sensitivity
What we just introduced is called the Adjoint Method, and why it’s not only useful but indispensable, we’ll see very soon. Visually, what the adjoint method does is shown in the figure below:
Figure 9. Adjoint Method
We start from the final time step, for which we have all the necessary information since it’s the output, and we move backward in time, calculating at each time step the error sensitivity with respect to the current state. If there's one thing we've learned, it's that we hate analytical integrals. So how do we solve our sensitivity equation...? Why, with an ODESolver, of course!
\[a(t_0) = \text{ODESolve}\left(-a(t) \cdot \frac{\partial f(x(t), t, \theta)}{\partial x},\; a(t_N),\; t_N,\; t_0\right)\]
Formula 8. Numerical Formulation of Sensitivity
where:
- \(-a(t) \cdot \frac{\partial f(x(t), t, \theta)}{\partial x}\) is the dynamics.
- \(a(t_N) = \frac{\partial L}{\partial x(t_N)}\) is the sensitivity at the final state and is computable, hence known.
- \(t_N\) is the final time, from which we start integration.
- \(t_0\) is the initial time, where we end the integration.
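As a sketch, Formula 8 with fixed Euler steps might look like this. The helper names are mine: `df_dx(x, t)` returns the Jacobian of the network with respect to the state, and `x_at(t)` returns the state at time \(t\), which we will learn how to reconstruct shortly:

```python
import numpy as np

def solve_sensitivity_backward(df_dx, x_at, a_tN, tN, t0, h=0.01):
    """Euler sketch of Formula 8: integrate da/dt = -a * df/dx from tN back to t0."""
    a, t = np.asarray(a_tN), tN
    while t > t0:
        da_dt = -(a @ df_dx(x_at(t), t))  # the adjoint dynamics
        a = a - h * da_dt                 # one Euler step backward in time
        t -= h
    return a  # a(t_0)
```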
An Adjoint to Rule Them All
If you've made it this far, congratulations, but I imagine you're more confused than before. Because, in fact, the adjoint method as it stands seems quite useless. So, to avoid lynching, let’s give all this some meaning. Let's start again from the gradient of the loss, \(\frac{dL}{d\theta}\). The basics of mathematics tell us that whatever we do on the right, we must also do on the left, and everything will continue to hold. So, we differentiate both sides with respect to time, and by Schwarz’s theorem we can swap the order of differentiation:
\[\frac{d}{dt}\,\frac{dL}{d\theta} = \frac{d}{d\theta}\,\frac{dL}{dt}\]
From which, recalling this equality and this other one, we get:
\[\frac{d}{dt}\,\frac{dL}{d\theta} = \frac{\partial}{\partial \theta}\Big(a(t) \cdot f(x(t), t, \theta)\Big) = a(t) \cdot \frac{\partial f(x(t), t, \theta)}{\partial \theta}\]
Integrating both sides over time gives:
\[\frac{dL}{d\theta} = \int_{t_0}^{t_N} a(t) \cdot \frac{\partial f(x(t), t, \theta)}{\partial \theta}\, dt\]
But remember, we’re doing backpropagation. That means we’re starting from the final instant. So, for the same reason as before, we must treat the sensitivity as a negative quantity and invert the sign of the integral, applying its properties:
\[\frac{dL}{d\theta} = -\int_{t_N}^{t_0} a(t) \cdot \frac{\partial f(x(t), t, \theta)}{\partial \theta}\, dt\]
And thus, solving:
\[\frac{dL(t_0)}{d\theta} = \frac{dL(t_N)}{d\theta} - \int_{t_N}^{t_0} a(t) \cdot \frac{\partial f(x(t), t, \theta)}{\partial \theta}\, dt\]
We know this gradient accumulates the contribution of the weights with respect to the loss. Initially, this container of contributions is empty, and as we move backward through time, we can quantify and add them. This means that:
\[\frac{dL(t_N)}{d\theta} = 0\]
so rewriting the equation above, we have:
\[\frac{dL(t_0)}{d\theta} = -\int_{t_N}^{t_0} a(t) \cdot \frac{\partial f(x(t), t, \theta)}{\partial \theta}\, dt\]
Formula 9. Integral Formulation of the Gradient
We have thus reached a new formulation of the error gradient that enables backpropagation.
And do you know what’s special about this formulation?
That it does not depend on the internal states.
And if you remember, that was exactly our problem. Since the neural network is hidden within a solver, we don’t have access to its internal states.
Even if we could find a way to save them, it wouldn’t be memory-efficient, because remember, we have hundreds of forward passes, and therefore hundreds of internal states.
Same story: we have an integral. So, once again, we call upon the solvers and write:
\[\frac{dL(t_0)}{d\theta} = \text{ODESolve}\left(-a(t) \cdot \frac{\partial f(x(t), t, \theta)}{\partial \theta},\; 0,\; t_N,\; t_0\right)\]
Formula 10. Numerical Formulation of the Gradient
Since we started from \(t_N\), the value in \(\frac{dL(t_0)}{d\theta}\) is the cumulative result of all the contributions gathered across time. It is therefore the complete gradient of the loss with respect to the weights \(\theta\). This is precisely the value used to update the parameters during the optimization process.
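A matching sketch for Formula 10, again with illustrative helpers (`a_at` and `x_at` return the sensitivity and the state at time \(t\); `df_dtheta` returns the Jacobian of \(f\) with respect to the weights):

```python
import numpy as np

def solve_gradient_backward(df_dtheta, a_at, x_at, tN, t0, n_params, h=0.01):
    """Euler sketch of Formula 10: integrate d(dL/dtheta)/dt = -a(t) * df/dtheta
    from tN back to t0, starting from an empty gradient."""
    grad, t = np.zeros(n_params), tN  # dL(t_N)/dtheta = 0: the container starts empty
    while t > t0:
        d_grad_dt = -(a_at(t) @ df_dtheta(x_at(t), t))
        grad = grad - h * d_grad_dt   # each backward step adds one contribution
        t -= h
    return grad  # dL(t_0)/dtheta: the complete gradient
```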
Thus, the ODESolver on the sensitivity function allows us to compute the instantaneous sensitivity, which is then used by the ODESolver for the gradient function. One final piece remains: what are the inputs of our \(f(x(t), t, \theta)\)? We know they are the states \(x(t)\), but where do these states come from? We had them during the forward phase, because they’re exactly what we computed with this ODESolver. But, as we’ve said over and over, there are hundreds of them, and storing them all would take too much memory. So how do we proceed? Well, once again… we use another ODE that solves our problem backward. So if this ODESolver performs the forward pass, then the following:
\[x(t_0) = \text{ODESolve}\big(f,\; x(t_N),\; t_N,\; t_0,\; \theta\big)\]
Formula 11. Numerical Formulation of the State in Backward
retraces our states backward. So we use the neural network again to predict the state, but instead of proceeding as
\[x_{i+1} = x_i + h \cdot f(x_i, t_i, \theta)\]
we go in reverse:
\[x_{i-1} = x_i - h \cdot f(x_i, t_i, \theta)\]
where \(x_N\) is known, since it’s the output of the ODESolver from the forward phase.
So yes, computationally we must solve another ODE by performing \(n\) forward passes of the neural network, but this way we instantaneously have the state \(x_i\) needed to solve the ODE for the sensitivity, and consequently, the one for the gradient. And since this result is only needed for a single instant in time, once it’s used for one step we can discard it, keeping memory usage minimal.
To conclude, it’s important to clarify one thing. We introduced this whole mechanism to avoid depending on the states, and now I’m telling you that we still need to compute them. Sounds like a contradiction, doesn’t it? Well, it isn't. Before, we were dependent on the entire state, a matrix describing how every parameter influences every component of the system. Now, instead, we only depend on the value of the state over time, a vector that tells us where we are at that specific instant. In other words, we no longer need to keep track of all possible directions of the evolution: it’s enough to know the current position to compute how the error propagates.
God Bless Recaps
If you’ve made it this far, congratulations, not because you’re the millionth visitor (in fact, you’re probably the second, I’m the first), but because I threw you to the Maenads and you came out as Orpheus. Not that that’s a good omen. Anyway, now that we have all the tools, let’s do a recap.
Let’s start by looking at the following diagram that I so diligently sketched out for you:
Figure 10. Overview
Just like in traditional neural networks, training a NODE consists of two phases: forward and backward.
Forward
We use an ODESolver (Euler, RK4, DOPRI5, etc.) from time \(t_0\) to a chosen time \(t_N\), using a dynamics function defined by a neural network \(f\), with \(x_0\) as input:
\[x(t_N) = \text{ODESolve}\big(f,\; x_0,\; t_0,\; t_N,\; \theta\big)\]
At time \(t_N\), we have our output (or prediction), given by the state \(x(t_N)\). We use the state \(x(t_N)\) and the true output \(y\) to calculate the Loss \(L\) and its derivative with respect to the state, which is the sensitivity at time \(t_N\):
\[a(t_N) = \frac{\partial L}{\partial x(t_N)}\]
Both \(x(t_N)\) and \(a(t_N)\), which have dimensions equal to the output of the neural network, are used in the backward steps.
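In PyTorch terms, the forward phase boils down to the following sketch (reusing the hypothetical `Dynamics` module from earlier, with an Euler loop standing in for the real solver and an illustrative squared-error loss):

```python
import torch

f = Dynamics(dim=2)                     # the learnable dynamics sketched earlier
x = torch.tensor([1.0, 0.0])            # x_0, the input
t0, tN, h = 0.0, 1.0, 0.01

t = t0
while t < tN:                           # forward ODESolve (Euler stand-in)
    x = x + h * f(x, t)
    t += h

x_tN = x.detach().requires_grad_(True)  # the prediction x(t_N)
y = torch.tensor([0.0, 1.0])            # the true output (illustrative)
L = ((y - x_tN) ** 2).mean()            # the Loss
L.backward()
a_tN = x_tN.grad                        # a(t_N) = dL/dx(t_N), the initial sensitivity
```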
Backward
We use a system of three ODESolvers, of the same type used for the forward phase. All three solve the dynamics in reverse, starting from \(t_N\) and going back to \(t_0\):
\[x(t_0) = \text{ODESolve}\big(f,\; x(t_N),\; t_N,\; t_0,\; \theta\big)\]
\[a(t_0) = \text{ODESolve}\left(-a(t) \cdot \frac{\partial f(x(t), t, \theta)}{\partial x},\; a(t_N),\; t_N,\; t_0\right)\]
\[\frac{dL(t_0)}{d\theta} = \text{ODESolve}\left(-a(t) \cdot \frac{\partial f(x(t), t, \theta)}{\partial \theta},\; 0,\; t_N,\; t_0\right)\]
where:
- The first takes in the state \(x(t_N)\) predicted in the forward phase and outputs the state \(x(t_{N-1})\) at time \(t_{N-1}\).
- The second takes in the sensitivity \(a(t_N)\) calculated during the forward phase, and the current state \(x(t_{N-1})\) calculated by the first ODESolver, and outputs the sensitivity \(a(t_{N-1})\) at time \(t_{N-1}\).
- The third takes in the current sensitivity \(a(t_{N-1})\) produced by the second ODESolver, and the current state \(x(t_{N-1})\) from the first ODESolver. The output of this ODESolver is a gradient matrix with the same dimensions as the neural network (or rather its parameters).
Once we reach time \(t_0\), we can use the gradient matrix produced by the third ODESolver, which contains the partial derivatives of the Loss with respect to the weights, to update the weights of the network.
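To tie the three solvers together, here is a minimal fixed-step sketch in PyTorch. The function and variable names are mine, Euler stands in for whatever solver was used in the forward phase, and `torch.autograd.grad` supplies the products \(a \cdot \frac{\partial f}{\partial x}\) and \(a \cdot \frac{\partial f}{\partial \theta}\) as vector-Jacobian products, without ever building the full Jacobians:

```python
import torch

def backward_pass(f, theta, x_tN, a_tN, tN, t0, h=0.01):
    """Euler sketch of the three coupled backward solves.
    f(x, t, theta) is the neural network; theta is a flat parameter
    tensor with requires_grad=True, so it appears in f's graph."""
    x, a = x_tN.detach(), a_tN.detach()
    dL_dtheta = torch.zeros_like(theta)          # ODE 3 starts from an empty gradient
    t = tN
    while t > t0:
        x = x.detach().requires_grad_(True)
        fx = f(x, t, theta)
        a_df_dx, a_df_dtheta = torch.autograd.grad(fx, (x, theta), grad_outputs=a)
        x = x - h * fx.detach()                  # ODE 1: retrace the state backward
        a = a + h * a_df_dx                      # ODE 2: step a backward (da/dt = -a*df/dx)
        dL_dtheta = dL_dtheta + h * a_df_dtheta  # ODE 3: accumulate the gradient
        t -= h
    return x, a, dL_dtheta                       # dL_dtheta updates the weights
```

For real use, a library such as torchdiffeq packages this adjoint scheme behind adaptive solvers, so you don't have to hand-roll the backward loop.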
One final note before we move on to conclusions: if you have any doubts about how to calculate the derivatives of \(f\) with respect to \(x\) and \(\theta\), it’s nothing more or less than what I’ve already shown you in the series of articles on neurons. Yet another reason to go check it out.
To Wrap Up
We've reached the end of this (final, for now) theoretical chapter on NODEs. It's been a long journey, starting from derivatives and integrals, that led us to formulate a continuous dynamic model capable of learning complex transformations over time.
As we've seen, NODEs are computationally more expensive than classic networks, but they offer unique advantages:
- They don’t need to store all intermediate activations, thus saving memory.
- They can operate with smaller networks. Thanks to the continuous nature of integration, they are able to generalize better with less complex architectures.
So, when does it make sense to use a NODE?
As often happens in machine learning, the answer is: it depends.
If your problem involves continuous temporal dynamics, irregular data, or physical systems that can be modeled with ODEs, then NODEs can truly make a difference.
In all other cases, they might just make your life more complicated.
In the next (and final) article of the series, we’ll finally get practical: we’ll look at a concrete application of NODEs and discuss some real-world use cases.
Until next time.