# Source code for chainerrl.agents.a2c

from logging import getLogger
import warnings

import chainer
from chainer import functions as F

from chainerrl import agent
from chainerrl.misc.batch_states import batch_states
from chainerrl.recurrent import RecurrentChainMixin

# Module-level logger, named after this module.
logger = getLogger(name=__name__)

class A2CModel(chainer.Link):
    """A2C model.

    Interface expected by :class:`A2C`: subclasses must implement
    :meth:`pi_and_v`.
    """

    def pi_and_v(self, obs):
        """Evaluate the policy and the V-function.

        Args:
            obs (Variable or ndarray): Batched observations.
        Returns:
            Distribution and Variable: the action distribution and the
            state-value estimates. (A2C indexes the values as ``[:, 0]``,
            so presumably they have shape (batch, 1) -- confirm per model.)
        """
        raise NotImplementedError()

    def __call__(self, obs):
        """Alias of :meth:`pi_and_v`."""
        return self.pi_and_v(obs)

class A2CSeparateModel(chainer.Chain, A2CModel, RecurrentChainMixin):
    """A2C model that consists of a separate policy and V-function.

    Args:
        pi (Policy): Policy.
        v (VFunction): V-function.
    """

    def __init__(self, pi, v):
        # Register both networks as child links of this Chain.
        super().__init__(pi=pi, v=v)

    def pi_and_v(self, obs):
        """Evaluate the policy and V-function independently on ``obs``.

        Returns:
            tuple: ``(self.pi(obs), self.v(obs))``.
        """
        pout = self.pi(obs)
        vout = self.v(obs)
        return pout, vout

