r"""
.. _examples_qlearning_offpolicy:

Off-policy Q-learning
=====================

This example tries to reproduce the results from the linear MPC numerical experiment in
:cite:`gros_datadriven_2020`, but in an off-policy setting. We are given an RL
environment whose cost function is

.. math::
    L(s,a) = \frac{1}{2} \left(
        s^\top s + \frac{1}{2} a^2 + w^\top \max\{0, \underline{s} - s\}
        + w^\top \max\{0, s - \overline{s}\}
    \right)

where :math:`s` is the state, :math:`a` is the action, :math:`w` is a weight vector, and
:math:`\underline{s}` and :math:`\overline{s}` are the lower and upper bounds of the
state, respectively. The dynamics of the real environment are

.. math::
    s_+ = \begin{bmatrix} 0.9 & 0.35 \\ 0 & 1.1 \end{bmatrix} s
        + \begin{bmatrix} 0.0813 \\ 0.2 \end{bmatrix} a
        + \begin{bmatrix} e \\ 0 \end{bmatrix}

where :math:`e \sim \mathcal{U}(-0.1, 0)`. Given the state :math:`s_k`, the following
MPC scheme is used to control the system

.. math::
   \begin{aligned}
      \min_{x_{0:N}, u_{0:N-1}, \sigma_{1:N}} \quad &
        V_0 + x_N^\top S x_N + \sum_{i=1}^{N}{ w^\top \sigma_i } \\
        & + \sum_{i=0}^{N-1}{ \gamma^i
            \left(
                x_i^\top x_i + 0.5 u_i^2 +
                f^\top \begin{bmatrix} x_i \\ u_i \end{bmatrix}
            \right)
        } \\
      \textrm{s.t.} \quad & x_0 = s_k \\
                          & x_{i+1} = A x_i + B u_i + b & i=0,\dots,N-1 \\
                          & \underline{s} + \underline{x} - \sigma_i \leq x_i
                            \leq \overline{s} + \overline{x} + \sigma_i
                            \quad & i=1,\dots,N
   \end{aligned}

with :math:`\gamma = 0.9`, and the learnable parameters are

.. math:: \theta = \left(
        V_0, \underline{x}, \overline{x}, b, f, A, B
    \right)

The parameters are initialized differently, and in particular, the prediction model of
the MPC is initialized wrongly as

.. math::
    A = \begin{bmatrix} 1 & 0.25 \\ 0 & 1 \end{bmatrix}, \quad
    B = \begin{bmatrix} 0.0312 \\ 0.25 \end{bmatrix},

and :math:`S` is the solution to the corresponding discrete-time algebraic Riccati
equation, i.e., computed with the wrong dynamics matrices. The task is simple: find a
parametrization :math:`\theta` such that the cost function is minimized. To solve it,
we will employ a second-order LSTD Q-learning algorithm. However, we will train this
agent with data generated in an off-policy fashion, i.e., generated by another
controller.
"""

import logging
from collections.abc import Callable
from typing import Any, Optional

import casadi as cs
import gymnasium as gym
import numpy as np
import numpy.typing as npt
from csnlp import Nlp
from csnlp.wrappers import Mpc
from gymnasium.spaces import Box
from gymnasium.wrappers import TimeLimit

from mpcrl import Agent, LearnableParameter, LearnableParametersDict, LstdQLearningAgent
from mpcrl.optim import NewtonMethod
from mpcrl.util.control import dlqr
from mpcrl.wrappers.agents import Evaluate, Log, RecordUpdates

# %%
# Defining the environment
# ------------------------
# First things first, we need to build the environment. We will use the :mod:`gymnasium`
# library to do so. The most important methods are :func:`gymnasium.Env.reset` and
# :func:`gymnasium.Env.step`, which will be called to reset the environment to its
# initial state and to step the dynamics and receive a realization of the reward signal,
# respectively. The environment is defined in the following class.


