"""
PID tuning via automatic differentiation with JAX.

Plant: mass-spring-damper  m*x'' + c*x' + k*x = u
        m=1, c=0.5, k=2

Optimises [Kp, Ki, Kd] to minimise a time-weighted integrated squared error
plus tail, control-effort, and integral-usage penalties.
Produces five plots used in the blog post.
"""

import os
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import optax



# --------------------------------------------------------------------------- #
# Plant parameters
# --------------------------------------------------------------------------- #
# Mass, damping, and spring constants of the plant  m*x'' + c*x' + k*x = u.
M, C, K = 1.0, 0.5, 2.0
DT = 0.01            # integration time step (s)
T_END = 8.0          # simulated horizon (s)
SETPOINT = 1.0       # step reference the controller must track
# NOTE(review): LAM is never referenced anywhere in this file — the
# control-effort weight is hardcoded as 1e-2 inside loss(). The previous
# comment ("no effort penalty") was wrong; either remove LAM or wire it
# through to loss().
LAM = 1e-1
N_STEPS = int(T_END / DT)   # number of Euler steps per rollout

# --------------------------------------------------------------------------- #
# Closed-loop simulation (fully differentiable)
# --------------------------------------------------------------------------- #

def simulate(gains):
    """Differentiable closed-loop rollout of the PID-controlled plant.

    Integrates the mass-spring-damper  m*x'' + c*x' + k*x = u  for N_STEPS
    steps of size DT under a PID controller with gains [Kp, Ki, Kd].

    State carried through the scan: (position, velocity, error integral).
    Returns (positions, control signals, errors), each of length N_STEPS.
    """
    Kp, Ki, Kd = gains
    ref = SETPOINT

    def advance(state, _):
        pos, vel, err_int = state
        err = ref - pos
        # Derivative-on-measurement (-Kd * velocity) instead of d(error)/dt:
        # avoids the derivative kick caused by the step discontinuity at t=0.
        ctrl = Kp * err + Ki * err_int - Kd * vel

        accel = (ctrl - C * vel - K * pos) / M
        # Semi-implicit Euler: advance velocity first, then position using the
        # *updated* velocity — more stable than fully explicit Euler.
        vel = vel + DT * accel
        pos = pos + DT * vel
        err_int = err_int + DT * err

        return (pos, vel, err_int), (pos, ctrl, err)

    _, (xs, us, es) = jax.lax.scan(advance, (0.0, 0.0, 0.0), None, length=N_STEPS)
    return xs, us, es


def loss(gains):
    """Scalar training objective for the PID gains.

    Four terms are combined:
      * time-weighted integrated squared error (late errors cost up to 6x),
      * a hinge penalty on |error| above a 1 cm dead band (kills the tail),
      * a quadratic control-effort penalty,
      * a small penalty on accumulated integral usage.
    """
    xs, us, es = simulate(gains)
    t = jnp.arange(N_STEPS) * DT

    # Quadratic ramp from 1.0 at t=0 up to 6.0 at t=T_END.
    weights = 1.0 + 5.0 * (t / T_END) ** 2
    tracking = jnp.sum(weights * es ** 2) * DT

    # Mean violation of a 1 cm dead band around zero error.
    deadband = 0.01
    tail = jnp.mean(jnp.maximum(jnp.abs(es) - deadband, 0.0))

    effort = 1e-2 * jnp.sum(us ** 2) * DT

    # Discourage the integrator from winding up a large running error.
    integral_usage = jnp.mean(jnp.abs(jnp.cumsum(es) * DT))

    return tracking + 10.0 * tail + effort + 1e-2 * integral_usage


# JIT-compile the objective together with its gradient w.r.t. the gains:
# one fused XLA computation returning (loss value, d(loss)/d(gains)).
# The first call triggers tracing/compilation; later calls are fast.
loss_and_grad = jax.jit(jax.value_and_grad(loss))

# --------------------------------------------------------------------------- #
# Optimisation
# --------------------------------------------------------------------------- #

# Start from a pure proportional controller: Kp=1, Ki=Kd=0.
gains_init = jnp.array([1.0, 0.0, 0.0])

N_ITER = 3000
# Cosine-decayed learning rate: large steps early, approaching 0 by N_ITER.
lr_schedule = optax.cosine_decay_schedule(init_value=0.1, decay_steps=N_ITER)
optimizer = optax.adam(learning_rate=lr_schedule)
opt_state = optimizer.init(gains_init)

gains = gains_init
loss_history = []                   # loss evaluated at the pre-update gains
gains_history = [np.array(gains)]   # N_ITER + 1 rows (includes the init point)
for _ in range(N_ITER):             # index was unused — renamed from `i`
    val, grads = loss_and_grad(gains)
    updates, opt_state = optimizer.update(grads, opt_state)
    gains = optax.apply_updates(gains, updates)
    loss_history.append(float(val))
    gains_history.append(np.array(gains))

gains_history = np.array(gains_history)
print(f"Initial gains : Kp={gains_init[0]:.3f}  Ki={gains_init[1]:.3f}  Kd={gains_init[2]:.3f}")
print(f"Optimised gains: Kp={gains[0]:.3f}  Ki={gains[1]:.3f}  Kd={gains[2]:.3f}")
print(f"Final loss: {loss_history[-1]:.4f}")

