from collections import deque
from collections.abc import Iterable
from typing import Any, Optional, SupportsFloat, TypeVar
from warnings import warn
from gymnasium import Env, Wrapper, utils
ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")
def _compact_dicts(
dicts: Iterable[dict[str, Any]], fill_value: Any = None
) -> dict[str, list[Any]]:
"""Compacts an iterable of dictionaries into a single dict with lists of entries. If
an entry is missing for any given dict, ``fill_value`` is used in place.
Parameters
----------
dicts : iterable of dicts of (str, any)
Dictionaries to be compacted into a single one.
fill_value : any, optional
The value to be used to fill in missing values.
Returns
-------
dict of (str, list of any | fill_value)
A unique dictionary made from all the passed dicts.
"""
out: dict[str, list[Any]] = {}
for i, dict_ in enumerate(dicts):
for k, v in dict_.items():
if k in out:
out[k].append(v)
else:
out[k] = [fill_value] * i + [v]
for k in out.keys() - dict_.keys():
out[k].append(fill_value)
return out
[docs]
class MonitorInfos(
Wrapper[ObsType, ActType, ObsType, ActType], utils.RecordConstructorArgs
):
"""This wrapper keeps track of the infos that are generated by calls to
:func:`gymnasium.Env.reset` and :func:`gymnasium.Env.step`.
Parameters
----------
env : Env[ObsType, ActType]
The environment to apply the wrapper to.
deque_size : int, optional
The maximum number of episodes to hold as historical data in the internal
deques. By default, ``None``, i.e., unlimited.
"""
def __init__(
self, env: Env[ObsType, ActType], deque_size: Optional[int] = None
) -> None:
utils.RecordConstructorArgs.__init__(self, deque_size=deque_size)
Wrapper.__init__(self, env)
# long-term storages
self.reset_infos: deque[dict[str, Any]] = deque(maxlen=deque_size)
self.step_infos: deque[list[dict[str, Any]]] = deque(maxlen=deque_size)
# current-episode storages
self.ep_step_infos: list[dict[str, Any]] = []
[docs]
def reset(
self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None
) -> tuple[ObsType, dict[str, Any]]:
observation, info = super().reset(seed=seed, options=options)
self.ep_step_infos.clear()
self.reset_infos.append(info)
return observation, info
[docs]
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
obs, reward, terminated, truncated, info = super().step(action)
self.ep_step_infos.append(info)
if terminated or truncated:
self.step_infos.append(self.ep_step_infos.copy())
self.ep_step_infos.clear()
return obs, reward, terminated, truncated, info
[docs]
def finalized_reset_infos(self, fill_value: Any = None) -> dict[str, list[Any]]:
"""Returns a compacted final dictionary containing the reset infos. Missing
values are filled automatically.
Parameters
----------
fill_value : Any, optional
The value to be used to fill in missing values.
Returns
-------
dict of (str, list)
A unique dict containing for each entry returned in the reset info the
corresponding list of entries, one per each reset call.
"""
return _compact_dicts(self.reset_infos, fill_value)
[docs]
def finalized_step_infos(
self, fill_value: Any = None
) -> dict[str, list[list[Any]]]:
"""Returns a compacted final dictionary containing the step infos. Missing
values are filled automatically.
Parameters
----------
fill_value : Any, optional
The value to be used to fill in missing values.
Returns
-------
dict of (str, list of lists)
A unique dict containing for each entry returned in the step info the
corresponding list of dicts per episode, with one entry per each step call.
"""
if self.ep_step_infos:
warn(
"Internal buffer of step infos not empty, meaning that the last "
"episode did not terminate properly.",
RuntimeWarning,
stacklevel=2,
)
return _compact_dicts(
(_compact_dicts(d, fill_value) for d in self.step_infos), fill_value
)