Skip to content

Commit

Permalink
Fix set goal bug (facebookresearch#772)
Browse files Browse the repository at this point in the history
* Fix a bug
  • Loading branch information
erikwijmans authored Dec 14, 2021
1 parent f512fae commit 27c487d
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 2 deletions.
10 changes: 10 additions & 0 deletions habitat/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ class Episode:
info: Optional[Dict[str, Any]] = None
_shortest_path_cache: Any = attr.ib(init=False, default=None)

# NB: This method is marked static despite taking self so that
# on_setattr=Episode._reset_shortest_path_cache_hook works as attrs
# will pass the instance as the first argument!
@staticmethod
def _reset_shortest_path_cache_hook(
self: "Episode", attribute: attr.Attribute, value: Any
) -> Any:
self._shortest_path_cache = None
return value

def __getstate__(self):
return {
k: v
Expand Down
3 changes: 2 additions & 1 deletion habitat/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import math
from typing import Any, Dict, List, Optional

import attr
import numpy as np
import quaternion # noqa: F401

Expand Down Expand Up @@ -60,7 +61,7 @@ def tile_images(images: List[np.ndarray]) -> np.ndarray:


def not_none_validator(
self: Any, attribute: Any, value: Optional[Any]
self: Any, attribute: attr.Attribute, value: Optional[Any]
) -> None:
if value is None:
raise ValueError(f"Argument '{attribute.name}' must be set")
Expand Down
4 changes: 3 additions & 1 deletion habitat/tasks/nav/nav.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ class NavigationEpisode(Episode):
"""

goals: List[NavigationGoal] = attr.ib(
default=None, validator=not_none_validator
default=None,
validator=not_none_validator,
on_setattr=Episode._reset_shortest_path_cache_hook,
)
start_room: Optional[str] = None
shortest_paths: Optional[List[List[ShortestPathPoint]]] = None
Expand Down
17 changes: 17 additions & 0 deletions test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest

from habitat.core.dataset import Dataset, Episode
from habitat.tasks.nav.nav import NavigationEpisode, NavigationGoal


def _construct_dataset(num_episodes, num_groups=10):
Expand Down Expand Up @@ -359,3 +360,19 @@ def test_preserve_order():
episode_iter = dataset.get_episode_iterator(shuffle=False, cycle=False)

assert list(episode_iter) == episodes


def test_reset_goals():
ep = NavigationEpisode(
episode_id="0",
scene_id="1",
start_position=[0, 0, 0],
start_rotation=[1, 0, 0, 0],
goals=[NavigationGoal(position=[1, 2, 3])],
)

ep._shortest_path_cache = "Dummy"
assert ep._shortest_path_cache is not None

ep.goals = [NavigationGoal(position=[3, 4, 5])]
assert ep._shortest_path_cache is None

0 comments on commit 27c487d

Please sign in to comment.