# Source code for mpcrl.core.update

"""The update strategy is likely to be one of the most important aspects that a designer
has to consider when training a Reinforcement Learning agent. When instantiating an
agent, via :class:`UpdateStrategy`, the user can specify when and with which frequency
to update the agent's MPC parametrization (e.g., at the end of every training episode,
or every ``N`` time steps), as well as the number of updates to skip at the beginning
(in case we need to wait for experience buffers to properly fill first). See
:ref:`user_guide_updating` for a more thorough explanation."""

from collections.abc import Iterator
from itertools import chain, repeat
from typing import Literal

from ..util.iters import bool_cycle


class UpdateStrategy:
    """A class holding information on the update strategy to be used by the
    learning algorithm.

    Parameters
    ----------
    frequency : int
        Frequency at which, each time the hook is called, an update should be
        carried out.
    hook : {"on_episode_end", "on_timestep_end"}, optional
        Specifies to which callback to hook, i.e., when to check if an update is
        due according to the given frequency. The options are:

        - ``"on_episode_end"`` checks if an update is due after each episode
          ends
        - ``"on_timestep_end"`` checks for an update after each simulation's
          time step.

        By default, ``"on_timestep_end"`` is selected.
    skip_first : int, optional
        Skips the first ``skip_first`` updates. By default ``0``, so no update
        is skipped. This is useful when, e.g., the agent has to wait for the
        experience buffer to be filled before starting to update.
    """

    def __init__(
        self,
        frequency: int,
        hook: Literal["on_episode_end", "on_timestep_end"] = "on_timestep_end",
        skip_first: int = 0,
    ) -> None:
        self.frequency = frequency
        self.hook = hook
        # The first ``skip_first * frequency`` checks are forced to ``False``,
        # i.e., ``skip_first`` whole update periods are suppressed; afterwards
        # the values are produced by ``bool_cycle(frequency)``.
        self._update_cycle: Iterator[bool] = chain(
            repeat(False, skip_first * frequency), bool_cycle(frequency)
        )

    def can_update(self) -> bool:
        """Returns whether an update must be carried out now, at the current
        instant, according to the specified strategy.

        Notes
        -----
        This method steps the internal iterator to check whether an update is
        due with :func:`next`. This means that calling this method has a side
        effect on the state of this iterator, and calling it immediately again
        can result in a different outcome.

        Returns
        -------
        bool
            ``True`` if the agent should update according to this strategy;
            otherwise, ``False``.
        """
        return next(self._update_cycle)

    def __iter__(self) -> Iterator[bool]:
        """Returns the internal update iterator itself (not a copy), so
        consuming it also advances the state seen by :meth:`can_update`."""
        return self._update_cycle

    def __next__(self) -> bool:
        """Steps the internal update iterator and returns its next flag."""
        return next(self._update_cycle)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(frequency={self.frequency},hook={self.hook})"

    def __str__(self) -> str:
        return self.__repr__()