from abc import ABCMeta
from abc import abstractmethod
from abc import abstractproperty
from cached_property import cached_property
import chainer
from chainer import functions as F
import numpy as np
from chainerrl.functions import arctanh
from chainerrl.functions import mellowmax
def _wrap_by_variable(x):
    """Return ``x`` as a ``chainer.Variable``.

    An object that already is a ``chainer.Variable`` is passed through
    unchanged; anything else is handed to the ``chainer.Variable``
    constructor.
    """
    already_wrapped = isinstance(x, chainer.Variable)
    return x if already_wrapped else chainer.Variable(x)
def _unwrap_variable(x):
    """Return the array held by ``x`` if it is a ``chainer.Variable``.

    Objects that are not ``chainer.Variable`` instances are returned as-is.
    """
    if not isinstance(x, chainer.Variable):
        return x
    return x.array
def sample_discrete_actions(batch_probs):
    """Draw one action index per row of a batch of action probabilities.

    Gumbel noise is added to the log-probabilities and the index of the
    largest perturbed value along axis 1 is taken as the sample.

    Args:
        batch_probs (ndarray): batch of action probabilities BxA

    Returns:
        ndarray consisting of sampled action indices (dtype int32)
    """
    # Use the array module that matches the input array.
    xp = chainer.cuda.get_array_module(batch_probs)
    log_probs = xp.log(batch_probs)
    gumbel_noise = xp.random.gumbel(size=batch_probs.shape)
    winners = xp.argmax(log_probs + gumbel_noise, axis=1)
    # copy=False avoids a copy when the indices are already int32.
    return winners.astype(np.int32, copy=False)
class Distribution(object, metaclass=ABCMeta):
    """Batch of distributions of data.

    Abstract interface: concrete distributions provide the properties
    ``entropy``, ``most_probable`` and ``params`` and the methods ``sample``,
    ``prob``, ``log_prob``, ``copy`` and ``kl``.
    """

    @property
    @abstractmethod
    def entropy(self):
        """Entropy of distributions.

        Returns:
            chainer.Variable
        """
        raise NotImplementedError()

    @abstractmethod
    def sample(self):
        """Sample from distributions.

        Returns:
            chainer.Variable
        """
        raise NotImplementedError()

    @abstractmethod
    def prob(self, x):
        """Compute p(x).

        Returns:
            chainer.Variable
        """
        raise NotImplementedError()

    @abstractmethod
    def log_prob(self, x):
        """Compute log p(x).

        Returns:
            chainer.Variable
        """
        raise NotImplementedError()

    @abstractmethod
    def copy(self):
        """Copy a distribution unchained from the computation graph.

        Returns:
            Distribution
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def most_probable(self):
        """Most probable data points.

        Returns:
            chainer.Variable
        """
        raise NotImplementedError()

    # kl takes an argument, so it is an abstract method, not a property.
    @abstractmethod
    def kl(self, distrib):
        """Compute KL divergence D_KL(P|Q).

        Args:
            distrib (Distribution): Distribution Q.

        Returns:
            chainer.Variable
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def params(self):
        """Learnable parameters of this distribution.

        Returns:
            tuple of chainer.Variable
        """
        raise NotImplementedError()

    def sample_with_log_prob(self):
        """Do `sample` and `log_prob` at the same time.

        This can be more efficient than calling `sample` and `log_prob`
        separately. This default implementation simply calls `sample` and
        then `log_prob` on the result; subclasses may override it.

        Returns:
            chainer.Variable: Samples.
            chainer.Variable: Log probability of the samples.
        """
        y = self.sample()
        return y, self.log_prob(y)
