Source code for mpcrl.optim.base_optimizer

from typing import Any

import numpy as np

from ..core.parameters import LearnableParametersDict


[docs] class BaseOptimizer: """Base class for optimization algorithms. This class contains useful methods for, e.g., initializing the optimizer, retrieving bounds on the learnable parameters, etc. Parameters ---------- max_percentage_update : float, optional A positive float that specifies the maximum percentage change the learnable parameters can experience in each update. For example, ``max_percentage_update=0.5`` means that the parameters can be updated by up to 50% of their current value. By default, it is set to ``+inf``. """ def __init__(self, max_percentage_update: float = float("+inf")) -> None: self.max_percentage_update = max_percentage_update self.learnable_parameters: LearnableParametersDict
[docs] def set_learnable_parameters(self, pars: LearnableParametersDict) -> None: """Makes the optimization class aware of the dictionary of the learnable parameters whose values are to be updated. Parameters ---------- pars : :class`mpcrl.LearnableParametersDict` The dictionary of the learnable parameters. """ self.learnable_parameters = pars self._update_solver = self._init_update_solver()
def _get_update_bounds( self, theta: np.ndarray, eps: float = 0.1 ) -> tuple[np.ndarray, np.ndarray]: """Internal utility to retrieve the current bounds on the learnable parameters. Only useful if the update problem is not unconstrained, i.e., there are either some lower- or upper-bounds, or a maximum percentage update was given.""" lb = self.learnable_parameters.lb - theta ub = self.learnable_parameters.ub - theta perc = self.max_percentage_update if perc != float("+inf"): max_update_delta = np.maximum(np.abs(perc * theta), eps) lb = np.maximum(lb, -max_update_delta) ub = np.minimum(ub, +max_update_delta) return lb, ub def _init_update_solver(self) -> Any: """Internal utility to initialize whatever solver is necessary to perform an update according to this learning strategy.""" def __str__(self) -> str: return self.__class__.__name__ def __repr__(self) -> str: cn = self.__class__.__name__ mp = self.max_percentage_update return f"{cn}(max%={mp})"