from cached_property import cached_property
import chainer
import numpy as np
from chainerrl.links.stateless_recurrent import call_recurrent_link
from chainerrl.links.stateless_recurrent import concatenate_sequences
from chainerrl.links.stateless_recurrent import is_recurrent_link
from chainerrl.links.stateless_recurrent import split_batched_sequences
from chainerrl.links.stateless_recurrent import StatelessRecurrentChainList
[docs]class StatelessRecurrentSequential(
StatelessRecurrentChainList, chainer.Sequential):
"""Sequential model that can contain stateless recurrent links.
This link a stateless recurrent analog to chainer.Sequential. It supports
the stateless recurrent interface by automatically detecting recurrent
links and handles recurrent states properly.
For non-recurrent layers (non-link callables or non-recurrent callable
links), this link automatically concatenates the input to the layers
for efficient computation.
Args:
*layers: Callable objects.
"""
def n_step_forward(self, sequences, recurrent_state, output_mode):
assert sequences
assert output_mode in ['split', 'concat']
if recurrent_state is None:
recurrent_state_queue = [None] * len(self.recurrent_children)
else:
assert len(recurrent_state) == len(self.recurrent_children)
recurrent_state_queue = list(reversed(recurrent_state))
new_recurrent_state = []
h = sequences
seq_mode = True
sections = np.cumsum([len(x) for x in sequences[:-1]], dtype=np.int32)
for layer in self._layers:
if is_recurrent_link(layer):
if not seq_mode:
h = split_batched_sequences(h, sections)
seq_mode = True
rs = recurrent_state_queue.pop()
h, rs = call_recurrent_link(layer, h, rs, output_mode='split')
new_recurrent_state.append(rs)
else:
if seq_mode:
seq_mode = False
h = concatenate_sequences(h)
if isinstance(h, tuple):
h = layer(*h)
else:
h = layer(h)
if not seq_mode and output_mode == 'split':
h = split_batched_sequences(h, sections)
seq_mode = True
elif seq_mode and output_mode == 'concat':
h = concatenate_sequences(h)
seq_mode = False
assert seq_mode is (output_mode == 'split')
assert not recurrent_state_queue
assert len(new_recurrent_state) == len(self.recurrent_children)
return h, tuple(new_recurrent_state)
@cached_property
def recurrent_children(self):
"""Return recurrent child links.
This overrides `StatelessRecurrentChainList.recurrent_children`
because `Sequential`'s evaluation order can be different from the
order of links in `Sequential.children()`.
See https://github.com/chainer/chainer/issues/6053
Returns:
tuple: Tuple of `chainer.Link`s that are recurrent.
"""
return tuple(child for child in self._layers
if is_recurrent_link(child))