Skip to content

Commit

Permalink
rolled back moving to attr.s in #1120 (#1182)
Browse files Browse the repository at this point in the history
  • Loading branch information
rpartsey authored Mar 3, 2023
1 parent 53b7a7a commit b63b024
Show file tree
Hide file tree
Showing 2 changed files with 248 additions and 228 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import attr
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING

Expand All @@ -15,12 +15,12 @@
cs = ConfigStore.instance()


@attr.s(auto_attribs=True, slots=True)
@dataclass
class HabitatBaselinesBaseConfig:
pass


@attr.s(auto_attribs=True, slots=True)
@dataclass
class WBConfig(HabitatBaselinesBaseConfig):
"""Weights and Biases config"""

Expand All @@ -35,7 +35,7 @@ class WBConfig(HabitatBaselinesBaseConfig):
run_name: str = ""


@attr.s(auto_attribs=True, slots=True)
@dataclass
class EvalConfig(HabitatBaselinesBaseConfig):
# The split to evaluate on
split: str = "val"
Expand All @@ -44,13 +44,16 @@ class EvalConfig(HabitatBaselinesBaseConfig):
# The number of time to run each episode through evaluation.
# Only works when evaluating on all episodes.
evals_per_ep: int = 1
video_option: List[
str
] = [] # available options are "disk" and "tensorboard"
extra_sim_sensors: Dict[str, SimulatorSensorConfig] = dict()
video_option: List[str] = field(
# available options are "disk" and "tensorboard"
default_factory=list
)
extra_sim_sensors: Dict[str, SimulatorSensorConfig] = field(
default_factory=dict
)


@attr.s(auto_attribs=True, slots=True)
@dataclass
class PreemptionConfig(HabitatBaselinesBaseConfig):
# Append the slurm job ID to the resume state filename if running
# a slurm job. This is useful when you want to have things from a different
Expand All @@ -63,7 +66,7 @@ class PreemptionConfig(HabitatBaselinesBaseConfig):
save_state_batch_only: bool = False


@attr.s(auto_attribs=True, slots=True)
@dataclass
class ActionDistributionConfig(HabitatBaselinesBaseConfig):
use_log_std: bool = True
use_softplus: bool = False
Expand All @@ -82,12 +85,12 @@ class ActionDistributionConfig(HabitatBaselinesBaseConfig):
scheduled_std: bool = False


@attr.s(auto_attribs=True, slots=True)
@dataclass
class ObsTransformConfig(HabitatBaselinesBaseConfig):
type: str = MISSING


@attr.s(auto_attribs=True, slots=True)
@dataclass
class CenterCropperConfig(ObsTransformConfig):
type: str = "CenterCropper"
height: int = 256
Expand All @@ -108,7 +111,7 @@ class CenterCropperConfig(ObsTransformConfig):
)


@attr.s(auto_attribs=True, slots=True)
@dataclass
class ResizeShortestEdgeConfig(ObsTransformConfig):
type: str = "ResizeShortestEdge"
size: int = 256
Expand All @@ -129,19 +132,21 @@ class ResizeShortestEdgeConfig(ObsTransformConfig):
)


@attr.s(auto_attribs=True, slots=True)
@dataclass
class Cube2EqConfig(ObsTransformConfig):
type: str = "CubeMap2Equirect"
height: int = 256
width: int = 512
sensor_uuids: List[str] = [
"BACK",
"DOWN",
"FRONT",
"LEFT",
"RIGHT",
"UP",
]
sensor_uuids: List[str] = field(
default_factory=lambda: [
"BACK",
"DOWN",
"FRONT",
"LEFT",
"RIGHT",
"UP",
]
)


cs.store(
Expand All @@ -152,21 +157,23 @@ class Cube2EqConfig(ObsTransformConfig):
)


@attr.s(auto_attribs=True, slots=True)
@dataclass
class Cube2FishConfig(ObsTransformConfig):
type: str = "CubeMap2Fisheye"
height: int = 256
width: int = 256
fov: int = 180
params: Tuple[float, ...] = (0.2, 0.2, 0.2)
sensor_uuids: List[str] = [
"BACK",
"DOWN",
"FRONT",
"LEFT",
"RIGHT",
"UP",
]
sensor_uuids: List[str] = field(
default_factory=lambda: [
"BACK",
"DOWN",
"FRONT",
"LEFT",
"RIGHT",
"UP",
]
)


cs.store(
Expand All @@ -177,10 +184,10 @@ class Cube2FishConfig(ObsTransformConfig):
)


@attr.s(auto_attribs=True, slots=True)
@dataclass
class AddVirtualKeysConfig(ObsTransformConfig):
type: str = "AddVirtualKeys"
virtual_keys: Dict[str, int] = dict()
virtual_keys: Dict[str, int] = field(default_factory=dict)


cs.store(
Expand All @@ -191,19 +198,21 @@ class AddVirtualKeysConfig(ObsTransformConfig):
)


@attr.s(auto_attribs=True, slots=True)
@dataclass
class Eq2CubeConfig(ObsTransformConfig):
type: str = "Equirect2CubeMap"
height: int = 256
width: int = 256
sensor_uuids: List[str] = [
"BACK",
"DOWN",
"FRONT",
"LEFT",
"RIGHT",
"UP",
]
sensor_uuids: List[str] = field(
default_factory=lambda: [
"BACK",
"DOWN",
"FRONT",
"LEFT",
"RIGHT",
"UP",
]
)


cs.store(
Expand All @@ -214,7 +223,7 @@ class Eq2CubeConfig(ObsTransformConfig):
)


