# Source code for chainerrl.agents.sarsa

from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from builtins import *  # NOQA
from future import standard_library
standard_library.install_aliases()

from chainerrl.agents import dqn


class SARSA(dqn.DQN):
    """SARSA.

    Unlike DQN, this agent uses actions that have been actually taken to
    compute target Q values, thus is an on-policy algorithm.
    """

    def _compute_target_values(self, exp_batch, gamma):
        """Compute SARSA target values for a batch of transitions.

        Args:
            exp_batch (dict): Batch of experiences. This method reads the
                keys 'next_state', 'next_action', 'reward' and
                'is_state_terminal'.
            gamma (float): Discount factor applied to the next-step value.

        Returns:
            ``reward + gamma * (1.0 - is_state_terminal) * Q_target(s', a')``,
            where ``Q_target`` is ``self.target_q_function`` and ``a'`` is
            the action actually taken in the next state.
        """
        batch_next_state = exp_batch['next_state']
        batch_next_action = exp_batch['next_action']

        next_target_action_value = self.target_q_function(batch_next_state)
        # On-policy: evaluate the target network at the action that was
        # actually taken in the next state, not at a chosen/greedy action.
        next_q = next_target_action_value.evaluate_actions(batch_next_action)

        batch_rewards = exp_batch['reward']
        # presumably 0/1 flags: (1.0 - terminal) drops the bootstrap term
        # for terminal transitions -- TODO confirm against the replay buffer.
        batch_terminal = exp_batch['is_state_terminal']

        # Honor the caller-supplied discount instead of reading self.gamma,
        # so the `gamma` parameter is not silently ignored.
        return batch_rewards + gamma * (1.0 - batch_terminal) * next_q