Source code for chainerrl.agents.al

import chainer
from chainer import functions as F

from chainerrl.agents import dqn


class AL(dqn.DQN):
    """Advantage Learning.

    See: http://arxiv.org/abs/1512.04860.

    Args:
        alpha (float): Weight of (persistent) advantages. Convergence
            is guaranteed only for alpha in [0, 1).

    For other arguments, see DQN.
    """

    def __init__(self, *args, **kwargs):
        self.alpha = kwargs.pop('alpha', 0.9)
        super().__init__(*args, **kwargs)

    def _compute_y_and_t(self, exp_batch):
        batch_state = exp_batch['state']
        batch_size = len(exp_batch['reward'])

        if self.recurrent:
            qout, _ = self.model.n_step_forward(
                batch_state, exp_batch['recurrent_state'],
                output_mode='concat')
        else:
            qout = self.model(batch_state)

        batch_actions = exp_batch['action']
        batch_q = qout.evaluate_actions(batch_actions)

        # Compute target values
        batch_next_state = exp_batch['next_state']
        with chainer.no_backprop_mode():
            if self.recurrent:
                target_qout, _ = self.target_model.n_step_forward(
                    batch_state, exp_batch['recurrent_state'],
                    output_mode='concat')
                target_next_qout, _ = self.target_model.n_step_forward(
                    batch_next_state, exp_batch['next_recurrent_state'],
                    output_mode='concat')
            else:
                target_qout = self.target_model(batch_state)
                target_next_qout = self.target_model(batch_next_state)

            next_q_max = F.reshape(target_next_qout.max, (batch_size,))

            batch_rewards = exp_batch['reward']
            batch_terminal = exp_batch['is_state_terminal']

            # T Q: Bellman operator
            t_q = batch_rewards + exp_batch['discount'] * \
                (1.0 - batch_terminal) * next_q_max

            # T_AL Q: advantage learning operator
            cur_advantage = F.reshape(
                target_qout.compute_advantage(batch_actions),
                (batch_size,))
            tal_q = t_q + self.alpha * cur_advantage

            return batch_q, tal_q
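
The target computed above is the Bellman target plus alpha times the (non-positive) advantage of the taken action under the target network, i.e. tal_q = T Q(s, a) + alpha * (Q(s, a) - max_a' Q(s, a')), which shrinks the values of non-greedy actions and widens the action gap. Below is a minimal construction sketch, not part of this module: it assumes the standard chainerrl components (FCStateQFunctionWithDiscreteAction, ReplayBuffer, ConstantEpsilonGreedy) and placeholder observation/action sizes. AL takes the same constructor arguments as DQN, plus the alpha keyword popped in __init__ above.

    import chainer
    import chainerrl
    import numpy as np

    from chainerrl.agents import AL

    # Placeholder environment dimensions; replace with your env's values.
    obs_size = 4
    n_actions = 2

    # A fully connected Q-function over a discrete action set.
    q_func = chainerrl.q_functions.FCStateQFunctionWithDiscreteAction(
        obs_size, n_actions, n_hidden_channels=50, n_hidden_layers=2)

    optimizer = chainer.optimizers.Adam(eps=1e-2)
    optimizer.setup(q_func)

    replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 5)

    explorer = chainerrl.explorers.ConstantEpsilonGreedy(
        epsilon=0.1,
        random_action_func=lambda: np.random.randint(n_actions))

    # alpha is popped from kwargs by AL.__init__ before the remaining
    # arguments are forwarded to DQN, so pass it as a keyword argument.
    agent = AL(
        q_func, optimizer, replay_buffer, gamma=0.99, explorer=explorer,
        replay_start_size=500, update_interval=1,
        target_update_interval=100, alpha=0.9)

The agent is then trained exactly like a DQN agent (e.g. via agent.act_and_train in an environment loop); only the target computation shown in _compute_y_and_t differs.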