Source code for chainerrl.experiments.train_agent_async

import logging
import multiprocessing as mp
import os

from chainerrl.experiments.evaluator import AsyncEvaluator
from chainerrl.misc import async_
from chainerrl.misc import random_seed


def train_loop(process_idx, env, agent, steps, outdir, counter,
               episodes_counter, training_done,
               max_episode_len=None, evaluator=None, eval_env=None,
               successful_score=None, logger=None,
               global_step_hooks=()):
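    """Run an asynchronous training loop in a single worker process.

    Each worker steps its own environment copy and updates the agent,
    coordinating progress with the other workers through the
    process-shared `counter`, `episodes_counter`, and `training_done`
    values.
    """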

    logger = logger or logging.getLogger(__name__)

    if eval_env is None:
        eval_env = env

    try:

        episode_r = 0
        global_t = 0
        local_t = 0
        global_episodes = 0
        obs = env.reset()
        r = 0
        done = False
        episode_len = 0
        successful = False

        while True:

            # Select a_t given o_t (this also trains the agent on the
            # latest transition)
            a = agent.act_and_train(obs, r)
            # Receive o_{t+1}, r_{t+1}
            obs, r, done, info = env.step(a)
            local_t += 1
            episode_r += r
            episode_len += 1

            # Get and increment the global counter
            with counter.get_lock():
                counter.value += 1
                global_t = counter.value

            for hook in global_step_hooks:
                hook(env, agent, global_t)

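            # An episode is also cut short when it reaches max_episode_len
            # or when the env signals a reset via info['needs_reset']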
            reset = (episode_len == max_episode_len
                     or info.get('needs_reset', False))
            if done or reset or global_t >= steps or training_done.value:
                agent.stop_episode_and_train(obs, r, done)

                if process_idx == 0:
                    logger.info(
                        'outdir:%s global_step:%s local_step:%s R:%s',
                        outdir, global_t, local_t, episode_r)
                    logger.info('statistics:%s', agent.get_statistics())

                # Evaluate the current agent
                if evaluator is not None:
                    eval_score = evaluator.evaluate_if_necessary(
                        t=global_t, episodes=global_episodes,
                        env=eval_env, agent=agent)
                    if (eval_score is not None and
                            successful_score is not None and
                            eval_score >= successful_score):
                        with training_done.get_lock():
                            if not training_done.value:
                                training_done.value = True
                                successful = True
                        # Break immediately in order to avoid an additional
                        # call of agent.act_and_train
                        break

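                # Get and increment the global episode counter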
                with episodes_counter.get_lock():
                    episodes_counter.value += 1
                    global_episodes = episodes_counter.value

                if global_t >= steps or training_done.value:
                    break

                # Start a new episode
                episode_r = 0
                episode_len = 0
                obs = env.reset()
                r = 0

    except (Exception, KeyboardInterrupt):
        if process_idx == 0:
            # Save the current model before being killed
            dirname = os.path.join(outdir, '{}_except'.format(global_t))
            agent.save(dirname)
            logger.warning('Saved the current model to %s', dirname)
        raise

    if global_t == steps:
        # Save the final model
        dirname = os.path.join(outdir, '{}_finish'.format(steps))
        agent.save(dirname)
        logger.info('Saved the final agent to %s', dirname)

    if successful:
        # Save the successful model
        dirname = os.path.join(outdir, 'successful')
        agent.save(dirname)
        logger.info('Saved the successful agent to %s', dirname)


def extract_shared_objects_from_agent(agent):
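    """Collect the agent's shared attributes as process-shared objects."""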
    return {attr: async_.as_shared_objects(getattr(agent, attr))
            for attr in agent.shared_attributes}


def set_shared_objects(agent, shared_objects):
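    """Synchronize the agent's attributes to the given shared objects."""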
    for attr, shared in shared_objects.items():
        new_value = async_.synchronize_to_shared_objects(
            getattr(agent, attr), shared)
        setattr(agent, attr, new_value)


def train_agent_async(outdir, processes, make_env,
                      profile=False,
                      steps=8 * 10 ** 7,
                      eval_interval=10 ** 6,
                      eval_n_steps=None,
                      eval_n_episodes=10,
                      max_episode_len=None,
                      step_offset=0,
                      successful_score=None,
                      agent=None,
                      make_agent=None,
                      global_step_hooks=(),
                      save_best_so_far_agent=True,
                      logger=None,
                      ):
    """Train agent asynchronously using multiprocessing.

    Either `agent` or `make_agent` must be specified.

    Args:
        outdir (str): Path to the directory to output things.
        processes (int): Number of processes.
        make_env (callable): (process_idx, test) -> Environment.
        profile (bool): Profile if set True.
        steps (int): Number of global time steps for training.
        eval_interval (int): Interval of evaluation. If set to None, the
            agent will not be evaluated at all.
        eval_n_steps (int): Number of eval timesteps at each eval phase.
        eval_n_episodes (int): Number of eval episodes at each eval phase.
        max_episode_len (int): Maximum episode length.
        step_offset (int): Time step from which training starts.
        successful_score (float): If not None, training finishes when the
            mean evaluation score is greater than or equal to this value.
        agent (Agent): Agent to train.
        make_agent (callable): (process_idx) -> Agent.
        global_step_hooks (Sequence): Sequence of callable objects that
            accept (env, agent, step) as arguments. They are called every
            global step. See chainerrl.experiments.hooks.
        save_best_so_far_agent (bool): If set to True, after each
            evaluation, the current agent is saved if its score (= mean
            return of evaluation episodes) exceeds the best-so-far score.
        logger (logging.Logger): Logger used in this function.

    Returns:
        Trained agent.
    """
    logger = logger or logging.getLogger(__name__)

    # Prevent numpy from using multiple threads
    os.environ['OMP_NUM_THREADS'] = '1'

    counter = mp.Value('l', 0)
    episodes_counter = mp.Value('l', 0)
    training_done = mp.Value('b', False)  # bool

    if agent is None:
        assert make_agent is not None
        agent = make_agent(0)

    shared_objects = extract_shared_objects_from_agent(agent)
    set_shared_objects(agent, shared_objects)

    if eval_interval is None:
        evaluator = None
    else:
        evaluator = AsyncEvaluator(
            n_steps=eval_n_steps,
            n_episodes=eval_n_episodes,
            eval_interval=eval_interval,
            outdir=outdir,
            max_episode_len=max_episode_len,
            step_offset=step_offset,
            save_best_so_far_agent=save_best_so_far_agent,
            logger=logger,
        )

    def run_func(process_idx):
        random_seed.set_random_seed(process_idx)

        env = make_env(process_idx, test=False)
        if evaluator is None:
            eval_env = env
        else:
            eval_env = make_env(process_idx, test=True)
        if make_agent is not None:
            local_agent = make_agent(process_idx)
            set_shared_objects(local_agent, shared_objects)
        else:
            local_agent = agent
        local_agent.process_idx = process_idx

        def f():
            train_loop(
                process_idx=process_idx,
                counter=counter,
                episodes_counter=episodes_counter,
                agent=local_agent,
                env=env,
                steps=steps,
                outdir=outdir,
                max_episode_len=max_episode_len,
                evaluator=evaluator,
                successful_score=successful_score,
                training_done=training_done,
                eval_env=eval_env,
                global_step_hooks=global_step_hooks,
                logger=logger)

        if profile:
            import cProfile
            cProfile.runctx('f()', globals(), locals(),
                            'profile-{}.out'.format(os.getpid()))
        else:
            f()

        env.close()
        if eval_env is not env:
            eval_env.close()

    async_.run_async(processes, run_func)

    return agent
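

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal, hedged example of how train_agent_async might be driven.
# RandomAgent and make_my_env below are hypothetical placeholders that
# merely satisfy the interface train_loop expects (act_and_train,
# stop_episode_and_train, get_statistics, save, shared_attributes); a
# real run would use an async-capable agent such as chainerrl's A3C.
# Assumes OpenAI Gym is installed.
if __name__ == '__main__':
    import gym

    class RandomAgent(object):
        """Toy stand-in agent that picks uniformly random actions."""

        shared_attributes = ()  # nothing to share across processes

        def __init__(self, action_space):
            self.action_space = action_space
            self.process_idx = 0

        def act_and_train(self, obs, reward):
            return self.action_space.sample()

        def stop_episode_and_train(self, obs, reward, done=False):
            pass

        def get_statistics(self):
            return []

        def save(self, dirname):
            if not os.path.exists(dirname):
                os.makedirs(dirname)

    def make_my_env(process_idx, test):
        return gym.make('CartPole-v0')

    sample_env = gym.make('CartPole-v0')

    def make_my_agent(process_idx):
        return RandomAgent(sample_env.action_space)

    train_agent_async(
        outdir='results',
        processes=2,
        make_env=make_my_env,
        make_agent=make_my_agent,
        steps=1000,
        eval_interval=None,  # skip evaluation for this toy run
    )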