# Source code for chainerrl.experiments.train_agent

import logging
import os

from chainerrl.experiments.evaluator import Evaluator
from chainerrl.experiments.evaluator import save_agent
from chainerrl.misc.ask_yes_no import ask_yes_no


def save_agent_replay_buffer(agent, t, outdir, suffix='', logger=None):
    """Save the agent's replay buffer as ``<outdir>/<t><suffix>.replay.pkl``.

    Args:
        agent: Object with a ``replay_buffer`` attribute that provides a
            ``save(filename)`` method.
        t: Value formatted into the start of the file name (presumably the
            current time step -- confirm against callers).
        outdir (str): Directory part of the destination path.
        suffix (str): Text inserted between ``t`` and ``.replay.pkl``.
        logger (logging.Logger): Logger that reports the saved path. Falls
            back to this module's logger when not given.
    """
    if not logger:
        logger = logging.getLogger(__name__)
    basename = '{}{}.replay.pkl'.format(t, suffix)
    destination = os.path.join(outdir, basename)
    agent.replay_buffer.save(destination)
    logger.info('Saved the current replay buffer to %s', destination)


def ask_and_save_agent_replay_buffer(agent, t, outdir, suffix=''):
    """Ask the user whether to save the agent's replay buffer, then do so.

    Nothing happens when ``agent`` has no ``replay_buffer`` attribute.
    Otherwise the user is asked through ``ask_yes_no`` (the question states
    how many transitions the buffer holds), and the buffer is written with
    ``save_agent_replay_buffer`` only if the answer is affirmative.

    Args:
        agent: Agent whose ``replay_buffer`` may be saved.
        t: Forwarded to ``save_agent_replay_buffer``.
        outdir (str): Forwarded to ``save_agent_replay_buffer``.
        suffix (str): Forwarded to ``save_agent_replay_buffer``.
    """
    if not hasattr(agent, 'replay_buffer'):
        return
    question = (
        'Replay buffer has {} transitions. '
        'Do you save them to a file?'
    ).format(len(agent.replay_buffer))
    if ask_yes_no(question):
        save_agent_replay_buffer(agent, t, outdir, suffix=suffix)


def train_agent(agent, env, steps, outdir, checkpoint_freq=None,
                max_episode_len=None, step_offset=0, evaluator=None,
                successful_score=None, step_hooks=(), logger=None):
    """Train ``agent`` on ``env`` until the step counter reaches ``steps``.

    The agent is stored through ``save_agent`` with the suffix ``_finish``
    when the loop ends normally, ``_checkpoint`` at periodic checkpoints and
    ``_except`` when an exception (including ``KeyboardInterrupt``) escapes
    the loop. In the last case the exception is re-raised afterwards.

    Args:
        agent: Object providing ``act_and_train``, ``stop_episode_and_train``
            and ``get_statistics``. If it has an attribute ``t``, that
            attribute is set to ``step_offset`` before training starts.
        env: Environment providing ``reset()`` and ``step(action)``. ``step``
            must return ``(obs, reward, done, info)`` where ``info``
            supports ``get``.
        steps (int): Step count at which training stops.
        outdir (str): Directory handed to ``save_agent`` and shown in logs.
        checkpoint_freq (int or None): If truthy, the agent is saved whenever
            the step counter is a multiple of this value.
        max_episode_len (int or None): Episode length at which the
            environment is reset even if the episode is not done.
        step_offset (int): Initial value of the step counter.
        evaluator: Optional object whose ``evaluate_if_necessary`` is called
            at the end of every episode and whose ``max_score`` is compared
            with ``successful_score``.
        successful_score (float or None): If not None and an evaluator is
            given, training stops once ``evaluator.max_score`` is greater
            than or equal to this value.
        step_hooks (Sequence): Callables invoked as ``hook(env, agent, t)``
            after every environment step.
        logger (logging.Logger): Logger to use. Falls back to this module's
            logger when not given.
    """
    if not logger:
        logger = logging.getLogger(__name__)

    # Initial observation and reward: o_0, r_0
    obs = env.reset()
    r = 0

    # Step counter; also written onto the agent when it exposes ``t``.
    t = step_offset
    if hasattr(agent, 't'):
        agent.t = step_offset

    episode_no = 0
    episode_return = 0
    steps_in_episode = 0

    try:
        while t < steps:
            # Pick a_t, then observe o_{t+1}, r_{t+1}
            obs, r, done, info = env.step(agent.act_and_train(obs, r))
            t += 1
            episode_return += r
            steps_in_episode += 1

            for hook in step_hooks:
                hook(env, agent, t)

            # Computed on every step, even when ``done`` is already True.
            needs_reset = (steps_in_episode == max_episode_len
                           or info.get('needs_reset', False))
            budget_spent = t == steps

            if done or needs_reset or budget_spent:
                agent.stop_episode_and_train(obs, r, done=done)
                logger.info('outdir:%s step:%s episode:%s R:%s',
                            outdir, t, episode_no, episode_return)
                logger.info('statistics:%s', agent.get_statistics())

                solved = False
                if evaluator is not None:
                    evaluator.evaluate_if_necessary(
                        t=t, episodes=episode_no + 1)
                    solved = (successful_score is not None
                              and evaluator.max_score >= successful_score)
                # Leaving here skips the checkpoint below; the final save
                # after the loop covers this step instead.
                if solved or budget_spent:
                    break

                # Begin the next episode
                episode_no += 1
                episode_return = 0
                steps_in_episode = 0
                obs = env.reset()
                r = 0

            if checkpoint_freq and t % checkpoint_freq == 0:
                save_agent(agent, t, outdir, logger, suffix='_checkpoint')
    except (Exception, KeyboardInterrupt):
        # Keep the current model on disk before the error propagates
        save_agent(agent, t, outdir, logger, suffix='_except')
        raise

    # Normal termination: store the final model
    save_agent(agent, t, outdir, logger, suffix='_finish')


