"""As it will be clear from the inheritance diagram in :ref:`module_reference_agents`,
all agents are derived from mixin classes that define callbacks and manage hooks
attached to these callbacks. This system not only allows the user to customize the
behaviour of a derived agent every time a callback is triggered, but also makes it
easy to implement and manage all those events and quantities that need to be
scheduled during training and evaluation. Some examples of such events are the decay
of the learning rate or of the exploration chances, or when and with which frequency
to invoke an update of the MPC parametrization. Here we list the classes that enable
this system; for an introduction to the callbacks and how to use them, see
:ref:`user_guide_callbacks`."""
from copy import deepcopy
from typing import Any, Callable, Literal, Optional, TypeVar, Union
import numpy as np
import numpy.typing as npt
from gymnasium import Env
from .errors import raise_or_warn_on_mpc_failure, raise_or_warn_on_update_failure
ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")
def _failure_msg(
category: Literal["mpc", "update"],
name: str,
episode: int,
timestep: Optional[int],
status: str,
) -> str:
"""Internal utility for composing message for mpc/update failure."""
C = category.title()
if timestep is None:
return f"{C} failure of {name} at episode {episode}, status: {status}."
return (
f"{C} failure of {name} at episode {episode}, time {timestep}, "
f"status: {status}."
)
[docs]
class CallbackMixin:
    """A class with the particular purpose of creating, storing and running hooks
    attached to callbacks.

    Hooks are stored in the ``_hooks`` dictionary, which maps the name of each callback
    to a dictionary of the functions hooked to it, keyed by the name of the attacher.

    Notes
    -----
    Hooks are never transferred to copies of the object: :meth:`__getstate__`,
    :meth:`__setstate__` and :meth:`__deepcopy__` all replace them with empty
    dictionaries. The reason is that the old object's hooks are likely to be pointing
    to methods belonging to old objects' instances, so a new object that inherited
    them would keep invoking callbacks on the old one. For this reason, hooks are not
    copied; instead, :meth:`_establish_callback_hooks` is automatically called by
    :meth:`__setstate__` and :meth:`__deepcopy__` to re-establish them with respect to
    the new object.
    """

    # names of the instance attributes that hold hooks and are reset on copy;
    # NOTE(review): `_hooked_callbacks` is not set by this class - presumably kept for
    # subclasses or states created elsewhere; confirm
    _HOOK_ATTRIBUTES = ("_hooks", "_hooked_callbacks")

    def __init__(self) -> None:
        self._hooks: dict[str, dict[str, Callable[..., None]]] = {}

    def __getstate__(self) -> dict[str, Any]:
        """Returns a shallow copy of the instance's dictionary with hooks emptied."""
        state = self.__dict__.copy()
        # remove hooks (otherwise, new copies will still be calling the old object)
        for key in self._HOOK_ATTRIBUTES:
            if key in state:
                state[key] = {}
        return state

    def __setstate__(
        self,
        state: Union[
            None, dict[str, Any], tuple[Optional[dict[str, Any]], dict[str, Any]]
        ],
    ) -> None:
        """Restores the instance from ``state`` and re-establishes its hooks.

        Parameters
        ----------
        state : dict, tuple of (dict or None, dict), or None
            Either the dictionary of attributes to restore, or a 2-tuple made of such
            a dictionary (possibly ``None``) and a dictionary of attributes that are
            set one by one via :func:`setattr`. The given objects are not modified.
        """
        if isinstance(state, tuple) and len(state) == 2:
            state, slotstate = state
        else:
            slotstate = None
        if state is not None:
            # remove hooks (otherwise, new copies will still be calling the old
            # object). This is done on a sanitized copy, so that the dictionary handed
            # in by the caller is left untouched
            self.__dict__.update(
                {
                    key: ({} if key in self._HOOK_ATTRIBUTES else value)
                    for key, value in state.items()
                }
            )
        if slotstate is not None:
            for key, value in slotstate.items():
                setattr(self, key, value)
        # re-establish hooks
        self._establish_callback_hooks()

    def __deepcopy__(self, memo: dict[int, Any]) -> "CallbackMixin":
        """Returns a deep copy of the instance whose hooks are freshly established."""
        cls = self.__class__
        other = cls.__new__(cls)
        # register the copy before recursing, so that attributes referencing `self`
        # are resolved to `other`
        memo[id(self)] = other
        for key, value in self.__dict__.items():
            if key in self._HOOK_ATTRIBUTES:
                setattr(other, key, {})
            else:
                setattr(other, key, deepcopy(value, memo))
        # re-establish hooks
        other._establish_callback_hooks()
        return other

    def _run_hooks(self, method_name: str, *args: Any) -> None:
        """Runs the internal hooks attached to the given method.

        Parameters
        ----------
        method_name : str
            Name of the callback whose hooks must be run. Nothing happens if no hook
            is attached to it.
        args
            Arguments that are passed to each hook.
        """
        if hooks := self._hooks.get(method_name):
            for hook in hooks.values():
                hook(*args)

    def _establish_callback_hooks(self) -> None:
        """This method must be used to perform the connections between callbacks and any
        invokable method (hook). If the object has no hooks, then this method does
        nothing."""

    def _hook_callback(
        self, attachername: str, callbackname: str, func: Callable[..., None]
    ) -> None:
        """Hooks a function to be called each time a callback is invoked.

        Parameters
        ----------
        attachername : str
            The name of the object requesting the hook. It is the key under which the
            hook is stored, so it must be unique for each callback.
        callbackname : str
            Name of the callback to hook to, i.e., the target of the hooking.
        func : Callable
            function to be called when the callback is invoked. Must accept the same
            input arguments as the callback it is hooked to. The return value is
            discarded.

        Raises
        ------
        ValueError
            If a hook with name ``attachername`` is already attached to this callback.
        """
        hook_dict = self._hooks.setdefault(callbackname, {})
        if attachername in hook_dict:
            raise ValueError(
                f"Hook '{attachername}' already attached to callback '{callbackname}'."
            )
        hook_dict[attachername] = func
