Source code for chainerrl.agents.acer

import copy
from logging import getLogger

import chainer
from chainer import functions as F
import numpy as np

from chainerrl.action_value import SingleActionValue
from chainerrl import agent
from chainerrl import distribution
from chainerrl import links
from chainerrl.misc import async_
from chainerrl.misc import copy_param
from chainerrl.recurrent import Recurrent
from chainerrl.recurrent import RecurrentChainMixin
from chainerrl.recurrent import state_kept
from chainerrl.recurrent import state_reset


def compute_importance(pi, mu, x):
    return np.nan_to_num(float(pi.prob(x).array) / float(mu.prob(x).array))


def compute_full_importance(pi, mu):
    pimu = pi.all_prob.array / mu.all_prob.array
    # NaN occurs when inf/inf or 0/0
    pimu[np.isnan(pimu)] = 1
    pimu = np.nan_to_num(pimu)
    return pimu
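

# Illustrative sketch (not part of ChainerRL): the snippet below mimics
# compute_full_importance with plain numpy arrays standing in for
# `pi.all_prob.array` and `mu.all_prob.array`, to show why the NaN produced
# by 0/0 is replaced with 1 before the ratios are used.
def _example_importance_ratios():
    pi_probs = np.asarray([[0.5, 0.5, 0.0]], dtype=np.float32)
    mu_probs = np.asarray([[0.25, 0.75, 0.0]], dtype=np.float32)
    with np.errstate(divide='ignore', invalid='ignore'):
        pimu = pi_probs / mu_probs  # [[2.0, 0.667, nan]]
    pimu[np.isnan(pimu)] = 1  # treat 0/0 as an importance weight of 1
    return np.nan_to_num(pimu)  # clip inf (p/0 with p > 0) to a finite value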


def compute_policy_gradient_full_correction(
        action_distrib, action_distrib_mu, action_value, v,
        truncation_threshold):
    """Compute off-policy bias correction term wrt all actions."""
    assert truncation_threshold is not None
    assert np.isscalar(v)
    with chainer.no_backprop_mode():
        rho_all_inv = compute_full_importance(action_distrib_mu,
                                              action_distrib)
        correction_weight = (
            np.maximum(1 - truncation_threshold * rho_all_inv,
                       np.zeros_like(rho_all_inv)) *
            action_distrib.all_prob.array[0])
        correction_advantage = action_value.q_values.array[0] - v
    return -F.sum(correction_weight *
                  action_distrib.all_log_prob *
                  correction_advantage, axis=1)
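

# Illustrative sketch (not part of ChainerRL): for a discrete policy the
# bias-correction term above weights every action a by
#     max(0, 1 - c / rho(a)) * pi(a),
# where rho(a) = pi(a) / mu(a) and c is the truncation threshold.  The code
# obtains 1 / rho(a) as compute_full_importance(action_distrib_mu,
# action_distrib).  The probabilities below are made up.
def _example_full_correction_weight(truncation_threshold=10.0):
    pi_probs = np.asarray([0.7, 0.2, 0.1], dtype=np.float32)
    mu_probs = np.asarray([0.05, 0.35, 0.6], dtype=np.float32)
    rho_all_inv = mu_probs / pi_probs
    correction_weight = np.maximum(
        1 - truncation_threshold * rho_all_inv,
        np.zeros_like(rho_all_inv)) * pi_probs
    # Only actions with rho(a) > c (here only the first) get nonzero weight.
    return correction_weight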


def compute_policy_gradient_sample_correction(
        action_distrib, action_distrib_mu, action_value, v,
        truncation_threshold):
    """Compute off-policy bias correction term wrt a sampled action."""
    assert np.isscalar(v)
    assert truncation_threshold is not None
    with chainer.no_backprop_mode():
        sample_action = action_distrib.sample().array
        rho_dash_inv = compute_importance(
            action_distrib_mu, action_distrib, sample_action)
        if (truncation_threshold > 0 and
                rho_dash_inv >= 1 / truncation_threshold):
            return chainer.Variable(np.asarray([0], dtype=np.float32))
        correction_weight = max(0, 1 - truncation_threshold * rho_dash_inv)
        assert correction_weight <= 1
        q = float(action_value.evaluate_actions(sample_action).array[0])
        correction_advantage = q - v
    return -(correction_weight *
             action_distrib.log_prob(sample_action) *
             correction_advantage)
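

# Illustrative sketch (not part of ChainerRL): for continuous actions the
# correction weight is computed for a single sampled action a' ~ pi as
#     max(0, 1 - c / rho(a')),
# via the inverse importance weight rho(a')^-1 = mu(a') / pi(a').  The
# densities below are made up.
def _example_sample_correction_weight(truncation_threshold=10.0):
    pi_prob = 0.8   # hypothetical density pi(a') of the sampled action
    mu_prob = 0.02  # hypothetical density mu(a') under the behaviour policy
    rho_dash_inv = mu_prob / pi_prob
    if rho_dash_inv >= 1 / truncation_threshold:
        # rho(a') <= c: the truncated main term already covers this sample,
        # so the correction contributes nothing.
        return 0.0
    return max(0, 1 - truncation_threshold * rho_dash_inv)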


def compute_policy_gradient_loss(action, advantage, action_distrib,
                                 action_distrib_mu, action_value, v,
                                 truncation_threshold):
    """Compute policy gradient loss with off-policy bias correction."""
    assert np.isscalar(advantage)
    assert np.isscalar(v)
    log_prob = action_distrib.log_prob(action)
    if action_distrib_mu is not None:
        # Off-policy
        rho = compute_importance(
            action_distrib, action_distrib_mu, action)
        g_loss = 0
        if truncation_threshold is None:
            g_loss -= rho * log_prob * advantage
        else:
            # Truncated off-policy policy gradient term
            g_loss -= min(truncation_threshold, rho) * log_prob * advantage
            # Bias correction term
            if isinstance(action_distrib,
                          distribution.CategoricalDistribution):
                g_loss += compute_policy_gradient_full_correction(
                    action_distrib=action_distrib,
                    action_distrib_mu=action_distrib_mu,
                    action_value=action_value,
                    v=v,
                    truncation_threshold=truncation_threshold)
            else:
                g_loss += compute_policy_gradient_sample_correction(
                    action_distrib=action_distrib,
                    action_distrib_mu=action_distrib_mu,
                    action_value=action_value,
                    v=v,
                    truncation_threshold=truncation_threshold)
    else:
        # On-policy
        g_loss = -log_prob * advantage
    return g_loss
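

# Illustrative sketch (not part of ChainerRL): the main off-policy term of
# compute_policy_gradient_loss scales the usual sample log pi(a) * advantage
# by the truncated importance weight min(c, rho), so a single unlikely
# behaviour action cannot blow up the update; the mass lost to truncation is
# recovered by the bias-correction helpers above.  All numbers are made up.
def _example_truncated_policy_gradient_term(truncation_threshold=10.0):
    rho = 37.5       # hypothetical importance weight pi(a) / mu(a)
    log_prob = -0.2  # hypothetical log pi(a)
    advantage = 1.5  # hypothetical Q_ret - V(s)
    return -min(truncation_threshold, rho) * log_prob * advantage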


class ACERSeparateModel(chainer.Chain, RecurrentChainMixin):
    """ACER model that consists of a separate policy and V-function.

    Args:
        pi (Policy): Policy.
        q (QFunction): Q-function.
    """

    def __init__(self, pi, q):
        super().__init__(pi=pi, q=q)

    def __call__(self, obs):
        action_distrib = self.pi(obs)
        action_value = self.q(obs)
        v = F.sum(action_distrib.all_prob *
                  action_value.q_values, axis=1)
        return action_distrib, action_value, v
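

# Illustrative sketch (not part of ChainerRL): ACERSeparateModel derives the
# state value as the expectation of Q under the policy,
#     V(s) = sum_a pi(a|s) * Q(s, a),
# which is what the F.sum over all_prob * q_values above computes.  The
# numbers below are made up.
def _example_state_value_from_pi_and_q():
    probs = np.asarray([[0.2, 0.5, 0.3]], dtype=np.float32)
    q_values = np.asarray([[1.0, 0.0, -1.0]], dtype=np.float32)
    return np.sum(probs * q_values, axis=1)  # array([-0.1], dtype=float32)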


class ACERSDNSeparateModel(chainer.Chain, RecurrentChainMixin):
    """ACER model that consists of a separate policy and V-function.

    Args:
        pi (Policy): Policy.
        v (VFunction): V-function.
        adv (StateActionQFunction): Advantage function.
    """

    def __init__(self, pi, v, adv, n=5):
        super().__init__(pi=pi, v=v, adv=adv)
        self.n = n

    def __call__(self, obs):
        action_distrib = self.pi(obs)
        v = self.v(obs)

        def evaluator(action):
            adv_mean = sum(self.adv(obs, action_distrib.sample().array)
                           for _ in range(self.n)) / self.n
            return v + self.adv(obs, action) - adv_mean

        action_value = SingleActionValue(evaluator)

        return action_distrib, action_value, v
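

# Illustrative sketch (not part of ChainerRL): the stochastic dueling network
# (SDN) above estimates
#     Q(s, a) ~= V(s) + A(s, a) - (1/n) * sum_i A(s, a_i),  a_i ~ pi,
# i.e. the advantage is centred by a Monte Carlo estimate of its mean under
# the current policy.  All numbers below are made up.
def _example_sdn_q_estimate():
    v = 2.0                                      # hypothetical V(s)
    adv_of_action = 0.8                          # hypothetical A(s, a)
    adv_of_samples = [0.5, 1.1, -0.2, 0.4, 0.7]  # hypothetical A(s, a_i), n=5
    return v + adv_of_action - sum(adv_of_samples) / len(adv_of_samples)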


class ACERSDNSharedModel(links.Sequence, RecurrentChainMixin):
    """ACER model where the policy and V-function share parameters.

    Args:
        shared (Link): Shared part. Nonlinearity must be included in it.
        pi (Policy): Policy that receives output of shared as input.
        v (VFunction): V-function that receives output of shared as input.
        adv (StateActionQFunction): Advantage function that receives output
            of shared as input.
    """

    def __init__(self, shared, pi, v, adv):
        super().__init__(shared, ACERSDNSeparateModel(pi, v, adv))


class ACERSharedModel(links.Sequence, RecurrentChainMixin):
    """ACER model where the policy and V-function share parameters.

    Args:
        shared (Link): Shared part. Nonlinearity must be included in it.
        pi (Policy): Policy that receives output of shared as input.
        q (QFunction): Q-function that receives output of shared as input.
    """

    def __init__(self, shared, pi, q):
        super().__init__(shared, ACERSeparateModel(pi, q))


def compute_loss_with_kl_constraint(distrib, another_distrib, original_loss,
                                    delta):
    """Compute loss considering a KL constraint.

    Args:
        distrib (Distribution): Distribution to optimize
        another_distrib (Distribution): Distribution used to compute KL
        original_loss (chainer.Variable): Loss to minimize
        delta (float): Trust-region threshold on the KL divergence; the
            component of the gradient along the KL gradient is reduced when
            it would violate this threshold.
    Returns:
        tuple of chainer.Variable and float: Adjusted loss and the value of
            the KL divergence between the two distributions.
    """
    for param in distrib.params:
        assert param.shape[0] == 1
        assert param.requires_grad
    # Compute g: a direction to minimize the original loss
    g = [grad.array[0] for grad in
         chainer.grad([F.squeeze(original_loss)], distrib.params)]

    # Compute k: a direction to increase KL div.
    kl = F.squeeze(another_distrib.kl(distrib))
    k = [grad.array[0] for grad in
         chainer.grad([-kl], distrib.params)]

    # Compute z: combination of g and k to keep small KL div.
    kg_dot = sum(np.dot(kp.ravel(), gp.ravel())
                 for kp, gp in zip(k, g))
    kk_dot = sum(np.dot(kp.ravel(), kp.ravel()) for kp in k)
    if kk_dot > 0:
        k_factor = max(0, ((kg_dot - delta) / kk_dot))
    else:
        k_factor = 0
    z = [gp - k_factor * kp for kp, gp in zip(k, g)]
    loss = 0
    for p, zp in zip(distrib.params, z):
        loss += F.sum(p * zp)
    return F.reshape(loss, original_loss.shape), float(kl.array)
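

# Illustrative sketch (not part of ChainerRL): with plain numpy vectors the
# adjustment performed above is
#     z = g - max(0, (k.g - delta) / ||k||^2) * k,
# where g is the gradient of the loss w.r.t. the distribution parameters and
# k is the gradient of the negated KL divergence from the average policy, so
# the component of g along k is scaled back only when k.g exceeds delta.
# The vectors below are made up.
def _example_trust_region_projection(delta=1.0):
    g = np.asarray([0.5, -1.0, 2.0])
    k = np.asarray([1.0, 0.0, 1.0])
    k_factor = max(0.0, (np.dot(k, g) - delta) / np.dot(k, k))
    return g - k_factor * k  # array([-0.25, -1.  ,  1.25])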


class ACER(agent.AttributeSavingMixin, agent.AsyncAgent):
    """ACER (Actor-Critic with Experience Replay).

    See http://arxiv.org/abs/1611.01224

    Args:
        model (ACERModel): Model to train. It must be a callable that accepts
            observations as input and returns three values: action
            distributions (Distribution), Q values (ActionValue) and state
            values (chainer.Variable).
        optimizer (chainer.Optimizer): Optimizer used to train the model.
        t_max (int): The model is updated after every t_max local steps.
        gamma (float): Discount factor [0,1].
        replay_buffer (EpisodicReplayBuffer): Replay buffer to use. If set
            None, this agent won't use experience replay.
        beta (float): Weight coefficient for the entropy regularization term.
        phi (callable): Feature extractor function.
        pi_loss_coef (float): Weight coefficient for the loss of the policy.
        Q_loss_coef (float): Weight coefficient for the loss of the value
            function.
        use_trust_region (bool): If set true, use efficient TRPO.
        trust_region_alpha (float): Decay rate of the average model used for
            efficient TRPO.
        trust_region_delta (float): Threshold used for efficient TRPO.
        truncation_threshold (float or None): Threshold used to truncate
            larger importance weights. If set None, importance weights are
            not truncated.
        disable_online_update (bool): If set true, disable online on-policy
            update and rely only on experience replay.
        n_times_replay (int): Number of times experience replay is repeated
            per one time of online update.
        replay_start_size (int): Experience replay is disabled if the number
            of transitions in the replay buffer is lower than this value.
        normalize_loss_by_steps (bool): If set true, losses are normalized by
            the number of steps taken to accumulate the losses.
        act_deterministically (bool): If set true, choose most probable
            actions in act method.
        use_Q_opc (bool): If set true, Q_opc, a Q-value estimate without
            importance sampling, is used to compute advantage values for
            policy gradients. The original paper recommends this for
            continuous action spaces.
        average_entropy_decay (float): Decay rate of average entropy. Used
            only to record statistics.
        average_value_decay (float): Decay rate of average value. Used only
            to record statistics.
        average_kl_decay (float): Decay rate of average KL divergence. Used
            only to record statistics.
    """

    process_idx = None
    saved_attributes = ['model', 'optimizer']

    def __init__(self, model, optimizer, t_max, gamma, replay_buffer,
                 beta=1e-2,
                 phi=lambda x: x,
                 pi_loss_coef=1.0,
                 Q_loss_coef=0.5,
                 use_trust_region=True,
                 trust_region_alpha=0.99,
                 trust_region_delta=1,
                 truncation_threshold=10,
                 disable_online_update=False,
                 n_times_replay=8,
                 replay_start_size=10 ** 4,
                 normalize_loss_by_steps=True,
                 act_deterministically=False,
                 use_Q_opc=False,
                 average_entropy_decay=0.999,
                 average_value_decay=0.999,
                 average_kl_decay=0.999,
                 logger=None):

        # Globally shared model
        self.shared_model = model

        # Globally shared average model used to compute trust regions
        self.shared_average_model = copy.deepcopy(self.shared_model)

        # Thread specific model
        self.model = copy.deepcopy(self.shared_model)
        async_.assert_params_not_shared(self.shared_model, self.model)

        self.optimizer = optimizer
        self.replay_buffer = replay_buffer
        self.t_max = t_max
        self.gamma = gamma
        self.beta = beta
        self.phi = phi
        self.pi_loss_coef = pi_loss_coef
        self.Q_loss_coef = Q_loss_coef
        self.normalize_loss_by_steps = normalize_loss_by_steps
        self.act_deterministically = act_deterministically
        self.use_trust_region = use_trust_region
        self.trust_region_alpha = trust_region_alpha
        self.truncation_threshold = truncation_threshold
        self.trust_region_delta = trust_region_delta
        self.disable_online_update = disable_online_update
        self.n_times_replay = n_times_replay
        self.use_Q_opc = use_Q_opc
        self.replay_start_size = replay_start_size
        self.average_value_decay = average_value_decay
        self.average_entropy_decay = average_entropy_decay
        self.average_kl_decay = average_kl_decay
        self.logger = logger if logger else getLogger(__name__)

        self.t = 0
        self.last_state = None
        self.last_action = None
        # ACER won't use an explorer, but this attribute is referenced by
        # run_dqn
        self.explorer = None

        # Stats
        self.average_value = 0
        self.average_entropy = 0
        self.average_kl = 0

        self.init_history_data_for_online_update()

    def init_history_data_for_online_update(self):
        self.past_states = {}
        self.past_actions = {}
        self.past_rewards = {}
        self.past_values = {}
        self.past_action_distrib = {}
        self.past_action_values = {}
        self.past_avg_action_distrib = {}
        self.t_start = self.t

    def sync_parameters(self):
        copy_param.copy_param(target_link=self.model,
                              source_link=self.shared_model)
        copy_param.soft_copy_param(target_link=self.shared_average_model,
                                   source_link=self.model,
                                   tau=1 - self.trust_region_alpha)

    @property
    def shared_attributes(self):
        return ('shared_model', 'shared_average_model', 'optimizer')

    def compute_one_step_pi_loss(self, action, advantage, action_distrib,
                                 action_distrib_mu, action_value, v,
                                 avg_action_distrib):
        assert np.isscalar(advantage)
        assert np.isscalar(v)

        g_loss = compute_policy_gradient_loss(
            action=action,
            advantage=advantage,
            action_distrib=action_distrib,
            action_distrib_mu=action_distrib_mu,
            action_value=action_value,
            v=v,
            truncation_threshold=self.truncation_threshold)

        if self.use_trust_region:
            pi_loss, kl = compute_loss_with_kl_constraint(
                action_distrib, avg_action_distrib, g_loss,
                delta=self.trust_region_delta)
            self.average_kl += (
                (1 - self.average_kl_decay) * (kl - self.average_kl))
        else:
            pi_loss = g_loss

        # Entropy is maximized
        pi_loss -= self.beta * action_distrib.entropy
        return pi_loss

    def compute_loss(
            self, t_start, t_stop, R, states, actions, rewards, values,
            action_values, action_distribs, action_distribs_mu,
            avg_action_distribs):

        assert np.isscalar(R)
        pi_loss = 0
        Q_loss = 0
        Q_ret = R
        Q_opc = R
        discrete = isinstance(action_distribs[t_start],
                              distribution.CategoricalDistribution)
        del R
        for i in reversed(range(t_start, t_stop)):
            r = rewards[i]
            v = values[i]
            action_distrib = action_distribs[i]
            action_distrib_mu = (action_distribs_mu[i]
                                 if action_distribs_mu else None)
            avg_action_distrib = avg_action_distribs[i]
            action_value = action_values[i]
            ba = np.expand_dims(actions[i], 0)
            if action_distrib_mu is not None:
                # Off-policy
                rho = compute_importance(
                    action_distrib, action_distrib_mu, ba)
            else:
                # On-policy
                rho = 1

            Q_ret = r + self.gamma * Q_ret
            Q_opc = r + self.gamma * Q_opc

            assert np.isscalar(Q_ret)
            assert np.isscalar(Q_opc)
            if self.use_Q_opc:
                advantage = Q_opc - float(v.array)
            else:
                advantage = Q_ret - float(v.array)
            pi_loss += self.compute_one_step_pi_loss(
                action=ba,
                advantage=advantage,
                action_distrib=action_distrib,
                action_distrib_mu=action_distrib_mu,
                action_value=action_value,
                v=float(v.array),
                avg_action_distrib=avg_action_distrib)

            # Accumulate gradients of value function
            Q = action_value.evaluate_actions(ba)
            assert isinstance(Q, chainer.Variable), "Q must be backprop-able"
            Q_loss += (Q_ret - Q) ** 2 / 2

            if not discrete:
                assert isinstance(v, chainer.Variable), \
                    "v must be backprop-able"
                v_target = (min(1, rho) * (Q_ret - float(Q.array)) +
                            float(v.array))
                Q_loss += (v_target - v) ** 2 / 2

            if self.process_idx == 0:
                self.logger.debug(
                    't:%s v:%s Q:%s Q_ret:%s Q_opc:%s',
                    i, float(v.array), float(Q.array), Q_ret, Q_opc)

            if discrete:
                c = min(1, rho)
            else:
                c = min(1, rho ** (1 / ba.size))
            Q_ret = c * (Q_ret - float(Q.array)) + float(v.array)
            Q_opc = Q_opc - float(Q.array) + float(v.array)

        pi_loss *= self.pi_loss_coef
        Q_loss *= self.Q_loss_coef

        if self.normalize_loss_by_steps:
            pi_loss /= t_stop - t_start
            Q_loss /= t_stop - t_start

        if self.process_idx == 0:
            self.logger.debug('pi_loss:%s Q_loss:%s',
                              pi_loss.array, Q_loss.array)

        return pi_loss + F.reshape(Q_loss, pi_loss.array.shape)

    def update(self, t_start, t_stop, R, states, actions, rewards, values,
               action_values, action_distribs, action_distribs_mu,
               avg_action_distribs):

        assert np.isscalar(R)

        total_loss = self.compute_loss(
            t_start=t_start,
            t_stop=t_stop,
            R=R,
            states=states,
            actions=actions,
            rewards=rewards,
            values=values,
            action_values=action_values,
            action_distribs=action_distribs,
            action_distribs_mu=action_distribs_mu,
            avg_action_distribs=avg_action_distribs)

        # Compute gradients using thread-specific model
        self.model.cleargrads()
        F.squeeze(total_loss).backward()
        # Copy the gradients to the globally shared model
        copy_param.copy_grad(
            target_link=self.shared_model, source_link=self.model)
        # Update the globally shared model
        if self.process_idx == 0:
            norm = sum(np.sum(np.square(param.grad))
                       for param in self.optimizer.target.params()
                       if param.grad is not None)
            self.logger.debug('grad norm:%s', norm)
        self.optimizer.update()

        self.sync_parameters()
        if isinstance(self.model, Recurrent):
            self.model.unchain_backward()

    def update_from_replay(self):

        if self.replay_buffer is None:
            return

        if len(self.replay_buffer) < self.replay_start_size:
            return

        episode = self.replay_buffer.sample_episodes(1, self.t_max)[0]

        with state_reset(self.model):
            with state_reset(self.shared_average_model):
                rewards = {}
                states = {}
                actions = {}
                action_distribs = {}
                action_distribs_mu = {}
                avg_action_distribs = {}
                action_values = {}
                values = {}
                for t, transition in enumerate(episode):
                    s = self.phi(transition['state'])
                    a = transition['action']
                    bs = np.expand_dims(s, 0)
                    action_distrib, action_value, v = self.model(bs)
                    with chainer.no_backprop_mode():
                        avg_action_distrib, _, _ = \
                            self.shared_average_model(bs)
                    states[t] = s
                    actions[t] = a
                    values[t] = v
                    action_distribs[t] = action_distrib
                    avg_action_distribs[t] = avg_action_distrib
                    rewards[t] = transition['reward']
                    action_distribs_mu[t] = transition['mu']
                    action_values[t] = action_value
                last_transition = episode[-1]
                if last_transition['is_state_terminal']:
                    R = 0
                else:
                    with chainer.no_backprop_mode():
                        last_s = last_transition['next_state']
                        action_distrib, action_value, last_v = self.model(
                            np.expand_dims(self.phi(last_s), 0))
                    R = float(last_v.array)
                return self.update(
                    R=R,
                    t_start=0,
                    t_stop=len(episode),
                    states=states,
                    rewards=rewards,
                    actions=actions,
                    values=values,
                    action_distribs=action_distribs,
                    action_distribs_mu=action_distribs_mu,
                    avg_action_distribs=avg_action_distribs,
                    action_values=action_values)

    def update_on_policy(self, statevar):
        assert self.t_start < self.t

        if not self.disable_online_update:
            if statevar is None:
                R = 0
            else:
                with chainer.no_backprop_mode():
                    with state_kept(self.model):
                        action_distrib, action_value, v = self.model(statevar)
                R = float(v.array)
            self.update(
                t_start=self.t_start, t_stop=self.t, R=R,
                states=self.past_states,
                actions=self.past_actions,
                rewards=self.past_rewards,
                values=self.past_values,
                action_values=self.past_action_values,
                action_distribs=self.past_action_distrib,
                action_distribs_mu=None,
                avg_action_distribs=self.past_avg_action_distrib)

        self.init_history_data_for_online_update()

    def act_and_train(self, obs, reward):

        statevar = np.expand_dims(self.phi(obs), 0)

        self.past_rewards[self.t - 1] = reward

        if self.t - self.t_start == self.t_max:
            self.update_on_policy(statevar)
            for _ in range(self.n_times_replay):
                self.update_from_replay()

        self.past_states[self.t] = statevar
        action_distrib, action_value, v = self.model(statevar)
        self.past_action_values[self.t] = action_value
        action = action_distrib.sample().array[0]

        # Save values for a later update
        self.past_values[self.t] = v
        self.past_action_distrib[self.t] = action_distrib
        with chainer.no_backprop_mode():
            avg_action_distrib, _, _ = self.shared_average_model(
                statevar)
        self.past_avg_action_distrib[self.t] = avg_action_distrib

        self.past_actions[self.t] = action

        self.t += 1

        if self.process_idx == 0:
            self.logger.debug('t:%s r:%s a:%s action_distrib:%s',
                              self.t, reward, action, action_distrib)
        # Update stats
        self.average_value += (
            (1 - self.average_value_decay) *
            (float(v.array[0]) - self.average_value))
        self.average_entropy += (
            (1 - self.average_entropy_decay) *
            (float(action_distrib.entropy.array[0]) - self.average_entropy))

        if self.replay_buffer is not None and self.last_state is not None:
            assert self.last_action is not None
            assert self.last_action_distrib is not None
            # Add a transition to the replay buffer
            self.replay_buffer.append(
                state=self.last_state,
                action=self.last_action,
                reward=reward,
                next_state=obs,
                next_action=action,
                is_state_terminal=False,
                mu=self.last_action_distrib,
            )

        self.last_state = obs
        self.last_action = action
        self.last_action_distrib = action_distrib.copy()

        return action

    def act(self, obs):
        # Use the process-local model for acting
        with chainer.no_backprop_mode():
            statevar = np.expand_dims(self.phi(obs), 0)
            action_distrib, _, _ = self.model(statevar)
            if self.act_deterministically:
                return action_distrib.most_probable.array[0]
            else:
                return action_distrib.sample().array[0]

    def stop_episode_and_train(self, state, reward, done=False):
        assert self.last_state is not None
        assert self.last_action is not None

        self.past_rewards[self.t - 1] = reward
        if done:
            self.update_on_policy(None)
        else:
            statevar = np.expand_dims(self.phi(state), 0)
            self.update_on_policy(statevar)
        for _ in range(self.n_times_replay):
            self.update_from_replay()

        if isinstance(self.model, Recurrent):
            self.model.reset_state()
            self.shared_average_model.reset_state()

        # Add a transition to the replay buffer
        if self.replay_buffer is not None:
            self.replay_buffer.append(
                state=self.last_state,
                action=self.last_action,
                reward=reward,
                next_state=state,
                next_action=self.last_action,
                is_state_terminal=done,
                mu=self.last_action_distrib)
            self.replay_buffer.stop_current_episode()

        self.last_state = None
        self.last_action = None
        self.last_action_distrib = None

    def stop_episode(self):
        if isinstance(self.model, Recurrent):
            self.model.reset_state()
            self.shared_average_model.reset_state()

    def load(self, dirname):
        super().load(dirname)
        copy_param.copy_param(target_link=self.shared_model,
                              source_link=self.model)

    def get_statistics(self):
        return [
            ('average_value', self.average_value),
            ('average_entropy', self.average_entropy),
            ('average_kl', self.average_kl),
        ]
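

# Illustrative sketch (not part of ChainerRL): compute_loss above builds its
# Retrace-style targets backwards through time,
#     Q_ret <- r_t + gamma * Q_ret        (target used at step t)
#     Q_ret <- c_t * (Q_ret - Q(s_t, a_t)) + V(s_t),  c_t = min(1, rho_t),
# while Q_opc follows the same recursion with c_t fixed to 1.  All numbers
# below are made up, purely to show the recursion.
def _example_retrace_targets(gamma=0.99):
    rewards = [0.0, 1.0, 0.0]
    q = [0.4, 0.9, 0.2]    # hypothetical Q(s_t, a_t)
    v = [0.3, 0.8, 0.1]    # hypothetical V(s_t)
    rho = [2.0, 0.5, 1.3]  # hypothetical importance weights pi / mu
    q_ret = q_opc = 0.0    # bootstrap value R of the state after the last step
    targets = []
    for t in reversed(range(len(rewards))):
        q_ret = rewards[t] + gamma * q_ret
        q_opc = rewards[t] + gamma * q_opc
        targets.append((t, q_ret, q_opc))  # targets used at step t
        c = min(1, rho[t])
        q_ret = c * (q_ret - q[t]) + v[t]
        q_opc = q_opc - q[t] + v[t]
    return targets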