Source code for mpcrl.wrappers.agents.evaluate

from typing import Any, Literal, Optional, TypeVar

import numpy as np
import numpy.typing as npt
from gymnasium import Env

from ...agents.common.learning_agent import ExpType, LearningAgent
from ...util.iters import bool_cycle
from ...util.seeding import RngType, mk_seed
from .wrapper import LearningWrapper, SymType

# Generic placeholders for the observation and action types of the evaluation
# environment; they parametrise the ``Env[ObsType, ActType]`` annotation below.
ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")


class Evaluate(LearningWrapper[SymType, ExpType]):
    """Wrapper that periodically evaluates a learning agent while it trains.

    The wrapper attaches a callback to one of the agent's hooks. Every
    ``frequency`` invocations of that hook, the agent is evaluated on
    ``eval_env`` through :meth:`mpcrl.Agent.evaluate`, and the returns of that
    evaluation are appended to :attr:`eval_returns`.

    Parameters
    ----------
    agent : LearningAgent
        The learning agent to be evaluated by the wrapper.
    eval_env : gymnasium.Env
        A gym environment to evaluate the agent in.
    hook : {"on_episode_end", "on_timestep_end", "on_update"}
        Hook that triggers the evaluation. The evaluation takes place every
        ``frequency`` invocations of this hook.
    frequency : int
        Frequency of the evaluation, in number of hook invocations.
    n_eval_episodes : int, optional
        Number of episodes each evaluation lasts, by default ``1``.
    eval_immediately : bool, optional
        Whether to run one evaluation right away, at the end of the wrapper's
        construction, by default ``False``.
    deterministic : bool, optional
        Whether the agent should act deterministically during evaluation; by
        default, ``True``.
    seed : None, int, array_like of ints, SeedSequence, BitGenerator, Generator
        Agent's and each env's random seeds for the evaluation.
    raises : bool, optional
        If ``True``, when any of the MPC solver runs fails, or when an update
        fails, the corresponding error is raised; otherwise, only a warning is
        raised.
    env_reset_options : dict, optional
        Additional information to specify how the environment is reset at each
        evaluation episode (optional, depending on the specific environment).
    fix_seed : bool, optional
        If ``True``, one seed is derived at construction (via ``mk_seed``) and
        that same seed is used for all evaluations; otherwise, the random
        generator built from ``seed`` is handed to each evaluation.
    """

    def __init__(
        self,
        agent: LearningAgent[SymType, ExpType],
        eval_env: Env[ObsType, ActType],
        hook: Literal["on_episode_end", "on_timestep_end", "on_update"],
        frequency: int,
        n_eval_episodes: int = 1,
        eval_immediately: bool = False,
        *,
        deterministic: bool = True,
        seed: RngType = None,
        raises: bool = True,
        env_reset_options: Optional[dict[str, Any]] = None,
        fix_seed: bool = False,
    ) -> None:
        # settings that are forwarded as-is to ``agent.evaluate`` later on
        self.eval_env = eval_env
        self._n_eval_episodes = n_eval_episodes
        self._deterministic = deterministic
        self._raises = raises
        self._env_reset_options = env_reset_options

        # either freeze one seed for all evaluations, or keep the generator itself
        rng = np.random.default_rng(seed)
        if fix_seed:
            self._seed = mk_seed(rng)
        else:
            self._seed = rng
        self._keep_seed_fixed = fix_seed

        # triggering logic and bookkeeping
        self._hook = hook
        self._eval_cycle = bool_cycle(frequency)
        self.eval_returns: list[npt.NDArray[np.floating]] = []
        self._is_eval_in_progress = False

        # NOTE(review): every attribute above is assigned before the base
        # constructor runs - presumably because it sets up the callback hooks,
        # which read ``self._hook``; confirm before reordering.
        super().__init__(agent)
        if eval_immediately:
            self._evaluate(force=True)

    def _evaluate(self, *_: Any, **kwargs: Any) -> None:
        """Callback that evaluates the agent and stores the resulting returns.

        Positional arguments are ignored. The only keyword inspected is
        ``force``: when truthy, the training and frequency checks are skipped
        (this is how ``eval_immediately=True`` triggers the first evaluation).
        """
        # never start a new evaluation while one is already running (reentrancy)
        if self._is_eval_in_progress:
            return

        forced = kwargs.get("force", False)
        base_agent = self.agent.unwrapped
        was_training = base_agent._is_training
        if not forced:
            # outside of training this hook must stay silent, so that it does
            # not fire during a call to ``.evaluate``
            if not was_training:
                return
            # the cycle is advanced only while training, and only every
            # ``frequency``-th step is an evaluation point
            if not next(self._eval_cycle):
                return

        self._is_eval_in_progress = True
        try:
            self.eval_returns.append(
                self.agent.evaluate(
                    self.eval_env,
                    self._n_eval_episodes,
                    self._deterministic,
                    self._seed,
                    self._raises,
                    self._env_reset_options,
                )
            )
        finally:
            # whatever happened, lower the guard and put the training flag back
            # to the value it had before the evaluation started
            self._is_eval_in_progress = False
            base_agent._is_training = was_training

    def _establish_callback_hooks(self) -> None:
        """Registers :meth:`_evaluate` on the selected hook, after the parent's hooks."""
        super()._establish_callback_hooks()
        callback_name = repr(self)
        self._hook_callback(callback_name, self._hook, self._evaluate)