Source code for chainerrl.distribution

from abc import ABCMeta
from abc import abstractmethod
from abc import abstractproperty

from cached_property import cached_property
import chainer
from chainer import functions as F
import numpy as np

from chainerrl.functions import arctanh
from chainerrl.functions import mellowmax


def _wrap_by_variable(x):
    if isinstance(x, chainer.Variable):
        return x
    else:
        return chainer.Variable(x)


def _unwrap_variable(x):
    if isinstance(x, chainer.Variable):
        return x.array
    else:
        return x


def sample_discrete_actions(batch_probs):
    """Sample a batch of actions from a batch of action probabilities.

    Args:
        batch_probs (ndarray): Batch of action probabilities of shape
            (B, A), where B is the batch size and A the number of actions.
    Returns:
        ndarray consisting of sampled action indices
    """
    xp = chainer.cuda.get_array_module(batch_probs)
    return xp.argmax(
        xp.log(batch_probs) + xp.random.gumbel(size=batch_probs.shape),
        axis=1).astype(np.int32, copy=False)
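

# Illustrative sketch (not part of the original module): `sample_discrete_actions`
# relies on the Gumbel-max trick, i.e. argmax_a(log p_a + g_a) with g_a ~ Gumbel(0, 1)
# is distributed according to p. The helper name below is hypothetical.
def _example_sample_discrete_actions():
    batch_probs = np.asarray([[0.7, 0.2, 0.1],
                              [0.1, 0.1, 0.8]], dtype=np.float32)
    actions = sample_discrete_actions(batch_probs)
    # One int32 action index per row of the batch.
    assert actions.shape == (2,) and actions.dtype == np.int32
    return actions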


class Distribution(object, metaclass=ABCMeta):
    """Batch of distributions of data."""

    @abstractproperty
    def entropy(self):
        """Entropy of distributions.

        Returns:
            chainer.Variable
        """
        raise NotImplementedError()

    @abstractmethod
    def sample(self):
        """Sample from distributions.

        Returns:
            chainer.Variable
        """
        raise NotImplementedError()

    @abstractmethod
    def prob(self, x):
        """Compute p(x).

        Returns:
            chainer.Variable
        """
        raise NotImplementedError()

    @abstractmethod
    def log_prob(self, x):
        """Compute log p(x).

        Returns:
            chainer.Variable
        """
        raise NotImplementedError()

    @abstractmethod
    def copy(self):
        """Copy a distribution unchained from the computation graph.

        Returns:
            Distribution
        """
        raise NotImplementedError()

    @abstractproperty
    def most_probable(self):
        """Most probable data points.

        Returns:
            chainer.Variable
        """
        raise NotImplementedError()

    @abstractmethod
    def kl(self, distrib):
        """Compute KL divergence D_KL(P|Q).

        Args:
            distrib (Distribution): Distribution Q.
        Returns:
            chainer.Variable
        """
        raise NotImplementedError()

    @abstractproperty
    def params(self):
        """Learnable parameters of this distribution.

        Returns:
            tuple of chainer.Variable
        """
        raise NotImplementedError()

    def sample_with_log_prob(self):
        """Do `sample` and `log_prob` at the same time.

        This can be more efficient than calling `sample` and `log_prob`
        separately.

        Returns:
            chainer.Variable: Samples.
            chainer.Variable: Log probability of the samples.
        """
        y = self.sample()
        return y, self.log_prob(y)


class CategoricalDistribution(Distribution):
    """Distribution of categorical data."""

    @cached_property
    def entropy(self):
        with chainer.force_backprop_mode():
            return - F.sum(self.all_prob * self.all_log_prob, axis=1)

    @cached_property
    def most_probable(self):
        return chainer.Variable(
            np.argmax(self.all_prob.array, axis=1).astype(np.int32))

    def sample(self):
        return chainer.Variable(sample_discrete_actions(self.all_prob.array))

    def prob(self, x):
        return F.select_item(self.all_prob, x)

    def log_prob(self, x):
        return F.select_item(self.all_log_prob, x)

    @abstractmethod
    def all_prob(self):
        raise NotImplementedError()

    @abstractmethod
    def all_log_prob(self):
        raise NotImplementedError()

    def kl(self, distrib):
        return F.sum(
            self.all_prob * (self.all_log_prob - distrib.all_log_prob),
            axis=1)


class SoftmaxDistribution(CategoricalDistribution):
    """Softmax distribution.

    Args:
        logits (ndarray or chainer.Variable): Logits for softmax
            distribution.
        beta (float): inverse of the temperature parameter of softmax
            distribution
        min_prob (float): minimum probability across all labels
    """

    def __init__(self, logits, beta=1.0, min_prob=0.0):
        self.logits = logits
        self.beta = beta
        self.min_prob = min_prob
        self.n = logits.shape[1]
        assert self.min_prob * self.n <= 1.0

    @property
    def params(self):
        return (self.logits,)

    @cached_property
    def all_prob(self):
        with chainer.force_backprop_mode():
            if self.min_prob > 0:
                return (F.softmax(self.beta * self.logits)
                        * (1 - self.min_prob * self.n)) + self.min_prob
            else:
                return F.softmax(self.beta * self.logits)

    @cached_property
    def all_log_prob(self):
        with chainer.force_backprop_mode():
            if self.min_prob > 0:
                return F.log(self.all_prob)
            else:
                return F.log_softmax(self.beta * self.logits)

    def copy(self):
        return SoftmaxDistribution(_unwrap_variable(self.logits).copy(),
                                   beta=self.beta, min_prob=self.min_prob)

    def __repr__(self):
        return 'SoftmaxDistribution(beta={}, min_prob={}) logits:{} probs:{} entropy:{}'.format(  # NOQA
            self.beta, self.min_prob, self.logits.array,
            self.all_prob.array, self.entropy.array)

    def __getitem__(self, i):
        return SoftmaxDistribution(self.logits[i],
                                   beta=self.beta, min_prob=self.min_prob)
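

# Illustrative sketch (not part of the original module): building a
# SoftmaxDistribution from a batch of logits and querying it. Shapes and the
# helper name are hypothetical.
def _example_softmax_distribution():
    logits = np.random.randn(2, 4).astype(np.float32)
    distrib = SoftmaxDistribution(logits, beta=1.0, min_prob=0.0)
    actions = distrib.sample()         # chainer.Variable of int32 indices, shape (2,)
    log_p = distrib.log_prob(actions)  # chainer.Variable, shape (2,)
    ent = distrib.entropy              # chainer.Variable, shape (2,)
    return actions, log_p, ent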


class MellowmaxDistribution(CategoricalDistribution):
    """Maximum entropy mellowmax distribution.

    See: http://arxiv.org/abs/1612.05628

    Args:
        values (ndarray or chainer.Variable): Values to apply mellowmax.
        omega (float): Parameter of the mellowmax operator.
    """

    def __init__(self, values, omega=8.):
        self.values = values
        self.omega = omega

    @property
    def params(self):
        return (self.values,)

    @cached_property
    def all_prob(self):
        with chainer.force_backprop_mode():
            return mellowmax.maximum_entropy_mellowmax(self.values)

    @cached_property
    def all_log_prob(self):
        with chainer.force_backprop_mode():
            return F.log(self.all_prob)

    def copy(self):
        return MellowmaxDistribution(_unwrap_variable(self.values).copy(),
                                     omega=self.omega)

    def __repr__(self):
        return 'MellowmaxDistribution(omega={}) values:{} probs:{} entropy:{}'.format(  # NOQA
            self.omega, self.values.array, self.all_prob.array,
            self.entropy.array)

    def __getitem__(self, i):
        return MellowmaxDistribution(self.values[i], omega=self.omega)
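

# Illustrative sketch (not part of the original module): once constructed from a
# batch of action values, a MellowmaxDistribution behaves like any other
# CategoricalDistribution. The helper name is hypothetical.
def _example_mellowmax_distribution():
    values = np.random.randn(2, 4).astype(np.float32)
    distrib = MellowmaxDistribution(values)
    return distrib.sample(), distrib.entropy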


def clip_actions(actions, min_action, max_action):
    min_actions = F.broadcast_to(min_action, actions.shape)
    max_actions = F.broadcast_to(max_action, actions.shape)
    return F.maximum(F.minimum(actions, max_actions), min_actions)


def _eltwise_gaussian_log_likelihood(x, mean, var, ln_var):
    # log N(x|mean,var)
    #   = -0.5 log(2*pi) - 0.5 log(var) - (x - mean)**2 / (2*var)
    return -0.5 * np.log(2 * np.pi) - \
        0.5 * ln_var - \
        ((x - mean) ** 2) / (2 * var)
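

# Illustrative numerical check (not part of the original module): the element-wise
# Gaussian log-likelihood above matches scipy.stats.norm.logpdf. Assumes SciPy is
# available; the helper name is hypothetical.
def _check_eltwise_gaussian_log_likelihood():
    from scipy.stats import norm
    x = np.array([[0.3, -1.2]], dtype=np.float64)
    mean = np.array([[0.0, 0.5]], dtype=np.float64)
    var = np.array([[1.0, 2.0]], dtype=np.float64)
    ours = _eltwise_gaussian_log_likelihood(x, mean, var, np.log(var))
    ref = norm.logpdf(x, loc=mean, scale=np.sqrt(var))
    assert np.allclose(ours, ref)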


class GaussianDistribution(Distribution):
    """Gaussian distribution."""

    def __init__(self, mean, var):
        self.mean = _wrap_by_variable(mean)
        self.var = _wrap_by_variable(var)
        self.ln_var = F.log(var)

    @property
    def params(self):
        return (self.mean, self.var)

    @cached_property
    def most_probable(self):
        return self.mean

    def sample(self):
        return F.gaussian(self.mean, self.ln_var)

    def prob(self, x):
        return F.exp(self.log_prob(x))

    def log_prob(self, x):
        eltwise_log_prob = _eltwise_gaussian_log_likelihood(
            x, self.mean, self.var, self.ln_var)
        return F.sum(eltwise_log_prob, axis=1)

    @cached_property
    def entropy(self):
        # Differential entropy of Gaussian is:
        #   0.5 * (log(2 * pi * var) + 1)
        #   = 0.5 * (log(2 * pi) + log var + 1)
        with chainer.force_backprop_mode():
            return 0.5 * self.mean.array.shape[1] * (np.log(2 * np.pi) + 1) + \
                0.5 * F.sum(self.ln_var, axis=1)

    def copy(self):
        return GaussianDistribution(_unwrap_variable(self.mean).copy(),
                                    _unwrap_variable(self.var).copy())

    def kl(self, q):
        p = self
        return 0.5 * F.sum(q.ln_var - p.ln_var +
                           (p.var + (p.mean - q.mean) ** 2) / q.var - 1,
                           axis=1)

    def __repr__(self):
        return 'GaussianDistribution mean:{} ln_var:{} entropy:{}'.format(
            self.mean.array, self.ln_var.array, self.entropy.array)

    def __getitem__(self, i):
        return GaussianDistribution(self.mean[i], self.var[i])
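

# Illustrative sketch (not part of the original module): the KL divergence between
# a GaussianDistribution and an independent copy of itself is zero. The helper name
# is hypothetical.
def _example_gaussian_kl():
    mean = np.zeros((2, 3), dtype=np.float32)
    var = np.ones((2, 3), dtype=np.float32)
    p = GaussianDistribution(mean, var)
    q = p.copy()
    kl = p.kl(q)        # chainer.Variable, shape (2,)
    assert np.allclose(kl.array, 0.0)
    return p.entropy    # 0.5 * 3 * (log(2*pi) + 1) per row for unit variance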


def _tanh_forward_log_det_jacobian(x):
    """Compute log|det(dy/dx)| except summation where y=tanh(x)."""
    # For the derivation of this formula, see:
    # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py  # NOQA
    return 2. * (np.log(2.) - x - F.softplus(-2. * x))


class SquashedGaussianDistribution(Distribution):
    """Gaussian distribution squashed by tanh.

    This type of distribution was used in https://arxiv.org/abs/1812.05905.
    """

    def __init__(self, mean, var):
        self.mean = _wrap_by_variable(mean)
        self.var = _wrap_by_variable(var)
        self.ln_var = F.log(var)

    @property
    def params(self):
        return (self.mean, self.var)

    @cached_property
    def most_probable(self):
        return F.tanh(self.mean)

    def sample_with_log_prob(self):
        x = F.gaussian(self.mean, self.ln_var)
        normal_log_prob = _eltwise_gaussian_log_likelihood(
            x, self.mean, self.var, self.ln_var)
        log_probs = normal_log_prob - _tanh_forward_log_det_jacobian(x)
        y = F.tanh(x)
        return y, F.sum(log_probs, axis=1)

    def sample(self):
        # Caution: If you would like to apply `log_prob` later, use
        # `sample_with_log_prob` instead for stability, especially when
        # tanh(x) can be close to -1 or 1.
        y = F.tanh(F.gaussian(self.mean, self.ln_var))
        return y

    def prob(self, x):
        return F.exp(self.log_prob(x))

    def log_prob(self, x):
        # Caution: If you would like to apply this to samples from the same
        # distribution, use `sample_with_log_prob` instead for stability,
        # especially when tanh(x) can be close to -1 or 1.
        raw_action = arctanh(x)
        normal_log_prob = _eltwise_gaussian_log_likelihood(
            raw_action, self.mean, self.var, self.ln_var)
        log_probs = normal_log_prob - _tanh_forward_log_det_jacobian(
            raw_action)
        return F.sum(log_probs, axis=1)

    @cached_property
    def entropy(self):
        raise NotImplementedError

    def copy(self):
        return SquashedGaussianDistribution(
            _unwrap_variable(self.mean).copy(),
            _unwrap_variable(self.var).copy())

    def kl(self, q):
        raise NotImplementedError

    def __repr__(self):
        return 'SquashedGaussianDistribution mean:{} ln_var:{}'.format(
            self.mean.array, self.ln_var.array)

    def __getitem__(self, i):
        return SquashedGaussianDistribution(self.mean[i], self.var[i])
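

# Illustrative sketch (not part of the original module): sampling a squashed action
# together with its log-probability, which the cautionary comments above recommend
# over calling `sample` and `log_prob` separately. The helper name is hypothetical.
def _example_squashed_gaussian():
    mean = np.zeros((1, 2), dtype=np.float32)
    var = np.full((1, 2), 0.25, dtype=np.float32)
    distrib = SquashedGaussianDistribution(mean, var)
    action, log_p = distrib.sample_with_log_prob()
    assert np.all(np.abs(action.array) <= 1.0)  # tanh keeps actions in [-1, 1]
    return action, log_p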


class ContinuousDeterministicDistribution(Distribution):
    """Continuous deterministic distribution.

    This distribution is supposed to be used in continuous deterministic
    policies.
    """

    def __init__(self, x):
        self.x = _wrap_by_variable(x)

    @cached_property
    def entropy(self):
        raise RuntimeError('Not defined')

    @cached_property
    def most_probable(self):
        return self.x

    def sample(self):
        return self.x

    def prob(self, x):
        raise RuntimeError('Not defined')

    def copy(self):
        return ContinuousDeterministicDistribution(
            _unwrap_variable(self.x).copy())

    def log_prob(self, x):
        raise RuntimeError('Not defined')

    def kl(self, distrib):
        raise RuntimeError('Not defined')

    @property
    def params(self):
        return (self.x,)