Experiments

Training and evaluation

chainerrl.experiments.train_agent_async(outdir, processes, make_env, profile=False, steps=80000000, eval_interval=1000000, eval_n_steps=None, eval_n_episodes=10, max_episode_len=None, step_offset=0, successful_score=None, agent=None, make_agent=None, global_step_hooks=[], save_best_so_far_agent=True, logger=None)

Train agent asynchronously using multiprocessing.

Either agent or make_agent must be specified.

Parameters:
  • outdir (str) – Path to the directory where output files are written.
  • processes (int) – Number of processes.
  • make_env (callable) – (process_idx, test) -> Environment.
  • profile (bool) – Profile if set True.
  • steps (int) – Number of global time steps for training.
  • eval_interval (int) – Interval of evaluation. If set to None, the agent will not be evaluated at all.
  • eval_n_steps (int) – Number of evaluation timesteps in each evaluation phase.
  • eval_n_episodes (int) – Number of evaluation episodes in each evaluation phase.
  • max_episode_len (int) – Maximum episode length.
  • step_offset (int) – Time step from which training starts.
  • successful_score (float) – If not None, training finishes once the mean score is greater than or equal to this value.
  • agent (Agent) – Agent to train.
  • make_agent (callable) – (process_idx) -> Agent
  • global_step_hooks (list) – List of callable objects that accept (env, agent, step) as arguments. They are called every global step. See chainerrl.experiments.hooks.
  • save_best_so_far_agent (bool) – If set to True, after each evaluation, if the score (= mean return of evaluation episodes) exceeds the best-so-far score, the current agent is saved.
  • logger (logging.Logger) – Logger used in this function.
Returns:
  Trained agent.
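
The make_env contract described above, (process_idx, test) -> Environment, can be sketched with a stand-in environment. ToyEnv, its seed attribute, and the base_seed argument are hypothetical placeholders and not part of ChainerRL. Giving each process its own seed, and mirroring the seed for test environments, is an assumed convention shown only to illustrate how the two arguments might be used.

```python
class ToyEnv:
    """Hypothetical stand-in for a real environment (not a ChainerRL class)."""

    def __init__(self, seed, test):
        self.seed = seed
        self.test = test


def make_env(process_idx, test, base_seed=0):
    # Give each process a distinct seed so that workers do not replay
    # identical episodes (assumed convention, for illustration only).
    seed = base_seed + process_idx
    if test:
        # Keep evaluation seeds apart from the training seeds.
        seed = 2 ** 32 - 1 - seed
    return ToyEnv(seed=seed, test=test)


# One training environment per process, as a caller with 4 processes might build.
envs = [make_env(process_idx, test=False) for process_idx in range(4)]
print([env.seed for env in envs])
```

Because base_seed has a default value, the factory can still be called with only (process_idx, test), which is the signature the parameter list specifies.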

chainerrl.experiments.train_agent_with_evaluation(agent, env, steps, eval_n_steps, eval_n_episodes, eval_interval, outdir, train_max_episode_len=None, step_offset=0, eval_max_episode_len=None, eval_env=None, successful_score=None, step_hooks=[], save_best_so_far_agent=True, logger=None)

Train an agent while periodically evaluating it.

Parameters:
  • agent – A chainerrl.agent.Agent
  • env – Environment to train the agent against.
  • steps (int) – Total number of timesteps for training.
  • eval_n_steps (int) – Number of timesteps at each evaluation phase.
  • eval_n_episodes (int) – Number of episodes at each evaluation phase.
  • eval_interval (int) – Interval of evaluation.
  • outdir (str) – Path to the directory to output data.
  • train_max_episode_len (int) – Maximum episode length during training.
  • step_offset (int) – Time step from which training starts.
  • eval_max_episode_len (int or None) – Maximum episode length of evaluation runs. If None, train_max_episode_len is used instead.
  • eval_env – Environment used for evaluation.
  • successful_score (float) – If not None, training finishes once the mean score is greater than or equal to this value.
  • step_hooks (list) – List of callable objects that accept (env, agent, step) as arguments. They are called every step. See chainerrl.experiments.hooks.
  • save_best_so_far_agent (bool) – If set to True, after each evaluation phase, if the score (= mean return of evaluation episodes) exceeds the best-so-far score, the current agent is saved.
  • logger (logging.Logger) – Logger used in this function.
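
The interaction between successful_score and save_best_so_far_agent can be sketched as a small decision function run after each evaluation phase. This is a minimal illustration of the documented rules, not the library's implementation: the name after_evaluation is hypothetical, and starting the best-so-far score at negative infinity is an assumption.

```python
def after_evaluation(mean_score, best_score, successful_score=None):
    """Return (should_save, should_stop, new_best_score) for one evaluation."""
    # Save only when the score strictly exceeds the best-so-far score.
    should_save = mean_score > best_score
    new_best = mean_score if should_save else best_score
    # Stop once the mean score reaches the threshold, if one was given.
    should_stop = successful_score is not None and mean_score >= successful_score
    return should_save, should_stop, new_best


best = float("-inf")  # assumed starting point before any evaluation
history = []
for score in [10.0, 8.0, 15.0, 20.0]:
    save, stop, best = after_evaluation(score, best, successful_score=15.0)
    history.append((score, save, stop))
    if stop:
        break

for entry in history:
    print(entry)
```

In this run the third evaluation both sets a new best score and reaches the threshold, so the loop ends before the fourth score is considered.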

Training hooks

class chainerrl.experiments.StepHook

Hook function that will be called in training.

This class is for clarifying the interface required for Hook functions. You don’t need to inherit this class to define your own hooks. Any callable that accepts (env, agent, step) as arguments can be used as a hook.
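
As a sketch of the interface described above, the two hooks below are ordinary callables that take (env, agent, step) and do not inherit from any base class. StepLogger and count_calls are hypothetical names, and the loop is only a stand-in for a training loop, with None passed in place of a real environment and agent.

```python
class StepLogger:
    """Hypothetical hook that records every `interval`-th step."""

    def __init__(self, interval):
        self.interval = interval
        self.seen = []

    def __call__(self, env, agent, step):
        if step % self.interval == 0:
            self.seen.append(step)


def count_calls(env, agent, step):
    # A plain function with the same three arguments also works as a hook.
    count_calls.n += 1


count_calls.n = 0

logger_hook = StepLogger(interval=2)
hooks = [logger_hook, count_calls]

# Stand-in for a training loop: call every hook once per step.
for step in range(1, 6):
    for hook in hooks:
        hook(None, None, step)

print(logger_hook.seen, count_calls.n)
```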

class chainerrl.experiments.LinearInterpolationHook(total_steps, start_value, stop_value, setter)

Hook that will set a linearly interpolated value.

You can use this hook to decay the learning rate by using a setter function as follows:

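from chainerrl.experiments import LinearInterpolationHook

def lr_setter(env, agent, value):
    # Called by the hook with the interpolated value for the current step.
    agent.optimizer.lr = value

# Decay the learning rate linearly from 1e-3 to 0 over 10 ** 6 steps.
hook = LinearInterpolationHook(10 ** 6, 1e-3, 0, lr_setter)

Parameters: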
  • total_steps (int) – Number of total steps.
  • start_value (float) – Start value.
  • stop_value (float) – Stop value.
  • setter (callable) – (env, agent, value) -> None
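
To make the four parameters concrete, the following is a minimal pure-Python sketch of a hook with the same constructor arguments. It is not the library's implementation: ToyLinearHook and interpolate are hypothetical names, and both the step indexing (fraction step / total_steps) and the clamping after total_steps are assumptions made for illustration.

```python
from types import SimpleNamespace


def interpolate(step, total_steps, start_value, stop_value):
    # Fraction of the schedule completed, clamped to [0, 1] so the value
    # stays at stop_value once total_steps has passed (assumed behavior).
    fraction = min(max(step / total_steps, 0.0), 1.0)
    return start_value + fraction * (stop_value - start_value)


class ToyLinearHook:
    """Hypothetical hook that passes a linearly interpolated value to a setter."""

    def __init__(self, total_steps, start_value, stop_value, setter):
        self.total_steps = total_steps
        self.start_value = start_value
        self.stop_value = stop_value
        self.setter = setter

    def __call__(self, env, agent, step):
        value = interpolate(
            step, self.total_steps, self.start_value, self.stop_value)
        self.setter(env, agent, value)


def lr_setter(env, agent, value):
    agent.optimizer.lr = value


# Stand-in agent exposing only the attribute that the setter touches.
agent = SimpleNamespace(optimizer=SimpleNamespace(lr=1e-3))
hook = ToyLinearHook(10 ** 6, 1e-3, 0, lr_setter)

hook(None, agent, 500000)
print(agent.optimizer.lr)  # halfway through the schedule
```

Keeping the setter separate from the schedule means the same hook can drive any scalar quantity, not only a learning rate.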