# --------------------------------------------------------------------------- #
# Simulate initial and optimised controllers
# --------------------------------------------------------------------------- #

# Roll out both controllers and pull the results back to NumPy for matplotlib.
xs_init, us_init, es_init = (np.array(a) for a in simulate(gains_init))
xs_opt, us_opt, es_opt = (np.array(a) for a in simulate(gains))

time = np.arange(N_STEPS) * DT

# --------------------------------------------------------------------------- #
# Plotting helpers
# --------------------------------------------------------------------------- #

# Plots are written next to this source file.
OUT_DIR = os.path.dirname(os.path.abspath(__file__))

# Shared colour palette for all figures.
BLUE   = "#2E86AB"
ORANGE = "#E07A5F"
GREEN  = "#3D405B"  # NOTE(review): named GREEN but the hex is a dark slate blue
GRAY   = "#AAAAAA"

def save(name):
    """Write the current matplotlib figure to OUT_DIR/name, then close it."""
    target = os.path.join(OUT_DIR, name)
    plt.savefig(target, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved {target}")

# --------------------------------------------------------------------------- #
# 1. Step response
# --------------------------------------------------------------------------- #

# Step response: setpoint, untuned P-only response, and optimised response.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(time, np.ones_like(time), linestyle="--", linewidth=1.2, color=GRAY, label="Setpoint")
ax.plot(time, xs_init, linewidth=1.8, color=ORANGE, label=r"Initial ($K_p{=}1,\;K_i{=}0,\;K_d{=}0$)")
ax.plot(time, xs_opt, linewidth=1.8, color=BLUE, label="Optimised")
ax.set(
    xlabel="Time (s)",
    ylabel="Position (m)",
    title="Step Response — Before and After Tuning",
    xlim=(0, T_END),
)
ax.legend(framealpha=0.9)
ax.grid(True, alpha=0.3)
fig.tight_layout()
save("step_response.png")

# --------------------------------------------------------------------------- #
# 2. Loss curve
# --------------------------------------------------------------------------- #

# Loss curve on a log axis, one point per Adam iteration.
fig, ax = plt.subplots(figsize=(7, 3.5))
ax.semilogy(loss_history, linewidth=1.8, color=BLUE)
ax.set(
    xlabel="Iteration",
    ylabel="Loss (log scale)",
    title="Optimisation Loss over Adam Iterations (time-weighted ISE + effort)",
)
ax.grid(True, which="both", alpha=0.3)
fig.tight_layout()
save("loss_curve.png")

# --------------------------------------------------------------------------- #
# 3. Gain trajectories
# --------------------------------------------------------------------------- #

# Evolution of each PID gain across optimisation iterations.
fig, ax = plt.subplots(figsize=(8, 4))
iters = np.arange(len(gains_history))
for column, colour, lbl in (
    (0, BLUE, r"$K_p$"),
    (1, ORANGE, r"$K_i$"),
    (2, GREEN, r"$K_d$"),
):
    ax.plot(iters, gains_history[:, column], color=colour, linewidth=1.8, label=lbl)
ax.set(
    xlabel="Iteration",
    ylabel="Gain value",
    title="PID Gain Trajectories During Optimisation",
)
ax.legend(framealpha=0.9)
ax.grid(True, alpha=0.3)
fig.tight_layout()
save("gain_trajectories.png")

# --------------------------------------------------------------------------- #
# 4. Tracking error — shows steady-state error of P-only vs zero error with Ki
# --------------------------------------------------------------------------- #

# Tracking error over time for both controllers, with a zero reference line.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(time, es_init, alpha=0.9, linewidth=1.5, color=ORANGE,
        label=r"Initial ($K_i{=}0$) — permanent steady-state error")
ax.plot(time, es_opt, linewidth=1.5, color=BLUE,
        label="Optimised — error converges to zero")
ax.axhline(0, linestyle="--", linewidth=0.8, color=GRAY)
ax.set(
    xlabel="Time (s)",
    ylabel="Tracking error  $e(t) = r - x(t)$",
    title="Tracking Error — Steady-State Error Eliminated by Integral Action",
    xlim=(0, T_END),
)
ax.legend(framealpha=0.9)
ax.grid(True, alpha=0.3)
fig.tight_layout()
save("tracking_error.png")

# --------------------------------------------------------------------------- #
# 5. Control effort — shows that optimised controller doesn't use much more effort than initial P-only
# --------------------------------------------------------------------------- #

# Control signal over time for both controllers.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(time, us_init, alpha=0.9, linewidth=1.5, color=ORANGE,
        label="Initial — more effort due to steady-state error")
ax.plot(time, us_opt, linewidth=1.5, color=BLUE,
        label="Optimised — less effort as error goes to zero")
ax.set(
    xlabel="Time (s)",
    ylabel="Control signal $u(t)$",
    title="Control Effort — Optimised Controller Uses Less Effort as Error Vanishes",
    xlim=(0, T_END),
)
ax.legend(framealpha=0.9)
ax.grid(True, alpha=0.3)
fig.tight_layout()
save("control_effort.png")


print("\nAll plots written to", OUT_DIR)
