# Source code for chainerrl.agents.dpp

from abc import ABCMeta
from abc import abstractmethod

import chainer
import chainer.functions as F

from chainerrl.agents.dqn import DQN

class AbstractDPP(DQN, metaclass=ABCMeta):
    """Dynamic Policy Programming.

    See: https://arxiv.org/abs/1004.2027.

    This base class implements the DPP target computation on top of DQN.
    Concrete subclasses supply the operator applied to the target network's
    output by implementing `_l_operator`.
    """

    @abstractmethod
    def _l_operator(self, qout):
        """Reduce the action-value output `qout` to one value per sample.

        Args:
            qout: Output of a Q-function model for a batch of states.

        Returns:
            The operator applied to `qout`; it is used both for the next
            state (bootstrapping) and for the current state (baseline).
        """
        raise NotImplementedError()

    def _compute_target_values(self, exp_batch):
        """Compute r + g * LQ'(s_{t+1}, a) for a batch of experiences.

        Args:
            exp_batch (dict): Batch with the keys 'next_state', 'reward',
                'is_state_terminal' and 'discount'; additionally
                'next_recurrent_state' when the agent is recurrent.

        Returns:
            Reward plus the discounted operator value of the next state as
            seen by the target model. The bootstrap term is zeroed where
            'is_state_terminal' is 1.
        """
        batch_next_state = exp_batch['next_state']

        if self.recurrent:
            # NOTE(review): the closing of this call was lost in the source
            # this file was recovered from; `output_mode='concat'` follows
            # the upstream ChainerRL recurrent interface -- confirm.
            target_next_qout, _ = self.target_model.n_step_forward(
                batch_next_state,
                exp_batch['next_recurrent_state'],
                output_mode='concat',
            )
        else:
            target_next_qout = self.target_model(batch_next_state)
        next_q_expect = self._l_operator(target_next_qout)

        batch_rewards = exp_batch['reward']
        batch_terminal = exp_batch['is_state_terminal']

        return (batch_rewards +
                exp_batch['discount'] * (1 - batch_terminal) * next_q_expect)

    def _compute_y_and_t(self, exp_batch):
        """Compute the predicted Q-values and their DPP regression targets.

        Args:
            exp_batch (dict): Batch with the keys 'state', 'action' and
                'reward', plus everything `_compute_target_values` reads;
                additionally 'recurrent_state' when the agent is recurrent.

        Returns:
            tuple: `(batch_q, t)`, both reshaped to `(batch_size, 1)`.
            `batch_q` is Q(s_t, a_t) from the online model; `t` is
            Q'(s_t, a_t) + r + g * LQ'(s_{t+1}, a) - LQ'(s_t, a) computed
            from the target model without backprop.
        """
        batch_state = exp_batch['state']
        batch_size = len(exp_batch['reward'])

        if self.recurrent:
            # NOTE(review): `output_mode='concat'` restored per upstream
            # ChainerRL recurrent interface -- confirm.
            qout, _ = self.model.n_step_forward(
                batch_state,
                exp_batch['recurrent_state'],
                output_mode='concat',
            )
        else:
            qout = self.model(batch_state)

        batch_actions = exp_batch['action']
        # Q(s_t,a_t)
        batch_q = F.reshape(qout.evaluate_actions(
            batch_actions), (batch_size, 1))

        # Targets are treated as constants: no gradient flows through them.
        with chainer.no_backprop_mode():
            # Compute target values
            if self.recurrent:
                target_qout, _ = self.target_model.n_step_forward(
                    batch_state,
                    exp_batch['recurrent_state'],
                    output_mode='concat',
                )
            else:
                target_qout = self.target_model(batch_state)

            # Q'(s_t,a_t)
            target_q = F.reshape(target_qout.evaluate_actions(
                batch_actions), (batch_size, 1))

            # LQ'(s_t,a)
            target_q_expect = F.reshape(
                self._l_operator(target_qout), (batch_size, 1))

            # r + g * LQ'(s_{t+1},a)
            batch_q_target = F.reshape(
                self._compute_target_values(exp_batch), (batch_size, 1))

            # Q'(s_t,a_t) + r + g * LQ'(s_{t+1},a) - LQ'(s_t,a)
            t = target_q + batch_q_target - target_q_expect

        return batch_q, t

class DPP(AbstractDPP):
    """Dynamic Policy Programming with softmax operator.

    Args:
        eta (float): Positive constant. Taken from the keyword arguments
            and defaults to 1.0 when omitted.

    For other arguments, see DQN.
    """

    def __init__(self, *args, **kwargs):
        # `eta` is consumed here so that it is not forwarded to the parent
        # constructor along with the remaining arguments.
        self.eta = kwargs.pop('eta', 1.0)
        super().__init__(*args, **kwargs)

    def _l_operator(self, qout):
        """Return `qout.compute_expectation(self.eta)`."""
        return qout.compute_expectation(self.eta)
class DPPL(AbstractDPP):
    """Dynamic Policy Programming with L operator.

    Args:
        eta (float): Positive constant. Taken from the keyword arguments
            and defaults to 1.0 when omitted.

    For other arguments, see DQN.
    """

    def __init__(self, *args, **kwargs):
        # `eta` is consumed here so that it is not forwarded to the parent
        # constructor along with the remaining arguments.
        self.eta = kwargs.pop('eta', 1.0)
        super().__init__(*args, **kwargs)

    def _l_operator(self, qout):
        """Return log-sum-exp of `eta * q_values` along axis 1, over `eta`."""
        return F.logsumexp(self.eta * qout.q_values, axis=1) / self.eta


class DPPGreedy(AbstractDPP):
    """Dynamic Policy Programming with max operator.

    This algorithm corresponds to DPP with eta = infinity.
    """

    def _l_operator(self, qout):
        """Return `qout.max`."""
        return qout.max