Source code for mpcrl.agents.common.learning_agent

from abc import ABC, abstractmethod
from collections.abc import Collection, Iterable
from typing import Any, Generic, Optional, TypeVar, Union

import numpy as np
import numpy.typing as npt
from gymnasium import Env
from gymnasium.spaces import Box

from ...core.callbacks import LearningAgentCallbackMixin
from ...core.experience import ExperienceReplay
from ...core.parameters import LearnableParametersDict
from ...core.update import UpdateStrategy
from ...util.seeding import RngType, mk_seed
from .agent import ActType, Agent, ObsType, SymType, _update_dicts

# Generic type of the items stored in (and sampled from) the agent's experience
# replay memory; it parametrizes both `ExperienceReplay` and `LearningAgent` below.
ExpType = TypeVar("ExpType")


class LearningAgent(
    Agent[SymType], LearningAgentCallbackMixin, ABC, Generic[SymType, ExpType]
):
    r"""Base abstract class for a learning agent with MPC as policy provider.

    The main method :meth:`update`, which is called to update the learnable
    parameters of the MPC according to the underlying learning methodology (e.g.,
    Bayesian Optimization, RL, etc.), is abstract and must be implemented by
    inheriting classes based on their learning algorithm.

    Aside from :meth:`update`, this class also provides the basic skeleton for both
    on-policy and off-policy learning, which require further ad-hoc implementations.
    For on-policy learning, the :meth:`train` method should be run, which requires
    the implementation of :meth:`train_one_episode`. For off-policy learning, run
    :meth:`train_offpolicy` and implement :meth:`train_one_rollout` instead. At
    least one of the two implementations is required in the inheriting class,
    depending on the learning algorithm. Some algorithms might support both.

    Parameters
    ----------
    update_strategy : UpdateStrategy or int
        The strategy used to decide which frequency to update the mpc parameters
        with. If an ``int`` is passed, then the default strategy that updates every
        ``n`` env's steps is used (where ``n`` is the argument passed); otherwise,
        an instance of :class:`core.update.UpdateStrategy` can be passed to specify
        the desired strategy in more details.
    learnable_parameters : :class:`core.parameters.LearnableParametersDict`
        A special dict containing the learnable parameters of the MPC (usually
        referred to as :math:`\theta`), together with their bounds and values. This
        dict is complementary to :attr:`fixed_parameters`, which contains the MPC
        parameters that are not learnt by the agent.
    experience : int or ExperienceReplay, optional
        The container for experience replay memory. If ``None`` is passed, then a
        memory with unitary length is created, i.e., it keeps only the latest memory
        transition. If an integer ``n`` is passed, then a memory with the length
        ``n`` is created and with sample size ``n``. Otherwise, pass an instance of
        :class:`core.experience.ExperienceReplay` to specify the requirements in
        more details.
    kwargs
        Additional arguments to be passed to :class:`Agent`.

    Notes
    -----
    This class makes no assumptions on the learning methodology used to update the
    MPC's learnable parameters. This could be either gradient-based or
    gradient-free, but the logic implemented in this class should largely remain
    untouched.
    """

    def __init__(
        self,
        update_strategy: Union[int, UpdateStrategy],
        learnable_parameters: LearnableParametersDict,
        experience: Union[None, int, ExperienceReplay[ExpType]] = None,
        **kwargs: Any,
    ) -> None:
        # both bases are initialized explicitly, each with its own arguments
        Agent.__init__(self, **kwargs)
        LearningAgentCallbackMixin.__init__(self)
        self._learnable_pars = self._reorder_learnable_parameters(learnable_parameters)

        # normalize `experience` to an `ExperienceReplay` instance
        if experience is None:
            experience = ExperienceReplay(maxlen=1)
        elif isinstance(experience, int):
            experience = ExperienceReplay(maxlen=experience, sample_size=experience)
        self._experience = experience

        # anything that is not already a strategy is forwarded to `UpdateStrategy`
        # and hooked at the end of each env's timestep
        if not isinstance(update_strategy, UpdateStrategy):
            update_strategy = UpdateStrategy(update_strategy, "on_timestep_end")
        self._update_strategy = update_strategy

        # `_raises` is only declared here; it is assigned by `train` and
        # `train_offpolicy` before any update can take place
        self._raises: bool
        self._is_training = False
        self._establish_callback_hooks()

    @property
    def experience(self) -> ExperienceReplay[ExpType]:
        """Gets the experience replay memory of the agent."""
        return self._experience

    @property
    def update_strategy(self) -> UpdateStrategy:
        """Gets the update strategy of the agent."""
        return self._update_strategy

    @property
    def learnable_parameters(self) -> LearnableParametersDict:
        """Gets the parameters of the MPC that can be learnt by the agent."""
        return self._learnable_pars

    def reset(self, seed: RngType = None) -> None:
        # besides the base agent, also the experience memory is reset with the seed
        super().reset(seed)
        self.experience.reset(seed)

    def store_experience(self, item: ExpType) -> None:
        """Stores the given item in the agent's :attr:`experience` for later usage
        in updating the parametrization.

        Parameters
        ----------
        item : ExpType
            Item to be stored in memory.
        """
        self._experience.append(item)

    def evaluate(self, *args: Any, **kwargs: Any) -> npt.NDArray[np.floating]:
        # make sure no update is performed while evaluating (the update callback
        # returns early when `_is_training` is False)
        self._is_training = False
        return super().evaluate(*args, **kwargs)

    def train(
        self,
        env: Env[ObsType, ActType],
        episodes: int,
        seed: RngType = None,
        raises: bool = True,
        env_reset_options: Optional[dict[str, Any]] = None,
    ) -> npt.NDArray[np.floating]:
        """On-policy training of the agent on an environment.

        Parameters
        ----------
        env : Env[ObsType, ActType]
            The gym environment where to train the agent on.
        episodes : int
            Number of training episodes.
        seed : None, int, array_like of ints, SeedSequence, BitGenerator, Generator
            Seed for the agent's and env's random number generator. By default
            ``None``.
        raises : bool, optional
            If ``True``, when any of the MPC solver runs fails, or when an update
            fails, the corresponding error is raised; otherwise, only a warning is
            raised.
        env_reset_options : dict, optional
            Additional information to specify how the environment is reset at each
            training episode (optional, depending on the specific environment).

        Returns
        -------
        array of doubles
            The cumulative returns for each training episode.

        Raises
        ------
        MpcSolverError or MpcSolverWarning
            Raises the error or the warning (depending on ``raises``) if any of the
            MPC solvers fail.
        UpdateError or UpdateWarning
            Raises the error or the warning (depending on ``raises``) if the update
            fails.
        """
        # the action space is checked only if the env exposes one
        if hasattr(env, "action_space"):
            assert isinstance(env.action_space, Box), "Env action space must be a Box,"

        # a single generator seeds both the agent and each env's reset
        rng = np.random.default_rng(seed)
        self.reset(rng)
        self._is_training = True
        self._raises = raises
        returns = np.zeros(episodes, float)

        self.on_training_start(env)
        for episode in range(episodes):
            state, _ = env.reset(seed=mk_seed(rng), options=env_reset_options)
            self.on_episode_start(env, episode, state)
            r = self.train_one_episode(env, episode, state, raises)
            self.on_episode_end(env, episode, r)
            returns[episode] = r

        self.on_training_end(env, returns)
        return returns

    def train_one_episode(
        self,
        env: Env[ObsType, ActType],
        episode: int,
        init_state: ObsType,
        raises: bool = True,
    ) -> float:
        """On-policy training of the agent on an environment for one single episode.

        Parameters
        ----------
        env : Env[ObsType, ActType]
            The gym environment where to train the agent on.
        episode : int
            Number of the current training episode.
        init_state : observation type
            Initial state/observation of the environment.
        raises : bool, optional
            If ``True``, when any of the MPC solver runs fails, or when an update
            fails, the corresponding error is raised; otherwise, only a warning is
            raised.

        Returns
        -------
        float
            The cumulative rewards for this training episode.

        Raises
        ------
        MpcSolverError or MpcSolverWarning
            Raises the error or the warning (depending on ``raises``) if any of the
            MPC solvers fail.
        UpdateError or UpdateWarning
            Raises the error or the warning (depending on ``raises``) if the update
            fails.
        NotImplementedError
            Raised by this base implementation, which must be overridden by agents
            supporting on-policy learning.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement `train_one_episode` for "
            "on-policy learning."
        )

    def train_offpolicy(
        self,
        episode_rollouts: Iterable[Iterable[Any]],
        seed: RngType = None,
        raises: bool = True,
    ) -> None:
        """Off-policy training of the agent on an environment.

        Parameters
        ----------
        episode_rollouts : iterable of iterables of any
            An iterable of episodical rollouts generated in an off-policy fashion.
            Each rollout is itself a sequence of transitions, e.g., SARSA tuples. In
            other words, `episode_rollouts` is a sequence of sequences of tuples.
            However, in general, its nature and the tuples' can widely differ from
            learning algorithm to learning algorithm.
        seed : None, int, array_like of ints, SeedSequence, BitGenerator, Generator
            Seed for the agent's random number generator. By default ``None``.
        raises : bool, optional
            If ``True``, when any of the MPC solver runs fails, or when an update
            fails, the corresponding error is raised; otherwise, only a warning is
            raised.

        Raises
        ------
        MpcSolverError or MpcSolverWarning
            Raises the error or the warning (depending on ``raises``) if any of the
            MPC solvers fail.
        UpdateError or UpdateWarning
            Raises the error or the warning (depending on ``raises``) if the update
            fails.
        """
        rng = np.random.default_rng(seed)
        self.reset(rng)
        self._raises = raises

        # there is no env in off-policy training, so the callbacks receive a
        # placeholder string in its place, NaN as episode's initial state and
        # return, and an empty array as the training returns
        env_proxy = "off-policy"
        self._is_training = True
        self.on_training_start(env_proxy)
        for episode, rollout in enumerate(episode_rollouts):
            self.on_episode_start(env_proxy, episode, float("nan"))
            self.train_one_rollout(rollout, episode, raises)
            self.on_episode_end(env_proxy, episode, float("nan"))
        self.on_training_end(env_proxy, np.empty(0))

    def train_one_rollout(
        self, rollout: Iterable[Any], episode: int, raises: bool = True
    ) -> None:
        """Train the agent in an off-policy manner on the given rollout.

        Parameters
        ----------
        rollout : iterable of any
            Rollout, i.e., a sequence of transitions generated off-policy, e.g.,
            SARSA tuples. However, in general, these tuples can be of different
            nature, depending on the specific learning algorithm.
        episode : int
            Number of the current training episode, i.e., the index of this rollout
            in the rollouts passed to :meth:`train_offpolicy`.
        raises : bool, optional
            If ``True``, when any of the MPC solver runs fails, or when an update
            fails, the corresponding error is raised; otherwise, only a warning is
            raised.

        Raises
        ------
        MpcSolverError or MpcSolverWarning
            Raises the error or the warning (depending on ``raises``) if any of the
            MPC solvers fail.
        UpdateError or UpdateWarning
            Raises the error or the warning (depending on ``raises``) if the update
            fails.
        NotImplementedError
            Raised by this base implementation, which must be overridden by agents
            supporting off-policy learning.
        """
        # the method missing from the subclass is `train_one_rollout` itself
        # (`train_offpolicy` is already provided by this base class)
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement `train_one_rollout` for "
            "off-policy learning."
        )

    @abstractmethod
    def update(self) -> Optional[str]:
        r"""Updates the learnable parameters (usually referred to as
        :math:`\theta`) of the MPC according to the agent's learning algorithm.

        Returns
        -------
        errormsg : str or None
            In case the update fails, an error message is returned to be raised as
            error or warning; otherwise, ``None`` is returned.
        """

    def _reorder_learnable_parameters(
        self, dict_: LearnableParametersDict
    ) -> LearnableParametersDict:
        """Reorders the learnable parameters of the MPC according to their creation
        order. The given dict is modified in place and returned."""
        # pop the parameters in the order they appear in the MPC; whatever is left
        # in the dict afterwards does not belong to the MPC
        reordered = [dict_.pop(name) for name in self.V.parameters if name in dict_]
        assert not dict_, (
            "Not all learnable parameters could be reordered. "
            "Please check for spurious learnable parameters in `learnable_parameters`."
        )
        dict_.update(reordered)
        return dict_

    def _establish_callback_hooks(self) -> None:
        super()._establish_callback_hooks()

        # hook exploration (only if necessary)
        exploration_hook = self._exploration.hook
        if exploration_hook is not None:
            self._hook_callback(
                repr(self._exploration), exploration_hook, self._exploration.step
            )

        # hook updates (always necessary)
        update_hook = self._update_strategy.hook
        self._hook_callback(
            repr(self._update_strategy), update_hook, self._check_and_perform_update
        )

    def _check_and_perform_update(
        self, _: Env[ObsType, ActType], episode: int, timestep_or_return: float
    ) -> None:
        """Internal utility to check if an update is due and perform it."""
        # no update outside of training, or when the strategy says it is not due
        if not self._is_training or not self._update_strategy.can_update():
            return

        update_msg = self.update()
        if update_msg is not None:
            # the last argument is forwarded as timestep only if it is an integer;
            # otherwise, ``None`` is reported
            timestep = (
                timestep_or_return if isinstance(timestep_or_return, int) else None
            )
            # NOTE(review): presumably raises or warns depending on `self._raises`
            # - confirm in the callback mixin
            self.on_update_failure(episode, timestep, update_msg, self._raises)
        self.on_update()

    def _get_parameters(
        self,
        overwrite_fixed_pars: Union[
            None, dict[str, npt.ArrayLike], Collection[dict[str, npt.ArrayLike]]
        ] = None,
    ) -> Union[None, dict[str, npt.ArrayLike], Collection[dict[str, npt.ArrayLike]]]:
        """Internal utility to retrieve parameters of the MPC in order to solve it.
        :class:`LearningAgent` returns both fixed and learnable parameters.

        Parameters
        ----------
        overwrite_fixed_pars : dict of (str, array_like), or collection of, optional
            If not ``None``, this argument is used instead of
            :attr:`fixed_parameters` to retrieve the fixed parameters of the MPC.

        Returns
        -------
        dict of (str, array_like), or tuple of
            The learnable parameters alone if there are no fixed parameters;
            otherwise, the fixed parameters (a single dict, or a tuple of dicts)
            updated with the learnable ones.
        """
        learnable_pars = self._learnable_pars.value_as_dict
        fixed_pars = (
            self.fixed_parameters
            if overwrite_fixed_pars is None
            else overwrite_fixed_pars
        )
        if fixed_pars is None:
            return learnable_pars
        if isinstance(fixed_pars, dict):
            # NOTE: the dict of fixed parameters is updated in place, i.e., it also
            # contains the learnable parameters after this call
            fixed_pars.update(learnable_pars)
            return fixed_pars
        # NOTE(review): `_update_dicts` presumably merges the learnable parameters
        # into each dict of the collection - confirm in `agent.py`
        return tuple(_update_dicts(fixed_pars, learnable_pars))