import copy
from logging import getLogger
import chainer
from chainer import functions as F
import numpy as np
from chainerrl import agent
from chainerrl.misc import async_
from chainerrl.misc.batch_states import batch_states
from chainerrl.misc import copy_param
from chainerrl.recurrent import Recurrent
from chainerrl.recurrent import RecurrentChainMixin
from chainerrl.recurrent import state_kept
logger = getLogger(__name__)
class A3CModel(chainer.Link):
"""A3C model."""
def pi_and_v(self, obs):
"""Evaluate the policy and the V-function.
Args:
obs (Variable or ndarray): Batched observations.
Returns:
            tuple of (Distribution, chainer.Variable): The policy's action
                distribution and the V-function's state-value estimate.
"""
raise NotImplementedError()
def __call__(self, obs):
return self.pi_and_v(obs)
class A3CSeparateModel(chainer.Chain, A3CModel, RecurrentChainMixin):
"""A3C model that consists of a separate policy and V-function.
Args:
pi (Policy): Policy.
v (VFunction): V-function.
"""
def __init__(self, pi, v):
super().__init__(pi=pi, v=v)
def pi_and_v(self, obs):
pout = self.pi(obs)
vout = self.v(obs)
return pout, vout
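# A minimal construction sketch (not part of the original module): building an
# A3CSeparateModel for a discrete-action task. The layer sizes are arbitrary,
# and SoftmaxPolicy and MLP are assumed to come from chainerrl.policies and
# chainerrl.links with the signatures used in the chainerrl examples.
def _example_separate_model(obs_size, n_actions):
    from chainerrl import links, policies
    pi = policies.SoftmaxPolicy(
        model=links.MLP(obs_size, n_actions, hidden_sizes=(64, 64)))
    v = links.MLP(obs_size, 1, hidden_sizes=(64, 64))
    return A3CSeparateModel(pi=pi, v=v)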
class A3CSharedModel(chainer.Chain, A3CModel, RecurrentChainMixin):
"""A3C model where the policy and V-function share parameters.
Args:
shared (Link): Shared part. Nonlinearity must be included in it.
pi (Policy): Policy that receives output of shared as input.
v (VFunction): V-function that receives output of shared as input.
"""
def __init__(self, shared, pi, v):
super().__init__(shared=shared, pi=pi, v=v)
def pi_and_v(self, obs):
h = self.shared(obs)
pout = self.pi(h)
vout = self.v(h)
return pout, vout
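# A minimal construction sketch (not part of the original module): an
# A3CSharedModel with one shared hidden layer. The shared part carries the
# nonlinearity, as the docstring above requires; links.Sequence is assumed to
# apply its layers in order.
def _example_shared_model(obs_size, n_actions, n_hidden=64):
    import chainer.links as L
    from chainerrl import links, policies
    shared = links.Sequence(L.Linear(obs_size, n_hidden), F.relu)
    pi = policies.SoftmaxPolicy(model=L.Linear(n_hidden, n_actions))
    v = L.Linear(n_hidden, 1)
    return A3CSharedModel(shared=shared, pi=pi, v=v)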
class A3C(agent.AttributeSavingMixin, agent.AsyncAgent):
"""A3C: Asynchronous Advantage Actor-Critic.
See http://arxiv.org/abs/1602.01783
Args:
        model (A3CModel): Model to train.
        optimizer (chainer.Optimizer): Optimizer used to train the model.
        t_max (int): The model is updated after every t_max local steps.
        gamma (float): Discount factor [0, 1].
        beta (float): Weight coefficient for the entropy regularization term.
        process_idx (int): Index of the process.
        phi (callable): Feature extractor function.
        pi_loss_coef (float): Weight coefficient for the loss of the policy.
        v_loss_coef (float): Weight coefficient for the loss of the value
            function.
        keep_loss_scale_same (bool): If set true, scale up the loss of a
            sequence truncated by a terminal state so that its scale matches
            that of a full t_max-step sequence.
        normalize_grad_by_t_max (bool): If set true, divide the loss by the
            number of steps in the update.
        use_average_reward (bool): If set true, use the average-reward
            formulation instead of the discounted-reward one.
        average_reward_tau (float): Step size used to update the running
            estimate of the average reward.
        act_deterministically (bool): If set true, choose most probable
            actions in the act method.
        average_entropy_decay (float): Decay rate of the average entropy
            recorded for get_statistics.
        average_value_decay (float): Decay rate of the average value
            recorded for get_statistics.
        batch_states (callable): Method which makes a batch of observations.
            Default is `chainerrl.misc.batch_states.batch_states`.
"""
process_idx = None
saved_attributes = ['model', 'optimizer']
def __init__(self, model, optimizer, t_max, gamma, beta=1e-2,
process_idx=0, phi=lambda x: x,
pi_loss_coef=1.0, v_loss_coef=0.5,
keep_loss_scale_same=False,
normalize_grad_by_t_max=False,
use_average_reward=False, average_reward_tau=1e-2,
act_deterministically=False,
average_entropy_decay=0.999,
average_value_decay=0.999,
batch_states=batch_states):
assert isinstance(model, A3CModel)
# Globally shared model
self.shared_model = model
# Thread specific model
self.model = copy.deepcopy(self.shared_model)
async_.assert_params_not_shared(self.shared_model, self.model)
self.optimizer = optimizer
self.t_max = t_max
self.gamma = gamma
self.beta = beta
self.phi = phi
self.pi_loss_coef = pi_loss_coef
self.v_loss_coef = v_loss_coef
self.keep_loss_scale_same = keep_loss_scale_same
self.normalize_grad_by_t_max = normalize_grad_by_t_max
self.use_average_reward = use_average_reward
self.average_reward_tau = average_reward_tau
self.act_deterministically = act_deterministically
self.average_value_decay = average_value_decay
self.average_entropy_decay = average_entropy_decay
self.batch_states = batch_states
self.t = 0
self.t_start = 0
self.past_action_log_prob = {}
self.past_action_entropy = {}
self.past_states = {}
self.past_rewards = {}
self.past_values = {}
self.average_reward = 0
        # A3C won't use an explorer, but this attribute is referenced by run_dqn
self.explorer = None
# Stats
self.average_value = 0
self.average_entropy = 0
def sync_parameters(self):
copy_param.copy_param(target_link=self.model,
source_link=self.shared_model)
@property
def shared_attributes(self):
return ('shared_model', 'optimizer')
def update(self, statevar):
assert self.t_start < self.t
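        # statevar is None when the episode ended with a terminal state, so
        # the return is bootstrapped from zero; otherwise it is bootstrapped
        # from the current V-function estimate of the next state.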
if statevar is None:
R = 0
else:
with state_kept(self.model):
_, vout = self.model.pi_and_v(statevar)
R = float(vout.array)
pi_loss = 0
v_loss = 0
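        # Walk backwards from the newest transition so that R accumulates the
        # bootstrapped n-step return: R_i = r_i + gamma * R_{i+1}.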
for i in reversed(range(self.t_start, self.t)):
R *= self.gamma
R += self.past_rewards[i]
if self.use_average_reward:
R -= self.average_reward
v = self.past_values[i]
advantage = R - v
if self.use_average_reward:
self.average_reward += self.average_reward_tau * \
float(advantage.array)
# Accumulate gradients of policy
log_prob = self.past_action_log_prob[i]
entropy = self.past_action_entropy[i]
# Log probability is increased proportionally to advantage
pi_loss -= log_prob * float(advantage.array)
# Entropy is maximized
pi_loss -= self.beta * entropy
# Accumulate gradients of value function
v_loss += (v - R) ** 2 / 2
if self.pi_loss_coef != 1.0:
pi_loss *= self.pi_loss_coef
if self.v_loss_coef != 1.0:
v_loss *= self.v_loss_coef
# Normalize the loss of sequences truncated by terminal states
if self.keep_loss_scale_same and \
self.t - self.t_start < self.t_max:
factor = self.t_max / (self.t - self.t_start)
pi_loss *= factor
v_loss *= factor
if self.normalize_grad_by_t_max:
pi_loss /= self.t - self.t_start
v_loss /= self.t - self.t_start
if self.process_idx == 0:
logger.debug('pi_loss:%s v_loss:%s', pi_loss.array, v_loss.array)
total_loss = F.squeeze(pi_loss) + F.squeeze(v_loss)
# Compute gradients using thread-specific model
self.model.cleargrads()
total_loss.backward()
# Copy the gradients to the globally shared model
copy_param.copy_grad(
target_link=self.shared_model, source_link=self.model)
# Update the globally shared model
if self.process_idx == 0:
            # The summed squares form the squared L2 norm; take the square
            # root so the logged value is the actual gradient norm.
            norm = np.sqrt(sum(np.sum(np.square(param.grad))
                               for param in self.optimizer.target.params()
                               if param.grad is not None))
            logger.debug('grad norm:%s', norm)
self.optimizer.update()
if self.process_idx == 0:
logger.debug('update')
self.sync_parameters()
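        # Truncate backpropagation through time at update boundaries so
        # recurrent models do not backprop into already-updated steps.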
if isinstance(self.model, Recurrent):
self.model.unchain_backward()
self.past_action_log_prob = {}
self.past_action_entropy = {}
self.past_states = {}
self.past_rewards = {}
self.past_values = {}
self.t_start = self.t
def act_and_train(self, obs, reward):
statevar = self.batch_states([obs], np, self.phi)
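        # The reward argument belongs to the previous action, so it is stored
        # under index t - 1.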
self.past_rewards[self.t - 1] = reward
if self.t - self.t_start == self.t_max:
self.update(statevar)
self.past_states[self.t] = statevar
pout, vout = self.model.pi_and_v(statevar)
action = pout.sample().array # Do not backprop through sampled actions
self.past_action_log_prob[self.t] = pout.log_prob(action)
self.past_action_entropy[self.t] = pout.entropy
self.past_values[self.t] = vout
self.t += 1
action = action[0]
if self.process_idx == 0:
logger.debug('t:%s r:%s a:%s pout:%s',
self.t, reward, action, pout)
# Update stats
self.average_value += (
(1 - self.average_value_decay) *
(float(vout.array[0]) - self.average_value))
self.average_entropy += (
(1 - self.average_entropy_decay) *
(float(pout.entropy.array[0]) - self.average_entropy))
return action
def act(self, obs):
# Use the process-local model for acting
with chainer.no_backprop_mode():
statevar = self.batch_states([obs], np, self.phi)
pout, _ = self.model.pi_and_v(statevar)
if self.act_deterministically:
return pout.most_probable.array[0]
else:
return pout.sample().array[0]
def stop_episode_and_train(self, state, reward, done=False):
self.past_rewards[self.t - 1] = reward
if done:
self.update(None)
else:
statevar = self.batch_states([state], np, self.phi)
self.update(statevar)
if isinstance(self.model, Recurrent):
self.model.reset_state()
def stop_episode(self):
if isinstance(self.model, Recurrent):
self.model.reset_state()
def load(self, dirname):
super().load(dirname)
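        # Push the loaded parameters into the globally shared model so that
        # all processes resume from the same weights.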
copy_param.copy_param(target_link=self.shared_model,
source_link=self.model)
def get_statistics(self):
return [
('average_value', self.average_value),
('average_entropy', self.average_entropy),
]
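# Usage sketch (illustrative, not part of the original module): wiring a
# model, an optimizer, and the agent together. RMSpropAsync follows the
# chainerrl examples; all hyperparameters here are placeholder assumptions.
def _example_make_agent(obs_size, n_actions):
    from chainerrl.optimizers import rmsprop_async
    model = _example_separate_model(obs_size, n_actions)
    opt = rmsprop_async.RMSpropAsync(lr=7e-4, eps=1e-1, alpha=0.99)
    opt.setup(model)
    return A3C(model, opt, t_max=5, gamma=0.99, beta=1e-2)
# In a single worker, the interaction loop would look like the following,
# where env is a hypothetical Gym-style environment:
#
#     agent = _example_make_agent(obs_size, n_actions)
#     obs, reward, done = env.reset(), 0, False
#     while not done:
#         action = agent.act_and_train(obs, reward)
#         obs, reward, done, _ = env.step(action)
#     agent.stop_episode_and_train(obs, reward, done)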