Source code for chainerrl.experiments.train_agent

from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future import standard_library
standard_library.install_aliases()

import logging
import os

from chainerrl.experiments.evaluator import Evaluator
from chainerrl.experiments.evaluator import save_agent
from chainerrl.misc.ask_yes_no import ask_yes_no
from chainerrl.misc.makedirs import makedirs


def save_agent_replay_buffer(agent, t, outdir, suffix='', logger=None):
    logger = logger or logging.getLogger(__name__)
    filename = os.path.join(outdir, '{}{}.replay.pkl'.format(t, suffix))
    agent.replay_buffer.save(filename)
    logger.info('Saved the current replay buffer to %s', filename)


def ask_and_save_agent_replay_buffer(agent, t, outdir, suffix=''):
    if hasattr(agent, 'replay_buffer') and \
            ask_yes_no('Replay buffer has {} transitions. Do you want to save them to a file?'.format(len(agent.replay_buffer))):  # NOQA
        save_agent_replay_buffer(agent, t, outdir, suffix=suffix)

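
# Usage sketch (illustrative, not part of the original module): offer to dump
# the replay buffer when a run is interrupted from the keyboard. It assumes
# the agent exposes a step counter ``t``, as the agents trained by
# ``train_agent`` below do.
def example_train_with_buffer_dump(agent, env, steps, outdir):
    try:
        train_agent(agent, env, steps, outdir)
    except KeyboardInterrupt:
        # Prompt before writing <outdir>/<t>_interrupt.replay.pkl
        ask_and_save_agent_replay_buffer(
            agent, getattr(agent, 't', 0), outdir, suffix='_interrupt')
        raise
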

def train_agent(agent, env, steps, outdir, max_episode_len=None,
                step_offset=0, evaluator=None, successful_score=None,
                step_hooks=(), logger=None):

    logger = logger or logging.getLogger(__name__)

    episode_r = 0
    episode_idx = 0

    # o_0, r_0
    obs = env.reset()
    r = 0
    done = False

    t = step_offset
    if hasattr(agent, 't'):
        agent.t = step_offset

    episode_len = 0
    try:
        while t < steps:

            # a_t
            action = agent.act_and_train(obs, r)
            # o_{t+1}, r_{t+1}
            obs, r, done, info = env.step(action)
            t += 1
            episode_r += r
            episode_len += 1

            for hook in step_hooks:
                hook(env, agent, t)

            if done or episode_len == max_episode_len or t == steps:
                agent.stop_episode_and_train(obs, r, done=done)
                logger.info('outdir:%s step:%s episode:%s R:%s',
                            outdir, t, episode_idx, episode_r)
                logger.info('statistics:%s', agent.get_statistics())
                if evaluator is not None:
                    evaluator.evaluate_if_necessary(
                        t=t, episodes=episode_idx + 1)
                    if (successful_score is not None and
                            evaluator.max_score >= successful_score):
                        break
                if t == steps:
                    break
                # Start a new episode
                episode_r = 0
                episode_idx += 1
                episode_len = 0
                obs = env.reset()
                r = 0
                done = False

    except Exception:
        # Save the current model before the exception propagates
        save_agent(agent, t, outdir, logger, suffix='_except')
        raise

    # Save the final model
    save_agent(agent, t, outdir, logger, suffix='_finish')

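
# Example step hook (illustrative, not part of the original module): any
# callable accepting (env, agent, step) can be passed to ``train_agent`` via
# ``step_hooks``; see chainerrl.experiments.hooks for ready-made hooks. This
# sketch only logs the agent's statistics at a fixed period.
def example_log_statistics_hook(env, agent, step):
    """Log agent statistics every 1000 steps (illustrative only)."""
    if step % 1000 == 0:
        logging.getLogger(__name__).info(
            'step:%s statistics:%s', step, agent.get_statistics())
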

def train_agent_with_evaluation(
        agent, env, steps, eval_n_runs, eval_interval, outdir,
        max_episode_len=None, step_offset=0, eval_explorer=None,
        eval_max_episode_len=None, eval_env=None, successful_score=None,
        step_hooks=(), logger=None):
    """Train an agent while regularly evaluating it.

    Args:
        agent: Agent to train.
        env: Environment to train the agent against.
        steps (int): Number of total time steps for training.
        eval_n_runs (int): Number of runs for each time of evaluation.
        eval_interval (int): Interval of evaluation.
        outdir (str): Path to the directory to output things.
        max_episode_len (int): Maximum episode length.
        step_offset (int): Time step from which training starts.
        eval_explorer: Explorer used for evaluation.
        eval_max_episode_len (int): Maximum episode length during evaluation.
            Defaults to max_episode_len if None.
        eval_env: Environment used for evaluation.
        successful_score (float): Finish training when the mean evaluation
            score is greater than or equal to this value, if not None.
        step_hooks (list): List of callable objects that accept
            (env, agent, step) as arguments. They are called every step.
            See chainerrl.experiments.hooks.
        logger (logging.Logger): Logger used in this function.
    """

    logger = logger or logging.getLogger(__name__)

    makedirs(outdir, exist_ok=True)

    if eval_env is None:
        eval_env = env

    if eval_max_episode_len is None:
        eval_max_episode_len = max_episode_len

    evaluator = Evaluator(agent=agent,
                          n_runs=eval_n_runs,
                          eval_interval=eval_interval,
                          outdir=outdir,
                          max_episode_len=eval_max_episode_len,
                          explorer=eval_explorer,
                          env=eval_env,
                          step_offset=step_offset,
                          logger=logger)

    train_agent(
        agent, env, steps, outdir,
        max_episode_len=max_episode_len,
        step_offset=step_offset,
        evaluator=evaluator,
        successful_score=successful_score,
        step_hooks=step_hooks,
        logger=logger)
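

# Usage sketch (illustrative, not part of the original module): training with
# periodic evaluation. ``make_env`` and ``make_agent`` are hypothetical
# callables standing in for your own environment and agent construction; the
# keyword arguments map directly onto the signature above.
def example_run(make_env, make_agent):
    env = make_env()
    agent = make_agent(env)
    train_agent_with_evaluation(
        agent=agent, env=env, steps=10 ** 5,
        eval_n_runs=10, eval_interval=10 ** 4,
        outdir='results', max_episode_len=200)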