# Source code for chainerrl.links.sequence

from inspect import Parameter
from inspect import signature

import chainer

from chainerrl.recurrent import RecurrentChainMixin


def accept_variable_arguments(func):
    """Tell whether a callable declares ``*args`` or ``**kwargs``.

    Args:
        func: Callable whose signature is inspected via
            ``inspect.signature``.

    Returns:
        bool: ``True`` if at least one parameter of ``func`` is
        var-positional or var-keyword, ``False`` otherwise.
    """
    variadic_kinds = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
    declared_params = signature(func).parameters.values()
    return any(p.kind in variadic_kinds for p in declared_params)


class Sequence(chainer.ChainList, RecurrentChainMixin):
    """Sequential callable Link that consists of other Links.

    Each layer is called in order, the output of one layer being passed as
    the first positional argument of the next. Keyword arguments given to
    ``__call__`` are forwarded only to the layers able to accept them.

    Args:
        *layers: Callables applied in the given order. Those that are
            ``chainer.Link`` instances are additionally registered as
            children of this ``ChainList``; every layer, Link or not, is
            kept in ``self.layers``.
    """

    def __init__(self, *layers):
        self.layers = list(layers)
        # Only Links are handed to ChainList; other callables are kept
        # solely in self.layers.
        links = [layer for layer in layers if isinstance(layer, chainer.Link)]
        # Cache the signatures because it might be slow
        self.argnames = [set(signature(layer).parameters)
                         for layer in layers]
        self.accept_var_args = [accept_variable_arguments(layer)
                                for layer in layers]
        super().__init__(*links)

    def __call__(self, x, **kwargs):
        """Apply all layers to ``x`` in order and return the final output.

        Args:
            x: Input passed to the first layer.
            **kwargs: Optional keyword arguments. A layer that declares
                ``*args`` or ``**kwargs`` receives all of them; any other
                layer receives only those whose names match one of its
                parameters.

        Returns:
            The output of the last layer, or ``x`` itself when there are
            no layers.
        """
        h = x
        for layer, argnames, accept_var_args in zip(self.layers,
                                                    self.argnames,
                                                    self.accept_var_args):
            if accept_var_args:
                layer_kwargs = kwargs
            else:
                # Drop keywords the layer does not declare so that the call
                # does not fail on an unexpected keyword argument.
                layer_kwargs = {k: v for k, v in kwargs.items()
                                if k in argnames}
            h = layer(h, **layer_kwargs)
        return h