import copy
from logging import getLogger
import chainer
from chainer import cuda
import chainer.functions as F
from chainerrl import agent
from chainerrl.misc.batch_states import batch_states
from chainerrl.misc.copy_param import synchronize_parameters
from chainerrl.replay_buffer import batch_experiences
from chainerrl.replay_buffer import batch_recurrent_experiences
from chainerrl.replay_buffer import ReplayUpdater
def compute_value_loss(y, t, clip_delta=True, batch_accumulator='mean'):
"""Compute a loss for value prediction problem.
Args:
y (Variable or ndarray): Predicted values.
t (Variable or ndarray): Target values.
clip_delta (bool): Use the Huber loss function if set True.
batch_accumulator (str): 'mean' or 'sum'. 'mean' will use the mean of
the loss values in a batch. 'sum' will use the sum.
Returns:
(Variable) scalar loss
"""
assert batch_accumulator in ('mean', 'sum')
y = F.reshape(y, (-1, 1))
t = F.reshape(t, (-1, 1))
if clip_delta:
loss_sum = F.sum(F.huber_loss(y, t, delta=1.0))
if batch_accumulator == 'mean':
loss = loss_sum / y.shape[0]
elif batch_accumulator == 'sum':
loss = loss_sum
else:
loss_mean = F.mean_squared_error(y, t) / 2
if batch_accumulator == 'mean':
loss = loss_mean
elif batch_accumulator == 'sum':
loss = loss_mean * y.shape[0]
return loss
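# --- Illustrative usage sketch (not part of the original module) ---
# A minimal check of compute_value_loss on toy NumPy arrays, assuming CPU-only
# Chainer. With clip_delta=True the Huber loss is used; with 'mean' the summed
# loss is divided by the batch size.
def _example_compute_value_loss():
    import numpy as np
    y = np.asarray([1.0, 2.0, 5.0], dtype=np.float32)  # predicted values
    t = np.asarray([1.5, 2.0, 2.0], dtype=np.float32)  # target values
    loss = compute_value_loss(y, t, clip_delta=True, batch_accumulator='mean')
    return float(loss.array)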
def compute_weighted_value_loss(y, t, weights,
clip_delta=True, batch_accumulator='mean'):
"""Compute a loss for value prediction problem.
Args:
y (Variable or ndarray): Predicted values.
t (Variable or ndarray): Target values.
        weights (ndarray): Per-example weights for y and t, e.g.
            importance-sampling weights.
clip_delta (bool): Use the Huber loss function if set True.
        batch_accumulator (str): 'mean' or 'sum'. 'mean' will divide the sum
            of weighted losses by the batch size. 'sum' will use the sum.
Returns:
(Variable) scalar loss
"""
assert batch_accumulator in ('mean', 'sum')
y = F.reshape(y, (-1, 1))
t = F.reshape(t, (-1, 1))
if clip_delta:
losses = F.huber_loss(y, t, delta=1.0)
else:
losses = F.square(y - t) / 2
losses = F.reshape(losses, (-1,))
loss_sum = F.sum(losses * weights)
if batch_accumulator == 'mean':
loss = loss_sum / y.shape[0]
elif batch_accumulator == 'sum':
loss = loss_sum
return loss
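# Illustrative sketch (not part of the original module): the weighted variant
# behaves the same way but scales each example's loss, e.g. by importance
# sampling weights taken from a prioritized replay buffer.
def _example_compute_weighted_value_loss():
    import numpy as np
    y = np.asarray([1.0, 2.0, 5.0], dtype=np.float32)
    t = np.asarray([1.5, 2.0, 2.0], dtype=np.float32)
    w = np.asarray([0.5, 1.0, 2.0], dtype=np.float32)  # per-example weights
    loss = compute_weighted_value_loss(y, t, w, clip_delta=True,
                                       batch_accumulator='mean')
    return float(loss.array)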
def _batch_reset_recurrent_states_when_episodes_end(
model, batch_done, batch_reset, recurrent_states):
"""Reset recurrent states when episodes end.
Args:
model (chainer.Link): Model that implements `StatelessRecurrent`.
batch_done (array-like of bool): True iff episodes are terminal.
batch_reset (array-like of bool): True iff episodes will be reset.
recurrent_states (object): Recurrent state.
Returns:
object: New recurrent states.
"""
indices_that_ended = [
i for i, (done, reset)
in enumerate(zip(batch_done, batch_reset)) if done or reset]
if indices_that_ended:
return model.mask_recurrent_state_at(
recurrent_states, indices_that_ended)
else:
return recurrent_states
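# Note (added for clarity, not in the original module): for example, with
# batch_done=[False, True] and batch_reset=[False, False], only the recurrent
# state at index 1 is masked via model.mask_recurrent_state_at; if no episode
# ended or was reset, the states are returned unchanged.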
class DQN(agent.AttributeSavingMixin, agent.BatchAgent):
"""Deep Q-Network algorithm.
Args:
q_function (StateQFunction): Q-function
optimizer (Optimizer): Optimizer that is already setup
replay_buffer (ReplayBuffer): Replay buffer
gamma (float): Discount factor
explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id. The GPU is used if this is neither None
            nor negative.
        replay_start_size (int): If the replay buffer's size is less than
            replay_start_size, updates are skipped.
minibatch_size (int): Minibatch size
update_interval (int): Model update interval in step
target_update_interval (int): Target model update interval in step
        clip_delta (bool): Use the Huber loss function (which clips the
            gradient of large TD errors) if set to True.
phi (callable): Feature extractor applied to observations
target_update_method (str): 'hard' or 'soft'.
soft_update_tau (float): Tau of soft target update.
        n_times_update (int): Number of repetitions of update
average_q_decay (float): Decay rate of average Q, only used for
recording statistics
average_loss_decay (float): Decay rate of average loss, only used for
recording statistics
batch_accumulator (str): 'mean' or 'sum'
        episodic_update_len (int or None): Subsequences of this length are
            used for update if set to an int and recurrent=True.
logger (Logger): Logger used
batch_states (callable): method which makes a batch of observations.
default is `chainerrl.misc.batch_states.batch_states`
recurrent (bool): If set to True, `model` is assumed to implement
`chainerrl.links.StatelessRecurrent` and is updated in a recurrent
manner.
"""
saved_attributes = ('model', 'target_model', 'optimizer')
def __init__(self, q_function, optimizer, replay_buffer, gamma,
explorer, gpu=None, replay_start_size=50000,
minibatch_size=32, update_interval=1,
target_update_interval=10000, clip_delta=True,
phi=lambda x: x,
target_update_method='hard',
soft_update_tau=1e-2,
n_times_update=1, average_q_decay=0.999,
average_loss_decay=0.99,
batch_accumulator='mean',
episodic_update_len=None,
logger=getLogger(__name__),
batch_states=batch_states,
recurrent=False,
):
self.model = q_function
self.q_function = q_function # For backward compatibility
if gpu is not None and gpu >= 0:
cuda.get_device_from_id(gpu).use()
self.model.to_gpu(device=gpu)
self.xp = self.model.xp
self.replay_buffer = replay_buffer
self.optimizer = optimizer
self.gamma = gamma
self.explorer = explorer
self.gpu = gpu
self.target_update_interval = target_update_interval
self.clip_delta = clip_delta
self.phi = phi
self.target_update_method = target_update_method
self.soft_update_tau = soft_update_tau
self.batch_accumulator = batch_accumulator
assert batch_accumulator in ('mean', 'sum')
self.logger = logger
self.batch_states = batch_states
self.recurrent = recurrent
if self.recurrent:
update_func = self.update_from_episodes
else:
update_func = self.update
self.replay_updater = ReplayUpdater(
replay_buffer=replay_buffer,
update_func=update_func,
batchsize=minibatch_size,
episodic_update=recurrent,
episodic_update_len=episodic_update_len,
n_times_update=n_times_update,
replay_start_size=replay_start_size,
update_interval=update_interval,
)
self.t = 0
self.last_state = None
self.last_action = None
self.target_model = None
self.sync_target_network()
# For backward compatibility
self.target_q_function = self.target_model
self.average_q = 0
self.average_q_decay = average_q_decay
self.average_loss = 0
self.average_loss_decay = average_loss_decay
# Recurrent states of the model
self.train_recurrent_states = None
self.train_prev_recurrent_states = None
self.test_recurrent_states = None
# Error checking
if (self.replay_buffer.capacity is not None and
self.replay_buffer.capacity <
self.replay_updater.replay_start_size):
raise ValueError(
'Replay start size cannot exceed '
'replay buffer capacity.')
def sync_target_network(self):
"""Synchronize target network with current network."""
if self.target_model is None:
self.target_model = copy.deepcopy(self.model)
call_orig = self.target_model.__call__
def call_test(self_, x):
with chainer.using_config('train', False):
return call_orig(self_, x)
self.target_model.__call__ = call_test
else:
synchronize_parameters(
src=self.model,
dst=self.target_model,
method=self.target_update_method,
tau=self.soft_update_tau)
def update(self, experiences, errors_out=None):
"""Update the model from experiences
Args:
experiences (list): List of lists of dicts.
                For DQN, each dict must contain:
- state (object): State
- action (object): Action
- reward (float): Reward
- is_state_terminal (bool): True iff next state is terminal
- next_state (object): Next state
- weight (float, optional): Weight coefficient. It can be
used for importance sampling.
errors_out (list or None): If set to a list, then TD-errors
computed from the given experiences are appended to the list.
Returns:
None
"""
has_weight = 'weight' in experiences[0][0]
exp_batch = batch_experiences(
experiences, xp=self.xp,
phi=self.phi, gamma=self.gamma,
batch_states=self.batch_states)
if has_weight:
exp_batch['weights'] = self.xp.asarray(
                [elem[0]['weight'] for elem in experiences],
dtype=self.xp.float32)
if errors_out is None:
errors_out = []
loss = self._compute_loss(exp_batch, errors_out=errors_out)
if has_weight:
self.replay_buffer.update_errors(errors_out)
# Update stats
self.average_loss *= self.average_loss_decay
self.average_loss += (1 - self.average_loss_decay) * float(loss.array)
self.model.cleargrads()
loss.backward()
self.optimizer.update()
def update_from_episodes(self, episodes, errors_out=None):
assert errors_out is None,\
"Recurrent DQN does not support PrioritizedBuffer"
exp_batch = batch_recurrent_experiences(
episodes,
model=self.model,
xp=self.xp,
phi=self.phi, gamma=self.gamma,
batch_states=self.batch_states,
)
loss = self._compute_loss(exp_batch, errors_out=None)
# Update stats
self.average_loss *= self.average_loss_decay
self.average_loss += (1 - self.average_loss_decay) * float(loss.array)
self.optimizer.update(lambda: loss)
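    # The target values below follow the standard Q-learning bootstrap,
    #     y = r + discount * (1 - done) * max_a' Q_target(s', a'),
    # where Q_target is the periodically synchronized target network.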
def _compute_target_values(self, exp_batch):
batch_next_state = exp_batch['next_state']
if self.recurrent:
target_next_qout, _ = self.target_model.n_step_forward(
batch_next_state, exp_batch['next_recurrent_state'],
output_mode='concat')
else:
target_next_qout = self.target_model(batch_next_state)
next_q_max = target_next_qout.max
batch_rewards = exp_batch['reward']
batch_terminal = exp_batch['is_state_terminal']
discount = exp_batch['discount']
return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max
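    # Pairs each transition's predicted Q(s, a) for the taken action with its
    # bootstrapped target; the target side is computed under no_backprop_mode
    # so gradients only flow through the online model.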
def _compute_y_and_t(self, exp_batch):
batch_size = exp_batch['reward'].shape[0]
# Compute Q-values for current states
batch_state = exp_batch['state']
if self.recurrent:
qout, _ = self.model.n_step_forward(
batch_state,
exp_batch['recurrent_state'],
output_mode='concat',
)
else:
qout = self.model(batch_state)
batch_actions = exp_batch['action']
batch_q = F.reshape(qout.evaluate_actions(
batch_actions), (batch_size, 1))
with chainer.no_backprop_mode():
batch_q_target = F.reshape(
self._compute_target_values(exp_batch),
(batch_size, 1))
return batch_q, batch_q_target
def _compute_loss(self, exp_batch, errors_out=None):
"""Compute the Q-learning loss for a batch of experiences
Args:
exp_batch (dict): A dict of batched arrays of transitions
Returns:
Computed loss from the minibatch of experiences
"""
y, t = self._compute_y_and_t(exp_batch)
if errors_out is not None:
del errors_out[:]
delta = F.absolute(y - t)
if delta.ndim == 2:
delta = F.sum(delta, axis=1)
delta = cuda.to_cpu(delta.array)
for e in delta:
errors_out.append(e)
if 'weights' in exp_batch:
return compute_weighted_value_loss(
y, t, exp_batch['weights'],
clip_delta=self.clip_delta,
batch_accumulator=self.batch_accumulator)
else:
return compute_value_loss(y, t, clip_delta=self.clip_delta,
batch_accumulator=self.batch_accumulator)
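    # Evaluation-time action selection: picks the greedy action under the
    # current model without exploration and without storing transitions.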
def act(self, obs):
with chainer.using_config('train', False), chainer.no_backprop_mode():
action_value =\
self._evaluate_model_and_update_recurrent_states(
[obs], test=True)
q = float(action_value.max.array)
action = cuda.to_cpu(action_value.greedy_actions.array)[0]
# Update stats
self.average_q *= self.average_q_decay
self.average_q += (1 - self.average_q_decay) * q
self.logger.debug('t:%s q:%s action_value:%s', self.t, q, action_value)
return action
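    # Training-time step: stores the previous transition in the replay buffer,
    # periodically syncs the target network, triggers replay updates when due,
    # and then selects an action through the explorer.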
def act_and_train(self, obs, reward):
# Observe the consequences
if self.last_state is not None:
assert self.last_action is not None
# Add a transition to the replay buffer
transition = {
'state': self.last_state,
'action': self.last_action,
'reward': reward,
'next_state': obs,
'is_state_terminal': False,
}
if self.recurrent:
transition['recurrent_state'] =\
self.model.get_recurrent_state_at(
self.train_prev_recurrent_states,
0, unwrap_variable=True)
self.train_prev_recurrent_states = None
transition['next_recurrent_state'] =\
self.model.get_recurrent_state_at(
self.train_recurrent_states, 0, unwrap_variable=True)
self.replay_buffer.append(**transition)
# Update the target network
if self.t % self.target_update_interval == 0:
self.sync_target_network()
# Update the model
self.replay_updater.update_if_necessary(self.t)
# Choose an action
with chainer.using_config('train', False), chainer.no_backprop_mode():
action_value =\
self._evaluate_model_and_update_recurrent_states(
[obs], test=False)
q = float(action_value.max.array)
greedy_action = cuda.to_cpu(action_value.greedy_actions.array)[0]
action = self.explorer.select_action(
self.t, lambda: greedy_action, action_value=action_value)
# Update stats
self.average_q *= self.average_q_decay
self.average_q += (1 - self.average_q_decay) * q
self.t += 1
self.last_state = obs
self.last_action = action
self.logger.debug('t:%s q:%s action_value:%s', self.t, q, action_value)
self.logger.debug('t:%s r:%s a:%s', self.t, reward, action)
return self.last_action
def _evaluate_model_and_update_recurrent_states(self, batch_obs, test):
batch_xs = self.batch_states(batch_obs, self.xp, self.phi)
if self.recurrent:
if test:
batch_av, self.test_recurrent_states = self.model(
batch_xs, self.test_recurrent_states)
else:
self.train_prev_recurrent_states = self.train_recurrent_states
batch_av, self.train_recurrent_states = self.model(
batch_xs, self.train_recurrent_states)
else:
batch_av = self.model(batch_xs)
return batch_av
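    # Batch counterparts of act/act_and_train for vectorized environments;
    # exploration is applied separately per environment index.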
def batch_act_and_train(self, batch_obs):
with chainer.using_config('train', False), chainer.no_backprop_mode():
batch_av = self._evaluate_model_and_update_recurrent_states(
batch_obs, test=False)
batch_maxq = batch_av.max.array
batch_argmax = cuda.to_cpu(batch_av.greedy_actions.array)
batch_action = [
self.explorer.select_action(
self.t, lambda: batch_argmax[i],
action_value=batch_av[i:i + 1],
)
for i in range(len(batch_obs))]
self.batch_last_obs = list(batch_obs)
self.batch_last_action = list(batch_action)
# Update stats
self.average_q *= self.average_q_decay
self.average_q += (1 - self.average_q_decay) * float(batch_maxq.mean())
return batch_action
def batch_act(self, batch_obs):
with chainer.using_config('train', False), chainer.no_backprop_mode():
batch_av = self._evaluate_model_and_update_recurrent_states(
batch_obs, test=True)
batch_argmax = cuda.to_cpu(batch_av.greedy_actions.array)
return batch_argmax
def batch_observe_and_train(self, batch_obs, batch_reward,
batch_done, batch_reset):
for i in range(len(batch_obs)):
self.t += 1
# Update the target network
if self.t % self.target_update_interval == 0:
self.sync_target_network()
if self.batch_last_obs[i] is not None:
assert self.batch_last_action[i] is not None
# Add a transition to the replay buffer
transition = {
'state': self.batch_last_obs[i],
'action': self.batch_last_action[i],
'reward': batch_reward[i],
'next_state': batch_obs[i],
'next_action': None,
'is_state_terminal': batch_done[i],
}
if self.recurrent:
transition['recurrent_state'] =\
self.model.get_recurrent_state_at(
self.train_prev_recurrent_states,
i, unwrap_variable=True)
transition['next_recurrent_state'] =\
self.model.get_recurrent_state_at(
self.train_recurrent_states,
i, unwrap_variable=True)
self.replay_buffer.append(env_id=i, **transition)
if batch_reset[i] or batch_done[i]:
self.batch_last_obs[i] = None
self.batch_last_action[i] = None
self.replay_buffer.stop_current_episode(env_id=i)
self.replay_updater.update_if_necessary(self.t)
if self.recurrent:
# Reset recurrent states when episodes end
self.train_prev_recurrent_states = None
self.train_recurrent_states =\
_batch_reset_recurrent_states_when_episodes_end(
model=self.model,
batch_done=batch_done,
batch_reset=batch_reset,
recurrent_states=self.train_recurrent_states,
)
def batch_observe(self, batch_obs, batch_reward,
batch_done, batch_reset):
if self.recurrent:
# Reset recurrent states when episodes end
self.test_recurrent_states =\
_batch_reset_recurrent_states_when_episodes_end(
model=self.model,
batch_done=batch_done,
batch_reset=batch_reset,
recurrent_states=self.test_recurrent_states,
)
def stop_episode_and_train(self, state, reward, done=False):
"""Observe a terminal state and a reward.
This function must be called once when an episode terminates.
"""
assert self.last_state is not None
assert self.last_action is not None
# Add a transition to the replay buffer
transition = {
'state': self.last_state,
'action': self.last_action,
'reward': reward,
'next_state': state,
'next_action': self.last_action,
'is_state_terminal': done,
}
if self.recurrent:
transition['recurrent_state'] =\
self.model.get_recurrent_state_at(
self.train_prev_recurrent_states, 0, unwrap_variable=True)
self.train_prev_recurrent_states = None
transition['next_recurrent_state'] =\
self.model.get_recurrent_state_at(
self.train_recurrent_states, 0, unwrap_variable=True)
self.replay_buffer.append(**transition)
self.last_state = None
self.last_action = None
if self.recurrent:
self.train_recurrent_states = None
self.replay_buffer.stop_current_episode()
def stop_episode(self):
if self.recurrent:
self.test_recurrent_states = None
def get_statistics(self):
return [
('average_q', self.average_q),
('average_loss', self.average_loss),
('n_updates', self.optimizer.t),
]
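# --- Illustrative construction sketch (not part of the original module) ---
# A minimal example of wiring up a DQN agent, assuming the standard ChainerRL
# components (FCStateQFunctionWithDiscreteAction, ConstantEpsilonGreedy,
# ReplayBuffer) and a 4-dimensional observation space with 2 discrete actions.
def _example_build_dqn_agent():
    import numpy as np
    import chainerrl
    q_func = chainerrl.q_functions.FCStateQFunctionWithDiscreteAction(
        ndim_obs=4, n_actions=2, n_hidden_channels=50, n_hidden_layers=2)
    optimizer = chainer.optimizers.Adam(eps=1e-2)
    optimizer.setup(q_func)
    explorer = chainerrl.explorers.ConstantEpsilonGreedy(
        epsilon=0.1, random_action_func=lambda: np.random.randint(2))
    replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 5)
    return DQN(q_func, optimizer, replay_buffer, gamma=0.99, explorer=explorer,
               replay_start_size=500, update_interval=1,
               target_update_interval=100, minibatch_size=32)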