Source code for chainerrl.agents.dqn

import copy
from logging import getLogger

import chainer
from chainer import cuda
import chainer.functions as F

from chainerrl import agent
from chainerrl.misc.batch_states import batch_states
from chainerrl.misc.copy_param import synchronize_parameters
from chainerrl.replay_buffer import batch_experiences
from chainerrl.replay_buffer import batch_recurrent_experiences
from chainerrl.replay_buffer import ReplayUpdater


def compute_value_loss(y, t, clip_delta=True, batch_accumulator='mean'):
    """Compute a loss for value prediction problem.

    Args:
        y (Variable or ndarray): Predicted values.
        t (Variable or ndarray): Target values.
        clip_delta (bool): Use the Huber loss function if set to True.
        batch_accumulator (str): 'mean' or 'sum'. 'mean' will use the mean of
            the loss values in a batch. 'sum' will use the sum.
    Returns:
        (Variable) scalar loss
    """
    assert batch_accumulator in ('mean', 'sum')
    y = F.reshape(y, (-1, 1))
    t = F.reshape(t, (-1, 1))
    if clip_delta:
        loss_sum = F.sum(F.huber_loss(y, t, delta=1.0))
        if batch_accumulator == 'mean':
            loss = loss_sum / y.shape[0]
        elif batch_accumulator == 'sum':
            loss = loss_sum
    else:
        loss_mean = F.mean_squared_error(y, t) / 2
        if batch_accumulator == 'mean':
            loss = loss_mean
        elif batch_accumulator == 'sum':
            loss = loss_mean * y.shape[0]
    return loss
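

# Illustrative usage sketch added for exposition; it is not part of the
# original chainerrl module. It shows how `compute_value_loss` reduces
# elementwise Huber losses over a small batch of predicted and target values
# to a single scalar. The array values are arbitrary.
def _example_compute_value_loss():
    import numpy as np
    y = np.asarray([1.0, 2.0, 3.0], dtype=np.float32)  # predicted values
    t = np.asarray([1.5, 2.0, 0.0], dtype=np.float32)  # target values
    loss = compute_value_loss(y, t, clip_delta=True, batch_accumulator='mean')
    return float(loss.array)  # scalar mean Huber loss over the batch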


def compute_weighted_value_loss(y, t, weights,
                                clip_delta=True, batch_accumulator='mean'):
    """Compute a loss for value prediction problem.

    Args:
        y (Variable or ndarray): Predicted values.
        t (Variable or ndarray): Target values.
        weights (ndarray): Weights for y, t.
        clip_delta (bool): Use the Huber loss function if set True.
        batch_accumulator (str): 'mean' will divide loss by batchsize
    Returns:
        (Variable) scalar loss
    """
    assert batch_accumulator in ('mean', 'sum')
    y = F.reshape(y, (-1, 1))
    t = F.reshape(t, (-1, 1))
    if clip_delta:
        losses = F.huber_loss(y, t, delta=1.0)
    else:
        losses = F.square(y - t) / 2
    losses = F.reshape(losses, (-1,))
    loss_sum = F.sum(losses * weights)
    if batch_accumulator == 'mean':
        loss = loss_sum / y.shape[0]
    elif batch_accumulator == 'sum':
        loss = loss_sum
    return loss
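

# Illustrative usage sketch, not part of the original chainerrl module: the
# same predicted and target values as the sketch above, but with per-example
# weights such as the importance-sampling weights a prioritized replay buffer
# would supply.
def _example_compute_weighted_value_loss():
    import numpy as np
    y = np.asarray([1.0, 2.0, 3.0], dtype=np.float32)
    t = np.asarray([1.5, 2.0, 0.0], dtype=np.float32)
    w = np.asarray([0.5, 1.0, 0.1], dtype=np.float32)  # importance weights
    loss = compute_weighted_value_loss(
        y, t, w, clip_delta=True, batch_accumulator='mean')
    return float(loss.array)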


def _batch_reset_recurrent_states_when_episodes_end(
        model, batch_done, batch_reset, recurrent_states):
    """Reset recurrent states when episodes end.

    Args:
        model (chainer.Link): Model that implements `StatelessRecurrent`.
        batch_done (array-like of bool): True iff episodes are terminal.
        batch_reset (array-like of bool): True iff episodes will be reset.
        recurrent_states (object): Recurrent state.

    Returns:
        object: New recurrent states.
    """
    indices_that_ended = [
        i for i, (done, reset)
        in enumerate(zip(batch_done, batch_reset)) if done or reset]
    if indices_that_ended:
        return model.mask_recurrent_state_at(
            recurrent_states, indices_that_ended)
    else:
        return recurrent_states
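

# Illustrative sketch, not part of the original chainerrl module: a dummy
# stand-in for a `StatelessRecurrent` link that mimics only the
# `mask_recurrent_state_at` interface used above, to show which recurrent
# state rows get reset when episodes end or are reset.
def _example_reset_recurrent_states():
    import numpy as np

    class _DummyRecurrentModel(object):

        def mask_recurrent_state_at(self, recurrent_states, indices):
            # Zero out the state rows of the episodes that just ended.
            masked = recurrent_states.copy()
            masked[indices] = 0
            return masked

    states = np.ones((4, 3), dtype=np.float32)  # 4 envs, 3 recurrent units
    return _batch_reset_recurrent_states_when_episodes_end(
        model=_DummyRecurrentModel(),
        batch_done=[False, True, False, False],
        batch_reset=[False, False, True, False],
        recurrent_states=states,
    )  # rows 1 and 2 are zeroed; rows 0 and 3 keep their state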


class DQN(agent.AttributeSavingMixin, agent.BatchAgent):
    """Deep Q-Network algorithm.

    Args:
        q_function (StateQFunction): Q-function
        optimizer (Optimizer): Optimizer that is already setup
        replay_buffer (ReplayBuffer): Replay buffer
        gamma (float): Discount factor
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): if the replay buffer's size is less than
            replay_start_size, skip update
        minibatch_size (int): Minibatch size
        update_interval (int): Model update interval in step
        target_update_interval (int): Target model update interval in step
        clip_delta (bool): Clip delta if set True
        phi (callable): Feature extractor applied to observations
        target_update_method (str): 'hard' or 'soft'.
        soft_update_tau (float): Tau of soft target update.
        n_times_update (int): Number of repetition of update
        average_q_decay (float): Decay rate of average Q, only used for
            recording statistics
        average_loss_decay (float): Decay rate of average loss, only used for
            recording statistics
        batch_accumulator (str): 'mean' or 'sum'
        episodic_update_len (int or None): Subsequences of this length are
            used for update if set int and episodic_update=True
        logger (Logger): Logger used
        batch_states (callable): method which makes a batch of observations.
            default is `chainerrl.misc.batch_states.batch_states`
        recurrent (bool): If set to True, `model` is assumed to implement
            `chainerrl.links.StatelessRecurrent` and is updated in a recurrent
            manner.
    """

    saved_attributes = ('model', 'target_model', 'optimizer')

    def __init__(self, q_function, optimizer, replay_buffer, gamma,
                 explorer, gpu=None, replay_start_size=50000,
                 minibatch_size=32, update_interval=1,
                 target_update_interval=10000, clip_delta=True,
                 phi=lambda x: x,
                 target_update_method='hard',
                 soft_update_tau=1e-2,
                 n_times_update=1, average_q_decay=0.999,
                 average_loss_decay=0.99,
                 batch_accumulator='mean',
                 episodic_update_len=None,
                 logger=getLogger(__name__),
                 batch_states=batch_states,
                 recurrent=False,
                 ):
        self.model = q_function
        self.q_function = q_function  # For backward compatibility

        if gpu is not None and gpu >= 0:
            cuda.get_device_from_id(gpu).use()
            self.model.to_gpu(device=gpu)

        self.xp = self.model.xp
        self.replay_buffer = replay_buffer
        self.optimizer = optimizer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_interval = target_update_interval
        self.clip_delta = clip_delta
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.batch_accumulator = batch_accumulator
        assert batch_accumulator in ('mean', 'sum')
        self.logger = logger
        self.batch_states = batch_states
        self.recurrent = recurrent
        if self.recurrent:
            update_func = self.update_from_episodes
        else:
            update_func = self.update
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=update_func,
            batchsize=minibatch_size,
            episodic_update=recurrent,
            episodic_update_len=episodic_update_len,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
        )

        self.t = 0
        self.last_state = None
        self.last_action = None
        self.target_model = None
        self.sync_target_network()
        # For backward compatibility
        self.target_q_function = self.target_model
        self.average_q = 0
        self.average_q_decay = average_q_decay
        self.average_loss = 0
        self.average_loss_decay = average_loss_decay

        # Recurrent states of the model
        self.train_recurrent_states = None
        self.train_prev_recurrent_states = None
        self.test_recurrent_states = None

        # Error checking
        if (self.replay_buffer.capacity is not None and
                self.replay_buffer.capacity <
                self.replay_updater.replay_start_size):
            raise ValueError(
                'Replay start size cannot exceed '
                'replay buffer capacity.')

    def sync_target_network(self):
        """Synchronize target network with current network."""
        if self.target_model is None:
            self.target_model = copy.deepcopy(self.model)
            call_orig = self.target_model.__call__

            def call_test(self_, x):
                with chainer.using_config('train', False):
                    return call_orig(self_, x)

            self.target_model.__call__ = call_test
        else:
            synchronize_parameters(
                src=self.model,
                dst=self.target_model,
                method=self.target_update_method,
                tau=self.soft_update_tau)

    def update(self, experiences, errors_out=None):
        """Update the model from experiences

        Args:
            experiences (list): List of lists of dicts.
                For DQN, each dict must contain:
                  - state (object): State
                  - action (object): Action
                  - reward (float): Reward
                  - is_state_terminal (bool): True iff next state is terminal
                  - next_state (object): Next state
                  - weight (float, optional): Weight coefficient. It can be
                    used for importance sampling.
            errors_out (list or None): If set to a list, then TD-errors
                computed from the given experiences are appended to the list.

        Returns:
            None
        """
        has_weight = 'weight' in experiences[0][0]
        exp_batch = batch_experiences(
            experiences, xp=self.xp,
            phi=self.phi, gamma=self.gamma,
            batch_states=self.batch_states)
        if has_weight:
            exp_batch['weights'] = self.xp.asarray(
                [elem[0]['weight'] for elem in experiences],
                dtype=self.xp.float32)
            if errors_out is None:
                errors_out = []
        loss = self._compute_loss(exp_batch, errors_out=errors_out)
        if has_weight:
            self.replay_buffer.update_errors(errors_out)

        # Update stats
        self.average_loss *= self.average_loss_decay
        self.average_loss += (1 - self.average_loss_decay) * float(loss.array)

        self.model.cleargrads()
        loss.backward()
        self.optimizer.update()

    def update_from_episodes(self, episodes, errors_out=None):
        assert errors_out is None,\
            "Recurrent DQN does not support PrioritizedBuffer"
        exp_batch = batch_recurrent_experiences(
            episodes,
            model=self.model,
            xp=self.xp,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )
        loss = self._compute_loss(exp_batch, errors_out=None)

        # Update stats
        self.average_loss *= self.average_loss_decay
        self.average_loss += (1 - self.average_loss_decay) * float(loss.array)

        self.optimizer.update(lambda: loss)

    def _compute_target_values(self, exp_batch):
        batch_next_state = exp_batch['next_state']

        if self.recurrent:
            target_next_qout, _ = self.target_model.n_step_forward(
                batch_next_state, exp_batch['next_recurrent_state'],
                output_mode='concat')
        else:
            target_next_qout = self.target_model(batch_next_state)
        next_q_max = target_next_qout.max

        batch_rewards = exp_batch['reward']
        batch_terminal = exp_batch['is_state_terminal']
        discount = exp_batch['discount']

        return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max

    def _compute_y_and_t(self, exp_batch):
        batch_size = exp_batch['reward'].shape[0]

        # Compute Q-values for current states
        batch_state = exp_batch['state']

        if self.recurrent:
            qout, _ = self.model.n_step_forward(
                batch_state, exp_batch['recurrent_state'],
                output_mode='concat',
            )
        else:
            qout = self.model(batch_state)

        batch_actions = exp_batch['action']
        batch_q = F.reshape(qout.evaluate_actions(
            batch_actions), (batch_size, 1))

        with chainer.no_backprop_mode():
            batch_q_target = F.reshape(
                self._compute_target_values(exp_batch),
                (batch_size, 1))

        return batch_q, batch_q_target

    def _compute_loss(self, exp_batch, errors_out=None):
        """Compute the Q-learning loss for a batch of experiences

        Args:
            exp_batch (dict): A dict of batched arrays of transitions
        Returns:
            Computed loss from the minibatch of experiences
        """
        y, t = self._compute_y_and_t(exp_batch)

        if errors_out is not None:
            del errors_out[:]
            delta = F.absolute(y - t)
            if delta.ndim == 2:
                delta = F.sum(delta, axis=1)
            delta = cuda.to_cpu(delta.array)
            for e in delta:
                errors_out.append(e)

        if 'weights' in exp_batch:
            return compute_weighted_value_loss(
                y, t, exp_batch['weights'],
                clip_delta=self.clip_delta,
                batch_accumulator=self.batch_accumulator)
        else:
            return compute_value_loss(
                y, t,
                clip_delta=self.clip_delta,
                batch_accumulator=self.batch_accumulator)

    def act(self, obs):
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            action_value =\
                self._evaluate_model_and_update_recurrent_states(
                    [obs], test=True)
            q = float(action_value.max.array)
            action = cuda.to_cpu(action_value.greedy_actions.array)[0]

        # Update stats
        self.average_q *= self.average_q_decay
        self.average_q += (1 - self.average_q_decay) * q

        self.logger.debug('t:%s q:%s action_value:%s',
                          self.t, q, action_value)
        return action

    def act_and_train(self, obs, reward):

        # Observe the consequences
        if self.last_state is not None:
            assert self.last_action is not None
            # Add a transition to the replay buffer
            transition = {
                'state': self.last_state,
                'action': self.last_action,
                'reward': reward,
                'next_state': obs,
                'is_state_terminal': False,
            }
            if self.recurrent:
                transition['recurrent_state'] =\
                    self.model.get_recurrent_state_at(
                        self.train_prev_recurrent_states, 0,
                        unwrap_variable=True)
                self.train_prev_recurrent_states = None
                transition['next_recurrent_state'] =\
                    self.model.get_recurrent_state_at(
                        self.train_recurrent_states, 0,
                        unwrap_variable=True)
            self.replay_buffer.append(**transition)

        # Update the target network
        if self.t % self.target_update_interval == 0:
            self.sync_target_network()

        # Update the model
        self.replay_updater.update_if_necessary(self.t)

        # Choose an action
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            action_value =\
                self._evaluate_model_and_update_recurrent_states(
                    [obs], test=False)
            q = float(action_value.max.array)
            greedy_action = cuda.to_cpu(action_value.greedy_actions.array)[0]
        action = self.explorer.select_action(
            self.t, lambda: greedy_action, action_value=action_value)

        # Update stats
        self.average_q *= self.average_q_decay
        self.average_q += (1 - self.average_q_decay) * q

        self.t += 1
        self.last_state = obs
        self.last_action = action

        self.logger.debug('t:%s q:%s action_value:%s',
                          self.t, q, action_value)
        self.logger.debug('t:%s r:%s a:%s', self.t, reward, action)

        return self.last_action

    def _evaluate_model_and_update_recurrent_states(self, batch_obs, test):
        batch_xs = self.batch_states(batch_obs, self.xp, self.phi)
        if self.recurrent:
            if test:
                batch_av, self.test_recurrent_states = self.model(
                    batch_xs, self.test_recurrent_states)
            else:
                self.train_prev_recurrent_states = self.train_recurrent_states
                batch_av, self.train_recurrent_states = self.model(
                    batch_xs, self.train_recurrent_states)
        else:
            batch_av = self.model(batch_xs)
        return batch_av

    def batch_act_and_train(self, batch_obs):
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            batch_av = self._evaluate_model_and_update_recurrent_states(
                batch_obs, test=False)
            batch_maxq = batch_av.max.array
            batch_argmax = cuda.to_cpu(batch_av.greedy_actions.array)
        batch_action = [
            self.explorer.select_action(
                self.t, lambda: batch_argmax[i],
                action_value=batch_av[i:i + 1],
            )
            for i in range(len(batch_obs))]
        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)

        # Update stats
        self.average_q *= self.average_q_decay
        self.average_q += (1 - self.average_q_decay) * float(batch_maxq.mean())

        return batch_action

    def batch_act(self, batch_obs):
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            batch_av = self._evaluate_model_and_update_recurrent_states(
                batch_obs, test=True)
            batch_argmax = cuda.to_cpu(batch_av.greedy_actions.array)
        return batch_argmax

    def batch_observe_and_train(self, batch_obs, batch_reward,
                                batch_done, batch_reset):
        for i in range(len(batch_obs)):
            self.t += 1
            # Update the target network
            if self.t % self.target_update_interval == 0:
                self.sync_target_network()
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                transition = {
                    'state': self.batch_last_obs[i],
                    'action': self.batch_last_action[i],
                    'reward': batch_reward[i],
                    'next_state': batch_obs[i],
                    'next_action': None,
                    'is_state_terminal': batch_done[i],
                }
                if self.recurrent:
                    transition['recurrent_state'] =\
                        self.model.get_recurrent_state_at(
                            self.train_prev_recurrent_states, i,
                            unwrap_variable=True)
                    transition['next_recurrent_state'] =\
                        self.model.get_recurrent_state_at(
                            self.train_recurrent_states, i,
                            unwrap_variable=True)
                self.replay_buffer.append(env_id=i, **transition)
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

        if self.recurrent:
            # Reset recurrent states when episodes end
            self.train_prev_recurrent_states = None
            self.train_recurrent_states =\
                _batch_reset_recurrent_states_when_episodes_end(
                    model=self.model,
                    batch_done=batch_done,
                    batch_reset=batch_reset,
                    recurrent_states=self.train_recurrent_states,
                )

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        if self.recurrent:
            # Reset recurrent states when episodes end
            self.test_recurrent_states =\
                _batch_reset_recurrent_states_when_episodes_end(
                    model=self.model,
                    batch_done=batch_done,
                    batch_reset=batch_reset,
                    recurrent_states=self.test_recurrent_states,
                )

    def stop_episode_and_train(self, state, reward, done=False):
        """Observe a terminal state and a reward.

        This function must be called once when an episode terminates.
        """

        assert self.last_state is not None
        assert self.last_action is not None

        # Add a transition to the replay buffer
        transition = {
            'state': self.last_state,
            'action': self.last_action,
            'reward': reward,
            'next_state': state,
            'next_action': self.last_action,
            'is_state_terminal': done,
        }
        if self.recurrent:
            transition['recurrent_state'] =\
                self.model.get_recurrent_state_at(
                    self.train_prev_recurrent_states, 0,
                    unwrap_variable=True)
            self.train_prev_recurrent_states = None
            transition['next_recurrent_state'] =\
                self.model.get_recurrent_state_at(
                    self.train_recurrent_states, 0, unwrap_variable=True)
        self.replay_buffer.append(**transition)

        self.last_state = None
        self.last_action = None
        if self.recurrent:
            self.train_recurrent_states = None

        self.replay_buffer.stop_current_episode()

    def stop_episode(self):
        if self.recurrent:
            self.test_recurrent_states = None

    def get_statistics(self):
        return [
            ('average_q', self.average_q),
            ('average_loss', self.average_loss),
            ('n_updates', self.optimizer.t),
        ]
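

# Illustrative training-loop sketch, not part of the original chainerrl
# module. It wires the DQN class above to a Gym environment using components
# from chainerrl's quickstart (FCStateQFunctionWithDiscreteAction,
# ConstantEpsilonGreedy, ReplayBuffer); the environment name, hyperparameters,
# and the helper name `_example_train_dqn` are arbitrary choices made for
# illustration.
def _example_train_dqn(n_episodes=10):
    import gym
    import numpy as np
    from chainer import optimizers
    from chainerrl import explorers
    from chainerrl import q_functions
    from chainerrl import replay_buffer as rb

    env = gym.make('CartPole-v0')
    obs_size = env.observation_space.shape[0]
    n_actions = env.action_space.n

    q_func = q_functions.FCStateQFunctionWithDiscreteAction(
        obs_size, n_actions, n_hidden_channels=50, n_hidden_layers=2)
    opt = optimizers.Adam(eps=1e-2)
    opt.setup(q_func)
    explorer = explorers.ConstantEpsilonGreedy(
        epsilon=0.3, random_action_func=env.action_space.sample)
    rbuf = rb.ReplayBuffer(capacity=10 ** 5)

    agent = DQN(
        q_func, opt, rbuf, gamma=0.95, explorer=explorer,
        replay_start_size=500, update_interval=1,
        target_update_interval=100,
        phi=lambda x: x.astype(np.float32, copy=False))

    for _ in range(n_episodes):
        obs = env.reset()
        reward = 0.0
        done = False
        while not done:
            # act_and_train stores the previous transition, updates the
            # model when due, and returns an exploratory action.
            action = agent.act_and_train(obs, reward)
            obs, reward, done, _ = env.step(action)
        # Flush the terminal transition and finish the episode.
        agent.stop_episode_and_train(obs, reward, done)
    return agent.get_statistics()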