Add Generics for new benchmark components #1011

Merged
11 changes: 6 additions & 5 deletions avalanche/benchmarks/scenarios/generic_scenario.py
@@ -10,9 +10,10 @@
################################################################################
from copy import copy
from enum import Enum
from typing import List, Iterable, TypeVar, Union, Generic
from typing import Generator, List, Iterable, TypeVar, Union, Generic

T = TypeVar("T")
E = TypeVar("E")


class MaskedAttributeError(ValueError):
@@ -151,7 +152,7 @@ def logging(self):
return exp


class CLStream:
class CLStream(Generic[E]):
"""A CL stream is a named iterator of experiences.

In general, many streams may be generators rather than explicit lists to avoid
@@ -176,7 +177,7 @@ def __init__(

self._iter = None

def __iter__(self):
def __iter__(self) -> Generator[E, None, None]:
def foo(self):
for i, exp in enumerate(self.exps_iter):
if self.set_stream_info:
@@ -187,7 +188,7 @@ def foo(self):
return foo(self)


class EagerCLStream(CLStream):
class EagerCLStream(CLStream[E]):
"""A CL stream which is a named list of experiences.

Eager streams are indexable and sliceable, like python lists.
@@ -223,7 +224,7 @@ def __init__(
e.origin_stream = self
e.current_experience = i

def __getitem__(self, item):
def __getitem__(self, item) -> Union['EagerCLStream[E]', E]:
# This check allows CL streams slicing
if isinstance(item, slice):
return EagerCLStream(
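For reviewers who want to see what the new type parameter buys downstream, here is a minimal usage sketch (not part of the diff; the import paths are assumptions based on the file paths shown above). Annotating a stream as `EagerCLStream[RLExperience]` lets a type checker infer the element type during iteration, since `__iter__` is now typed as `Generator[E, None, None]`.

```python
# Hedged sketch only -- import paths are assumed from the file paths in this PR.
from avalanche.benchmarks.scenarios.generic_scenario import EagerCLStream
from avalanche.benchmarks.scenarios.rl_scenario import RLExperience


def describe_stream(stream: EagerCLStream[RLExperience]) -> None:
    # __iter__ -> Generator[E, None, None], so `exp` is inferred as RLExperience.
    for exp in stream:
        print(exp.current_experience, exp.task_label)

    # __getitem__ -> Union['EagerCLStream[E]', E]: slicing keeps the stream
    # type at runtime, while integer indexing returns the union and may need
    # narrowing before member access under a strict type checker.
    head = stream[:2]
```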
55 changes: 31 additions & 24 deletions avalanche/benchmarks/scenarios/rl_scenario.py
@@ -69,7 +69,7 @@ class RLScenario(CLScenario):

def __init__(self, envs: List[Env],
n_parallel_envs: Union[int, List[int]],
eval_envs: Union[List[Env], List[Callable[[], Env]]],
eval_envs: Union[List[Env], List[Callable[[], Env]]] = None,
wrappers_generators: Dict[str, List[Wrapper]] = None,
task_labels: bool = True,
shuffle: bool = False):
@@ -83,7 +83,8 @@ def __init__(self, envs: List[Env],
the same degree of parallelism will be used for every env.
:param eval_envs: list of gym environments
to be used for evaluating the agent. Each environment will
be wrapped within an RLExperience.
Passing None or `[]` will result in no evaluation.
:param wrappers_generators: dict mapping env ids to a list of
`gym.Wrapper` generators. Wrappers represent behavior
added as post-processing steps (e.g. reward scaling).
@@ -95,7 +96,7 @@ def __init__(self, envs: List[Env],
"""

n_experiences = len(envs)
if type(n_parallel_envs) is int:
if isinstance(n_parallel_envs, int):
n_parallel_envs = [n_parallel_envs] * n_experiences
assert len(n_parallel_envs) == len(envs)
# this is so that we can infer the task labels
@@ -106,24 +107,28 @@ def __init__(self, envs: List[Env],
must be a positive integer"
tr_envs = envs
eval_envs = eval_envs or []
self._num_original_envs = len(tr_envs)
self.n_envs = n_parallel_envs
# this can contain shallow copies of envs to have multiple

def get_unique_task_labels(env_list):
# assign task label by checking whether the same instance of env is
# provided multiple times, using object hash as key
tlabels = []
env_occ = {}
j = 0
for e in env_list:
if e in env_occ:
tlabels.append(env_occ[e])
else:
tlabels.append(j)
env_occ[e] = j
j += 1
return tlabels

# accounts for shallow copies of envs to have multiple
# experiences from the same task
tr_task_labels = []
env_occ = {}
j = 0
# assign task label by checking whether the same instance of env is
# provided multiple times (shallow copy only)
for e in envs:
if e in env_occ:
tr_task_labels.append(env_occ[e])
else:
tr_task_labels.append(j)
env_occ[e] = j
j += 1

# eval_task_labels = list(range(len(eval_envs)))
tr_task_labels = get_unique_task_labels(tr_envs)
eval_task_labels = get_unique_task_labels(eval_envs)

self._wrappers_generators = wrappers_generators

if shuffle:
@@ -137,19 +142,21 @@ def __init__(self, envs: List[Env],

tr_exps = [RLExperience(tr_envs[i], n_parallel_envs[i],
tr_task_labels[i]) for i in range(len(tr_envs))]
tstream = EagerCLStream("train", tr_exps)
tstream = EagerCLStream[RLExperience]("train", tr_exps)
# we're only supporting single process envs in evaluation atm
eval_exps = [RLExperience(e, 1) for e in eval_envs]
estream = EagerCLStream("eval", eval_exps)
print("EVAL ", eval_task_labels)
eval_exps = [RLExperience(e, 1, l)
for e, l in zip(eval_envs, eval_task_labels)]
estream = EagerCLStream[RLExperience]("eval", eval_exps)

super().__init__([tstream, estream])

@property
def train_stream(self):
def train_stream(self) -> EagerCLStream[RLExperience]:
return self.streams["train_stream"]

@property
def eval_stream(self):
def eval_stream(self) -> EagerCLStream[RLExperience]:
return self.streams["eval_stream"]


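To make the behaviour of the new `get_unique_task_labels` helper concrete, the rule can be restated outside the class as a small standalone sketch (illustration only, not code from this PR): repeated references to the same environment object share a task label, while distinct instances each receive a fresh one.

```python
# Standalone illustration of the labelling rule used above (not PR code).
import gym


def unique_task_labels(env_list):
    labels, seen, next_label = [], {}, 0
    for env in env_list:
        if env in seen:          # same object passed again -> reuse its label
            labels.append(seen[env])
        else:                    # new object -> assign the next free label
            labels.append(next_label)
            seen[env] = next_label
            next_label += 1
    return labels


cartpole = gym.make('CartPole-v1')
acrobot = gym.make('Acrobot-v1')
print(unique_task_labels([cartpole, cartpole, acrobot]))                # [0, 0, 1]
print(unique_task_labels([gym.make('CartPole-v1') for _ in range(2)]))  # [0, 1]
```

This matches the expectations in the updated test below, where two handles to the same `CartPole-v1` instance passed as `eval_envs` share task label 0.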
16 changes: 14 additions & 2 deletions tests/benchmarks/scenarios/test_rl_scenario.py
@@ -33,17 +33,29 @@ def test_multiple_envs():
envs = [gym.make('CartPole-v0'), gym.make('CartPole-v1'),
gym.make('Acrobot-v1')]
rl_scenario = RLScenario(envs, n_parallel_envs=1,
task_labels=True, eval_envs=[])
task_labels=True, eval_envs=envs[:2])
tr_stream = rl_scenario.train_stream
assert len(tr_stream) == 3

for i, exp in enumerate(tr_stream):
assert exp.current_experience == i == exp.task_label

assert len(rl_scenario.eval_stream) == 2
for i, exp in enumerate(rl_scenario.eval_stream):
assert exp.task_label == i
assert exp.environment.spec.id == envs[i].spec.id

# deep copies of the same env are considered as different tasks
envs = [gym.make('CartPole-v1') for _ in range(3)]
eval_envs = [gym.make('CartPole-v1')] * 2
rl_scenario = RLScenario(envs, n_parallel_envs=1,
task_labels=True, eval_envs=[])
task_labels=True, eval_envs=eval_envs)
for i, exp in enumerate(rl_scenario.train_stream):
assert exp.task_label == i
# while repeated references to the same env (shallow copies) in eval share a task label, as in train
assert len(rl_scenario.eval_stream) == 2
for i, exp in enumerate(rl_scenario.eval_stream):
assert exp.task_label == 0
assert exp.environment.spec.id == envs[0].spec.id
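
Finally, a short usage sketch of the call pattern the updated test exercises (assumptions: the import path follows the file layout above, and the gym environments are available locally). With the new `eval_envs=None` default, a training-only scenario can be built without passing an empty list, and the typed stream properties expose `EagerCLStream[RLExperience]` to type checkers.

```python
# Hedged usage sketch, not part of the PR's test suite.
import gym
from avalanche.benchmarks.scenarios.rl_scenario import RLScenario

envs = [gym.make('CartPole-v1'), gym.make('Acrobot-v1')]

# eval_envs now defaults to None, which is normalized to "no evaluation".
scenario = RLScenario(envs, n_parallel_envs=1, task_labels=True)

for exp in scenario.train_stream:   # EagerCLStream[RLExperience]
    print(exp.task_label, exp.environment.spec.id)

assert len(scenario.eval_stream) == 0
```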