class LtiSystem(gym.Env[npt.NDArray[np.floating], float]):
    """Two-state, single-input discrete-time LTI system with additive uniform noise.

    The true transition is ``x+ = A x + B a + [e; 0]`` with ``e ~ U(e_bnd)``. The
    reward returned by :meth:`step` is the stage cost :math:`L(s,a)` of the state
    *before* the transition. The environment never terminates or truncates on its
    own; episode length is imposed externally (e.g., with ``TimeLimit``).
    """

    nx = 2  # state dimension
    nu = 1  # input dimension
    A = np.asarray([[0.9, 0.35], [0, 1.1]])  # true state matrix
    B = np.asarray([[0.0813], [0.2]])  # true input matrix
    x_bnd = (np.asarray([[0], [-1]]), np.asarray([[1], [1]]))  # (lower, upper) states
    a_bnd = (-1, 1)  # (lower, upper) bounds on the action
    w = np.asarray([[1e2], [1e2]])  # weights penalizing state-bound violations
    e_bnd = (-1e-1, 0)  # support of the uniform noise on the first state
    action_space = Box(*a_bnd, (nu,), np.float64)

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict[str, Any]] = None,
    ) -> tuple[npt.NDArray[np.floating], dict[str, Any]]:
        """Forwards ``seed``/``options`` to the gymnasium base class and restores the
        initial state.

        Returns
        -------
        tuple
            The ``(nx, 1)`` initial state ``[0, 0.15]^T`` and an empty info dict.
        """
        super().reset(seed=seed, options=options)
        initial_state = np.reshape([0.0, 0.15], (self.nx, 1))
        self.x = initial_state
        return initial_state, {}

    def get_stage_cost(self, state: npt.NDArray[np.floating], action: float) -> float:
        """Evaluates the stage cost :math:`L(s,a)` for one state-action pair."""
        lower, upper = self.x_bnd
        state_term = np.square(state).sum()
        action_term = 0.5 * action**2
        below_penalty = self.w.T @ np.maximum(0, lower - state)
        above_penalty = self.w.T @ np.maximum(0, state - upper)
        total = state_term + action_term + below_penalty + above_penalty
        return 0.5 * total.item()

    def step(
        self, action: cs.DM
    ) -> tuple[npt.NDArray[np.floating], float, bool, bool, dict[str, Any]]:
        """Applies ``action`` to the true dynamics for one time step.

        Returns the next state, the stage cost of the current (pre-transition) state
        and action, ``False`` for both terminated and truncated, and an empty info dict.
        """
        u = float(action)
        next_state = self.A @ self.x + self.B * u
        disturbance = self.np_random.uniform(*self.e_bnd)
        next_state[0] += disturbance  # noise only enters the first state
        cost = self.get_stage_cost(self.x, u)
        self.x = next_state
        return next_state, cost, False, False, {}


# %%
# Defining the MPC controller
# ---------------------------
# The second component is the MPC controller. We'll create a custom class that, of
# course, inherits from :class:`csnlp.wrappers.Mpc`. The implementation is as follows,
# and it closely follows the formulation presented above.


