@dataclass-style structured configs to @attrs-style #1172

Closed · wants to merge 2 commits
examples/new_actions.py (4 changes: 2 additions & 2 deletions)
@@ -12,7 +12,7 @@
We will use the strafe action outline in the habitat_sim example
"""

-import attr
+import attrs
import numpy as np

import habitat
@@ -25,7 +25,7 @@
from habitat.tasks.nav.nav import SimulatorTaskAction


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class NoisyStrafeActuationSpec:
move_amount: float
# Classic strafing is to move perpendicular (90 deg) to the forward direction
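For readers unfamiliar with the two attrs APIs in this diff: `attr.s(auto_attribs=True, slots=True)` is the classic decorator, while `attrs.define` is the next-generation API that produces slotted classes by default, which is why the explicit `slots=True` disappears in the new spelling. A minimal sketch of the equivalence, with illustrative class names that are not taken from this PR (assumes attrs >= 21.3, which ships the `attrs` namespace):

```python
import attr    # classic namespace
import attrs   # next-generation namespace (attrs >= 21.3)


@attr.s(auto_attribs=True, slots=True)  # old spelling, used before this change
class OldStyleSpec:
    move_amount: float
    strafe_angle: float = 90.0


@attrs.define(auto_attribs=True)  # new spelling; slots=True is already the default
class NewStyleSpec:
    move_amount: float
    strafe_angle: float = 90.0


old = OldStyleSpec(move_amount=0.25)
new = NewStyleSpec(move_amount=0.25)
assert (old.move_amount, old.strafe_angle) == (new.move_amount, new.strafe_angle)
```

Passing `auto_attribs=True` explicitly is harmless; `attrs.define` would also auto-detect an annotation-only class.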
examples/register_new_sensors_and_measures.py (6 changes: 3 additions & 3 deletions)
@@ -4,9 +4,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

-from dataclasses import dataclass
from typing import Any

+import attrs
import numpy as np
from gym import spaces
from omegaconf import MISSING
@@ -46,7 +46,7 @@ def update_metric(self, *args: Any, episode, action, **kwargs: Any):


# define a configuration for this new measure
-@dataclass
+@attrs.define(auto_attribs=True)
class EpisodeInfoExampleConfig(MeasurementConfig):
# Note that typing is required on all fields
type: str = "EpisodeInfoExample"
@@ -89,7 +89,7 @@ def get_observation(


# define a configuration for this new sensor
-@dataclass
+@attrs.define(auto_attribs=True)
class AgentPositionSensorConfig(LabSensorConfig):
# Note that typing is required on all fields
type: str = "my_supercool_sensor"
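The "Note that typing is required on all fields" comments are what make these classes usable as OmegaConf structured configs: OmegaConf accepts attrs classes as well as dataclasses and enforces the annotations at runtime. A hedged sketch of the registration pattern this example file relies on (the group string and the extra field are illustrative, not copied from the PR):

```python
import attrs
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf
from omegaconf.errors import ValidationError


@attrs.define(auto_attribs=True)
class AgentPositionSensorConfigSketch:
    type: str = "my_supercool_sensor"
    answer_to_life: int = 42  # illustrative extra field


# Register the schema so YAML/CLI overrides are validated against it.
cs = ConfigStore.instance()
cs.store(
    group="habitat/task/lab_sensors",  # illustrative group path
    name="agent_position_sensor_sketch",
    node=AgentPositionSensorConfigSketch,
)

cfg = OmegaConf.structured(AgentPositionSensorConfigSketch)
cfg.answer_to_life = 7           # fine: matches the int annotation
try:
    cfg.answer_to_life = "many"  # rejected: the annotation is enforced
except ValidationError:
    pass
```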
examples/tutorials/colabs/Habitat2_Quickstart.ipynb (8 changes: 4 additions & 4 deletions)
@@ -17,7 +17,7 @@
"outputs": [],
"source": [
"# Play a teaser video\n",
"from dataclasses import dataclass\n",
"import attrs\n",
"\n",
"from habitat.config.default import get_agent_config\n",
"from habitat.config.default_structured_configs import (\n",
@@ -354,12 +354,12 @@
" self._metric = abs_targ_obj_idx == self._sim.grasp_mgr.snap_idx\n",
"\n",
"\n",
"@dataclass\n",
"@attrs.define(auto_attribs=True)\n",
"class DistanceToTargetObjectMeasurementConfig(MeasurementConfig):\n",
" type: str = \"DistanceToTargetObject\"\n",
"\n",
"\n",
"@dataclass\n",
"@attrs.define(auto_attribs=True)\n",
"class NavPickRewardMeasurementConfig(MeasurementConfig):\n",
" type: str = \"NavPickReward\"\n",
" scaling_factor: float = 0.1\n",
@@ -370,7 +370,7 @@
" force_end_pen: float = 10.0\n",
"\n",
"\n",
"@dataclass\n",
"@attrs.define(auto_attribs=True)\n",
"class NavPickSuccessMeasurementConfig(MeasurementConfig):\n",
" type: str = \"NavPickSuccess\"\n",
"\n",
examples/tutorials/nb_python/Habitat2_Quickstart.py (8 changes: 4 additions & 4 deletions)
@@ -27,7 +27,7 @@

# %%
# Play a teaser video
-from dataclasses import dataclass
+import attrs

from habitat.config.default import get_agent_config
from habitat.config.default_structured_configs import (
@@ -309,12 +309,12 @@ def update_metric(self, *args, episode, task, observations, **kwargs):
self._metric = abs_targ_obj_idx == self._sim.grasp_mgr.snap_idx


-@dataclass
+@attrs.define(auto_attribs=True)
class DistanceToTargetObjectMeasurementConfig(MeasurementConfig):
type: str = "DistanceToTargetObject"


-@dataclass
+@attrs.define(auto_attribs=True)
class NavPickRewardMeasurementConfig(MeasurementConfig):
type: str = "NavPickReward"
scaling_factor: float = 0.1
@@ -325,7 +325,7 @@ class NavPickRewardMeasurementConfig(MeasurementConfig):
force_end_pen: float = 10.0


-@dataclass
+@attrs.define(auto_attribs=True)
class NavPickSuccessMeasurementConfig(MeasurementConfig):
type: str = "NavPickSuccess"

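Behaviour of the converted tutorial configs should be unchanged: OmegaConf turns the attrs class into a typed DictConfig, and YAML/CLI overrides merge onto it as before. A small sketch using a standalone stand-in for the NavPickReward config above (the MeasurementConfig base class is omitted here for brevity):

```python
import attrs
from omegaconf import OmegaConf


@attrs.define(auto_attribs=True)
class NavPickRewardSketch:
    type: str = "NavPickReward"
    scaling_factor: float = 0.1


cfg = OmegaConf.merge(
    OmegaConf.structured(NavPickRewardSketch),
    {"scaling_factor": 0.5},  # e.g. a value coming from YAML or the CLI
)
assert cfg.type == "NavPickReward" and cfg.scaling_factor == 0.5
```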
habitat-baselines/habitat_baselines/agents/ppo_agents.py (4 changes: 2 additions & 2 deletions)
@@ -7,9 +7,9 @@

import argparse
import random
-from dataclasses import dataclass
from typing import Dict, Optional

+import attrs
import numpy as np
import torch
from gym.spaces import Box
@@ -24,7 +24,7 @@
from habitat_baselines.utils.common import batch_obs


-@dataclass
+@attrs.define(auto_attribs=True)
class PPOAgentConfig:
INPUT_TYPE: str = "rgb"
MODEL_PATH: str = "data/checkpoints/gibson-rgb-best.pth"
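PPOAgentConfig stays a plain record either way; one small convenience of the attrs spelling is `attrs.evolve` for non-destructive overrides. A sketch with a simplified stand-in class (the field values mirror the diff, the class itself is illustrative):

```python
import attrs


@attrs.define(auto_attribs=True)
class PPOAgentConfigSketch:
    INPUT_TYPE: str = "rgb"
    MODEL_PATH: str = "data/checkpoints/gibson-rgb-best.pth"


base = PPOAgentConfigSketch()
depth = attrs.evolve(base, INPUT_TYPE="depth")  # copy with one field changed
assert base.INPUT_TYPE == "rgb" and depth.INPUT_TYPE == "depth"
```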
@@ -8,18 +8,18 @@
import numbers
from typing import Optional, Sequence, Union

-import attr
+import attrs
import numpy as np


-@attr.s(auto_attribs=True, slots=True, repr=False)
+@attrs.define(auto_attribs=True, repr=False)
class WindowedRunningMean:
r"""Efficient implementation of a windowed running mean. Supports an infinite window"""
window_size: Union[int, float]
-_sum: float = attr.ib(0.0, init=False)
-_count: int = attr.ib(0, init=False)
-_ptr: int = attr.ib(0, init=False)
-_buffer: Optional[np.ndarray] = attr.ib(None, init=False)
+_sum: float = attrs.field(default=0.0, init=False)
+_count: int = attrs.field(default=0, init=False)
+_ptr: int = attrs.field(default=0, init=False)
+_buffer: Optional[np.ndarray] = attrs.field(default=None, init=False)

def __attrs_post_init__(self):
if not self.infinite_window:
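The pattern above (private fields with `init=False` defaults plus `__attrs_post_init__` for derived state) carries over directly from `attr.ib` to `attrs.field`. A simplified re-implementation with the same shape, not the PR's exact logic:

```python
from typing import Optional, Union

import attrs
import numpy as np


@attrs.define(auto_attribs=True)
class RunningMeanSketch:
    window_size: Union[int, float]
    _sum: float = attrs.field(default=0.0, init=False)
    _count: int = attrs.field(default=0, init=False)
    _buffer: Optional[np.ndarray] = attrs.field(default=None, init=False)

    def __attrs_post_init__(self) -> None:
        # Derived state: only finite windows need a ring buffer.
        if np.isfinite(self.window_size):
            self._buffer = np.zeros(int(self.window_size), dtype=np.float64)

    def add(self, value: float) -> None:
        if self._buffer is None:  # infinite window
            self._sum += value
        else:  # fixed window: replace the oldest sample
            idx = self._count % len(self._buffer)
            self._sum += value - self._buffer[idx]
            self._buffer[idx] = value
        self._count += 1

    @property
    def mean(self) -> float:
        n = self._count if self._buffer is None else min(self._count, len(self._buffer))
        return self._sum / max(n, 1)


m = RunningMeanSketch(window_size=2)
for v in (1.0, 2.0, 3.0):
    m.add(v)
assert m.mean == 2.5  # mean of the last two samples
```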
@@ -6,7 +6,7 @@

from typing import Any, Dict, List, Optional, Tuple

-import attr
+import attrs
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING

@@ -15,12 +15,12 @@
cs = ConfigStore.instance()


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class HabitatBaselinesBaseConfig:
pass


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class WBConfig(HabitatBaselinesBaseConfig):
"""Weights and Biases config"""

@@ -35,7 +35,7 @@ class WBConfig(HabitatBaselinesBaseConfig):
run_name: str = ""


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class EvalConfig(HabitatBaselinesBaseConfig):
# The split to evaluate on
split: str = "val"
@@ -50,7 +50,7 @@ class EvalConfig(HabitatBaselinesBaseConfig):
extra_sim_sensors: Dict[str, SimulatorSensorConfig] = dict()


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class PreemptionConfig(HabitatBaselinesBaseConfig):
# Append the slurm job ID to the resume state filename if running
# a slurm job. This is useful when you want to have things from a different
@@ -63,7 +63,7 @@ class PreemptionConfig(HabitatBaselinesBaseConfig):
save_state_batch_only: bool = False


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class ActionDistributionConfig(HabitatBaselinesBaseConfig):
use_log_std: bool = True
use_softplus: bool = False
@@ -82,12 +82,12 @@ class ActionDistributionConfig(HabitatBaselinesBaseConfig):
scheduled_std: bool = False


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class ObsTransformConfig(HabitatBaselinesBaseConfig):
type: str = MISSING


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class CenterCropperConfig(ObsTransformConfig):
type: str = "CenterCropper"
height: int = 256
@@ -108,7 +108,7 @@ class CenterCropperConfig(ObsTransformConfig):
)


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class ResizeShortestEdgeConfig(ObsTransformConfig):
type: str = "ResizeShortestEdge"
size: int = 256
@@ -129,7 +129,7 @@ class ResizeShortestEdgeConfig(ObsTransformConfig):
)


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class Cube2EqConfig(ObsTransformConfig):
type: str = "CubeMap2Equirect"
height: int = 256
@@ -152,7 +152,7 @@ class Cube2EqConfig(ObsTransformConfig):
)


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class Cube2FishConfig(ObsTransformConfig):
type: str = "CubeMap2Fisheye"
height: int = 256
@@ -177,7 +177,7 @@ class Cube2FishConfig(ObsTransformConfig):
)


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class AddVirtualKeysConfig(ObsTransformConfig):
type: str = "AddVirtualKeys"
virtual_keys: Dict[str, int] = dict()
@@ -191,7 +191,7 @@ class AddVirtualKeysConfig(ObsTransformConfig):
)


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class Eq2CubeConfig(ObsTransformConfig):
type: str = "Equirect2CubeMap"
height: int = 256
@@ -214,7 +214,7 @@ class Eq2CubeConfig(ObsTransformConfig):
)


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class HrlDefinedSkillConfig(HabitatBaselinesBaseConfig):
"""
Defines a low-level skill to be used in the hierarchical policy.
@@ -248,14 +248,14 @@ class HrlDefinedSkillConfig(HabitatBaselinesBaseConfig):
pddl_action_names: Optional[List[str]] = None


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class HierarchicalPolicyConfig(HabitatBaselinesBaseConfig):
high_level_policy: Dict[str, Any] = MISSING
defined_skills: Dict[str, HrlDefinedSkillConfig] = dict()
use_skills: Dict[str, str] = dict()


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class PolicyConfig(HabitatBaselinesBaseConfig):
name: str = "PointNavResNetPolicy"
action_distribution_type: str = "categorical" # or 'gaussian'
@@ -266,7 +266,7 @@ class PolicyConfig(HabitatBaselinesBaseConfig):
hierarchical_policy: HierarchicalPolicyConfig = MISSING


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class PPOConfig(HabitatBaselinesBaseConfig):
"""Proximal policy optimization config"""

@@ -297,7 +297,7 @@ class PPOConfig(HabitatBaselinesBaseConfig):
use_double_buffered_sampler: bool = False


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class VERConfig(HabitatBaselinesBaseConfig):
"""Variable experience rollout config"""

@@ -306,12 +306,12 @@ class VERConfig(HabitatBaselinesBaseConfig):
overlap_rollouts_and_learn: bool = False


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class AuxLossConfig(HabitatBaselinesBaseConfig):
pass


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class CPCALossConfig(AuxLossConfig):
"""Action-conditional contrastive predictive coding loss"""

@@ -321,7 +321,7 @@ class CPCALossConfig(AuxLossConfig):
loss_scale: float = 0.1


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class DDPPOConfig(HabitatBaselinesBaseConfig):
"""Decentralized distributed proximal policy optimization config"""

@@ -344,7 +344,7 @@ class DDPPOConfig(HabitatBaselinesBaseConfig):
force_distributed: bool = False


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class RLConfig(HabitatBaselinesBaseConfig):
"""Reinforcement learning config"""

@@ -356,13 +356,13 @@ class RLConfig(HabitatBaselinesBaseConfig):
auxiliary_losses: Dict[str, AuxLossConfig] = dict()


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class ProfilingConfig(HabitatBaselinesBaseConfig):
capture_start_step: int = -1
num_steps_to_capture: int = -1


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class HabitatBaselinesConfig(HabitatBaselinesBaseConfig):
# task config can be a list of configs like "A.yaml,B.yaml"
# If habitat_baselines.evaluate is true, the run will be in evaluation mode
@@ -411,17 +411,17 @@ class HabitatBaselinesConfig(HabitatBaselinesBaseConfig):
profiling: ProfilingConfig = ProfilingConfig()


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class HabitatBaselinesRLConfig(HabitatBaselinesConfig):
rl: RLConfig = RLConfig()


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class HabitatBaselinesILConfig(HabitatBaselinesConfig):
il: Dict[str, Any] = dict()


-@attr.s(auto_attribs=True, slots=True)
+@attrs.define(auto_attribs=True)
class HabitatBaselinesSPAConfig(HabitatBaselinesConfig):
sense_plan_act: Any = MISSING

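Since the point of these classes is to serve as Hydra/OmegaConf schemas, it may help to see that nested attrs configs compose the same way the nested defaults above (e.g. `rl: RLConfig = RLConfig()`) suggest. A hedged sketch with simplified stand-in classes, not the PR's real ones:

```python
import attrs
from omegaconf import OmegaConf


@attrs.define(auto_attribs=True)
class PPOConfigSketch:
    clip_param: float = 0.2
    ppo_epoch: int = 4


@attrs.define(auto_attribs=True)
class RLConfigSketch:
    ppo: PPOConfigSketch = PPOConfigSketch()  # nested default, as in the diff


cfg = OmegaConf.merge(
    OmegaConf.structured(RLConfigSketch),
    OmegaConf.from_dotlist(["ppo.ppo_epoch=8"]),  # dotted override, Hydra-style
)
assert cfg.ppo.clip_param == 0.2 and cfg.ppo.ppo_epoch == 8
print(OmegaConf.to_yaml(cfg))
```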