[docs]
class AgentCallbackMixin(CallbackMixin):
    """Mixin that provides agents with callbacks.

    Each callback runs the hooks that have been attached to it. The callbacks defined
    by this class are:

    - :meth:`on_mpc_failure`, invoked when an MPC solver fails
    - :meth:`on_validation_start`, invoked when validation starts (see
      :meth:`mpcrl.Agent.evaluate`)
    - :meth:`on_validation_end`, invoked when validation ends
    - :meth:`on_episode_start`, invoked when a training or validation episode starts
    - :meth:`on_episode_end`, invoked when a training or validation episode ends
    - :meth:`on_env_step`, invoked when a training or validation episode steps, i.e.,
      after :func:`gymnasium.Env.step`
    - :meth:`on_timestep_end`, invoked when the current simulation's time step reaches
      an end, i.e., after having stepped the environment and done all the internal
      computations according to the algorithm.
    """

    def on_mpc_failure(
        self, episode: int, timestep: Optional[int], status: str, raises: bool
    ) -> None:
        """Callback invoked when the MPC solver fails.

        The failure is first reported (as an exception or as a warning), and then the
        hooks attached to this callback are run.

        Parameters
        ----------
        episode : int
            Number of the episode in which the failure occurred.
        timestep : int or None
            Time step of the episode in which the failure occurred. It is ``None``
            when the failure occurs inter-episodically or no notion of time step is
            available.
        status : str
            Status of the solver that failed.
        raises : bool
            If ``True``, the failure is to be raised as an exception; if ``False``,
            as a warning.
        """
        # fall back to a generic name if the object has no `name` attribute
        agent_name: str = getattr(self, "name", "agent")
        message = _failure_msg("mpc", agent_name, episode, timestep, status)
        raise_or_warn_on_mpc_failure(message, raises)
        self._run_hooks("on_mpc_failure", episode, timestep, status, raises)

    def on_validation_start(self, env: Env[ObsType, ActType]) -> None:
        """Callback invoked when the validation process begins (see
        :meth:`mpcrl.Agent.evaluate`).

        Parameters
        ----------
        env : gym env
            The gym environment on which the agent is being validated.
        """
        self._run_hooks("on_validation_start", env)

    def on_validation_end(
        self, env: Env[ObsType, ActType], returns: npt.NDArray[np.floating]
    ) -> None:
        """Callback invoked when the validation process is over (see
        :meth:`mpcrl.Agent.evaluate`).

        Parameters
        ----------
        env : gym env
            The gym environment on which the agent has been validated.
        returns : array of double
            Cumulative rewards of each episode.
        """
        self._run_hooks("on_validation_end", env, returns)

    def on_episode_start(
        self, env: Env[ObsType, ActType], episode: int, state: ObsType
    ) -> None:
        """Callback invoked when an episode of the training or validation process
        begins (see :meth:`mpcrl.Agent.evaluate`, :meth:`mpcrl.LearningAgent.train`
        and :meth:`mpcrl.LearningAgent.train_offpolicy`).

        Parameters
        ----------
        env : gym env
            The gym environment of the episode.
        episode : int
            Number of the episode.
        state : ObsType
            State from which the episode starts.
        """
        self._run_hooks("on_episode_start", env, episode, state)

    def on_episode_end(
        self, env: Env[ObsType, ActType], episode: int, rewards: float
    ) -> None:
        """Callback invoked when an episode of the training or evaluation process is
        over (see :meth:`mpcrl.Agent.evaluate`, :meth:`mpcrl.LearningAgent.train` and
        :meth:`mpcrl.LearningAgent.train_offpolicy`).

        Parameters
        ----------
        env : gym env
            The gym environment of the episode.
        episode : int
            Number of the episode.
        rewards : float
            Cumulative rewards collected in the episode.
        """
        self._run_hooks("on_episode_end", env, episode, rewards)

    def on_env_step(
        self, env: Env[ObsType, ActType], episode: int, timestep: int
    ) -> None:
        """Callback invoked right after every call to :func:`gymnasium.Env.step`.

        Parameters
        ----------
        env : gym env
            The gym environment that has been stepped.
        episode : int
            Number of the current episode.
        timestep : int
            Time instant of the current episode.
        """
        self._run_hooks("on_env_step", env, episode, timestep)

    def on_timestep_end(
        self, env: Env[ObsType, ActType], episode: int, timestep: int
    ) -> None:
        """Callback invoked when a time iteration is over. Its frequency is the same
        as that of :meth:`on_env_step`, but its timing is different.

        Parameters
        ----------
        env : gym env
            The gym environment of the current episode.
        episode : int
            Number of the current episode.
        timestep : int
            Time instant of the current episode.
        """
        self._run_hooks("on_timestep_end", env, episode, timestep)
