# Source code for chainerrl.policies.mellowmax_policy

from logging import getLogger

import chainer

from chainerrl import distribution
from chainerrl.policy import Policy


logger = getLogger(__name__)


class MellowmaxPolicy(chainer.Chain, Policy):
    """Mellowmax policy.

    See: http://arxiv.org/abs/1612.05628

    Args:
        model (chainer.Link): Link that is callable and outputs action values.
        omega (float): Parameter of the mellowmax function.
    """

    def __init__(self, model, omega=1.):
        # Initialise the Chain machinery before assigning any attributes.
        super().__init__()
        self.omega = omega
        # Register ``model`` as a child link named "model". This replaces the
        # deprecated ``Chain.__init__(model=model)`` keyword form.
        with self.init_scope():
            self.model = model

    def __call__(self, x):
        """Compute the action distribution for the input ``x``.

        Args:
            x: Input forwarded unchanged to the wrapped ``model``.

        Returns:
            distribution.MellowmaxDistribution: Distribution built from the
            model's output with this policy's ``omega``.
        """
        action_values = self.model(x)
        return distribution.MellowmaxDistribution(
            action_values, omega=self.omega)