-
Notifications
You must be signed in to change notification settings - Fork 505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Environment Factory+Habitat 2.0 Code Cleanup #1401
Changes from 53 commits
03acdf8
c286381
2c74d33
3d38661
ce2be9d
f97ca99
ee048e4
f85aa36
41a5932
6b0f3bb
9489329
c7f8ff3
de60c70
c258ce5
863e567
39aa5ef
add16c3
05f138b
bc5a23e
5c128d4
b684acd
6642340
13a8d6c
a92a538
c433a15
67afc58
02e9675
682bcb7
cc96de2
1a3a599
ce76655
571c486
5597da7
cc41d54
940bead
997031c
0d0fb0c
116a7ab
205b4e4
55a51a1
63b2591
41573d3
a2799c4
99dd2d7
27cd367
a204d1c
d62d0d5
3738b0d
b92a701
516967b
5499e6c
a4d3830
af9d621
0d86ea0
e98ea57
9d7217a
6b1c0fe
8d0d414
d865cbf
a7f5e0d
2997ad0
c6cd2fe
151a726
d769af5
5d0a742
7e1093d
7baa1cb
f5eee45
16230d8
c793c8c
f8a297b
b8a185a
4e03f78
15301b4
c4b9c3f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import TYPE_CHECKING | ||
|
||
from habitat import VectorEnv | ||
|
||
if TYPE_CHECKING: | ||
from omegaconf import DictConfig | ||
|
||
|
||
class VectorEnvFactory(ABC): | ||
""" | ||
Interface responsible for constructing vectorized environments used in training. | ||
""" | ||
|
||
@abstractmethod | ||
def construct_envs( | ||
self, | ||
config: "DictConfig", | ||
workers_ignore_signals: bool = False, | ||
enforce_scenes_greater_eq_environments: bool = False, | ||
is_first_rank: bool = True, | ||
) -> VectorEnv: | ||
""" | ||
Setup a vectorized environment. | ||
|
||
:param config: configs that contain num_environments as well as information | ||
:param workers_ignore_signals: Passed to :ref:`habitat.VectorEnv`'s constructor | ||
:param enforce_scenes_greater_eq_environments: Make sure that there are more (or equal) | ||
:param enforce_scenes_greater_eq_environments: Make sure that there are more (or equal) | ||
scenes than environments. This is needed for correct evaluation. | ||
:param is_first_rank: If these environments are being constructed on the rank0 GPU. | ||
|
||
:return: VectorEnv object created according to specification. | ||
""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# Copyright (c) Meta Platforms, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import os | ||
import random | ||
from typing import TYPE_CHECKING, Any, List, Type | ||
|
||
from habitat import ThreadedVectorEnv, VectorEnv, logger, make_dataset | ||
from habitat.config import read_write | ||
from habitat.gym import make_gym_from_config | ||
from habitat_baselines.common.env_factory import VectorEnvFactory | ||
|
||
if TYPE_CHECKING: | ||
from omegaconf import DictConfig | ||
|
||
|
||
class HabitatVectorEnvFactory(VectorEnvFactory): | ||
def construct_envs( | ||
self, | ||
config: "DictConfig", | ||
workers_ignore_signals: bool = False, | ||
enforce_scenes_greater_eq_environments: bool = False, | ||
is_first_rank: bool = True, | ||
) -> VectorEnv: | ||
r"""Create VectorEnv object with specified config and env class type. | ||
To allow better performance, dataset are split into small ones for | ||
each individual env, grouped by scenes. | ||
""" | ||
|
||
num_environments = config.habitat_baselines.num_environments | ||
configs = [] | ||
dataset = make_dataset(config.habitat.dataset.type) | ||
scenes = config.habitat.dataset.content_scenes | ||
if "*" in config.habitat.dataset.content_scenes: | ||
scenes = dataset.get_scenes_to_load(config.habitat.dataset) | ||
|
||
if num_environments < 1: | ||
raise RuntimeError("num_environments must be strictly positive") | ||
|
||
if len(scenes) == 0: | ||
raise RuntimeError( | ||
"No scenes to load, multiple process logic relies on being able to split scenes uniquely between processes" | ||
) | ||
|
||
random.shuffle(scenes) | ||
|
||
scene_splits: List[List[str]] = [[] for _ in range(num_environments)] | ||
if len(scenes) < num_environments: | ||
msg = f"There are less scenes ({len(scenes)}) than environments ({num_environments}). " | ||
if enforce_scenes_greater_eq_environments: | ||
logger.warn( | ||
msg | ||
+ "Reducing the number of environments to be the number of scenes." | ||
) | ||
num_environments = len(scenes) | ||
scene_splits = [[s] for s in scenes] | ||
else: | ||
logger.warn( | ||
msg | ||
+ "Each environment will use all the scenes instead of using a subset." | ||
) | ||
for scene in scenes: | ||
for split in scene_splits: | ||
split.append(scene) | ||
else: | ||
for idx, scene in enumerate(scenes): | ||
scene_splits[idx % len(scene_splits)].append(scene) | ||
assert sum(map(len, scene_splits)) == len(scenes) | ||
|
||
for env_index in range(num_environments): | ||
proc_config = config.copy() | ||
with read_write(proc_config): | ||
task_config = proc_config.habitat | ||
task_config.seed = task_config.seed + env_index | ||
remove_measure_names = [] | ||
if not is_first_rank: | ||
# Filter out non rank0_measure from the task config if we are not on rank0. | ||
remove_measure_names.extend( | ||
task_config.task.rank0_measure_names | ||
) | ||
if (env_index != 0) or not is_first_rank: | ||
# Filter out non-rank0_env0 measures from the task config if we | ||
# are not on rank0 env0. | ||
remove_measure_names.extend( | ||
task_config.task.rank0_env0_measure_names | ||
) | ||
|
||
task_config.task.measurements = { | ||
k: v | ||
for k, v in task_config.task.measurements.items() | ||
if k not in remove_measure_names | ||
} | ||
|
||
if len(scenes) > 0: | ||
task_config.dataset.content_scenes = scene_splits[ | ||
env_index | ||
] | ||
|
||
configs.append(proc_config) | ||
|
||
vector_env_cls: Type[Any] | ||
if int(os.environ.get("HABITAT_ENV_DEBUG", 0)): | ||
logger.warn( | ||
"Using the debug Vector environment interface. Expect slower performance." | ||
) | ||
vector_env_cls = ThreadedVectorEnv | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be a separate VectorEnvFactory? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think so because this is controlling the VectorEnv implementation returned. |
||
else: | ||
vector_env_cls = VectorEnv | ||
|
||
envs = vector_env_cls( | ||
make_env_fn=make_gym_from_config, | ||
env_fn_args=tuple((c,) for c in configs), | ||
workers_ignore_signals=workers_ignore_signals, | ||
) | ||
|
||
if config.habitat.simulator.renderer.enable_batch_renderer: | ||
envs.initialize_batch_renderer(config) | ||
|
||
return envs |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -264,6 +264,8 @@ class HrlDefinedSkillConfig(HabitatBaselinesBaseConfig): | |
@dataclass | ||
class HierarchicalPolicyConfig(HabitatBaselinesBaseConfig): | ||
high_level_policy: Dict[str, Any] = MISSING | ||
# Names of the skills to not load. | ||
ignore_skills: List[str] = field(default_factory=list) | ||
defined_skills: Dict[str, HrlDefinedSkillConfig] = field( | ||
default_factory=dict | ||
) | ||
|
@@ -383,6 +385,28 @@ class ProfilingConfig(HabitatBaselinesBaseConfig): | |
num_steps_to_capture: int = -1 | ||
|
||
|
||
@dataclass | ||
class VectorEnvFactoryConfig(HabitatBaselinesBaseConfig): | ||
""" | ||
`_target_` points to the `VectorEnvFactory` to setup the vectorized | ||
environment. Defaults to the Habitat vectorized environment setup. | ||
""" | ||
|
||
_target_: str = ( | ||
"habitat_baselines.common.habitat_env_factory.HabitatEnvFactory" | ||
) | ||
|
||
|
||
@dataclass | ||
class HydraCallbackConfig(HabitatBaselinesBaseConfig): | ||
""" | ||
Generic callback option for Hydra. Used to create the `_target_` class or | ||
call the `_target_` method. | ||
""" | ||
|
||
_target_: Optional[str] = None | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should these two be merged into a single class? Like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think so because I want the default to be different. |
||
@dataclass | ||
class HabitatBaselinesConfig(HabitatBaselinesBaseConfig): | ||
# task config can be a list of configs like "A.yaml,B.yaml" | ||
|
@@ -413,6 +437,8 @@ class HabitatBaselinesConfig(HabitatBaselinesBaseConfig): | |
log_file: str = "train.log" | ||
force_blind_policy: bool = False | ||
verbose: bool = True | ||
# Creates the vectorized environment. | ||
vector_env_factory: VectorEnvFactoryConfig = VectorEnvFactoryConfig() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not have this be a string? Is it because we need to have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it is so |
||
eval_keys_to_include_in_name: List[str] = field(default_factory=list) | ||
# For our use case, the CPU side things are mainly memory copies | ||
# and nothing of substantive compute. PyTorch has been making | ||
|
@@ -430,6 +456,13 @@ class HabitatBaselinesConfig(HabitatBaselinesBaseConfig): | |
load_resume_state_config: bool = True | ||
eval: EvalConfig = EvalConfig() | ||
profiling: ProfilingConfig = ProfilingConfig() | ||
# Whether to log the infos that are only logged to a single process to the | ||
# CLI along with the other metrics. | ||
should_log_single_proc_infos: bool = False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. docstring |
||
# Called every time a checkpoint is saved. | ||
# Function signature: fn(save_file_path: str) -> None | ||
# If not specified, there is no callback. | ||
on_save_ckpt_callback: Optional[HydraCallbackConfig] = None | ||
|
||
|
||
@dataclass | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
VectorEnvFactory seem to only have one implementation. Is this overengineered?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree it only has 1 implementation in the current setup. However, this makes it much easier to (1) integrate with external environments that have a custom vectorized environment implementation, (2) create custom vector env logic.