Source code for chainerrl.agent

from abc import ABCMeta
from abc import abstractmethod
from abc import abstractproperty
import os
import warnings

from chainer import serializers
import numpy


def load_npz_no_strict(filename, obj):
    """Load an npz file, tolerating entries that are missing from the file.

    If strict loading fails with a KeyError, a warning is emitted and the
    file is re-read with a non-strict NpzDeserializer, so that attributes
    absent from the file are simply left untouched.
    """
    try:
        serializers.load_npz(filename, obj)
    except KeyError as e:
        warnings.warn(repr(e))
        with numpy.load(filename) as f:
            d = serializers.NpzDeserializer(f, strict=False)
            d.load(obj)
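
# Illustrative usage (not part of the module; the model below is a
# hypothetical chainer.Link): load_npz_no_strict tolerates npz files that
# lack some of the target's entries, e.g. an optimizer state saved before
# its first update.
#
#     import chainer
#
#     model = chainer.links.Linear(4, 2)
#     serializers.save_npz('/tmp/model.npz', model)
#     load_npz_no_strict('/tmp/model.npz', model)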


class Agent(object, metaclass=ABCMeta):
    """Abstract agent class."""

    @abstractmethod
    def act_and_train(self, obs, reward):
        """Select an action for training.

        Returns:
            ~object: action
        """
        raise NotImplementedError()

    @abstractmethod
    def act(self, obs):
        """Select an action for evaluation.

        Returns:
            ~object: action
        """
        raise NotImplementedError()

    @abstractmethod
    def stop_episode_and_train(self, state, reward, done=False):
        """Observe consequences and prepare for a new episode.

        Returns:
            None
        """
        raise NotImplementedError()

    @abstractmethod
    def stop_episode(self):
        """Prepare for a new episode.

        Returns:
            None
        """
        raise NotImplementedError()

    @abstractmethod
    def save(self, dirname):
        """Save internal states.

        Returns:
            None
        """
        pass

    @abstractmethod
    def load(self, dirname):
        """Load internal states.

        Returns:
            None
        """
        pass

    @abstractmethod
    def get_statistics(self):
        """Get statistics of the agent.

        Returns:
            List of two-item tuples. The first item in a tuple is a str that
            represents the name of the item, while the second item is a value
            to be recorded.

            Example: [('average_loss', 0), ('average_value', 1), ...]
        """
        pass
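
# Sketch of the single-env episode loop this interface assumes (``env`` is a
# hypothetical Gym-style environment and ``agent`` a concrete Agent
# subclass; chainerrl.experiments.train_agent drives agents this way):
#
#     obs = env.reset()
#     reward = 0.0
#     done = False
#     while not done:
#         action = agent.act_and_train(obs, reward)
#         obs, reward, done, _ = env.step(action)
#     agent.stop_episode_and_train(obs, reward, done=done)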
class AttributeSavingMixin(object):
    """Mixin that provides save and load functionalities."""

    @abstractproperty
    def saved_attributes(self):
        """Specify attribute names to save or load as a tuple of str."""
        pass

    def save(self, dirname):
        """Save internal states."""
        self.__save(dirname, [])

    def __save(self, dirname, ancestors):
        os.makedirs(dirname, exist_ok=True)
        ancestors.append(self)
        for attr in self.saved_attributes:
            assert hasattr(self, attr)
            attr_value = getattr(self, attr)
            if attr_value is None:
                continue
            if isinstance(attr_value, AttributeSavingMixin):
                assert not any(
                    attr_value is ancestor
                    for ancestor in ancestors
                ), "Avoid an infinite loop"
                attr_value.__save(os.path.join(dirname, attr), ancestors)
            else:
                serializers.save_npz(
                    os.path.join(dirname, '{}.npz'.format(attr)),
                    getattr(self, attr))
        ancestors.pop()

    def load(self, dirname):
        """Load internal states."""
        self.__load(dirname, [])

    def __load(self, dirname, ancestors):
        ancestors.append(self)
        for attr in self.saved_attributes:
            assert hasattr(self, attr)
            attr_value = getattr(self, attr)
            if attr_value is None:
                continue
            if isinstance(attr_value, AttributeSavingMixin):
                assert not any(
                    attr_value is ancestor
                    for ancestor in ancestors
                ), "Avoid an infinite loop"
                attr_value.__load(os.path.join(dirname, attr), ancestors)
            else:
                # Fix Chainer Issue #2772: in Chainer v2, a (stateful)
                # optimizer cannot be loaded from an npz saved before the
                # first update.
                load_npz_no_strict(
                    os.path.join(dirname, '{}.npz'.format(attr)),
                    getattr(self, attr))
        ancestors.pop()


class AsyncAgent(Agent, metaclass=ABCMeta):
    """Abstract asynchronous agent class."""

    @abstractproperty
    def process_idx(self):
        """Index of process as integer, 0 for the representative process."""
        pass

    @abstractproperty
    def shared_attributes(self):
        """Tuple of names of shared attributes."""
        pass


class BatchAgent(Agent, metaclass=ABCMeta):
    """Abstract agent class that can interact with a batch of envs."""

    @abstractmethod
    def batch_act(self, batch_obs):
        """Select a batch of actions for evaluation.

        Args:
            batch_obs (Sequence of ~object): Observations.

        Returns:
            Sequence of ~object: Actions.
        """
        raise NotImplementedError()

    @abstractmethod
    def batch_act_and_train(self, batch_obs):
        """Select a batch of actions for training.

        Args:
            batch_obs (Sequence of ~object): Observations.

        Returns:
            Sequence of ~object: Actions.
        """
        raise NotImplementedError()

    @abstractmethod
    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        """Observe a batch of action consequences for evaluation.

        Args:
            batch_obs (Sequence of ~object): Observations.
            batch_reward (Sequence of float): Rewards.
            batch_done (Sequence of boolean): Boolean values where True
                indicates the current state is terminal.
            batch_reset (Sequence of boolean): Boolean values where True
                indicates the current episode will be reset, even if the
                current state is not terminal.

        Returns:
            None
        """
        raise NotImplementedError()

    @abstractmethod
    def batch_observe_and_train(
            self, batch_obs, batch_reward, batch_done, batch_reset):
        """Observe a batch of action consequences for training.

        Args:
            batch_obs (Sequence of ~object): Observations.
            batch_reward (Sequence of float): Rewards.
            batch_done (Sequence of boolean): Boolean values where True
                indicates the current state is terminal.
            batch_reset (Sequence of boolean): Boolean values where True
                indicates the current episode will be reset, even if the
                current state is not terminal.

        Returns:
            None
        """
        raise NotImplementedError()
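
# Sketch of how a concrete agent typically uses AttributeSavingMixin above
# (class and attribute names here are illustrative): listing attributes in
# ``saved_attributes`` makes save()/load() persist each one under
# ``<dirname>/<attr>.npz``, recursing into attributes that are themselves
# AttributeSavingMixin instances.
#
#     class MyAgent(AttributeSavingMixin, Agent):
#         saved_attributes = ('model', 'optimizer')
#
#         def __init__(self, model, optimizer):
#             self.model = model
#             self.optimizer = optimizer
#         ...
#
#     agent.save('agent_dir')  # writes agent_dir/model.npz etc.
#     agent.load('agent_dir')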
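
# Sketch of the synchronous batch loop BatchAgent assumes (``vec_env`` is a
# hypothetical vectorized environment stepping all envs at once; the
# ``needs_reset`` info key follows chainerrl's convention for episodes cut
# off before reaching a terminal state, e.g. by a time limit):
#
#     batch_obs = vec_env.reset()
#     for _ in range(max_steps):
#         batch_action = agent.batch_act_and_train(batch_obs)
#         batch_obs, batch_reward, batch_done, infos = vec_env.step(
#             batch_action)
#         batch_reset = [info.get('needs_reset', False) for info in infos]
#         agent.batch_observe_and_train(
#             batch_obs, batch_reward, batch_done, batch_reset)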