from typing import Any, Callable, Generic, Union
from ...agents.common.agent import Agent, SymType
from ...agents.common.learning_agent import ExpType, LearningAgent
from ...core.callbacks import CallbackMixin
[docs]
class Wrapper(CallbackMixin, Generic[SymType]):
    """Wraps an instance of :class:`mpcrl.Agent` to allow a modular transformation of
    its behaviour. This class is the base class for all wrappers. The subclass could
    override some methods to change the behavior of the original agent without touching
    the original code.

    Parameters
    ----------
    agent : Agent or subclass
        The agent to wrap.
    """

    def __init__(self, agent: Agent[SymType]) -> None:
        CallbackMixin.__init__(self)
        del self._hooks  # keep only one dict of hooks, i.e., the agent's one
        self.agent = agent
        # callback name -> names of the hooks that this wrapper attached to it
        self._hooked_callbacks: dict[str, list[str]] = {}
        # NOTE(review): `_establish_callback_hooks` is not defined in this class;
        # presumably it is provided by `CallbackMixin` or overridden by subclasses.
        self._establish_callback_hooks()

    @property
    def unwrapped(self) -> Union[Agent[SymType], LearningAgent[SymType, ExpType]]:
        """Returns the original agent wrapped by this wrapper."""
        return self.agent.unwrapped

    def is_wrapped(self, wrapper_type: type["Wrapper[SymType]"]) -> bool:
        """Gets whether the agent instance is wrapped or not by the wrapper type.

        Parameters
        ----------
        wrapper_type : type of Wrapper
            Type of wrapper to check if the agent is wrapped with.

        Returns
        -------
        bool
            ``True`` if wrapped by an instance of ``wrapper_type``; ``False``,
            otherwise.
        """
        if isinstance(self, wrapper_type):
            return True
        # not this layer: ask the next layer (another wrapper or the agent itself)
        return self.agent.is_wrapped(wrapper_type)

    def _hook_callback(
        self, attachername: str, callbackname: str, func: Callable[..., None]
    ) -> None:
        """Registers ``func`` as a hook on the unwrapped agent, and records it so that
        it can be removed later by :meth:`detach_wrapper`.

        Parameters
        ----------
        attachername : str
            Name identifying the hook being attached.
        callbackname : str
            Name of the callback the hook is attached to.
        func : callable
            The hook, forwarded as is to the unwrapped agent's ``_hook_callback``.
        """
        # store the callback id for later removal via `detach_wrapper(s)`
        self._hooked_callbacks.setdefault(callbackname, []).append(attachername)
        self.unwrapped._hook_callback(attachername, callbackname, func)

    def detach_wrapper(
        self, recursive: bool = False
    ) -> Union[Agent[SymType], LearningAgent[SymType, ExpType], "Wrapper[SymType]"]:
        """Detaches the wrapper from the agent, returning the unwrapped agent. De facto,
        this method detaches all the hooks attached by this wrapper.

        Parameters
        ----------
        recursive : bool, optional
            If ``True``, detaches all the wrappers around the agent recursively.

        Returns
        -------
        Agent or LearningAgent or Wrapper
            Returns the wrapped agent (or other wrapper) instance. This instance has no
            more active hooks attached by this wrapper. If ``recursive=True``, all the
            wrappers around the agent and their hooks are detached.

        Notes
        -----
        Detaching a wrapper is useful when you want to make sure that the wrapper's
        hooked callback cannot modify the behaviour or data of the agent, for example,
        after learning is done and you want to save and evaluate your learnt policy.

        The hooks are looked up and popped without defaults, so a hook that is tracked
        by this wrapper but is no longer present on the unwrapped agent makes the
        lookup fail (a ``KeyError``, assuming the hooks are stored in plain dicts).
        """
        hooks = self.unwrapped._hooks
        hooked_callbacks = self._hooked_callbacks

        # for each callback type, remove the hooks attached by this wrapper
        for callbackname, attachernames in hooked_callbacks.items():
            hook_group = hooks[callbackname]
            for attachername in attachernames:
                hook_group.pop(attachername)

            # if the callback has no more hooks, remove it
            if not hook_group:
                hooks.pop(callbackname)

        # clear hooked callbacks tracking, so that detaching twice is harmless
        hooked_callbacks.clear()

        # only recurse if the next layer actually exposes a callable `detach_wrapper`
        inner = self.agent
        if recursive and callable(getattr(inner, "detach_wrapper", None)):
            return inner.detach_wrapper(True)
        return inner

    def __getattr__(self, name: str) -> Any:
        """Reroutes attributes to the wrapped agent instance.

        This method is only invoked when the normal attribute lookup on the wrapper
        fails.

        Raises
        ------
        AttributeError
            If ``name`` is private (starts with an underscore), or if ``name`` is
            ``"agent"`` itself, i.e., the wrapper holds no agent yet.
        """
        if name.startswith("_"):
            raise AttributeError(f"Accessing private attribute '{name}' is prohibited.")
        if name == "agent":
            # reaching here means `agent` was never set on this instance (e.g., it was
            # created via `__new__`); reading `self.agent` below would re-enter this
            # method forever and end in a `RecursionError`
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute 'agent'"
            )
        return getattr(self.agent, name)

    def __str__(self) -> str:
        return f"<{self.__class__.__name__}{self.agent!s}>"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}{self.agent!r}>"
[docs]
class LearningWrapper(Wrapper[SymType], Generic[SymType, ExpType]):
    """Specialization of :class:`Wrapper` for wrapping instances of
    :class:`mpcrl.LearningAgent`.

    Parameters
    ----------
    agent : LearningAgent or subclass
        The learning agent to wrap.
    """

    def __init__(self, agent: LearningAgent[SymType, ExpType]) -> None:
        # annotation only (no assignment): narrows the declared type of `agent`
        self.agent: LearningAgent[SymType, ExpType]
        Wrapper.__init__(self, agent)

    @property
    def unwrapped(self) -> LearningAgent[SymType, ExpType]:
        """Returns the original learning agent wrapped by this wrapper."""
        wrapped = self.agent
        return wrapped.unwrapped