Source code for chainerrl.recurrent

from abc import ABCMeta
from abc import abstractmethod
import collections
import collections.abc
import contextlib

import chainer


def unchain_backward(state):
    """Call Variable.unchain_backward recursively.

    Args:
        state: A ``chainer.Variable``, an arbitrarily nested iterable of
            states, or any other object. Iterables are traversed element by
            element; every ``chainer.Variable`` found that way has its
            ``unchain_backward`` method called. Anything else is ignored.
    """
    # Iterating a string yields strings again, which would recurse forever;
    # text can never hold a Variable, so there is nothing to do for it.
    if isinstance(state, (str, bytes)):
        return
    # The ABC lives in collections.abc; the old collections.Iterable alias
    # was removed in Python 3.10.
    if isinstance(state, collections.abc.Iterable):
        for s in state:
            unchain_backward(s)
    elif isinstance(state, chainer.Variable):
        state.unchain_backward()


class Recurrent(object, metaclass=ABCMeta):
    """Interface of recurrent and stateful models.

    This is an interface of recurrent and stateful models. ChainerRL supports
    recurrent neural network models as stateful models that implement this
    interface.

    To implement this interface, you need to implement three abstract methods
    of it: get_state, set_state and reset_state.
    """

    # Stack of saved states. It is created lazily per instance (see
    # __saved_states) so that it works without __init__ being called and so
    # that distinct models never push onto / pop from the same list.
    __state_stack = None

    def __saved_states(self):
        """Return this instance's stack of saved states, creating it once."""
        if self.__state_stack is None:
            self.__state_stack = []
        return self.__state_stack

    @abstractmethod
    def get_state(self):
        """Get the current state of this model.

        Returns:
            Any object that represents a state of this model.
        """
        raise NotImplementedError()

    @abstractmethod
    def set_state(self, state):
        """Overwrite the state of this model with a given state.

        Args:
            state (object): Any object that represents a state of this model.
        """
        raise NotImplementedError()

    @abstractmethod
    def reset_state(self):
        """Reset the state of this model to the initial state.

        For typical RL models, this method is expected to be called before
        every episode.
        """
        raise NotImplementedError()

    def unchain_backward(self):
        """Call the module-level unchain_backward on the current state."""
        unchain_backward(self.get_state())

    def push_state(self):
        """Save the current state on the stack, then reset the state."""
        self.__saved_states().append(self.get_state())
        self.reset_state()

    def pop_state(self):
        """Restore the most recently saved state and drop it from the stack.

        Raises:
            IndexError: If no state has been pushed.
        """
        self.set_state(self.__saved_states().pop())

    def push_and_keep_state(self):
        """Save the current state on the stack without changing it."""
        self.__saved_states().append(self.get_state())

    def update_state(self, *args, **kwargs):
        """Update this model's state as if self.__call__ is called.

        Unlike __call__, stateless objects may do nothing.
        """
        self(*args, **kwargs)

    @contextlib.contextmanager
    def state_reset(self):
        """Context manager: reset the state on entry, restore it on exit."""
        self.push_state()
        try:
            yield
        finally:
            # Restore even when the body raises, so the stack stays balanced.
            self.pop_state()

    @contextlib.contextmanager
    def state_kept(self):
        """Context manager: keep the state on entry, restore it on exit."""
        self.push_and_keep_state()
        try:
            yield
        finally:
            # Restore even when the body raises, so the stack stays balanced.
            self.pop_state()
def get_state(chain):
    """Collect the states of all children of a chain into a list.

    The returned list has one entry per child, in ``chain.children()`` order:
    a ``(c, h)`` tuple for an LSTM link, the result of ``get_state()`` for a
    Recurrent child, a nested list for a child Chain/ChainList, and ``None``
    for anything else.
    """
    assert isinstance(chain, (chainer.Chain, chainer.ChainList))

    def _state_of(link):
        if isinstance(link, chainer.links.LSTM):
            return (link.c, link.h)
        if isinstance(link, Recurrent):
            return link.get_state()
        if isinstance(link, (chainer.Chain, chainer.ChainList)):
            return get_state(link)
        return None

    return [_state_of(child) for child in chain.children()]