class CategoricalDistribution(Distribution):
    """Distribution of categorical data.

    Subclasses provide ``all_prob`` and ``all_log_prob`` as attributes
    (properties); everything else here is derived from those two.
    """

    @cached_property
    def entropy(self):
        """Entropy -sum_a p(a) log p(a), summed over axis 1."""
        with chainer.force_backprop_mode():
            return - F.sum(self.all_prob * self.all_log_prob, axis=1)

    @cached_property
    def most_probable(self):
        """Index of the largest probability along axis 1, as int32."""
        probs = self.all_prob.array
        # Pick the array module from the data, as sample_discrete_actions
        # does, instead of hard-coding NumPy.
        xp = chainer.cuda.get_array_module(probs)
        return chainer.Variable(xp.argmax(probs, axis=1).astype(np.int32))

    def sample(self):
        """Sample indices according to ``all_prob``.

        Returns:
            chainer.Variable wrapping the sampled int32 indices.
        """
        return chainer.Variable(sample_discrete_actions(self.all_prob.array))

    def prob(self, x):
        """Select the probability of index ``x`` from ``all_prob``."""
        return F.select_item(self.all_prob, x)

    def log_prob(self, x):
        """Select the log-probability of index ``x`` from ``all_log_prob``."""
        return F.select_item(self.all_log_prob, x)

    # all_prob / all_log_prob are read as attributes throughout this class
    # (e.g. ``self.all_prob.array``), so they are declared as abstract
    # properties rather than abstract methods.
    @property
    @abstractmethod
    def all_prob(self):
        """Probabilities of all categories."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def all_log_prob(self):
        """Log-probabilities of all categories."""
        raise NotImplementedError()

    def kl(self, distrib):
        """Compute D_KL(self | distrib), summed over axis 1.

        Args:
            distrib: distribution Q; only its ``all_log_prob`` is read.
        """
        return F.sum(
            self.all_prob * (self.all_log_prob - distrib.all_log_prob), axis=1)
class SoftmaxDistribution(CategoricalDistribution):
    """Softmax distribution.

    Args:
        logits (ndarray or chainer.Variable): Logits for softmax
            distribution.
        beta (float): inverse of the temperature parameter of softmax
            distribution
        min_prob (float): minimum probability across all labels
    """

    def __init__(self, logits, beta=1.0, min_prob=0.0):
        self.logits = logits
        self.beta = beta
        self.min_prob = min_prob
        # Number of labels is read from axis 1, so logits must be at least
        # two-dimensional.
        self.n = logits.shape[1]
        # The per-label floor cannot add up to more than total probability 1.
        assert self.min_prob * self.n <= 1.0

    @property
    def params(self):
        """Tuple holding the logits exactly as passed to the constructor."""
        return (self.logits,)

    @cached_property
    def all_prob(self):
        """Softmax of ``beta * logits``, with the ``min_prob`` floor applied.

        When ``min_prob > 0`` the softmax output is scaled by
        ``1 - min_prob * n`` and shifted by ``min_prob``.
        """
        with chainer.force_backprop_mode():
            if self.min_prob > 0:
                return (F.softmax(self.beta * self.logits)
                        * (1 - self.min_prob * self.n)) + self.min_prob
            else:
                return F.softmax(self.beta * self.logits)

    @cached_property
    def all_log_prob(self):
        """Log of ``all_prob``; uses ``log_softmax`` when there is no floor."""
        with chainer.force_backprop_mode():
            if self.min_prob > 0:
                return F.log(self.all_prob)
            else:
                return F.log_softmax(self.beta * self.logits)

    def copy(self):
        """Return a new distribution built from a copy of the raw logits."""
        return SoftmaxDistribution(_unwrap_variable(self.logits).copy(),
                                   beta=self.beta, min_prob=self.min_prob)

    def __repr__(self):
        # logits may be a plain ndarray (always the case after copy()), so
        # unwrap it instead of assuming a ``.array`` attribute.
        return 'SoftmaxDistribution(beta={}, min_prob={}) logits:{} probs:{} entropy:{}'.format(  # NOQA
            self.beta, self.min_prob, _unwrap_variable(self.logits),
            self.all_prob.array, self.entropy.array)

    def __getitem__(self, i):
        """Return the distribution over ``logits[i]``.

        The index must keep the label axis (e.g. a slice), because the
        constructor reads ``logits.shape[1]``.
        """
        return SoftmaxDistribution(self.logits[i],
                                   beta=self.beta, min_prob=self.min_prob)
class MellowmaxDistribution(CategoricalDistribution):
    """Maximum entropy mellowmax distribution.

    See: http://arxiv.org/abs/1612.05628

    Args:
        values (ndarray or chainer.Variable): Values to apply mellowmax.
        omega (float): Parameter of mellowmax, forwarded to
            ``mellowmax.maximum_entropy_mellowmax``.
    """

    def __init__(self, values, omega=8.):
        self.values = values
        self.omega = omega

    @property
    def params(self):
        """Tuple holding the values exactly as passed to the constructor."""
        return (self.values,)

    @cached_property
    def all_prob(self):
        """Probabilities of all categories under maximum entropy mellowmax."""
        with chainer.force_backprop_mode():
            # Forward omega explicitly; otherwise the constructor argument
            # would be ignored and the helper's own default used instead.
            return mellowmax.maximum_entropy_mellowmax(
                self.values, omega=self.omega)

    @cached_property
    def all_log_prob(self):
        """Element-wise log of ``all_prob``."""
        with chainer.force_backprop_mode():
            return F.log(self.all_prob)

    def copy(self):
        """Return a new distribution built from a copy of the raw values."""
        return MellowmaxDistribution(_unwrap_variable(self.values).copy(),
                                     omega=self.omega)

    def __repr__(self):
        # values may be a plain ndarray (always the case after copy()), so
        # unwrap it instead of assuming a ``.array`` attribute.
        return 'MellowmaxDistribution(omega={}) values:{} probs:{} entropy:{}'.format(  # NOQA
            self.omega, _unwrap_variable(self.values), self.all_prob.array,
            self.entropy.array)

    def __getitem__(self, i):
        """Return the distribution over ``values[i]`` with the same omega."""
        return MellowmaxDistribution(self.values[i], omega=self.omega)
def clip_actions(actions, min_action, max_action):
    """Clamp ``actions`` element-wise into ``[min_action, max_action]``.

    Both bounds are broadcast to the shape of ``actions`` before clamping.
    """
    target_shape = actions.shape
    lower = F.broadcast_to(min_action, target_shape)
    upper = F.broadcast_to(max_action, target_shape)
    capped = F.minimum(actions, upper)
    return F.maximum(capped, lower)
def _eltwise_gaussian_log_likelihood(x, mean, var, ln_var):
    """Element-wise log density of a Gaussian evaluated at ``x``.

    Implements
        log N(x|mean,var)
            = -0.5log(2pi) - 0.5log(var) - (x - mean)**2 / (2*var)
    with ``ln_var`` supplied by the caller in place of log(var).
    """
    log_normalizer = -0.5 * np.log(2 * np.pi)
    deviation = x - mean
    return log_normalizer - 0.5 * ln_var - (deviation ** 2) / (2 * var)
class GaussianDistribution(Distribution):
    """Gaussian distribution.

    Densities are evaluated element-wise from ``mean`` and ``var`` and then
    summed over axis 1.

    Args:
        mean: Mean; wrapped into a ``chainer.Variable`` if it is not one.
        var: Variance; wrapped into a ``chainer.Variable`` if it is not one.
    """

    def __init__(self, mean, var):
        self.mean = _wrap_by_variable(mean)
        self.var = _wrap_by_variable(var)
        # Log-variance is kept alongside var because sampling, log_prob,
        # entropy and kl all read it.
        self.ln_var = F.log(var)

    @property
    def params(self):
        """Tuple ``(mean, var)``."""
        return (self.mean, self.var)

    @cached_property
    def most_probable(self):
        """The mean."""
        return self.mean

    def sample(self):
        """Draw a sample via ``F.gaussian(mean, ln_var)``."""
        return F.gaussian(self.mean, self.ln_var)

    def prob(self, x):
        """Density at ``x``, computed as ``exp(log_prob(x))``."""
        log_density = self.log_prob(x)
        return F.exp(log_density)

    def log_prob(self, x):
        """Log density at ``x``: element-wise terms summed over axis 1."""
        per_element = _eltwise_gaussian_log_likelihood(
            x, self.mean, self.var, self.ln_var)
        return F.sum(per_element, axis=1)

    @cached_property
    def entropy(self):
        """Differential entropy, summed over axis 1.

        Per element this is
            0.5 * (log(2 * pi * var) + 1)
            = 0.5 * (log(2 * pi) + log var + 1)
        """
        with chainer.force_backprop_mode():
            n_dims = self.mean.array.shape[1]
            constant_part = 0.5 * n_dims * (np.log(2 * np.pi) + 1)
            return constant_part + 0.5 * F.sum(self.ln_var, axis=1)

    def copy(self):
        """Return a new distribution built from copies of the raw arrays."""
        mean_copy = _unwrap_variable(self.mean).copy()
        var_copy = _unwrap_variable(self.var).copy()
        return GaussianDistribution(mean_copy, var_copy)

    def kl(self, q):
        """KL divergence from this distribution (P) to ``q`` (Q).

        Element-wise terms are summed over axis 1.
        """
        p = self
        mean_gap = p.mean - q.mean
        per_element = (q.ln_var - p.ln_var +
                       (p.var + mean_gap ** 2) / q.var -
                       1)
        return 0.5 * F.sum(per_element, axis=1)

    def __repr__(self):
        template = 'GaussianDistribution mean:{} ln_var:{} entropy:{}'
        return template.format(
            self.mean.array, self.ln_var.array, self.entropy.array)

    def __getitem__(self, i):
        """Return the distribution over ``mean[i]`` and ``var[i]``."""
        return GaussianDistribution(self.mean[i], self.var[i])
def _tanh_forward_log_det_jacobian(x):
    """Compute log|det(dy/dx)| except summation where y=tanh(x).

    The value is returned per element; the caller sums over dimensions.
    """
    # For the derivation of this formula, see:
    # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py  # NOQA
    softplus_term = F.softplus(-2. * x)
    return 2. * (np.log(2.) - x - softplus_term)
class SquashedGaussianDistribution(Distribution):
    """Gaussian distribution squashed by tanh.

    This type of distribution was used in https://arxiv.org/abs/1812.05905.

    Args:
        mean: Mean of the underlying Gaussian (before tanh).
        var: Variance of the underlying Gaussian (before tanh).
    """

    def __init__(self, mean, var):
        self.mean = _wrap_by_variable(mean)
        self.var = _wrap_by_variable(var)
        self.ln_var = F.log(var)

    @property
    def params(self):
        """Tuple ``(mean, var)`` of the underlying Gaussian."""
        return (self.mean, self.var)

    @cached_property
    def most_probable(self):
        """tanh of the mean."""
        return F.tanh(self.mean)

    def sample_with_log_prob(self):
        """Sample and compute the log probability of that sample.

        The log probability is computed from the pre-tanh sample, so no
        arctanh of the squashed value is needed.

        Returns:
            chainer.Variable: Samples (after tanh).
            chainer.Variable: Log probability of the samples, summed over
                axis 1.
        """
        pre_squash = F.gaussian(self.mean, self.ln_var)
        gaussian_term = _eltwise_gaussian_log_likelihood(
            pre_squash, self.mean, self.var, self.ln_var)
        jacobian_term = _tanh_forward_log_det_jacobian(pre_squash)
        per_element = gaussian_term - jacobian_term
        squashed = F.tanh(pre_squash)
        return squashed, F.sum(per_element, axis=1)

    def sample(self):
        """Draw a Gaussian sample and squash it with tanh."""
        # Caution: If you would like to apply `log_prob` later, use
        # `sample_with_log_prob` instead for stability, especially when
        # tanh(x) can be close to -1 or 1.
        return F.tanh(F.gaussian(self.mean, self.ln_var))

    def prob(self, x):
        """Density at ``x``, computed as ``exp(log_prob(x))``."""
        log_density = self.log_prob(x)
        return F.exp(log_density)

    def log_prob(self, x):
        """Log density of an already squashed value ``x``, summed over axis 1."""
        # Caution: If you would like to apply this to samples from the same
        # distribution, use `sample_with_log_prob` instead for stability,
        # especially when tanh(x) can be close to -1 or 1.
        pre_squash = arctanh(x)
        gaussian_term = _eltwise_gaussian_log_likelihood(
            pre_squash, self.mean, self.var, self.ln_var)
        jacobian_term = _tanh_forward_log_det_jacobian(pre_squash)
        return F.sum(gaussian_term - jacobian_term, axis=1)

    @cached_property
    def entropy(self):
        """Not implemented for this distribution."""
        raise NotImplementedError

    def copy(self):
        """Return a new distribution built from copies of the raw arrays."""
        mean_copy = _unwrap_variable(self.mean).copy()
        var_copy = _unwrap_variable(self.var).copy()
        return SquashedGaussianDistribution(mean_copy, var_copy)

    def kl(self, q):
        """Not implemented for this distribution."""
        raise NotImplementedError

    def __repr__(self):
        template = 'SquashedGaussianDistribution mean:{} ln_var:{}'
        return template.format(self.mean.array, self.ln_var.array)

    def __getitem__(self, i):
        """Return the distribution over ``mean[i]`` and ``var[i]``."""
        return SquashedGaussianDistribution(self.mean[i], self.var[i])
class ContinuousDeterministicDistribution(Distribution):
    """Continuous deterministic distribution.

    This distribution is supposed to be used in continuous deterministic
    policies. It always yields the value it was built with; entropy, prob,
    log_prob and kl raise ``RuntimeError('Not defined')``.

    Args:
        x: The value; wrapped into a ``chainer.Variable`` if it is not one.
    """

    def __init__(self, x):
        self.x = _wrap_by_variable(x)

    @property
    def params(self):
        """Tuple holding the single stored value."""
        return (self.x,)

    @cached_property
    def most_probable(self):
        """The stored value."""
        return self.x

    def sample(self):
        """Return the stored value; no randomness is involved."""
        return self.x

    def copy(self):
        """Return a new distribution built from a copy of the raw array."""
        raw_copy = _unwrap_variable(self.x).copy()
        return ContinuousDeterministicDistribution(raw_copy)

    @cached_property
    def entropy(self):
        raise RuntimeError('Not defined')

    def prob(self, x):
        raise RuntimeError('Not defined')

    def log_prob(self, x):
        raise RuntimeError('Not defined')

    def kl(self, distrib):
        raise RuntimeError('Not defined')