Source code for chainerrl.experiments.collect_demos

import logging
import os

import chainer


def collect_demonstrations(agent, env, steps, episodes, outdir,
                           max_episode_len=None, logger=None):
    """Collects demonstrations from an agent and writes them to a file.

    Args:
        agent: Agent from which demonstrations are collected.
        env: Environment in which the agent produces demonstrations.
        steps (int): Number of total time steps to collect demonstrations for.
        episodes (int): Number of total episodes to collect demonstrations
            for.
        outdir (str): Path to the directory to output demonstrations.
        max_episode_len (int): Maximum episode length.
        logger (logging.Logger): Logger used in this function.
    """
    assert (steps is None) != (episodes is None)
    logger = logger or logging.getLogger(__name__)

    with chainer.datasets.open_pickle_dataset_writer(
            os.path.join(outdir, "demos.pickle")) as dataset:
        # o_0, r_0
        terminate = False  # True if we should stop collecting demos
        timestep = 0  # number of timesteps of demos collected
        episode_num = 0  # number of episodes of demos collected
        episode_len = 0  # length of most recent episode
        reset = True  # whether to reset environment
        episode_r = 0  # Episode reward
        while not terminate:
            if reset:
                if episode_num > 0:
                    logger.info('demonstration episode %s length:%s R:%s',
                                episode_num, episode_len, episode_r)
                obs = env.reset()
                done = False
                r = 0
                episode_r = 0
                episode_len = 0
                info = {}
            # a_t
            a = agent.act(obs)
            # o_{t+1}, r_{t+1}
            new_obs, r, done, info = env.step(a)
            # o_t, a_t, r_{t+1}, o_{t+1}
            dataset.write((obs, a, r, new_obs, done, info))
            obs = new_obs
            reset = (done
                     or episode_len == max_episode_len
                     or info.get('needs_reset', False))
            timestep += 1
            episode_len += 1
            episode_r += r
            episode_num = episode_num + 1 if reset else episode_num
            if steps is None:
                terminate = episode_num >= episodes
            else:
                terminate = timestep >= steps
            if reset or terminate:
                agent.stop_episode()
        if terminate:
            # log final episode
            logger.info('demonstration episode %s length:%s R:%s',
                        episode_num, episode_len, episode_r)
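
Below is a minimal usage sketch, not part of the original module. It assumes a Gym-style environment and an already-trained ChainerRL agent: make_trained_agent is a hypothetical placeholder for whatever constructs that agent, and 'CartPole-v0' and 'demos_out' are illustrative names. Reading the written file back uses Chainer's pickle dataset reader, which mirrors the writer used above.

import os

import chainer
import gym

from chainerrl.experiments.collect_demos import collect_demonstrations

env = gym.make('CartPole-v0')  # illustrative environment
agent = make_trained_agent(env)  # hypothetical helper returning a trained chainerrl agent

outdir = 'demos_out'
os.makedirs(outdir, exist_ok=True)

# Collect 10 full episodes of (obs, action, reward, next_obs, done, info)
# transitions; exactly one of steps/episodes must be None.
collect_demonstrations(agent, env, steps=None, episodes=10, outdir=outdir)

# Read the demonstrations back from demos_out/demos.pickle.
with chainer.datasets.open_pickle_dataset(
        os.path.join(outdir, 'demos.pickle')) as demos:
    print('number of transitions collected:', len(demos))
    obs, a, r, new_obs, done, info = demos[0]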