
Tuning PID Controllers with Automatic Differentiation

Control Theory · Optimization · JAX · Automatic Differentiation

PID controllers are the workhorses of industrial control. They are simple enough to fit on a microcontroller, yet surprisingly effective across a wide range of systems. The catch is that tuning them well takes time and — if you are being honest — a fair amount of intuition. Classic recipes like Ziegler-Nichols are fast but often leave performance on the table. Model-free autotuning approaches work but can require live experiments on the real plant.

Here is a different take: write down a differentiable simulation of your system, define a loss that captures what “good control” means to you, and let gradient descent find the gains. Automatic differentiation handles the calculus; you just specify the objective.

This post walks through exactly that, using JAX.


The Example: A Second-Order Mass-Spring-Damper

The plant is a linear mass-spring-damper:

$$m\ddot{x} + c\dot{x} + kx = u(t)$$

with mass $m = 1\,\text{kg}$, damping $c = 0.5\,\text{N·s/m}$, and spring stiffness $k = 2\,\text{N/m}$. The PID controller drives the position $x$ to a unit step setpoint:

$$u(t) = K_p\,e(t) + K_i\!\int_0^t e(\tau)\,d\tau + K_d\,\dot{e}(t), \quad e(t) = r - x(t)$$

This system has a natural frequency of $\omega_n = \sqrt{k/m} \approx 1.41\,\text{rad/s}$ and a damping ratio of $\zeta = c/(2\sqrt{mk}) \approx 0.18$, so the open-loop step response is lightly damped and oscillatory, which makes it a good stress test for the tuning method.
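Both numbers follow directly from the plant parameters. A quick check in plain Python (the variable names are just for illustration):

```python
import math

# Plant parameters from the text.
m = 1.0  # mass [kg]
c = 0.5  # damping [N·s/m]
k = 2.0  # spring stiffness [N/m]

omega_n = math.sqrt(k / m)           # natural frequency [rad/s]
zeta = c / (2.0 * math.sqrt(m * k))  # damping ratio [-]

print(f"omega_n = {omega_n:.3f} rad/s, zeta = {zeta:.3f}")
# omega_n = 1.414 rad/s, zeta = 0.177
```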


Differentiating Through a Simulation

The key idea is to discretize the system with a fixed time step and write the entire rollout as a JAX function. Because JAX traces every arithmetic operation in the loop, you get exact gradients of any scalar loss with respect to the PID gains without deriving anything by hand.

Derivative-on-measurement

One subtlety worth flagging upfront: the derivative term is applied to the measurement rather than the error:

$$u_n = K_p\,e_n + K_i\,z_n - K_d\,\dot{x}_n$$

The sign flip on $K_d$ is intentional: with a constant setpoint, $\dot{e} = -\dot{x}$, so the two forms agree everywhere except at the instant the setpoint changes. To see why that instant matters, consider what happens at $t = 0$ with derivative-on-error instead. The setpoint is a unit step, so the error jumps discontinuously from $e_{-1} = 0$ to $e_0 = 1$ in the very first time step. The finite-difference derivative is then $\dot{e}_0 \approx (e_0 - e_{-1})/\Delta t = 1/0.01 = 100\,\text{s}^{-1}$, and the control signal jumps by $100\,K_d$. The spike scales as $K_d/\Delta t$, so any non-zero $K_d$ produces one, and a finer time step only makes it taller. With the optimised $K_d \approx 2.7$ found below, it would be $270\,\text{N}$ on top of the proportional term, more than a hundred times the $2\,\text{N}$ needed to hold the mass at the setpoint.

With derivative-on-measurement, this problem disappears. At $t = 0$ the system starts at rest, so $\dot{x}_0 = 0$ and the derivative contribution is $-K_d \cdot 0 = 0$; the integrator is also empty ($z_0 = 0$), so the first control output is simply $u_0 = K_p \cdot 1 = K_p$. Subsequent samples see a non-zero but finite $\dot{x}$, so the damping action builds up smoothly. This is standard practice in industrial PID implementations.
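The arithmetic above is easy to check. This sketch uses plain Python, a unit step and the approximate optimised gains quoted later in the post; the variable names are illustrative.

```python
# First control sample u_0 for a unit step, comparing the two derivative placements.
# Gains are the approximate optimised values quoted later in the post.
Kp, Ki, Kd = 7.0, 2.0, 2.7
dt = 0.01  # time step [s]
r = 1.0    # unit step setpoint [m]

x0, v0, z0 = 0.0, 0.0, 0.0  # at rest, empty integrator
e_prev = 0.0                # error just before the step (setpoint was 0)
e0 = r - x0                 # error jumps to 1 at t = 0

# Derivative-on-error: the finite difference sees the setpoint jump.
u0_on_error = Kp * e0 + Ki * z0 + Kd * (e0 - e_prev) / dt
# Derivative-on-measurement: -Kd * velocity, which is zero at rest.
u0_on_measurement = Kp * e0 + Ki * z0 - Kd * v0

print(f"derivative-on-error:       u0 = {u0_on_error:.1f} N")
print(f"derivative-on-measurement: u0 = {u0_on_measurement:.1f} N")
```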

Discrete update equations

The integrator is semi-implicit (symplectic) Euler: velocity is updated first using the current acceleration, then position is advanced using the updated velocity, with step size $\Delta t = 0.01\,\text{s}$. For oscillatory systems this is more than a cosmetic change: on an undamped oscillator, explicit Euler adds energy at every step for any $\Delta t$, whereas symplectic Euler keeps the oscillation bounded as long as $\omega\,\Delta t < 2$.

The complete update for state $s = [x,\,\dot{x}]^\top$ and integral accumulator $z$ is:

$$\begin{pmatrix} \dot{x}_{n+1} \\ x_{n+1} \\ z_{n+1} \end{pmatrix} = \begin{pmatrix} \dot{x}_n + \Delta t\,(u_n - c\,\dot{x}_n - k\,x_n)/m \\ x_n + \Delta t\,\dot{x}_{n+1} \\ z_n + \Delta t\,e_n \end{pmatrix}$$

Note that $x_{n+1}$ uses $\dot{x}_{n+1}$ (already updated), not $\dot{x}_n$.
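The update equations map almost line for line onto a jax.lax.scan body. What follows is a minimal sketch, not the post's original script: the constant names, recording $(x_n, e_n, u_n)$ for $n = 0, \dots, N$ with $N = T/\Delta t = 800$, and the default float32 precision are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

# Plant and simulation constants from the post.
M, C, K = 1.0, 0.5, 2.0  # mass [kg], damping [N·s/m], stiffness [N/m]
DT, T = 0.01, 8.0        # time step [s], horizon [s]
N = int(round(T / DT))   # 800 steps; samples n = 0..N


def rollout(gains, r=1.0):
    """Closed-loop step response from rest; returns (x_n, e_n, u_n) for n = 0..N."""
    kp, ki, kd = gains

    def step(carry, _):
        x, v, z = carry
        e = r - x
        u = kp * e + ki * z - kd * v               # derivative-on-measurement
        v_next = v + DT * (u - C * v - K * x) / M  # velocity first ...
        x_next = x + DT * v_next                   # ... then position with the new velocity
        z_next = z + DT * e                        # integral of the current error
        return (x_next, v_next, z_next), (x, e, u)

    zero = jnp.zeros(())
    _, (xs, es, us) = jax.lax.scan(step, (zero, zero, zero), None, length=N + 1)
    return xs, es, us


xs, es, us = jax.jit(rollout)(jnp.array([1.0, 0.0, 0.0]))  # P-only controller
print(float(us[0]), float(xs[-1]))
```

For the P-only gains the first control sample is exactly $K_p = 1\,\text{N}$, and the final position lands near the $1/3\,\text{m}$ offset discussed in the results.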


The Loss Function

The objective combines four terms:

$$\mathcal{L} = \underbrace{\sum_{n=0}^{N} w_n\, e_n^2\,\Delta t}_{\text{weighted tracking}} + \underbrace{10 \cdot \frac{1}{N}\sum_n \max(|e_n| - 0.01,\, 0)}_{\text{tail penalty}} + \underbrace{10^{-2}\!\sum_n u_n^2\,\Delta t}_{\text{effort}} + \underbrace{10^{-2} \cdot \frac{1}{N}\sum_n\left|\sum_{j\le n} e_j\,\Delta t\right|}_{\text{integral penalty}}$$

where $w_n = 1 + 5\,(t_n / T)^2$ with $t_n = n\,\Delta t$, $T = 8\,\text{s}$, and $N = T/\Delta t = 800$.

Weighted tracking is the dominant term. The quadratic time-weighting $w_n$ rises from $1$ at $t = 0$ to $6$ at $t = T$, so an error near the end of the simulation costs six times as much as the same error early in the transient. This is similar in spirit to ITAE ($\int t\,|e|\,dt$) but squares the error and uses a quadratic ramp normalised by the horizon length rather than a linear ramp in absolute time.

Tail penalty adds a constant-magnitude push toward zero whenever $|e_n|$ exceeds $0.01\,\text{m}$: its gradient with respect to each such $e_n$ is $\pm 10/N$, no matter how large the error is. It also acts as a deadband: below the threshold the gradient from this term vanishes, which keeps the optimiser from over-tightening the loop in the noise-free simulation.

Effort penalises large control signals with a small coefficient ($10^{-2}$), discouraging unnecessarily aggressive gains.

Integral penalty discourages large integral accumulation, which indirectly limits integral windup.

Together, these terms make it much harder for the optimiser to plateau with a residual steady-state offset. Late-time errors are weighted up to six times more heavily than under plain ISE, so their gradient stays correspondingly larger, and as long as $|e_n| > 0.01\,\text{m}$ the tail term adds a push whose size does not shrink with the error.
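The four terms translate into a few jnp reductions. The sketch below carries its own compact rollout so it runs on its own. It assumes samples $n = 0, \dots, N$ with $N = 800$ and that every sum in the formula runs over that same range; the post's own script may handle such details differently.

```python
import jax
import jax.numpy as jnp

M, C, K = 1.0, 0.5, 2.0  # plant parameters
DT, T = 0.01, 8.0        # time step [s], horizon [s]
N = int(round(T / DT))   # 800 steps; samples n = 0..N


def rollout(gains, r=1.0):
    """Return the error e_n and control u_n for n = 0..N (semi-implicit Euler)."""
    kp, ki, kd = gains

    def step(carry, _):
        x, v, z = carry
        e = r - x
        u = kp * e + ki * z - kd * v
        v_next = v + DT * (u - C * v - K * x) / M
        return (x + DT * v_next, v_next, z + DT * e), (e, u)

    zero = jnp.zeros(())
    _, (es, us) = jax.lax.scan(step, (zero, zero, zero), None, length=N + 1)
    return es, us


def loss(gains):
    es, us = rollout(gains)
    t = jnp.arange(N + 1) * DT
    w = 1.0 + 5.0 * (t / T) ** 2                                    # 1 at t=0, 6 at t=T
    tracking = jnp.sum(w * es**2) * DT                               # weighted tracking
    tail = 10.0 * jnp.sum(jnp.maximum(jnp.abs(es) - 0.01, 0.0)) / N  # deadband at 0.01 m
    effort = 1e-2 * jnp.sum(us**2) * DT                              # control effort
    integral = 1e-2 * jnp.sum(jnp.abs(jnp.cumsum(es) * DT)) / N      # |sum_{j<=n} e_j dt|
    return tracking + tail + effort + integral


loss_fn = jax.jit(loss)
print(float(loss_fn(jnp.array([1.0, 0.0, 0.0]))),  # initial P-only gains
      float(loss_fn(jnp.array([7.0, 2.0, 2.7]))))  # approximate optimised gains
```

With these definitions the approximate optimised gains score well below the initial P-only controller.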


Why JAX?

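JAX's jit + grad pipeline means the gradient costs only a small constant multiple of one forward pass, independent of how many parameters are being tuned; that is the appeal of reverse-mode AD. With jax.lax.scan the loop body is traced and compiled once instead of being unrolled into N copies, as a Python for loop under jit would be, so compilation time stays roughly flat as the horizon grows. Reverse mode still stores per-step intermediates for the backward pass, so memory grows linearly with the horizon; jax.checkpoint can trade recomputation for memory if that becomes a problem.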

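There are no finite-difference approximations and no symbolic manipulations. The gradients are exact derivatives of the discretised rollout, up to floating-point rounding, however complex the integrator or the loss. Two caveats: they describe the simulation rather than the continuous-time system, and at the kinks of max and |·| JAX returns one particular subgradient.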
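To make the "no finite differences" point concrete, the sketch below differentiates a rollout with jax.grad and compares the result with central finite differences. For brevity it uses only the smooth time-weighted tracking term rather than the full four-term loss, and it switches JAX to float64 only so that the finite-difference reference is accurate.

```python
import jax
import jax.numpy as jnp

# float64 only so the finite-difference reference below is accurate.
jax.config.update("jax_enable_x64", True)

M, C, K = 1.0, 0.5, 2.0
DT, T = 0.01, 8.0
N = int(round(T / DT))


def tracking_loss(gains, r=1.0):
    """Time-weighted squared error of the closed-loop rollout (smooth part of the loss)."""
    kp, ki, kd = gains

    def step(carry, n):
        x, v, z = carry
        e = r - x
        u = kp * e + ki * z - kd * v
        v_next = v + DT * (u - C * v - K * x) / M
        w = 1.0 + 5.0 * (n * DT / T) ** 2
        return (x + DT * v_next, v_next, z + DT * e), w * e**2 * DT

    zero = jnp.zeros(())
    _, terms = jax.lax.scan(step, (zero, zero, zero), jnp.arange(N + 1))
    return jnp.sum(terms)


gains = jnp.array([7.0, 2.0, 2.7])
g_ad = jax.jit(jax.grad(tracking_loss))(gains)  # reverse-mode AD through the scan

# Central finite differences, one gain at a time, as an independent reference.
h = 1e-5
eye = jnp.eye(3)
g_fd = jnp.array([
    (tracking_loss(gains + h * eye[i]) - tracking_loss(gains - h * eye[i])) / (2 * h)
    for i in range(3)
])
print(g_ad)
print(g_fd)
```

The AD result needs one forward and one backward pass; the finite-difference reference needs two rollouts per gain and a carefully chosen step h.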


Results

The optimisation starts from $K_p = 1,\, K_i = 0,\, K_d = 0$ (a pure proportional controller) and runs Adam for 3000 iterations with a cosine-decaying learning rate starting at $0.1$. The decay is important: without it the gains could keep growing, since larger gains tend to reduce the time-weighted tracking error further and the small effort penalty does little to hold them back. The cosine schedule lets the optimiser move fast early and then freezes the gains as the learning rate approaches zero.
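A minimal, dependency-free sketch of that training loop follows. It hand-rolls Adam and the cosine decay so that it needs only JAX; in practice optax.adam combined with optax.cosine_decay_schedule(init_value=0.1, decay_steps=3000) does the same job. The loss re-declares the four-term objective so the block runs on its own, and nothing here is guaranteed to reproduce the exact gains reported in the post.

```python
import math

import jax
import jax.numpy as jnp

M, C, K = 1.0, 0.5, 2.0
DT, T = 0.01, 8.0
N = int(round(T / DT))


def loss(gains, r=1.0):
    """Four-term loss on a semi-implicit Euler rollout (samples n = 0..N)."""
    kp, ki, kd = gains

    def step(carry, _):
        x, v, z = carry
        e = r - x
        u = kp * e + ki * z - kd * v
        v_next = v + DT * (u - C * v - K * x) / M
        return (x + DT * v_next, v_next, z + DT * e), (e, u)

    zero = jnp.zeros(())
    _, (es, us) = jax.lax.scan(step, (zero, zero, zero), None, length=N + 1)
    w = 1.0 + 5.0 * (jnp.arange(N + 1) * DT / T) ** 2
    return (jnp.sum(w * es**2) * DT
            + 10.0 * jnp.sum(jnp.maximum(jnp.abs(es) - 0.01, 0.0)) / N
            + 1e-2 * jnp.sum(us**2) * DT
            + 1e-2 * jnp.sum(jnp.abs(jnp.cumsum(es) * DT)) / N)


loss_and_grad = jax.jit(jax.value_and_grad(loss))


def cosine_lr(step, total_steps, lr0):
    """Cosine decay from lr0 at step 0 down to 0 at total_steps."""
    return 0.5 * lr0 * (1.0 + math.cos(math.pi * step / total_steps))


def train(num_steps=3000, lr0=0.1, b1=0.9, b2=0.999, eps=1e-8):
    """Hand-rolled Adam on the gains, starting from a P-only controller."""
    gains = jnp.array([1.0, 0.0, 0.0])
    mom1 = jnp.zeros(3)  # first-moment estimate
    mom2 = jnp.zeros(3)  # second-moment estimate
    history = []
    for t in range(1, num_steps + 1):
        value, g = loss_and_grad(gains)
        mom1 = b1 * mom1 + (1 - b1) * g
        mom2 = b2 * mom2 + (1 - b2) * g**2
        m_hat = mom1 / (1 - b1**t)  # bias correction
        v_hat = mom2 / (1 - b2**t)
        lr = cosine_lr(t - 1, num_steps, lr0)
        gains = gains - lr * m_hat / (jnp.sqrt(v_hat) + eps)
        history.append(float(value))  # loss before this update
    return gains, history
```

Calling train() runs the full 3000-iteration schedule; train(num_steps=300) is a quicker smoke test.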

Step Response Comparison

The figure below compares the step response over the 8-second simulation window.

[Figure: Step response comparison]

Orange is the initial P-only controller ($K_p = 1,\, K_i = 0,\, K_d = 0$). It never reaches the setpoint. This is expected: in steady state $\ddot{x} = \dot{x} = 0$, so the spring force balances the proportional force, $k\,x_{ss} = K_p (r - x_{ss})$, which gives $x_{ss} = K_p r / (k + K_p) = 1 \cdot 1 / (2 + 1) = 1/3\,\text{m}$ and leaves a permanent offset of $0.67\,\text{m}$.

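Blue is the optimised controller. It reaches $1.0\,\text{m}$ and stays there. The transient has a mild overshoot followed by a few damped oscillations, which decay so that the response settles cleanly within the 8-second window. The closed loop rings far less than the lightly damped plant ($\zeta \approx 0.18$) would on its own, because the derivative gain adds to the plant's damping: the effective damping coefficient is $c + K_d \approx 0.5 + 2.7 = 3.2\,\text{N·s/m}$.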
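A small hypothetical helper evaluates the P-only steady-state balance for a few proportional gains, showing that the offset shrinks as $K_p$ grows but never vanishes:

```python
def p_only_steady_state(kp, k=2.0, r=1.0):
    """Steady state of m*x'' + c*x' + k*x = kp*(r - x) with x'' = x' = 0."""
    x_ss = kp * r / (k + kp)  # spring force k*x_ss balances kp*(r - x_ss)
    offset = r - x_ss         # = k*r / (k + kp): shrinks with kp, never zero
    u_ss = kp * offset        # controller force, equal to k*x_ss
    return x_ss, offset, u_ss


for kp in (1.0, 7.0, 100.0):
    x_ss, offset, u_ss = p_only_steady_state(kp)
    print(f"Kp = {kp:5.1f}: x_ss = {x_ss:.3f} m, offset = {offset:.3f} m, u_ss = {u_ss:.3f} N")
```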


Loss Curve

The loss drops by almost three orders of magnitude over 3000 iterations.

[Figure: Loss over iterations]

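The vertical axis is on a log scale. The steepest descent happens in the first few hundred iterations, while the learning rate is still near its peak value of $0.1$. After roughly iteration 1000 the curve flattens. If the cosine decay spans all 3000 iterations, the learning rate at that point is still 75% of its peak ($0.075$), so the early flattening mostly reflects the optimiser reaching a flat region of the loss; the schedule does most of its braking in the final third of the run, where the learning rate falls below $0.025$. The final plateau marks where the decaying learning rate has frozen the gains. Since larger gains can keep lowering this loss, it should not be read as a strict minimum.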
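The schedule's shape is easy to tabulate. This sketch assumes a plain cosine decay from 0.1 to 0 spanning all 3000 iterations with no warm-up and no floor (the same shape as optax.cosine_decay_schedule with alpha=0.0); the post does not state these details.

```python
import math


def cosine_lr(step, total_steps=3000, lr0=0.1):
    """Cosine decay from lr0 at step 0 to 0 at total_steps (no warm-up, no floor)."""
    return 0.5 * lr0 * (1.0 + math.cos(math.pi * step / total_steps))


for step in range(0, 3001, 500):
    print(f"step {step:4d}: lr = {cosine_lr(step):.4f}")
# lr: 0.1000, 0.0933, 0.0750, 0.0500, 0.0250, 0.0067, 0.0000
```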


Gain Trajectories

This is the most instructive plot: watching the gains evolve reveals what the gradient is actually discovering.

[Figure: PID gain trajectories]

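$K_p$ (blue) and $K_d$ (green) respond immediately, within the first few dozen iterations. They act on the current error and the current velocity respectively, so their effect on the loss is direct. $K_i$ (orange) lags, plausibly because its effect accumulates over the entire trajectory, so the gradient reaching it is more diffuse and keeps changing as $K_p$ and $K_d$ move. Since Adam normalises each parameter's step by its running gradient magnitude, it is the consistency of a gradient's sign, more than its raw size, that sets how fast a gain moves.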

All three curves flatten as the cosine schedule drives the learning rate toward zero, so the gains come to rest smoothly rather than being cut off mid-move by a fixed iteration budget. The final values are roughly $K_p \approx 7$, $K_i \approx 2$, $K_d \approx 2.7$. Nobody specified these values; the optimiser discovered them by minimising the loss.
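As a sanity check on these gains, the closed-loop poles can be computed from the continuous-time characteristic polynomial. Eliminating the integral state from $m\ddot{x} + c\dot{x} + kx = K_p(r - x) + K_i z - K_d\dot{x}$ with $\dot{z} = r - x$ gives $m s^3 + (c + K_d)s^2 + (k + K_p)s + K_i = 0$; measuring the derivative on $x$ rather than $e$ changes the zeros but not these poles. The sketch uses NumPy and the rounded gains, and ignores the discretisation.

```python
import numpy as np

m, c, k = 1.0, 0.5, 2.0     # plant
kp, ki, kd = 7.0, 2.0, 2.7  # rounded optimised gains from the post

# Closed-loop characteristic polynomial: m s^3 + (c + kd) s^2 + (k + kp) s + ki
poles = np.roots([m, c + kd, k + kp, ki])

# Damping ratio of the oscillatory (complex-conjugate) pair.
pair = poles[np.abs(poles.imag) > 1e-9]
wn = np.abs(pair[0])
zeta_cl = -pair[0].real / wn

print("poles:", np.round(poles, 3))
print(f"closed-loop zeta ~ {zeta_cl:.2f} (open-loop plant: {c / (2 * np.sqrt(m * k)):.2f})")
```

With these numbers the complex pair has a damping ratio of roughly 0.5, far above the plant's 0.18. The remaining real pole, near $-0.24$, sits close to the closed-loop zero at $-K_i/K_p \approx -0.29$, so it contributes only a small, slow tail to the step response.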


Tracking Error

The clearest way to see the steady-state improvement is to look at the error signal directly.

[Figure: Tracking error]

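Orange (initial, $K_i = 0$): the error settles around $0.67\,\text{m}$ rather than zero. The P-only loop keeps only the plant's own damping, so its oscillations decay slowly, as $e^{-ct/(2m)} = e^{-0.25t}$. The offset itself is not a transient: it is a permanent steady-state error. For any finite $K_p$ it equals $k r/(k + K_p)$, which shrinks as $K_p$ grows but never reaches zero, so proportional action alone cannot eliminate it when a spring load is present.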

Blue (optimised): the error shows a positive excursion around $t \approx 0.5\,\text{s}$ and a small negative dip around $t \approx 1.5\,\text{s}$ (since $e = r - x$, the negative dip is the position overshoot), then decays to zero. Gradient descent, given a loss that penalises late-time errors heavily, figured out on its own that integral action is the right tool for eliminating steady-state error, because without $K_i > 0$ the time-weighted loss is unavoidably large.


Control Effort

[Figure: Control effort]

Orange (initial, P-only): the control signal starts at $u_0 = K_p e_0 = 1 \cdot 1 = 1\,\text{N}$ and oscillates while settling toward $u_{ss} = k\,x_{ss} = 2 \cdot (1/3) \approx 0.67\,\text{N}$, the force the spring exerts at the steady-state offset position.

Blue (optimised): the first sample is $u_0 = K_p e_0 \approx 7 \cdot 1 = 7\,\text{N}$, a pure proportional kick (the integral and derivative terms are both zero at $t = 0$ since $z_0 = 0$ and $\dot{x}_0 = 0$). The signal then dips negative around $t \approx 1\,\text{s}$: this is the derivative braking term $-K_d\,\dot{x}$ pulling back as the mass approaches the setpoint near its peak velocity. After the transient the signal settles to approximately $2\,\text{N}$, the spring force $k\,x_{ss} = 2 \cdot 1 = 2\,\text{N}$ needed to hold the mass at $x = 1\,\text{m}$. Because $e$ and $\dot{x}$ are both zero in steady state, that force comes entirely from the integral term: $K_i z_{ss} = 2\,\text{N}$, i.e. $z_{ss} \approx 1\,\text{m·s}$ with $K_i \approx 2$.
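The steady-state numbers can be reproduced by hand: at $x = r$ the error and velocity vanish, so the integral term alone must supply the spring force. A few lines of arithmetic with the rounded gains:

```python
k = 2.0   # spring stiffness [N/m]
r = 1.0   # setpoint [m]
ki = 2.0  # rounded optimised integral gain from the post

# Steady state: e = 0 and x_dot = 0, so u_ss = ki * z_ss must equal k * r.
u_ss = k * r      # holding force [N]
z_ss = u_ss / ki  # integral state that produces it [m·s]

print(f"u_ss = {u_ss:.1f} N, z_ss = {z_ss:.2f} m·s")
```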

There is no large spike at $t = 0$ in the optimised controller's signal. This is a direct consequence of using derivative-on-measurement: since $\dot{x}_0 = 0$, the derivative term contributes nothing at the first step. Had we used derivative-on-error, the step discontinuity in $e$ at $t = 0$ would have produced a spike of approximately $K_d / \Delta t \approx 2.7 / 0.01 = 270\,\text{N}$ on top of the proportional kick.


Practical Considerations

Simulation fidelity. Gradient-based tuning is only as good as your model. If the simulation drifts significantly from the real plant, the optimal gains may not transfer. The sim-to-real gap is the main risk.

Numerical stability. Long rollout horizons can cause gradient magnitudes to grow or shrink exponentially, the same problem that plagues RNN training. For stiff systems or long horizons, consider a smaller step size, a higher-order integrator such as RK4 (for genuinely stiff dynamics an implicit method is the safer choice, since explicit schemes like RK4 still need small steps there), or gradient clipping.
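Gradient clipping is a small addition to the training loop. Below is a minimal sketch of clipping by global norm over a pytree of gradients; the dictionary keys, the gradient values and max_norm are placeholders, and optax.clip_by_global_norm provides the same transformation for optax-based loops.

```python
import jax
import jax.numpy as jnp


def clip_by_global_norm(grads, max_norm):
    """Scale a pytree of gradients so its global L2 norm is at most max_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in leaves))
    scale = jnp.minimum(1.0, max_norm / (global_norm + 1e-12))  # never scales up
    return jax.tree_util.tree_map(lambda g: g * scale, grads)


# Hypothetical gradient values with one exploding component.
grads = {"kp": jnp.array(300.0), "ki": jnp.array(-40.0), "kd": jnp.array(0.5)}
clipped = clip_by_global_norm(grads, max_norm=1.0)
print({name: float(g) for name, g in clipped.items()})
```

Clipping rescales all components together, so the update direction is preserved and only its length is capped.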

Local minima. The closed-loop simulation loss is generally non-convex in the gains. For simple plants like this one, reasonable initial guesses often work, but it is worth running from several starting points to build confidence.
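Multi-start runs vectorise naturally with jax.vmap. The sketch below evaluates a simplified loss (only the time-weighted tracking term) and its gradient for eight random initial guesses in one compiled call. The sampling ranges are arbitrary assumptions, chosen so that $(c + K_d)(k + K_p) > m K_i$ and every start therefore gives a stable closed loop.

```python
import jax
import jax.numpy as jnp

M, C, K = 1.0, 0.5, 2.0
DT, T = 0.01, 8.0
N = int(round(T / DT))


def weighted_ise(gains, r=1.0):
    """Time-weighted squared tracking error: a simplified stand-in for the full loss."""
    kp, ki, kd = gains

    def step(carry, n):
        x, v, z = carry
        e = r - x
        u = kp * e + ki * z - kd * v
        v_next = v + DT * (u - C * v - K * x) / M
        w = 1.0 + 5.0 * (n * DT / T) ** 2
        return (x + DT * v_next, v_next, z + DT * e), w * e**2 * DT

    zero = jnp.zeros(())
    _, terms = jax.lax.scan(step, (zero, zero, zero), jnp.arange(N + 1))
    return jnp.sum(terms)


# Loss and gradient for a batch of starting guesses in a single compiled call.
batched_value_and_grad = jax.jit(jax.vmap(jax.value_and_grad(weighted_ise)))

# Random starts: K_p in [1, 10], K_i in [0, 3], K_d in [1, 5] (arbitrary ranges).
key = jax.random.PRNGKey(0)
starts = jax.random.uniform(key, (8, 3),
                            minval=jnp.array([1.0, 0.0, 1.0]),
                            maxval=jnp.array([10.0, 3.0, 5.0]))
losses, grads = batched_value_and_grad(starts)
print(losses)
```

Each row of starts would then be handed to the optimiser, keeping whichever run ends with the lowest loss.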

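Nonlinear plants. The approach extends directly to nonlinear systems: just swap the linear update for your actual model equations. AD does not care whether the dynamics are linear, only that they are written as differentiable JAX operations. Hard nonlinearities such as actuator saturation or dead zones deserve care: wherever they are active, the gradient through them is zero, so those samples give the optimiser no signal.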
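Swapping in nonlinear dynamics only changes the line that computes the acceleration. The sketch below uses a hypothetical plant with a cubic hardening spring (coefficient K3) and an actuator limit u_max, neither of which is part of the post's example, together with the simplified tracking-only loss. It also illustrates the saturation caveat: when the limit is active at every step, the gains no longer affect the trajectory and the gradient is exactly zero.

```python
import jax
import jax.numpy as jnp

M, C, K = 1.0, 0.5, 2.0
K3 = 0.5               # hypothetical cubic (hardening) spring coefficient
DT, T = 0.01, 8.0
N = int(round(T / DT))


def weighted_ise(gains, u_max, r=1.0):
    """Tracking-only loss for a nonlinear plant with actuator saturation."""
    kp, ki, kd = gains

    def step(carry, n):
        x, v, z = carry
        e = r - x
        u = jnp.clip(kp * e + ki * z - kd * v, -u_max, u_max)  # actuator limit
        accel = (u - C * v - K * x - K3 * x**3) / M             # nonlinear spring
        v_next = v + DT * accel
        w = 1.0 + 5.0 * (n * DT / T) ** 2
        return (x + DT * v_next, v_next, z + DT * e), w * e**2 * DT

    zero = jnp.zeros(())
    _, terms = jax.lax.scan(step, (zero, zero, zero), jnp.arange(N + 1))
    return jnp.sum(terms)


grad_fn = jax.jit(jax.grad(weighted_ise))
gains = jnp.array([7.0, 2.0, 2.7])
g_free = grad_fn(gains, 1e3)       # limit never reached: ordinary gradient
g_saturated = grad_fn(gains, 0.1)  # limit always active: gains have no effect
print(g_free, g_saturated)
```

Partial saturation is the common case in practice: gradients still flow through the unsaturated samples, but the saturated ones contribute nothing.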


Takeaways

  • You can treat PID gain tuning as a standard gradient-based optimisation problem by differentiating through a closed-loop simulation.
  • JAX makes this straightforward: write the rollout, define the loss, call jax.grad.
  • The choice of loss matters. A quadratic time-weighting combined with a tail penalty keeps the gradient alive for late-time errors, naturally suppressing steady-state error and sustained ringing. Plain ISE can plateau with a residual offset because the gradient gets small as the error shrinks.
  • Gradient descent discovers control structure. Starting from a P-only controller with $K_i = 0$, the optimiser finds on its own that integral action is needed, because without it the time-weighted loss is unavoidably large.
  • Derivative-on-measurement eliminates the actuator spike that would otherwise appear at $t = 0$ when the setpoint steps. The benefit is easy to quantify: with $K_d \approx 2.7$ and $\Delta t = 0.01\,\text{s}$, the spike avoided is roughly $270\,\text{N}$.
  • The main assumption is that you have a reasonable simulation of your plant. If you do, this approach is fast, flexible, and easy to extend to more complex controllers or multi-loop architectures.

The Python script used to generate all plots is available if you want to run the experiments yourself.