Source code for chainerrl.replay_buffers.episodic

import collections
import pickle

from chainerrl.misc.collections import RandomAccessQueue
from chainerrl.replay_buffer import AbstractEpisodicReplayBuffer
from chainerrl.replay_buffer import random_subseq


class EpisodicReplayBuffer(AbstractEpisodicReplayBuffer):
    """Replay buffer that stores transitions grouped into episodes.

    Transitions are first collected per ``env_id`` in a pending episode and
    only become visible to :meth:`sample` / :meth:`sample_episodes` once that
    episode is finalized by :meth:`stop_current_episode`, which is called
    automatically when a terminal transition is appended.

    Args:
        capacity (int or None): Maximum number of stored transitions. When
            it is exceeded, whole episodes are discarded oldest-first, so the
            buffer may end up holding fewer than ``capacity`` transitions.
            ``None`` means unbounded.
    """

    def __init__(self, capacity=None):
        # Pending (not yet finalized) episodes, keyed by env_id.
        self.current_episode = collections.defaultdict(list)
        # Finalized episodes, each a list of transition dicts, oldest first.
        self.episodic_memory = RandomAccessQueue()
        # Flat sequence of the transitions of all finalized episodes, in the
        # same order as ``episodic_memory``. Eviction pops len(episode) items
        # from here per discarded episode, so the two must stay aligned.
        self.memory = RandomAccessQueue()
        self.capacity = capacity

    def append(self, state, action, reward, next_state=None, next_action=None,
               is_state_terminal=False, env_id=0, **kwargs):
        """Add a transition to the pending episode of ``env_id``.

        The transition is stored as a dict with keys ``state``, ``action``,
        ``reward``, ``next_state``, ``next_action``, ``is_state_terminal``
        plus any extra keyword arguments. If ``is_state_terminal`` is true,
        the pending episode is finalized immediately.
        """
        current_episode = self.current_episode[env_id]
        experience = dict(state=state, action=action, reward=reward,
                          next_state=next_state, next_action=next_action,
                          is_state_terminal=is_state_terminal,
                          **kwargs)
        current_episode.append(experience)
        if is_state_terminal:
            self.stop_current_episode(env_id=env_id)

    def sample(self, n):
        """Sample ``n`` transitions from finalized episodes.

        Each returned item is whatever was stored in ``memory``; for
        transitions added via :meth:`stop_current_episode` that is a
        one-element list wrapping the transition dict.

        Raises:
            AssertionError: If fewer than ``n`` transitions are stored
                (only when assertions are enabled).
        """
        assert len(self.memory) >= n
        return self.memory.sample(n)

    def sample_episodes(self, n_episodes, max_len=None):
        """Sample ``n_episodes`` finalized episodes.

        Args:
            n_episodes (int): Number of episodes to sample.
            max_len (int or None): If given, each sampled episode is passed
                through ``random_subseq(ep, max_len)`` (presumably a random
                contiguous slice of at most ``max_len`` transitions).

        Returns:
            list: Sampled episodes, each a list of transition dicts.

        Raises:
            AssertionError: If fewer than ``n_episodes`` episodes are stored
                (only when assertions are enabled).
        """
        assert len(self.episodic_memory) >= n_episodes
        episodes = self.episodic_memory.sample(n_episodes)
        if max_len is not None:
            return [random_subseq(ep, max_len) for ep in episodes]
        else:
            return episodes

    def __len__(self):
        """Return the number of transitions in finalized episodes."""
        return len(self.memory)

    @property
    def n_episodes(self):
        """Number of finalized episodes currently stored."""
        return len(self.episodic_memory)

    def save(self, filename):
        """Pickle ``(memory, episodic_memory)`` to ``filename``.

        Pending (unfinalized) episodes are not saved.
        """
        with open(filename, 'wb') as f:
            pickle.dump((self.memory, self.episodic_memory), f)

    def load(self, filename):
        """Restore the buffer from a file written by :meth:`save`.

        A pickle that is not a tuple is treated as the v0.2 format: a flat
        sequence of transition dicts, whose episode boundaries are rebuilt
        from their ``is_state_terminal`` flags.

        Warning:
            This unpickles the file, which can execute arbitrary code. Only
            load files from trusted sources.
        """
        with open(filename, 'rb') as f:
            # NOTE: pickle.load on untrusted data is unsafe; see docstring.
            memory = pickle.load(f)
        if isinstance(memory, tuple):
            self.memory, self.episodic_memory = memory
        else:
            # Load v0.2
            # FIXME: The code works with EpisodicReplayBuffer
            # but not with PrioritizedEpisodicReplayBuffer
            # NOTE(review): legacy items are used as bare transition dicts
            # here, whereas stop_current_episode stores one-element lists in
            # ``memory``; confirm consumers of sample() handle both shapes.
            self.memory = RandomAccessQueue(memory)
            self.episodic_memory = RandomAccessQueue()

            # Recover episodic_memory with best effort.
            episode = []
            for item in self.memory:
                episode.append(item)
                if item['is_state_terminal']:
                    self.episodic_memory.append(episode)
                    episode = []
            # Keep a trailing, non-terminated run of transitions as its own
            # episode so every item in ``memory`` belongs to some episode.
            # Otherwise capacity eviction would leave ``memory`` and
            # ``episodic_memory`` out of sync.
            if episode:
                self.episodic_memory.append(episode)

    def stop_current_episode(self, env_id=0):
        """Finalize the pending episode of ``env_id``.

        A non-empty pending episode is moved into the buffer. Then, if
        ``capacity`` is set, whole episodes are evicted oldest-first until
        the number of stored transitions is at most ``capacity``. When there
        are no pending transitions, the stored data is left unchanged.
        """
        current_episode = self.current_episode[env_id]
        if current_episode:
            self.episodic_memory.append(current_episode)
            # ``memory`` holds each transition wrapped in a one-element list.
            for transition in current_episode:
                self.memory.append([transition])
            self.current_episode[env_id] = []
            while (self.capacity is not None
                   and len(self.memory) > self.capacity):
                discarded_episode = self.episodic_memory.popleft()
                for _ in range(len(discarded_episode)):
                    self.memory.popleft()
        assert not self.current_episode[env_id]