from inspect import Parameter
from inspect import signature
import chainer
from chainerrl.recurrent import RecurrentChainMixin
def accept_variable_arguments(func):
    """Tell whether ``func`` declares ``*args`` or ``**kwargs``.

    Args:
        func: A callable that ``inspect.signature`` can introspect; any
            exception raised by ``inspect.signature`` propagates.

    Returns:
        bool: True if at least one parameter of ``func`` is variadic
        (either positional or keyword), False otherwise.
    """
    variadic_kinds = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
    params = signature(func).parameters.values()
    return any(param.kind in variadic_kinds for param in params)
class Sequence(chainer.ChainList, RecurrentChainMixin):
    """Sequential callable Link that consists of other Links.

    Each layer is called in order, and the output of one layer is passed as
    the first positional argument of the next. Layers may be
    ``chainer.Link`` instances or plain callables; only the Links are
    registered as children of this ``ChainList``.

    Keyword arguments given to ``__call__`` are forwarded to each layer
    that can take them. A layer with a ``**kwargs`` parameter receives all
    of them. Any other layer receives only those whose names appear in its
    signature.
    """

    def __init__(self, *layers):
        """Build the sequence.

        Args:
            *layers: Callables applied in order. Each must be introspectable
                by ``inspect.signature``; otherwise the exception it raises
                propagates.
        """
        # All layers in call order (Links and plain callables alike).
        # NOTE(review): the Links in this list are also registered as
        # children through super().__init__. Anything that replaces the
        # children without updating this list would leave the two out of
        # sync — confirm no caller relies on that.
        self.layers = list(layers)
        links = [layer for layer in layers if isinstance(layer, chainer.Link)]
        # Cache the signatures because inspect.signature might be slow.
        # Each one is computed once, and every list below derives from it.
        signatures = [signature(layer) for layer in layers]
        # Parameter names of each layer, used to filter keyword arguments.
        self.argnames = [set(sig.parameters) for sig in signatures]
        # Whether each layer declares *args or **kwargs. This keeps its
        # original meaning for backward compatibility. __call__ uses
        # accept_var_kwargs instead, because *args alone cannot absorb
        # keyword arguments.
        self.accept_var_args = [
            any(p.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
                for p in sig.parameters.values())
            for sig in signatures]
        # Whether each layer declares **kwargs and can therefore receive
        # arbitrary keyword arguments.
        self.accept_var_kwargs = [
            any(p.kind == Parameter.VAR_KEYWORD
                for p in sig.parameters.values())
            for sig in signatures]
        super().__init__(*links)

    def __call__(self, x, **kwargs):
        """Apply every layer in order.

        Args:
            x: Input passed as the first positional argument to the first
                layer.
            **kwargs: Keyword arguments forwarded to the layers that accept
                them (see the class docstring).

        Returns:
            The output of the last layer, or ``x`` itself if the sequence
            has no layers.
        """
        h = x
        for layer, argnames, accept_var_kwargs in zip(
                self.layers, self.argnames, self.accept_var_kwargs):
            if accept_var_kwargs:
                layer_kwargs = kwargs
            else:
                # Keep only the keyword arguments named in the layer's
                # signature; a layer with *args but no **kwargs would
                # raise TypeError on any other name.
                layer_kwargs = {k: v for k, v in kwargs.items()
                                if k in argnames}
            h = layer(h, **layer_kwargs)
        return h