# Source code for chainerrl.agents.nsq

import copy
from logging import getLogger
import multiprocessing as mp

import chainer
from chainer import functions as F
import numpy as np

from chainerrl.agent import AsyncAgent
from chainerrl.agent import AttributeSavingMixin
from chainerrl.misc import async_
from chainerrl.misc.batch_states import batch_states
from chainerrl.misc import copy_param
from chainerrl.recurrent import Recurrent
from chainerrl.recurrent import state_kept

class NSQ(AttributeSavingMixin, AsyncAgent):
    """Asynchronous N-step Q-Learning.

    See http://arxiv.org/abs/1602.01783

    Each worker keeps a local working copy of the Q-function
    (``q_function``). It uses that copy to act and to compute gradients.
    The gradients are copied to the shared model (``shared_q_function``)
    and applied there by ``optimizer``. The local copy is then
    re-synchronized from the shared model. A separate target network
    (``target_q_function``) provides bootstrap values for the n-step
    returns.

    Args:
        q_function (chainer.Link): Q-function to train. Calling it on a
            batch of states must return an action-value object providing
            ``greedy_actions``, ``max`` and ``evaluate_actions``.
        optimizer (chainer.Optimizer): Optimizer used to train the model.
        t_max (int): The model is updated after every t_max local steps.
        gamma (float): Discount factor [0, 1].
        i_target (int): The target model is updated after every i_target
            global steps.
        explorer (Explorer): Explorer to use in training.
        phi (callable): Feature extractor function.
        average_q_decay (float): Decay rate of average Q, only used for
            recording statistics.
        logger (logging.Logger): Logger used for debug messages.
        batch_states (callable): Method which makes a batch of observations.
            Default is `chainerrl.misc.batch_states.batch_states`.
    """

    process_idx = None
    saved_attributes = ['q_function', 'target_q_function', 'optimizer']

    def __init__(self, q_function, optimizer,
                 t_max, gamma, i_target, explorer, phi=lambda x: x,
                 average_q_decay=0.999, logger=getLogger(__name__),
                 batch_states=batch_states):

        # The model passed in becomes the shared model. The target network
        # and the local working copy are independent deep copies of it.
        self.shared_q_function = q_function
        self.target_q_function = copy.deepcopy(q_function)
        self.q_function = copy.deepcopy(self.shared_q_function)

        # The local copy must own its parameters; otherwise local gradient
        # computation would write straight into the shared model.
        async_.assert_params_not_shared(
            self.shared_q_function, self.q_function)

        self.optimizer = optimizer

        self.t_max = t_max
        self.gamma = gamma
        self.explorer = explorer
        self.i_target = i_target
        self.phi = phi
        self.logger = logger
        self.average_q_decay = average_q_decay
        self.batch_states = batch_states

        # Global step counter; listed in shared_attributes so that worker
        # processes share it.
        self.t_global = mp.Value('l', 0)
        # Local step counter and the local step at which the current
        # rollout started.
        self.t = 0
        self.t_start = 0
        # Rollout buffers keyed by local step index.
        self.past_action_values = {}
        self.past_states = {}
        self.past_rewards = {}
        # Exponential moving average of Q(s, a) for statistics only.
        self.average_q = 0

    def sync_parameters(self):
        """Overwrite the local model's parameters with the shared ones."""
        copy_param.copy_param(target_link=self.q_function,
                              source_link=self.shared_q_function)

    @property
    def shared_attributes(self):
        """Names of attributes shared among asynchronous workers."""
        return ('shared_q_function', 'target_q_function', 'optimizer',
                't_global')

    def update(self, statevar):
        """Run one n-step Q-learning update over the buffered rollout.

        Args:
            statevar: Batched state used to bootstrap the return with the
                target network's max Q-value, or None when the rollout
                ended in a terminal state (bootstrap value 0).
        """
        assert self.t_start < self.t

        # Bootstrap value for the step after the last buffered one.
        if statevar is None:
            R = 0
        else:
            # state_kept keeps this evaluation from altering the target
            # network's recurrent state.
            with state_kept(self.target_q_function):
                R = float(self.target_q_function(statevar).max.array)

        # Walk the rollout backwards, accumulating discounted returns.
        loss = 0
        for i in reversed(range(self.t_start, self.t)):
            R *= self.gamma
            R += self.past_rewards[i]
            q = F.reshape(self.past_action_values[i], (1, 1))
            # Accumulate gradients of Q-function
            loss += F.sum(F.huber_loss(
                q,
                chainer.Variable(np.asarray([[R]], dtype=np.float32)),
                delta=1.0))

        # Do we need to normalize losses by (self.t - self.t_start)?
        # Otherwise, loss scales can be different in case of self.t_max
        # and in case of termination.
        # If normalization turns out to be needed:
        # loss /= self.t - self.t_start

        # Compute gradients using the local model
        self.q_function.cleargrads()
        loss.backward()
        # Copy the gradients to the shared model
        copy_param.copy_grad(self.shared_q_function, self.q_function)
        # Update the shared model (presumably the optimizer was set up on
        # shared_q_function -- confirm at the call site).
        self.optimizer.update()

        self.sync_parameters()
        if isinstance(self.q_function, Recurrent):
            # Cut the computational graph so the next rollout does not
            # backpropagate into this one.
            self.q_function.unchain_backward()

        self.past_action_values = {}
        self.past_states = {}
        self.past_rewards = {}

        self.t_start = self.t

    def act_and_train(self, obs, reward):
        """Select an action for ``obs`` and learn from ``reward``.

        ``reward`` is the reward for the previous step. An update runs when
        t_max steps have been buffered since the last one.

        Returns:
            The action chosen by the explorer.
        """
        statevar = self.batch_states([obs], np, self.phi)

        # The reward belongs to the action taken in the previous step.
        self.past_rewards[self.t - 1] = reward

        if self.t - self.t_start == self.t_max:
            self.update(statevar)

        self.past_states[self.t] = statevar
        if isinstance(self.target_q_function, Recurrent):
            # Evaluate it to update states
            self.target_q_function(statevar)
        qout = self.q_function(statevar)
        action = self.explorer.select_action(
            self.t_global.value, lambda: qout.greedy_actions.array[0],
            action_value=qout)
        q = qout.evaluate_actions(np.asarray([action]))
        self.past_action_values[self.t] = q
        self.t += 1

        # Update the moving average of Q for statistics
        self.average_q += ((1 - self.average_q_decay) *
                           (float(q.array[0]) - self.average_q))

        with self.t_global.get_lock():
            self.t_global.value += 1
            t_global = self.t_global.value

        if t_global % self.i_target == 0:
            self.logger.debug(
                'target synchronized t_global:%s t_local:%s',
                t_global, self.t)
            # NOTE(review): this copies from the local model, which can lag
            # shared_q_function by up to t_max - 1 local steps; the paper's
            # target update uses the shared parameters -- confirm intended.
            copy_param.copy_param(self.target_q_function, self.q_function)

        return action

    def act(self, obs):
        """Return the greedy action for ``obs`` without training."""
        statevar = self.batch_states([obs], np, self.phi)
        qout = self.q_function(statevar)
        self.logger.debug('act action_value: %s', qout)
        return qout.greedy_actions.array[0]

    def stop_episode_and_train(self, state, reward, done=False):
        """Finish an episode: record the last reward and run an update.

        Args:
            state: Final observation, used for bootstrapping when the
                episode was cut off rather than terminated.
            reward: Reward for the last action taken.
            done (bool): True if ``state`` is terminal (no bootstrapping).
        """
        self.past_rewards[self.t - 1] = reward
        if done:
            self.update(None)
        else:
            statevar = self.batch_states([state], np, self.phi)
            self.update(statevar)

        self._reset_recurrent_states()

    def stop_episode(self):
        """Finish an episode without training."""
        self._reset_recurrent_states()

    def _reset_recurrent_states(self):
        # Reset the recurrent state of all three models, if recurrent.
        if isinstance(self.q_function, Recurrent):
            self.q_function.reset_state()
            self.shared_q_function.reset_state()
            self.target_q_function.reset_state()

    def load(self, dirname):
        """Load saved attributes, then propagate them to the shared model."""
        super().load(dirname)
        # saved_attributes restores q_function; copy it into the shared
        # model so that all workers pick up the loaded parameters.
        copy_param.copy_param(target_link=self.shared_q_function,
                              source_link=self.q_function)

    def get_statistics(self):
        """Return a list of (name, value) training statistics."""
        return [
            ('average_q', self.average_q),
        ]