[docs]
class LearningAgentCallbackMixin(AgentCallbackMixin):
    """Mixin that provides learning agents with callbacks.

    In addition to the callbacks inherited from :class:`AgentCallbackMixin`, this
    class defines the following ones:

    - :meth:`on_update_failure`, invoked when an update of the parametrization fails
    - :meth:`on_training_start`, invoked when training starts (see
      :meth:`mpcrl.LearningAgent.train` and
      :meth:`mpcrl.LearningAgent.train_offpolicy`)
    - :meth:`on_training_end`, invoked when training ends
    - :meth:`on_update`, invoked after each update of the parametrization.
    """

    def on_update_failure(
        self, episode: int, timestep: Optional[int], errormsg: str, raises: bool
    ) -> None:
        """Callback invoked when an update fails.

        The failure is first reported (as an exception or as a warning), and then the
        hooks attached to this callback are run.

        Parameters
        ----------
        episode : int
            Number of the episode in which the failure occurred.
        timestep : int or None
            Time step of the episode in which the failure occurred. It is ``None``
            when the update occurs inter-episodically or no notion of time step is
            available.
        errormsg : str
            Error message of the failed update.
        raises : bool
            If ``True``, the failure is to be raised as an exception; if ``False``,
            as a warning.
        """
        # fall back to a generic name if the object has no `name` attribute
        agent_name: str = getattr(self, "name", "agent")
        message = _failure_msg("update", agent_name, episode, timestep, errormsg)
        raise_or_warn_on_update_failure(message, raises)
        self._run_hooks("on_update_failure", episode, timestep, errormsg, raises)

    def on_training_start(self, env: Env[ObsType, ActType]) -> None:
        """Callback invoked when the training process begins.

        Parameters
        ----------
        env : gym env
            The gym environment on which the agent is being trained.
        """
        self._run_hooks("on_training_start", env)

    def on_training_end(
        self, env: Env[ObsType, ActType], returns: npt.NDArray[np.floating]
    ) -> None:
        """Callback invoked when the training process is over.

        Parameters
        ----------
        env : gym env
            The gym environment on which the agent has been trained.
        returns : array of double
            Cumulative rewards of each episode.
        """
        self._run_hooks("on_training_end", env, returns)

    def on_update(self) -> None:
        """Callback invoked after every :func:`mpcrl.LearningAgent.update`.

        It comes in handy, e.g., for decaying exploration probabilities or learning
        rates."""
        self._run_hooks("on_update")