import copy
from logging import getLogger
import chainer
from chainer import cuda
import chainer.functions as F
from chainerrl.agent import AttributeSavingMixin
from chainerrl.agent import BatchAgent
from chainerrl.misc.batch_states import batch_states
from chainerrl.misc.copy_param import synchronize_parameters
from chainerrl.recurrent import Recurrent
from chainerrl.recurrent import RecurrentChainMixin
from chainerrl.recurrent import state_kept
from chainerrl.replay_buffer import batch_experiences
from chainerrl.replay_buffer import ReplayUpdater
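

# Helper that puts a chain into permanent test mode: its __call__ is wrapped
# so that it always runs with chainer's 'train' config set to False. It is
# used below for the target policy and target Q-function.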
def disable_train(chain):
call_orig = chain.__call__
def call_test(self, x):
with chainer.using_config('train', False):
return call_orig(self, x)
chain.__call__ = call_test
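

# Chain that bundles the policy and the Q-function so that the target
# networks can be created with a single deepcopy and kept in sync with one
# synchronize_parameters call (see sync_target_network below).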
class DDPGModel(chainer.Chain, RecurrentChainMixin):
def __init__(self, policy, q_func):
super().__init__(policy=policy, q_function=q_func)
class DDPG(AttributeSavingMixin, BatchAgent):
"""Deep Deterministic Policy Gradients.
This can be used as SVG(0) by specifying a Gaussian policy instead of a
deterministic policy.
Args:
model (DDPGModel): DDPG model that contains both a policy and a
Q-function
actor_optimizer (Optimizer): Optimizer setup with the policy
critic_optimizer (Optimizer): Optimizer setup with the Q-function
replay_buffer (ReplayBuffer): Replay buffer
gamma (float): Discount factor
explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id. The CPU is used if it is None or negative.
        replay_start_size (int): Updates are skipped while the replay buffer
            holds fewer than replay_start_size transitions
minibatch_size (int): Minibatch size
        update_interval (int): Model update interval in steps
        target_update_interval (int): Target model update interval in steps
phi (callable): Feature extractor applied to observations
target_update_method (str): 'hard' or 'soft'.
soft_update_tau (float): Tau of soft target update.
        n_times_update (int): Number of times the model is updated per update
            trigger
average_q_decay (float): Decay rate of average Q, only used for
recording statistics
average_loss_decay (float): Decay rate of average loss, only used for
recording statistics
        episodic_update (bool): Use full episodes for updates if set to True
        episodic_update_len (int or None): If set to an int and
            episodic_update is True, subsequences of this length are used for
            updates
logger (Logger): Logger used
        batch_states (callable): Method that makes a batch of observations.
            Default is `chainerrl.misc.batch_states.batch_states`
        burnin_action_func (callable or None): If not None, this callable
            object is used to select actions until the model has been updated
            at least once during training.
"""
saved_attributes = ('model',
'target_model',
'actor_optimizer',
'critic_optimizer')
def __init__(self, model, actor_optimizer, critic_optimizer, replay_buffer,
gamma, explorer,
gpu=None, replay_start_size=50000,
minibatch_size=32, update_interval=1,
target_update_interval=10000,
phi=lambda x: x,
target_update_method='hard',
soft_update_tau=1e-2,
n_times_update=1, average_q_decay=0.999,
average_loss_decay=0.99,
episodic_update=False,
episodic_update_len=None,
logger=getLogger(__name__),
batch_states=batch_states,
burnin_action_func=None,
):
self.model = model
if gpu is not None and gpu >= 0:
cuda.get_device_from_id(gpu).use()
self.model.to_gpu(device=gpu)
self.xp = self.model.xp
self.replay_buffer = replay_buffer
self.gamma = gamma
self.explorer = explorer
self.gpu = gpu
self.target_update_interval = target_update_interval
self.phi = phi
self.target_update_method = target_update_method
self.soft_update_tau = soft_update_tau
self.logger = logger
self.average_q_decay = average_q_decay
self.average_loss_decay = average_loss_decay
self.actor_optimizer = actor_optimizer
self.critic_optimizer = critic_optimizer
if episodic_update:
update_func = self.update_from_episodes
else:
update_func = self.update
self.replay_updater = ReplayUpdater(
replay_buffer=replay_buffer,
update_func=update_func,
batchsize=minibatch_size,
episodic_update=episodic_update,
episodic_update_len=episodic_update_len,
n_times_update=n_times_update,
replay_start_size=replay_start_size,
update_interval=update_interval,
)
self.batch_states = batch_states
self.burnin_action_func = burnin_action_func
self.t = 0
self.last_state = None
self.last_action = None
self.target_model = copy.deepcopy(self.model)
disable_train(self.target_model['q_function'])
disable_train(self.target_model['policy'])
self.average_q = 0
self.average_actor_loss = 0.0
self.average_critic_loss = 0.0
# Aliases for convenience
self.q_function = self.model['q_function']
self.policy = self.model['policy']
self.target_q_function = self.target_model['q_function']
self.target_policy = self.target_model['policy']
self.sync_target_network()
def sync_target_network(self):
"""Synchronize target network with current network."""
synchronize_parameters(
src=self.model,
dst=self.target_model,
method=self.target_update_method,
tau=self.soft_update_tau)
# Update Q-function
def compute_critic_loss(self, batch):
"""Compute loss for critic.
Preconditions:
target_q_function must have seen up to s_t and a_t.
target_policy must have seen up to s_t.
q_function must have seen up to s_{t-1} and a_{t-1}.
Postconditions:
target_q_function must have seen up to s_{t+1} and a_{t+1}.
target_policy must have seen up to s_{t+1}.
q_function must have seen up to s_t and a_t.
"""
batch_next_state = batch['next_state']
batch_rewards = batch['reward']
batch_terminal = batch['is_state_terminal']
batch_state = batch['state']
batch_actions = batch['action']
batchsize = len(batch_rewards)
with chainer.no_backprop_mode():
# Target policy observes s_{t+1}
next_actions = self.target_policy(
batch_next_state).sample()
            # Q(s_{t+1}, mu(s_{t+1})) is evaluated.
# This should not affect the internal state of Q.
with state_kept(self.target_q_function):
next_q = self.target_q_function(batch_next_state, next_actions)
# Target Q-function observes s_{t+1} and a_{t+1}
if isinstance(self.target_q_function, Recurrent):
batch_next_actions = batch['next_action']
self.target_q_function.update_state(
batch_next_state, batch_next_actions)
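            # TD target: y_t = r_t + gamma * (1 - done_t) * Q'(s_{t+1}, mu'(s_{t+1}))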
target_q = batch_rewards + self.gamma * \
(1.0 - batch_terminal) * F.reshape(next_q, (batchsize,))
# Estimated Q-function observes s_t and a_t
predict_q = F.reshape(
self.q_function(batch_state, batch_actions),
(batchsize,))
loss = F.mean_squared_error(target_q, predict_q)
# Update stats
self.average_critic_loss *= self.average_loss_decay
self.average_critic_loss += ((1 - self.average_loss_decay) *
float(loss.array))
return loss
def compute_actor_loss(self, batch):
"""Compute loss for actor.
Preconditions:
            q_function must have seen up to s_{t-1} and a_{t-1}.
policy must have seen up to s_{t-1}.
Postconditions:
            q_function must have seen up to s_t and a_t.
policy must have seen up to s_t.
"""
batch_state = batch['state']
batch_action = batch['action']
batch_size = len(batch_action)
# Estimated policy observes s_t
onpolicy_actions = self.policy(batch_state).sample()
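        # For a deterministic policy, sample() simply returns mu(s_t); with a
        # Gaussian policy this becomes the SVG(0) estimator mentioned in the
        # class docstring.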
# Q(s_t, mu(s_t)) is evaluated.
# This should not affect the internal state of Q.
with state_kept(self.q_function):
q = self.q_function(batch_state, onpolicy_actions)
# Estimated Q-function observes s_t and a_t
if isinstance(self.q_function, Recurrent):
self.q_function.update_state(batch_state, batch_action)
# Avoid the numpy #9165 bug (see also: chainer #2744)
q = q[:, :]
# Since we want to maximize Q, loss is negation of Q
loss = - F.sum(q) / batch_size
# Update stats
self.average_actor_loss *= self.average_loss_decay
self.average_actor_loss += ((1 - self.average_loss_decay) *
float(loss.array))
return loss
def update(self, experiences, errors_out=None):
"""Update the model from experiences"""
batch = batch_experiences(experiences, self.xp, self.phi, self.gamma)
self.critic_optimizer.update(lambda: self.compute_critic_loss(batch))
self.actor_optimizer.update(lambda: self.compute_actor_loss(batch))
def update_from_episodes(self, episodes, errors_out=None):
        # Sort episodes in descending order of length
sorted_episodes = list(reversed(sorted(episodes, key=len)))
max_epi_len = len(sorted_episodes[0])
# Precompute all the input batches
batches = []
for i in range(max_epi_len):
transitions = []
for ep in sorted_episodes:
if len(ep) <= i:
break
transitions.append([ep[i]])
batch = batch_experiences(
transitions, xp=self.xp, phi=self.phi, gamma=self.gamma)
batches.append(batch)
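        # batches[i] now holds the i-th transition of every episode that is
        # at least i + 1 steps long; since episodes are sorted by length in
        # descending order, shorter episodes simply drop out of later batches.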
with self.model.state_reset(), self.target_model.state_reset():
# Since the target model is evaluated one-step ahead,
# its internal states need to be updated
self.target_q_function.update_state(
batches[0]['state'], batches[0]['action'])
self.target_policy(batches[0]['state'])
# Update critic through time
critic_loss = 0
for batch in batches:
critic_loss += self.compute_critic_loss(batch)
self.critic_optimizer.update(lambda: critic_loss / max_epi_len)
with self.model.state_reset():
# Update actor through time
actor_loss = 0
for batch in batches:
actor_loss += self.compute_actor_loss(batch)
self.actor_optimizer.update(lambda: actor_loss / max_epi_len)
def act_and_train(self, obs, reward):
self.logger.debug('t:%s r:%s', self.t, reward)
if (self.burnin_action_func is not None
and self.actor_optimizer.t == 0):
action = self.burnin_action_func()
else:
greedy_action = self.act(obs)
action = self.explorer.select_action(self.t, lambda: greedy_action)
self.t += 1
# Update the target network
if self.t % self.target_update_interval == 0:
self.sync_target_network()
if self.last_state is not None:
assert self.last_action is not None
# Add a transition to the replay buffer
self.replay_buffer.append(
state=self.last_state,
action=self.last_action,
reward=reward,
next_state=obs,
next_action=action,
is_state_terminal=False)
self.last_state = obs
self.last_action = action
self.replay_updater.update_if_necessary(self.t)
return self.last_action
def act(self, obs):
with chainer.using_config('train', False):
s = self.batch_states([obs], self.xp, self.phi)
action = self.policy(s).sample()
# Q is not needed here, but log it just for information
q = self.q_function(s, action)
# Update stats
self.average_q *= self.average_q_decay
self.average_q += (1 - self.average_q_decay) * float(q.array)
self.logger.debug('t:%s a:%s q:%s',
self.t, action.array[0], q.array)
return cuda.to_cpu(action.array[0])
def batch_act(self, batch_obs):
"""Select a batch of actions for evaluation.
Args:
batch_obs (Sequence of ~object): Observations.
Returns:
Sequence of ~object: Actions.
"""
with chainer.using_config('train', False), chainer.no_backprop_mode():
batch_xs = self.batch_states(batch_obs, self.xp, self.phi)
batch_action = self.policy(batch_xs).sample()
# Q is not needed here, but log it just for information
q = self.q_function(batch_xs, batch_action)
# Update stats
self.average_q *= self.average_q_decay
self.average_q += (1 - self.average_q_decay) * float(
q.array.mean(axis=0))
self.logger.debug('t:%s a:%s q:%s',
self.t, batch_action.array[0], q.array)
return [cuda.to_cpu(action.array) for action in batch_action]
def batch_act_and_train(self, batch_obs):
"""Select a batch of actions for training.
Args:
batch_obs (Sequence of ~object): Observations.
Returns:
Sequence of ~object: Actions.
"""
if (self.burnin_action_func is not None
and self.actor_optimizer.t == 0):
batch_action = [self.burnin_action_func()
for _ in range(len(batch_obs))]
else:
batch_greedy_action = self.batch_act(batch_obs)
batch_action = [
self.explorer.select_action(
self.t, lambda: batch_greedy_action[i])
for i in range(len(batch_greedy_action))]
self.batch_last_obs = list(batch_obs)
self.batch_last_action = list(batch_action)
return batch_action
def batch_observe_and_train(
self, batch_obs, batch_reward, batch_done, batch_reset):
"""Observe a batch of action consequences for training.
Args:
batch_obs (Sequence of ~object): Observations.
batch_reward (Sequence of float): Rewards.
batch_done (Sequence of boolean): Boolean values where True
indicates the current state is terminal.
batch_reset (Sequence of boolean): Boolean values where True
indicates the current episode will be reset, even if the
current state is not terminal.
Returns:
None
"""
for i in range(len(batch_obs)):
self.t += 1
# Update the target network
if self.t % self.target_update_interval == 0:
self.sync_target_network()
if self.batch_last_obs[i] is not None:
assert self.batch_last_action[i] is not None
# Add a transition to the replay buffer
self.replay_buffer.append(
state=self.batch_last_obs[i],
action=self.batch_last_action[i],
reward=batch_reward[i],
next_state=batch_obs[i],
next_action=None,
is_state_terminal=batch_done[i],
env_id=i,
)
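                # env_id identifies which vectorized environment the
                # transition came from, so the replay buffer can keep episode
                # boundaries separate per environment.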
if batch_reset[i] or batch_done[i]:
self.batch_last_obs[i] = None
self.replay_buffer.stop_current_episode(env_id=i)
self.replay_updater.update_if_necessary(self.t)
def batch_observe(self, batch_obs, batch_reward,
batch_done, batch_reset):
pass
def stop_episode_and_train(self, state, reward, done=False):
assert self.last_state is not None
assert self.last_action is not None
# Add a transition to the replay buffer
self.replay_buffer.append(
state=self.last_state,
action=self.last_action,
reward=reward,
next_state=state,
next_action=self.last_action,
is_state_terminal=done)
self.stop_episode()
def stop_episode(self):
self.last_state = None
self.last_action = None
if isinstance(self.model, Recurrent):
self.model.reset_state()
self.replay_buffer.stop_current_episode()
def get_statistics(self):
return [
('average_q', self.average_q),
('average_actor_loss', self.average_actor_loss),
('average_critic_loss', self.average_critic_loss),
]
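

# A minimal construction sketch, assuming standard ChainerRL components (the
# concrete policy, Q-function, and explorer instances are illustrative and
# not defined in this module):
#
#     policy = ...        # e.g. a fully-connected deterministic policy
#     q_func = ...        # e.g. a fully-connected state-action Q-function
#     model = DDPGModel(policy=policy, q_func=q_func)
#     actor_opt = chainer.optimizers.Adam(alpha=1e-4)
#     actor_opt.setup(model['policy'])
#     critic_opt = chainer.optimizers.Adam(alpha=1e-3)
#     critic_opt.setup(model['q_function'])
#     agent = DDPG(model, actor_opt, critic_opt,
#                  replay_buffer=chainerrl.replay_buffer.ReplayBuffer(10 ** 6),
#                  gamma=0.99,
#                  explorer=...)  # e.g. additive Gaussian noise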