# Source code for chainerrl.links.stateless_recurrent_branched

import chainer

from chainerrl.links.stateless_recurrent import call_recurrent_link
from chainerrl.links.stateless_recurrent import StatelessRecurrentChainList


class StatelessRecurrentBranched(
        StatelessRecurrentChainList, chainer.ChainList):
    """Stateless recurrent parallel link.

    This is a recurrent analog to chainerrl.links.Branched. It bundles
    multiple links that implement `StatelessRecurrent` and feeds the same
    input sequences to each of them.

    Args:
        *links: Child links. Each link should be recurrent and callable.
    """

    def n_step_forward(self, sequences, recurrent_state, output_mode):
        """Run every child link on the same input sequences.

        Args:
            sequences: Input sequences, passed unchanged to each child link
                through ``call_recurrent_link``.
            recurrent_state: Per-child recurrent states (one entry per child
                link, in child order), or None to pass None as the state of
                every child.
            output_mode (str): Either ``'concat'`` or ``'split'``; forwarded
                to each child and also controls how outputs are grouped.

        Returns:
            tuple: ``(outputs, new_recurrent_state)``.

            - For ``'concat'``, ``outputs`` is a tuple with one entry per
              child link (that child's output).
            - For ``'split'``, ``outputs`` is a list with one entry per input
              sequence; each entry is a tuple holding that sequence's output
              from every child link, in child order.

            ``new_recurrent_state`` is a tuple with one state per child link.

        Raises:
            ValueError: If ``output_mode`` is not ``'concat'``/``'split'``,
                if this link has no children, or if ``recurrent_state`` does
                not have exactly one entry per child link.
        """
        # Validate up front (not via assert) so bad input is rejected even
        # under ``python -O`` and before any child link does work.
        if output_mode not in ('concat', 'split'):
            raise ValueError(
                "output_mode must be 'concat' or 'split', got {!r}".format(
                    output_mode))
        n_children = len(self)
        if n_children == 0:
            raise ValueError(
                'StatelessRecurrentBranched requires at least one child link')
        if recurrent_state is None:
            recurrent_state = [None] * n_children
        elif len(recurrent_state) != n_children:
            # zip() would silently drop the extra children/states, so a
            # length mismatch must be an explicit error.
            raise ValueError(
                'recurrent_state must have one entry per child link: '
                'expected {}, got {}'.format(
                    n_children, len(recurrent_state)))

        # Each call returns an (output, new_state) pair; transpose the list
        # of pairs into a tuple of outputs and a tuple of states.
        child_results = [
            call_recurrent_link(link, sequences, child_state, output_mode)
            for link, child_state in zip(self, recurrent_state)]
        child_ys, new_states = tuple(zip(*child_results))

        if output_mode == 'concat':
            return child_ys, new_states

        # 'split' mode: every child returns one output per input sequence.
        # These are internal consistency checks, not input validation.
        assert len(child_ys) == n_children
        assert len(child_ys[0]) == len(sequences)
        assert len(new_states) == n_children
        # Regroup from per-child to per-sequence.
        out = list(zip(*child_ys))
        return out, new_states