@attr.s(auto_attribs=True, slots=True)
@dataclass
class HrlDefinedSkillConfig(HabitatBaselinesBaseConfig):
"""
Defines a low-level skill to be used in the hierarchical policy.
Expand All @@ -234,7 +243,7 @@ class HrlDefinedSkillConfig(HabitatBaselinesBaseConfig):
# If true, this willapply the post-conditions of the skill after it
# terminates.
apply_postconds: bool = False
obs_skill_inputs: List[str] = list()
obs_skill_inputs: List[str] = field(default_factory=list)
obs_skill_input_dim: int = 3
start_zone_radius: float = 0.3
# For the oracle navigation skill
Expand All @@ -248,25 +257,27 @@ class HrlDefinedSkillConfig(HabitatBaselinesBaseConfig):
pddl_action_names: Optional[List[str]] = None


@attr.s(auto_attribs=True, slots=True)
@dataclass
class HierarchicalPolicyConfig(HabitatBaselinesBaseConfig):
high_level_policy: Dict[str, Any] = MISSING
defined_skills: Dict[str, HrlDefinedSkillConfig] = dict()
use_skills: Dict[str, str] = dict()
defined_skills: Dict[str, HrlDefinedSkillConfig] = field(
default_factory=dict
)
use_skills: Dict[str, str] = field(default_factory=dict)


@attr.s(auto_attribs=True, slots=True)
@dataclass
class PolicyConfig(HabitatBaselinesBaseConfig):
name: str = "PointNavResNetPolicy"
action_distribution_type: str = "categorical" # or 'gaussian'
# If the list is empty, all keys will be included.
# For gaussian action distribution:
action_dist: ActionDistributionConfig = ActionDistributionConfig()
obs_transforms: Dict[str, ObsTransformConfig] = dict()
obs_transforms: Dict[str, ObsTransformConfig] = field(default_factory=dict)
hierarchical_policy: HierarchicalPolicyConfig = MISSING


@attr.s(auto_attribs=True, slots=True)
@dataclass
class PPOConfig(HabitatBaselinesBaseConfig):
"""Proximal policy optimization config"""

Expand Down Expand Up @@ -297,7 +308,7 @@ class PPOConfig(HabitatBaselinesBaseConfig):
use_double_buffered_sampler: bool = False


@attr.s(auto_attribs=True, slots=True)
@dataclass
class VERConfig(HabitatBaselinesBaseConfig):
"""Variable experience rollout config"""

Expand All @@ -306,12 +317,12 @@ class VERConfig(HabitatBaselinesBaseConfig):
overlap_rollouts_and_learn: bool = False


@attr.s(auto_attribs=True, slots=True)
@dataclass
class AuxLossConfig(HabitatBaselinesBaseConfig):
pass


@attr.s(auto_attribs=True, slots=True)
@dataclass
class CPCALossConfig(AuxLossConfig):
"""Action-conditional contrastive predictive coding loss"""

Expand All @@ -321,7 +332,7 @@ class CPCALossConfig(AuxLossConfig):
loss_scale: float = 0.1


@attr.s(auto_attribs=True, slots=True)
@dataclass
class DDPPOConfig(HabitatBaselinesBaseConfig):
"""Decentralized distributed proximal policy optimization config"""

Expand All @@ -344,7 +355,7 @@ class DDPPOConfig(HabitatBaselinesBaseConfig):
force_distributed: bool = False


@attr.s(auto_attribs=True, slots=True)
@dataclass
class RLConfig(HabitatBaselinesBaseConfig):
"""Reinforcement learning config"""

Expand All @@ -353,16 +364,16 @@ class RLConfig(HabitatBaselinesBaseConfig):
ppo: PPOConfig = PPOConfig()
ddppo: DDPPOConfig = DDPPOConfig()
ver: VERConfig = VERConfig()
auxiliary_losses: Dict[str, AuxLossConfig] = dict()
auxiliary_losses: Dict[str, AuxLossConfig] = field(default_factory=dict)


@attr.s(auto_attribs=True, slots=True)
@dataclass
class ProfilingConfig(HabitatBaselinesBaseConfig):
capture_start_step: int = -1
num_steps_to_capture: int = -1


@attr.s(auto_attribs=True, slots=True)
@dataclass
class HabitatBaselinesConfig(HabitatBaselinesBaseConfig):
# task config can be a list of configs like "A.yaml,B.yaml"
# If habitat_baselines.evaluate is true, the run will be in evaluation mode
Expand Down Expand Up @@ -392,7 +403,7 @@ class HabitatBaselinesConfig(HabitatBaselinesBaseConfig):
log_file: str = "train.log"
force_blind_policy: bool = False
verbose: bool = True
eval_keys_to_include_in_name: List[str] = []
eval_keys_to_include_in_name: List[str] = field(default_factory=list)
# For our use case, the CPU side things are mainly memory copies
# and nothing of substantive compute. PyTorch has been making
# more and more memory copies parallel, but that just ends up
Expand All @@ -411,17 +422,17 @@ class HabitatBaselinesConfig(HabitatBaselinesBaseConfig):
profiling: ProfilingConfig = ProfilingConfig()


@attr.s(auto_attribs=True, slots=True)
@dataclass
class HabitatBaselinesRLConfig(HabitatBaselinesConfig):
rl: RLConfig = RLConfig()


@attr.s(auto_attribs=True, slots=True)
@dataclass
class HabitatBaselinesILConfig(HabitatBaselinesConfig):
il: Dict[str, Any] = dict()
il: Dict[str, Any] = field(default_factory=dict)


@attr.s(auto_attribs=True, slots=True)
@dataclass
class HabitatBaselinesSPAConfig(HabitatBaselinesConfig):
sense_plan_act: Any = MISSING

Expand Down
Loading

0 comments on commit b63b024

Please sign in to comment.