class LinearMpc(Mpc[cs.SX]):
    """A simple linear MPC controller.

    Parametric MPC scheme for :class:`LtiSystem` with horizon :attr:`horizon` and
    discount :attr:`discount_factor`. Its learnable parameters are ``V0``, ``x_lb``,
    ``x_ub``, ``b``, ``f``, ``A`` and ``B``, whose initial values are given in
    :attr:`learnable_pars_init`. The terminal cost matrix ``S`` is computed once, at
    construction, from the initial (i.e., wrong) ``A`` and ``B``; it is not a
    parameter and is therefore not learned.
    """

    horizon = 10  # prediction horizon N
    discount_factor = 0.9  # discount factor gamma of the MPC objective
    # Initial values of the learnable parameters. "A" and "B" differ from the true
    # LtiSystem.A and LtiSystem.B, so the prediction model starts out wrong.
    # NOTE(review): "x_lb" and "x_ub" are integer-dtype arrays; confirm that downstream
    # code casts them to float before applying (fractional) parameter updates.
    learnable_pars_init = {
        "V0": np.asarray(0.0),
        "x_lb": np.asarray([0, 0]),
        "x_ub": np.asarray([1, 0]),
        "b": np.zeros(LtiSystem.nx),
        "f": np.zeros(LtiSystem.nx + LtiSystem.nu),
        "A": np.asarray([[1, 0.25], [0, 1]]),
        "B": np.asarray([[0.0312], [0.25]]),
    }

    def __init__(self) -> None:
        """Builds the parametric optimal control problem and initializes its solver."""
        N = self.horizon
        gamma = self.discount_factor
        w = LtiSystem.w
        nx, nu = LtiSystem.nx, LtiSystem.nu
        x_bnd, a_bnd = LtiSystem.x_bnd, LtiSystem.a_bnd
        nlp = Nlp[cs.SX]()
        super().__init__(nlp, N)

        # parameters: one symbolic parameter per entry of learnable_pars_init
        V0 = self.parameter("V0")
        x_lb = self.parameter("x_lb", (nx,))
        x_ub = self.parameter("x_ub", (nx,))
        b = self.parameter("b", (nx, 1))
        f = self.parameter("f", (nx + nu, 1))
        A = self.parameter("A", (nx, nx))
        B = self.parameter("B", (nx, nu))

        # variables (state, action, slack)
        # NOTE(review): bound_initial=False presumably keeps state bounds off x_0 (no
        # state bounds are passed here anyway) -- confirm against csnlp's Mpc.state.
        x, _ = self.state("x", nx, bound_initial=False)
        # actions are box-constrained to the environment's action bounds
        u, _ = self.action("u", nu, lb=a_bnd[0], ub=a_bnd[1])
        # nonnegative slacks, one column per predicted state x_1, ..., x_N
        s, _, _ = self.variable("s", (nx, N), lb=0)

        # dynamics: affine prediction model x_{i+1} = A x_i + B u_i + b
        self.set_affine_dynamics(A, B, c=b)

        # other constraints: state bounds on x_1, ..., x_N, shifted by the learnable
        # backoffs x_lb/x_ub and softened by the slacks s
        self.constraint("x_lb", x_bnd[0] + x_lb - s, "<=", x[:, 1:])
        self.constraint("x_ub", x[:, 1:], "<=", x_bnd[1] + x_ub + s)

        # objective
        # The terminal weight S uses the initial (not the learned) A and B, with
        # Q = 0.5 I and R = 0.25 I, mirroring the 0.5 * (x'x + 0.5 u^2) stage cost.
        # NOTE(review): index [1] presumably selects the Riccati solution from dlqr's
        # output -- confirm against mpcrl.util.control.dlqr.
        A_init, B_init = self.learnable_pars_init["A"], self.learnable_pars_init["B"]
        S = cs.DM(dlqr(A_init, B_init, 0.5 * np.eye(nx), 0.25 * np.eye(nu))[1])
        # row vector [1, gamma, ..., gamma^(N-1)]
        gammapowers = cs.DM(gamma ** np.arange(N)).T
        # NOTE(review): as implemented, the 0.5 factor and the discount multiply both the
        # state/action cost and the slack penalty, while the f term is undiscounted.
        # This differs from the objective written in the module docstring; confirm
        # which one is intended.
        self.minimize(
            V0
            + cs.bilin(S, x[:, -1])
            + cs.sum2(f.T @ cs.vertcat(x[:, :-1], u))
            + 0.5
            * cs.sum2(
                gammapowers * (cs.sum1(x[:, :-1] ** 2) + 0.5 * cs.sum1(u**2) + w.T @ s)
            )
        )

        # solver: fatrop, silent, at most 500 iterations
        # NOTE(review): calc_lam_x=True presumably exposes the bound multipliers needed
        # for the sensitivities used by LSTD Q-learning -- verify.
        opts = {
            "expand": True,
            "print_time": False,
            "bound_consistency": True,
            "calc_lam_x": True,
            "calc_lam_p": False,
            "fatrop": {"max_iter": 500, "print_level": 0},
        }
        self.init_solver(opts, solver="fatrop", type="nlp")


# %%
# Behaviour policy
# ----------------
# Q-learning is a versatile algorithm that can learn in an off-policy fashion, i.e.,
# from data generated by a different policy than the one being learned. To exploit
# this, we need a behaviour policy that generates the training data. Here, we employ a
# non-learning agent that uses the nominal MPC controller. Simpler policies (e.g., LQR),
# or even more complex, expert policies, could be used as well.


def get_behaviour_policy() -> Callable[[npt.NDArray[np.floating]], float]:
    """Creates the behaviour policy that generates the off-policy training data.

    The policy queries a non-learning :class:`mpcrl.Agent` built around a fresh
    :class:`LinearMpc`. The agent's parameters are fixed to a shallow copy of
    :attr:`LinearMpc.learnable_pars_init`, i.e., this is the nominal, wrongly-initialized
    MPC.

    Returns
    -------
    Callable
        A function mapping a state to the scalar action chosen by the nominal MPC.
    """
    controller = LinearMpc()
    fixed_parameters = dict(LinearMpc.learnable_pars_init)
    expert = Agent(controller, fixed_parameters)

    def act(state: npt.NDArray[np.floating]) -> float:
        # NOTE(review): the positional ``True`` presumably requests a deterministic
        # action from Agent.state_value -- confirm against mpcrl's signature.
        chosen_action, _solution = expert.state_value(state, True)
        return float(chosen_action)

    return act


# %%
# Simulation
# ----------
# So far, we have only defined the classes for the environment, the MPC controller, and
# the off-policy behaviour policy. Now, it is time to integrate these and run the
# simulation. This is comprised of multiple steps, which are detailed below.
#
# 1. We instantiate the environment. Note how it is wrapped in
#    :class:`gymnasium.wrappers.TimeLimit` to impose a maximum amount of steps
#    to be simulated. Note also that the Q-learning policy will NOT interact with this
#    environment, but rather the behaviour policy will.
# 2. We instantiate the MPC controller and define its learnable parameters.
# 3. We instantiate the Q-learning agent. We pass different options to it, such as
#    the update strategy, the optimizer, the Hessian type, etc. For plotting purposes,
#    it is also wrapped such that the updated parameters are recorded. We also log
#    the progress of the simulation. Additionally, we evaluate the performance of the
#    agent periodically to monitor its progress (since it is trained from offline data).
# 4. We define the behaviour policy.
# 5. We run the simulation. Under the hood, the agent will sequentially collect data
#    from the behaviour policy and update the parameters of the MPC controller.
# 6. Finally, we plot the results. The first plot shows the TD error and the periodic
#    evaluations of the learned policy. The second plot shows how each learnable
#    parameter evolves over time.

