Source code for mpcrl.wrappers.agents.record_updates
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
from ...agents.common.learning_agent import ExpType, LearningAgent
from ...util.iters import bool_cycle
from .wrapper import LearningWrapper, SymType
[docs]
class RecordUpdates(LearningWrapper[SymType, ExpType]):
    """Wrapper that keeps a history of the learnable parameters' values.

    Whenever the wrapped agent triggers its ``on_update`` callback, the current
    value of every parameter in :attr:`mpcrl.LearningAgent.learnable_parameters`
    may be appended to :attr:`updates_history`, a dictionary mapping each
    parameter's name to the list of values recorded so far. Whether a given
    update is stored is decided by the recording frequency.

    Parameters
    ----------
    agent : LearningAgent or subclass
        The agent whose parameter updates are to be recorded.
    frequency : int, optional
        How often updates are recorded: ``1`` stores every update, ``2`` stores
        every second update, and so on. By default, ``1``. The values held by the
        parameters when the wrapper is created are always stored as the first
        entry of each history.
    """

    def __init__(
        self,
        agent: LearningAgent[SymType, ExpType],
        frequency: int = 1,
    ) -> None:
        super().__init__(agent)
        # iterator consulted once per update to decide whether to record it
        self._record_cycle = bool_cycle(frequency)
        # seed each parameter's history with the value it has at wrapping time
        initial_history: dict[str, list[npt.NDArray[np.floating]]] = {}
        for parameter in agent.learnable_parameters.values():
            name = parameter.name
            initial_history[name] = [parameter.value]
        self.updates_history = initial_history

    def _on_update(self, *_: object, **__: object) -> None:
        """Callback that appends the parameters' current values to the history.

        Any arguments passed by the callback machinery are ignored. Nothing is
        stored when the recording cycle yields a falsy flag for this update.
        """
        should_record = next(self._record_cycle)
        if not should_record:
            return
        for parameter in self.agent.learnable_parameters.values():
            self.updates_history[parameter.name].append(parameter.value)

    def _establish_callback_hooks(self) -> None:
        """Registers this wrapper's recording callback on top of the parent's
        hooks."""
        super()._establish_callback_hooks()
        # attach the recording routine to the agent's on_update callback, using
        # this wrapper's repr as the name of the hook
        hook_name = repr(self)
        self._hook_callback(hook_name, "on_update", self._on_update)