Source code for mpcrl.agents.common.rl_learning_agent
from abc import ABC
from typing import Any, Generic, TypeVar
from ...optim.gradient_based_optimizer import GradientBasedOptimizer, LrType
from .agent import SymType
from .learning_agent import LearningAgent
# Unconstrained type variable that ``RlLearningAgent`` forwards as the second type
# parameter of ``LearningAgent``; presumably the type of the experience items the
# agent learns from (TODO confirm against ``LearningAgent``).
ExpType = TypeVar("ExpType")
[docs]
class RlLearningAgent(
    LearningAgent[SymType, ExpType], ABC, Generic[SymType, ExpType, LrType]
):
    r"""Abstract base class for learning agents that employ gradient-based RL
    strategies to learn/improve the MPC policy.

    Compared to :class:`LearningAgent`, this class additionally takes the discount
    factor of the RL task and a gradient-based optimizer, the latter dictating how
    the learnable parameters are updated.

    Parameters
    ----------
    discount_factor : float
        In RL, the factor that discounts future rewards in favor of immediate
        rewards. Usually denoted as :math:`\gamma`. It should satisfy
        :math:`\gamma \in (0, 1]`; note that this range is not checked here.
    optimizer : GradientBasedOptimizer
        A gradient-based optimizer (e.g., :class:`optim.GradientDescent`) used to
        compute the updates of the learnable parameters, based on the current
        gradient-based RL algorithm.
    kwargs
        Additional arguments to be passed to :class:`LearningAgent`.
    """

    def __init__(
        self,
        discount_factor: float,
        optimizer: GradientBasedOptimizer[LrType],
        **kwargs: Any,
    ) -> None:
        # Both attributes are stored *before* delegating to the parent constructor.
        # NOTE(review): presumably the parent constructor triggers
        # `_establish_callback_hooks`, which reads `self.optimizer` - confirm before
        # moving these assignments.
        self.discount_factor, self.optimizer = discount_factor, optimizer
        super().__init__(**kwargs)
        # `_learnable_pars` is only read after the parent has been initialised
        # (assumed to be created there); it is handed over to the optimizer.
        self.optimizer.set_learnable_parameters(self._learnable_pars)

    def _establish_callback_hooks(self) -> None:
        """Sets up the parent's callback hooks and then, if the optimizer declares
        a hook (i.e., ``optimizer.hook`` is not ``None``), registers the optimizer's
        ``step`` method via ``_hook_callback``, named after ``repr(optimizer)``."""
        super()._establish_callback_hooks()
        optimizer = self.optimizer
        hook = optimizer.hook
        if hook is None:
            # the optimizer does not ask to be stepped by any callback
            return
        self._hook_callback(repr(optimizer), hook, optimizer.step)