if __name__ == "__main__":
    # instantiate the env and wrap it
    env = TimeLimit(LtiSystem(), 100)

    # now build the MPC and the dict of learnable parameters
    seed = 69
    mpc = LinearMpc()
    learnable_pars = LearnableParametersDict(
        (
            LearnableParameter(name, val.shape, val)
            for name, val in mpc.learnable_pars_init.items()
        )
    )

    # build and wrap appropriately the agent: the LSTD Q-learning agent records its
    # parameter updates (for plotting), logs at the end of every episode, and is
    # periodically evaluated on a separate environment (see hook/frequency below)
    agent = Evaluate(
        Log(
            RecordUpdates(
                LstdQLearningAgent(
                    mpc=mpc,
                    learnable_parameters=learnable_pars,
                    discount_factor=mpc.discount_factor,
                    update_strategy=1,
                    optimizer=NewtonMethod(learning_rate=5e-2),
                    hessian_type="approx",
                    record_td_errors=True,
                    remove_bounds_on_initial_action=True,
                )
            ),
            level=logging.DEBUG,
            log_frequencies={"on_episode_end": 1},
        ),
        eval_env=TimeLimit(LtiSystem(), 100),
        hook="on_episode_end",
        frequency=10,
        n_eval_episodes=5,
        seed=seed,
    )

    # before training, let's create a nominal non-learning agent which will be used to
    # generate expert rollout data. This data will then be used to train the off-policy
    # Q-learning agent.
    behaviour_policy = get_behaviour_policy()

    # finally, we can launch the off-policy training by just passing the
    # `behaviour_policy` to the `train` method of the agent. This will use that policy
    # to generate learning data, instead of the agent's own policy.
    agent.train(env=env, episodes=100, behaviour_policy=behaviour_policy, seed=seed + 1)
    eval_returns = np.asarray(agent.eval_returns)

    # plot the results
    import matplotlib.pyplot as plt

    # figure 1: TD errors collected during training (top), and mean +/- std of the
    # returns of the periodic evaluations (bottom)
    _, axs = plt.subplots(2, 1, constrained_layout=True)
    eval_returns_avg = eval_returns.mean(1)
    eval_returns_std = eval_returns.std(1)
    evals = np.arange(1, eval_returns.shape[0] + 1)
    axs[0].plot(agent.td_errors, "o", markersize=1)
    axs[0].set_xlabel("Time steps")
    axs[0].set_ylabel(r"$\tau$")
    patch = axs[1].fill_between(
        evals,
        eval_returns_avg - eval_returns_std,
        eval_returns_avg + eval_returns_std,
        alpha=0.3,
    )
    axs[1].plot(evals, eval_returns_avg, color=patch.get_facecolor())
    axs[1].set_xlabel("Evaluations")
    axs[1].set_ylabel(r"$\sum L$")

    # figure 2: evolution of each learnable parameter over the recorded updates
    _, axs = plt.subplots(3, 2, constrained_layout=True, sharex=True)
    updates_history = {k: np.asarray(v) for k, v in agent.updates_history.items()}
    B_history = updates_history["B"]
    axs[0, 0].plot(updates_history["b"])
    axs[0, 1].plot(np.stack([updates_history[n][:, 0] for n in ("x_lb", "x_ub")], -1))
    axs[1, 0].plot(updates_history["f"])
    axs[1, 1].plot(updates_history["V0"])
    axs[2, 0].plot(updates_history["A"].reshape(-1, 4))
    # flatten each recorded (nx, nu) matrix into a row; unlike .squeeze(), this keeps
    # the 2-D (records, entries) layout even when only a single record exists
    axs[2, 1].plot(B_history.reshape(B_history.shape[0], -1))
    axs[0, 0].set_ylabel("$b$")
    axs[0, 1].set_ylabel("$x_1$")
    axs[1, 0].set_ylabel("$f$")
    axs[1, 1].set_ylabel("$V_0$")
    axs[2, 0].set_ylabel("$A$")
    axs[2, 1].set_ylabel("$B$")
    for ax in axs[-1]:  # x-axes are shared, so label only the bottom row
        ax.set_xlabel("Updates")
    plt.show()