def stateful_links(chain):
    """Yield every LSTM or Recurrent link found under a chain, depth-first.

    Children that are plain Chain/ChainList containers are searched
    recursively; an LSTM or Recurrent child is yielded as-is and not
    descended into.
    """
    for child in chain.children():
        if isinstance(child, (chainer.links.LSTM, Recurrent)):
            yield child
        elif isinstance(child, (chainer.Chain, chainer.ChainList)):
            yield from stateful_links(child)


def set_state(chain, state):
    """Write back a state list (as produced by get_state) into a chain.

    Children and state entries are paired positionally with ``zip``.
    """
    assert isinstance(chain, (chainer.Chain, chainer.ChainList))
    for child, child_state in zip(chain.children(), state):
        if isinstance(child, chainer.links.LSTM):
            cell, hidden = child_state
            # LSTM.set_state doesn't accept None state
            if cell is not None:
                child.set_state(cell, hidden)
        elif isinstance(child, Recurrent):
            child.set_state(child_state)
        elif isinstance(child, (chainer.Chain, chainer.ChainList)):
            set_state(child, child_state)
        else:
            # Stateless children must have been recorded as None.
            assert child_state is None


def reset_state(chain):
    """Reset the state of every stateful link under a chain."""
    assert isinstance(chain, (chainer.Chain, chainer.ChainList))
    for child in chain.children():
        if isinstance(child, (chainer.links.LSTM, Recurrent)):
            child.reset_state()
        elif isinstance(child, (chainer.Chain, chainer.ChainList)):
            reset_state(child)


class RecurrentChainMixin(Recurrent):
    """Mixin that aggregates states of children.

    This mixin can only be applied to chainer.Chain or chainer.ChainList. The
    resulting class will implement Recurrent by searching recurrent models
    recursively from its children.
    """

    def get_state(self):
        # Delegates to the module-level function of the same name.
        return get_state(self)

    def set_state(self, state):
        # Delegates to the module-level function of the same name.
        set_state(self, state)

    def reset_state(self):
        # Delegates to the module-level function of the same name.
        reset_state(self)
@contextlib.contextmanager
def state_kept(link):
    """Keeps the previous state of a given link.

    This is a context manager that saves the current state of the link before
    entering the context, and then restores the saved state after escaping
    the context. The state is restored even if the body of the ``with``
    statement raises.

    This will just ignore non-Recurrent links.

    .. code-block:: python

        # Suppose the link is in a state A
        assert link.get_state() is A

        with state_kept(link):
            # The link is still in a state A
            assert link.get_state() is A
            # After evaluating the link, it may be in a different state
            y1 = link(x1)
            assert link.get_state() is not A

        # After escaping from the context, the link is in a state A again
        # because of the context manager
        assert link.get_state() is A

    Args:
        link: Any object; only instances of Recurrent are affected.
    """
    if not isinstance(link, Recurrent):
        yield
        return
    link.push_and_keep_state()
    try:
        yield
    finally:
        # Always pop, otherwise an exception in the body would leave the
        # saved state on the stack and the link in its modified state.
        link.pop_state()
@contextlib.contextmanager
def state_reset(link):
    """Reset the state while keeping the previous state of a given link.

    This is a context manager that saves the current state of the link and
    resets it to the initial state before entering the context, and then
    restores the saved state after escaping the context. The state is
    restored even if the body of the ``with`` statement raises.

    This will just ignore non-Recurrent links.

    .. code-block:: python

        # Suppose the link is in a non-initial state A
        assert link.get_state() is A

        with state_reset(link):
            # The link's state has been reset to the initial state
            assert link.get_state() is InitialState
            # After evaluating the link, it may be in a different state
            y1 = link(x1)
            assert link.get_state() is not InitialState

        # After escaping from the context, the link is in a state A again
        # because of the context manager
        assert link.get_state() is A

    Args:
        link: Any object; only instances of Recurrent are affected.
    """
    if not isinstance(link, Recurrent):
        yield
        return
    link.push_state()
    try:
        yield
    finally:
        # Always pop, otherwise an exception in the body would leave the
        # saved state on the stack and the link in its reset/modified state.
        link.pop_state()