from logging import getLogger
import chainer
from chainerrl import distribution
from chainerrl.policy import Policy
logger = getLogger(__name__)
class MellowmaxPolicy(chainer.Chain, Policy):
    """Mellowmax policy.

    Wraps a callable link and turns its output into a
    ``MellowmaxDistribution``.

    See: http://arxiv.org/abs/1612.05628

    Args:
        model (chainer.Link):
            Link that is callable and outputs action values.
        omega (float):
            Parameter of the mellowmax function.
    """

    def __init__(self, model, omega=1.):
        """Store ``omega`` and register ``model`` as a child link.

        Args:
            model (chainer.Link): Callable link that outputs action values.
            omega (float): Parameter of the mellowmax function.
        """
        self.omega = omega
        # Passing the link as a keyword argument registers it on the chain
        # under the name ``model``, which ``__call__`` reads as ``self.model``.
        super().__init__(model=model)

    def __call__(self, x):
        """Evaluate the wrapped model on ``x`` and build the policy output.

        Args:
            x: Input forwarded unchanged to ``self.model``.

        Returns:
            distribution.MellowmaxDistribution: Distribution constructed from
            the model output with this policy's ``omega``.
        """
        h = self.model(x)
        return distribution.MellowmaxDistribution(h, omega=self.omega)