def train_agent_with_evaluation(agent,
                                env,
                                steps,
                                eval_n_steps,
                                eval_n_episodes,
                                eval_interval,
                                outdir,
                                checkpoint_freq=None,
                                train_max_episode_len=None,
                                step_offset=0,
                                eval_max_episode_len=None,
                                eval_env=None,
                                successful_score=None,
                                step_hooks=(),
                                save_best_so_far_agent=True,
                                logger=None,
                                ):
    """Train an agent while periodically evaluating it.

    Args:
        agent: A chainerrl.agent.Agent
        env: Environment train the agent against.
        steps (int): Total number of timesteps for training.
        eval_n_steps (int): Number of timesteps at each evaluation phase.
        eval_n_episodes (int): Number of episodes at each evaluation phase.
        eval_interval (int): Interval of evaluation.
        outdir (str): Path to the directory to output data. It is created
            if it does not exist yet.
        checkpoint_freq (int): frequency at which agents are stored.
        train_max_episode_len (int): Maximum episode length during training.
        step_offset (int): Time step from which training starts.
        eval_max_episode_len (int or None): Maximum episode length of
            evaluation runs. If None, train_max_episode_len is used instead.
        eval_env: Environment used for evaluation. If None, env is used
            instead.
        successful_score (float): Finish training if the mean score is
            greater than or equal to this value if not None
        step_hooks (Sequence): Sequence of callable objects that accepts
            (env, agent, step) as arguments. They are called every step.
            See chainerrl.experiments.hooks.
        save_best_so_far_agent (bool): If set to True, after each evaluation
            phase, if the score (= mean return of evaluation episodes)
            exceeds the best-so-far score, the current agent is saved.
        logger (logging.Logger): Logger used in this function.
    """
    logger = logger or logging.getLogger(__name__)

    # Create the output directory first so that every later write into it
    # (by the evaluator or the training loop) finds it in place.
    os.makedirs(outdir, exist_ok=True)

    # Evaluation falls back to the training settings when none are given.
    if eval_env is None:
        eval_env = env

    if eval_max_episode_len is None:
        eval_max_episode_len = train_max_episode_len

    evaluator = Evaluator(agent=agent,
                          n_steps=eval_n_steps,
                          n_episodes=eval_n_episodes,
                          eval_interval=eval_interval,
                          outdir=outdir,
                          max_episode_len=eval_max_episode_len,
                          env=eval_env,
                          step_offset=step_offset,
                          save_best_so_far_agent=save_best_so_far_agent,
                          logger=logger,
                          )

    train_agent(
        agent, env, steps, outdir,
        checkpoint_freq=checkpoint_freq,
        max_episode_len=train_max_episode_len,
        step_offset=step_offset,
        evaluator=evaluator,
        successful_score=successful_score,
        step_hooks=step_hooks,
        logger=logger)