Source code for chainerrl.links.mlp_bn

import chainer
from chainer import functions as F
from chainer.initializers import LeCunNormal
from chainer import links as L


class LinearBN(chainer.Chain):
    """Linear layer with BatchNormalization."""

    def __init__(self, in_size, out_size):
        super().__init__()
        with self.init_scope():
            self.linear = L.Linear(in_size, out_size)
            bn = L.BatchNormalization(out_size)
            bn.avg_var[:] = 1
            self.bn = bn

    def __call__(self, x):
        return self.bn(self.linear(x))


[docs]class MLPBN(chainer.Chain): """Multi-Layer Perceptron with Batch Normalization. Args: in_size (int): Input size. out_size (int): Output size. hidden_sizes (list of ints): Sizes of hidden channels. normalize_input (bool): If set to True, Batch Normalization is applied to inputs. normalize_output (bool): If set to True, Batch Normalization is applied to outputs. nonlinearity (callable): Nonlinearity between layers. It must accept a Variable as an argument and return a Variable with the same shape. Nonlinearities with learnable parameters such as PReLU are not supported. last_wscale (float): Scale of weight initialization of the last layer. """ def __init__(self, in_size, out_size, hidden_sizes, normalize_input=True, normalize_output=False, nonlinearity=F.relu, last_wscale=1): self.in_size = in_size self.out_size = out_size self.hidden_sizes = hidden_sizes self.normalize_input = normalize_input self.normalize_output = normalize_output self.nonlinearity = nonlinearity super().__init__() with self.init_scope(): if normalize_input: self.input_bn = L.BatchNormalization(in_size) self.input_bn.avg_var[:] = 1 if hidden_sizes: hidden_layers = [] hidden_layers.append(LinearBN(in_size, hidden_sizes[0])) for hin, hout in zip(hidden_sizes, hidden_sizes[1:]): hidden_layers.append(LinearBN(hin, hout)) self.hidden_layers = chainer.ChainList(*hidden_layers) self.output = L.Linear(hidden_sizes[-1], out_size, initialW=LeCunNormal(last_wscale)) else: self.output = L.Linear(in_size, out_size, initialW=LeCunNormal(last_wscale)) if normalize_output: self.output_bn = L.BatchNormalization(out_size) self.output_bn.avg_var[:] = 1 def __call__(self, x): h = x assert (not chainer.config.train) or x.shape[0] > 1 if self.normalize_input: h = self.input_bn(h) if self.hidden_sizes: for link in self.hidden_layers: h = self.nonlinearity(link(h)) h = self.output(h) if self.normalize_output: h = self.output_bn(h) return h