Source code for chainerrl.agents.sarsa

import chainer

from chainerrl.agents import dqn

class SARSA(dqn.DQN):
    """Off-policy SARSA.

    Instead of learning the Q-function of the optimal policy, this agent
    learns the Q-function of the behavior policy defined via the given
    explorer: the next action used for bootstrapping is chosen by
    ``self.explorer`` rather than taken directly from the target network's
    greedy actions.
    """

    def _compute_target_values(self, exp_batch):
        """Return SARSA bootstrap targets for a batch of transitions.

        Args:
            exp_batch (dict): Batched experiences. This method reads the keys
                ``'next_state'``, ``'action'`` (only to get the batch size),
                ``'reward'``, ``'is_state_terminal'``, ``'discount'`` and,
                when ``self.recurrent`` is true, ``'next_recurrent_state'``.

        Returns:
            ``reward + discount * (1.0 - is_state_terminal) * next_q``, where
            ``next_q`` is the target network's value of the action selected
            by the explorer for each next state.
        """
        next_states = exp_batch['next_state']

        if not self.recurrent:
            next_qout = self.target_model(next_states)
        else:
            next_qout, _ = self.target_model.n_step_forward(
                next_states,
                exp_batch['next_recurrent_state'],
                output_mode='concat',
            )

        # Copy the greedy actions to host memory once; the per-transition
        # greedy callbacks below index into this array.
        greedy_host = chainer.cuda.to_cpu(next_qout.greedy_actions.array)

        # Choose each next action with the behavior policy (the explorer).
        batch_size = len(exp_batch['action'])
        chosen = []
        for idx in range(batch_size):
            # NOTE(review): the lambda reads ``idx`` when it is called, so
            # this assumes select_action invokes it before returning.
            chosen.append(self.explorer.select_action(
                self.t,
                lambda: greedy_host[idx],
                action_value=next_qout[idx:idx + 1],
            ))
        next_actions = self.xp.array(chosen)

        next_q = next_qout.evaluate_actions(next_actions)

        rewards = exp_batch['reward']
        terminals = exp_batch['is_state_terminal']
        discount = exp_batch['discount']
        return rewards + discount * (1.0 - terminals) * next_q