Source code for chainerrl.agents.residual_dqn

from chainerrl.agents.dqn import DQN
from chainerrl.functions import scale_grad


class ResidualDQN(DQN):
    """DQN that allows maxQ also backpropagate gradients."""

    def __init__(self, *args, **kwargs):
        self.grad_scale = kwargs.pop('grad_scale', 1.0)
        super().__init__(*args, **kwargs)

    def sync_target_network(self):
        # No target network is used, so there is nothing to sync
        pass

    def _compute_target_values(self, exp_batch):
        batch_next_state = exp_batch['next_state']

        if self.recurrent:
            target_next_qout, _ = self.model.n_step_forward(
                batch_next_state, exp_batch['next_recurrent_state'],
                output_mode='concat')
        else:
            target_next_qout = self.model(batch_next_state)
        next_q_max = target_next_qout.max

        batch_rewards = exp_batch['reward']
        batch_terminal = exp_batch['is_state_terminal']
        batch_discount = exp_batch['discount']

        return (batch_rewards
                + batch_discount * (1.0 - batch_terminal) * next_q_max)

    def _compute_y_and_t(self, exp_batch):
        batch_state = exp_batch['state']

        # Compute Q-values for current states
        if self.recurrent:
            qout, _ = self.model.n_step_forward(
                batch_state, exp_batch['recurrent_state'],
                output_mode='concat')
        else:
            qout = self.model(batch_state)

        batch_actions = exp_batch['action']
        batch_q = qout.evaluate_actions(batch_actions)[..., None]

        # Target values must also backprop gradients
        batch_q_target = self._compute_target_values(exp_batch)[..., None]

        return batch_q, scale_grad.scale_grad(
            batch_q_target, self.grad_scale)

    @property
    def saved_attributes(self):
        # ResidualDQN doesn't use target models
        return ('model', 'optimizer')
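A minimal construction sketch follows, not part of this module: ResidualDQN accepts the same arguments as chainerrl's DQN plus the extra grad_scale keyword popped in __init__ above. The network size, epsilon, buffer capacity, and other parameter values are illustrative only.

import chainer
from chainerrl import agents, explorers, q_functions, replay_buffer

# Illustrative Q-function for a 4-dimensional observation, 2 discrete actions
q_func = q_functions.FCStateQFunctionWithDiscreteAction(
    ndim_obs=4, n_actions=2, n_hidden_channels=64, n_hidden_layers=2)
optimizer = chainer.optimizers.Adam(eps=1e-2)
optimizer.setup(q_func)

agent = agents.ResidualDQN(
    q_func, optimizer,
    replay_buffer.ReplayBuffer(capacity=10 ** 5),
    gamma=0.99,
    explorer=explorers.ConstantEpsilonGreedy(
        epsilon=0.1, random_action_func=lambda: 0),
    replay_start_size=500,
    # grad_scale scales the gradient flowing through the target term
    # (maxQ of the next state), which is not stopped in this agent.
    grad_scale=1.0,
)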