class A2C(agent.AttributeSavingMixin, agent.BatchAgent):
    """A2C: Advantage Actor-Critic.

    A2C is a synchronous, deterministic variant of Asynchronous Advantage
    Actor Critic (A3C). Only batch training is supported.

    Args:
        model (A2CModel): Model to train
        optimizer (chainer.Optimizer): optimizer used to train the model
        gamma (float): Discount factor [0,1]
        num_processes (int): The number of processes, i.e. the number of
            parallel environments whose observations arrive as one batch.
        gpu (int): GPU device id if not None nor negative.
        update_steps (int): The number of steps collected before each
            update (the rollout length).
        phi (callable): Feature extractor function
        pi_loss_coef (float): Weight coefficient for the loss of the policy
        v_loss_coef (float): Weight coefficient for the loss of the value
            function
        entropy_coeff (float): Weight coefficient for the loss of the entropy
        use_gae (bool): use generalized advantage estimation (GAE)
        tau (float): GAE parameter (weights the recursive GAE term)
        average_actor_loss_decay (float): Decay rate of average actor loss.
            Used only to record statistics.
        average_entropy_decay (float): Decay rate of average entropy. Used
            only to record statistics.
        average_value_decay (float): Decay rate of average value. Used only
            to record statistics. (The recorded quantity is the value loss.)
        act_deterministically (bool): If set true, choose most probable
            actions in the act and batch_act methods.
        batch_states (callable): method which makes a batch of observations.
            default is `chainerrl.misc.batch_states.batch_states`
    """

    process_idx = None
    saved_attributes = ['model', 'optimizer']

    def __init__(self, model, optimizer, gamma, num_processes,
                 gpu=None,
                 update_steps=5,
                 phi=lambda x: x,
                 pi_loss_coef=1.0,
                 v_loss_coef=0.5,
                 entropy_coeff=0.01,
                 use_gae=False,
                 tau=0.95,
                 act_deterministically=False,
                 average_actor_loss_decay=0.999,
                 average_entropy_decay=0.999,
                 average_value_decay=0.999,
                 batch_states=batch_states):

        assert isinstance(model, A2CModel)

        self.model = model
        self.gpu = gpu
        if gpu is not None and gpu >= 0:
            chainer.cuda.get_device_from_id(gpu).use()
            self.model.to_gpu(device=gpu)

        self.optimizer = optimizer

        self.update_steps = update_steps
        self.num_processes = num_processes

        self.gamma = gamma
        self.use_gae = use_gae
        self.tau = tau
        self.act_deterministically = act_deterministically
        self.phi = phi
        self.pi_loss_coef = pi_loss_coef
        self.v_loss_coef = v_loss_coef
        self.entropy_coeff = entropy_coeff

        self.average_actor_loss_decay = average_actor_loss_decay
        self.average_value_decay = average_value_decay
        self.average_entropy_decay = average_entropy_decay
        self.batch_states = batch_states

        # numpy or cupy, depending on where the model lives.
        self.xp = self.model.xp
        # Global step counter, and the step at which the current rollout
        # began; ``self.t - self.t_start`` indexes into the rollout buffers.
        self.t = 0
        self.t_start = 0

        # Stats
        self.average_actor_loss = 0
        self.average_value = 0
        self.average_entropy = 0

    def _flush_storage(self, obs_shape, action):
        """Allocate the rollout buffers.

        Args:
            obs_shape (tuple): Shape of a batched observation; the leading
                (batch) dimension is dropped.
            action (ndarray): A sample batched action, used for the
                per-action shape and dtype.
        """
        obs_shape = obs_shape[1:]
        action_shape = action.shape[1:]

        # ``states`` and ``value_preds``/``returns`` have one extra slot for
        # the state after the last step, used to bootstrap the returns.
        self.states = self.xp.zeros(
            [self.update_steps + 1, self.num_processes] + list(obs_shape),
            dtype='f')
        self.actions = self.xp.zeros(
            [self.update_steps, self.num_processes] + list(action_shape),
            dtype=action.dtype)
        self.rewards = self.xp.zeros(
            (self.update_steps, self.num_processes), dtype='f')
        self.value_preds = self.xp.zeros(
            (self.update_steps + 1, self.num_processes), dtype='f')
        self.returns = self.xp.zeros(
            (self.update_steps + 1, self.num_processes), dtype='f')
        # 0 where the episode ended at that step, 1 otherwise.
        self.masks = self.xp.ones(
            (self.update_steps, self.num_processes), dtype='f')

        self.obs_shape = obs_shape
        self.action_shape = action_shape

    def _compute_returns(self, next_value):
        """Fill ``self.returns`` backwards from the bootstrap ``next_value``.

        With ``use_gae`` the returns are GAE advantages plus value
        predictions; otherwise they are discounted n-step returns. Masks cut
        the recursion at episode boundaries.
        """
        if self.use_gae:
            self.value_preds[-1] = next_value
            gae = 0
            for i in reversed(range(self.update_steps)):
                delta = self.rewards[i] + \
                    self.gamma * self.value_preds[i + 1] * self.masks[i] - \
                    self.value_preds[i]
                gae = delta + self.gamma * self.tau * self.masks[i] * gae
                self.returns[i] = gae + self.value_preds[i]
        else:
            self.returns[-1] = next_value
            for i in reversed(range(self.update_steps)):
                self.returns[i] = self.rewards[i] + \
                    self.gamma * self.returns[i + 1] * self.masks[i]

    def update(self):
        """Run one optimizer step on the rollout held in the buffers."""
        # Bootstrap value of the state after the last collected step.
        with chainer.no_backprop_mode():
            _, next_value = self.model.pi_and_v(self.states[-1])
            next_value = next_value.array[:, 0]

        self._compute_returns(next_value)
        # Flatten (update_steps, num_processes) into a single batch axis.
        pout, values = \
            self.model.pi_and_v(chainer.Variable(
                self.states[:-1].reshape([-1] + list(self.obs_shape))))

        actions = chainer.Variable(
            self.actions.reshape([-1] + list(self.action_shape)))
        dist_entropy = F.mean(pout.entropy)
        action_log_probs = pout.log_prob(actions)

        values = values.reshape((self.update_steps, self.num_processes))
        action_log_probs = action_log_probs.reshape(
            (self.update_steps, self.num_processes))
        advantages = self.returns[:-1] - values
        value_loss = F.mean(advantages * advantages)
        # ``.array`` detaches the advantages so the policy-gradient term does
        # not backpropagate into the value function.
        action_loss = \
            - F.mean(advantages.array * action_log_probs)

        self.model.cleargrads()

        (value_loss * self.v_loss_coef +
         action_loss * self.pi_loss_coef -
         dist_entropy * self.entropy_coeff).backward()

        self.optimizer.update()
        # Carry the latest observation over as the first state of the next
        # rollout.
        self.states[0] = self.states[-1]

        self.t_start = self.t

        # Update stats (exponential moving averages)
        self.average_actor_loss += (
            (1 - self.average_actor_loss_decay) *
            (float(action_loss.array) - self.average_actor_loss))
        # NOTE: 'average_value' tracks the value *loss*, not predicted values.
        self.average_value += (
            (1 - self.average_value_decay) *
            (float(value_loss.array) - self.average_value))
        self.average_entropy += (
            (1 - self.average_entropy_decay) *
            (float(dist_entropy.array) - self.average_entropy))

    def batch_act_and_train(self, batch_obs):
        """Sample actions for a batch of observations and record them.

        Triggers :meth:`update` when the rollout buffers are full.

        Returns:
            ndarray: Sampled actions on the CPU.
        """
        statevar = self.batch_states(batch_obs, self.xp, self.phi)

        if self.t == 0:
            # Buffers are allocated lazily because the observation and
            # action shapes are only known once the first batch arrives.
            with chainer.no_backprop_mode():
                pout, _ = self.model.pi_and_v(statevar)
                action = pout.sample().array
            self._flush_storage(statevar.shape, action)

        self.states[self.t - self.t_start] = statevar

        if self.t - self.t_start == self.update_steps:
            self.update()

        with chainer.no_backprop_mode():
            pout, value = self.model.pi_and_v(statevar)
            action = pout.sample().array

        self.actions[self.t - self.t_start] \
            = action.reshape([-1] + list(self.action_shape))
        self.value_preds[self.t - self.t_start] = value.array[:, 0]

        self.t += 1

        return chainer.cuda.to_cpu(action)

    def batch_act(self, batch_obs):
        """Select actions for a batch of observations without training.

        Honors ``act_deterministically`` in the same way as :meth:`act`.

        Returns:
            ndarray: Actions on the CPU.
        """
        statevar = self.batch_states(batch_obs, self.xp, self.phi)
        with chainer.no_backprop_mode():
            pout, _ = self.model.pi_and_v(statevar)
            if self.act_deterministically:
                action = pout.most_probable.array
            else:
                action = pout.sample().array
        return chainer.cuda.to_cpu(action)

    def batch_observe_and_train(self, batch_obs, batch_reward, batch_done,
                                batch_reset):
        """Record rewards/terminal flags for the previous step.

        Also records the next observations, and triggers :meth:`update` when
        the rollout buffers are full.
        """
        if any(batch_reset):
            warnings.warn('A2C currently does not support resetting an env without reaching a terminal state during training. When receiving True in batch_reset, A2C considers it as True in batch_done instead.')  # NOQA
            batch_done = list(batch_done)
            for i, reset in enumerate(batch_reset):
                if reset:
                    batch_done[i] = True

        statevar = self.batch_states(batch_obs, self.xp, self.phi)

        # ``self.t`` was already advanced by batch_act_and_train, so the
        # reward/done of the step just taken belong at index - 1.
        self.masks[self.t - self.t_start - 1] = \
            self.xp.array([0.0 if done else 1.0 for done in batch_done])
        self.rewards[self.t - self.t_start - 1] = \
            self.xp.array(batch_reward, dtype=self.xp.float32)
        self.states[self.t - self.t_start] = statevar

        if self.t - self.t_start == self.update_steps:
            self.update()

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        """No-op: nothing needs to be recorded outside training."""
        pass

    def act_and_train(self, obs, reward):
        """Not supported; A2C only trains through the batch interface.

        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError('A2C does not support non-batch training')

    def act(self, obs):
        """Select an action for a single observation without training."""
        with chainer.no_backprop_mode():
            statevar = self.batch_states([obs], self.xp, self.phi)
            pout, _ = self.model.pi_and_v(statevar)
            if self.act_deterministically:
                return chainer.cuda.to_cpu(pout.most_probable.array)[0]
            else:
                return chainer.cuda.to_cpu(pout.sample().array)[0]

    def stop_episode_and_train(self, state, reward, done=False):
        """Not supported; A2C only trains through the batch interface.

        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError('A2C does not support non-batch training')

    def stop_episode(self):
        """No-op: A2C keeps no per-episode state outside training."""
        pass

    def get_statistics(self):
        """Return the moving-average training statistics as (name, value)."""
        return [
            ('average_actor', self.average_actor_loss),
            ('average_value', self.average_value),
            ('average_entropy', self.average_entropy),
        ]