Skip to content

Commit

Permalink
Merge branch 'instadeepai:main' into feat/lbf-truncate
Browse files Browse the repository at this point in the history
  • Loading branch information
WiemKhlifi authored Jan 18, 2024
2 parents d7a78ea + 8168c5c commit 92fc862
Show file tree
Hide file tree
Showing 16 changed files with 38 additions and 30 deletions.
8 changes: 3 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="docs/img/jumanji_logo_dm.png">
<source media="(prefers-color-scheme: light)" srcset="docs/img/jumanji_logo.png">
<img alt="Jumanji Logo" src="docs/img/jumanji_logo.png", width="50%">
</picture>
<a href="docs/img/jumanji_logo.png">
<img src="docs/img/jumanji_logo.png" alt="Jumanji logo" width="50%"/>
</a>
</p>

[![Python Versions](https://img.shields.io/pypi/pyversions/jumanji.svg?style=flat-square)](https://www.python.org/doc/versions/)
Expand Down
2 changes: 1 addition & 1 deletion docs/environments/cleaner.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ always start in the top left corner of the maze.
## Observation
The **observation** seen by the agent is a `NamedTuple` containing the following:

- `grid`: jax array (int) of shape `(num_rows, num_cols)`, array representing the grid, each tile is
- `grid`: jax array (int8) of shape `(num_rows, num_cols)`, array representing the grid, each tile is
either dirty (0), clean (1), or a wall (2).

- `agents_locations`: jax array (int) of shape `(num_agents, 2)`, array specifying the x and y
Expand Down
Binary file removed docs/img/jumanji_logo_dm.png
Binary file not shown.
6 changes: 4 additions & 2 deletions jumanji/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Abstract environment class"""

from __future__ import annotations

import abc
from typing import Any, Generic, Tuple, TypeVar

Expand Down Expand Up @@ -105,7 +107,7 @@ def discount_spec(self) -> specs.BoundedArray:
)

@property
def unwrapped(self) -> "Environment":
def unwrapped(self) -> Environment:
return self

def render(self, state: State) -> Any:
Expand All @@ -119,7 +121,7 @@ def render(self, state: State) -> Any:
def close(self) -> None:
"""Perform any necessary cleanup."""

def __enter__(self) -> "Environment":
def __enter__(self) -> Environment:
return self

def __exit__(self, *args: Any) -> None:
Expand Down
11 changes: 6 additions & 5 deletions jumanji/environments/packing/bin_pack/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any

Expand All @@ -32,7 +33,7 @@ class Space:
z1: chex.Numeric
z2: chex.Numeric

def astype(self, dtype: Any) -> "Space":
def astype(self, dtype: Any) -> Space:
space_dict = {
key: jnp.asarray(value, dtype) for key, value in self.__dict__.items()
}
Expand Down Expand Up @@ -90,7 +91,7 @@ def volume(self) -> chex.Numeric:
z_len = jnp.asarray(self.z2 - self.z1, float)
return x_len * y_len * z_len

def intersection(self, space: "Space") -> "Space":
def intersection(self, space: Space) -> Space:
"""Returns the intersected space with another space (i.e. the space that is included in both
spaces whose volume is maximum).
"""
Expand All @@ -102,15 +103,15 @@ def intersection(self, space: "Space") -> "Space":
z2 = jnp.minimum(self.z2, space.z2)
return Space(x1=x1, x2=x2, y1=y1, y2=y2, z1=z1, z2=z2)

def intersect(self, space: "Space") -> chex.Numeric:
def intersect(self, space: Space) -> chex.Numeric:
"""Returns whether a space intersect another space or not."""
return ~(self.intersection(space).is_empty())

def is_empty(self) -> chex.Numeric:
"""A space is empty if at least one dimension is negative or zero."""
return (self.x1 >= self.x2) | (self.y1 >= self.y2) | (self.z1 >= self.z2)

def is_included(self, space: "Space") -> chex.Numeric:
def is_included(self, space: Space) -> chex.Numeric:
"""Returns whether self is included into another space."""
return (
(self.x1 >= space.x1)
Expand All @@ -121,7 +122,7 @@ def is_included(self, space: "Space") -> chex.Numeric:
& (self.z2 <= space.z2)
)

def hyperplane(self, axis: str, direction: str) -> "Space":
def hyperplane(self, axis: str, direction: str) -> Space:
"""Returns the hyperplane (e.g. lower hyperplane on the x axis) for EMS creation when
packing an item.
Expand Down
8 changes: 4 additions & 4 deletions jumanji/environments/routing/cleaner/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Cleaner(Environment[State]):
a maze.
- observation: `Observation`
- grid: jax array (int32) of shape (num_rows, num_cols)
- grid: jax array (int8) of shape (num_rows, num_cols)
contains the state of the board: 0 for dirty tile, 1 for clean tile, 2 for wall.
- agents_locations: jax array (int32) of shape (num_agents, 2)
contains the location of each agent on the board.
Expand All @@ -57,7 +57,7 @@ class Cleaner(Environment[State]):
- An invalid action is selected for any of the agents.
- state: `State`
- grid: jax array (int32) of shape (num_rows, num_cols)
- grid: jax array (int8) of shape (num_rows, num_cols)
contains the current state of the board: 0 for dirty tile, 1 for clean tile, 2 for wall.
- agents_locations: jax array (int32) of shape (num_agents, 2)
contains the location of each agent on the board.
Expand Down Expand Up @@ -127,15 +127,15 @@ def observation_spec(self) -> specs.Spec[Observation]:
Returns:
Spec for the `Observation`, consisting of the fields:
- grid: BoundedArray (int32) of shape (num_rows, num_cols). Values
- grid: BoundedArray (int8) of shape (num_rows, num_cols). Values
are between 0 and 2 (inclusive).
- agents_locations: BoundedArray (int32) of shape (num_agents, 2).
Maximum value for the first column is num_rows, and maximum value
for the second is num_cols.
- action_mask: BoundedArray (bool) of shape (num_agents, 4).
- step_count: BoundedArray (int32) of shape ().
"""
grid = specs.BoundedArray(self.grid_shape, jnp.int32, 0, 2, "grid")
grid = specs.BoundedArray(self.grid_shape, jnp.int8, 0, 2, "grid")
agents_locations = specs.BoundedArray(
(self.num_agents, 2), jnp.int32, [0, 0], self.grid_shape, "agents_locations"
)
Expand Down
3 changes: 2 additions & 1 deletion jumanji/environments/routing/cleaner/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
[DIRTY, DIRTY, DIRTY, DIRTY, WALL],
[DIRTY, WALL, WALL, DIRTY, WALL],
[DIRTY, WALL, DIRTY, DIRTY, DIRTY],
]
],
dtype=jnp.int8,
)


Expand Down
3 changes: 2 additions & 1 deletion jumanji/environments/routing/connector/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, NamedTuple

Expand Down Expand Up @@ -42,7 +43,7 @@ def connected(self) -> chex.Array:
"""returns: True if the agent has reached its target."""
return jnp.all(self.position == self.target, axis=-1)

def __eq__(self: "Agent", agent_2: Any) -> chex.Array:
def __eq__(self: Agent, agent_2: Any) -> chex.Array:
if not isinstance(agent_2, Agent):
return NotImplemented
same_ids = (agent_2.id == self.id).all()
Expand Down
5 changes: 3 additions & 2 deletions jumanji/environments/routing/snake/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from enum import IntEnum
from typing import TYPE_CHECKING, NamedTuple
Expand All @@ -27,12 +28,12 @@ class Position(NamedTuple):
row: chex.Array
col: chex.Array

def __eq__(self, other: "Position") -> chex.Array: # type: ignore[override]
def __eq__(self, other: Position) -> chex.Array: # type: ignore[override]
if not isinstance(other, Position):
return NotImplemented
return (self.row == other.row) & (self.col == other.col)

def __add__(self, other: "Position") -> "Position": # type: ignore[override]
def __add__(self, other: Position) -> Position: # type: ignore[override]
if not isinstance(other, Position):
return NotImplemented
return Position(row=self.row + other.row, col=self.col + other.col)
Expand Down
4 changes: 2 additions & 2 deletions jumanji/specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,13 @@ def test_shape_element_type_error(self) -> None:

def test_dtype_type_error(self) -> None:
with pytest.raises(TypeError):
specs.Array((1, 2, 3), "32")
specs.Array((1, 2, 3), "32") # type: ignore

def test_scalar_shape(self) -> None:
specs.Array((), jnp.int32)

def test_string_dtype_error(self) -> None:
specs.Array((1, 2, 3), "int32")
specs.Array((1, 2, 3), "int32") # type: ignore

def test_dtype(self) -> None:
specs.Array((1, 2, 3), int)
Expand Down
3 changes: 2 additions & 1 deletion jumanji/training/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import abc
import collections
Expand Down Expand Up @@ -56,7 +57,7 @@ def close(self) -> None:
def upload_checkpoint(self) -> None:
"""Uploads a checkpoint when exiting the logger."""

def __enter__(self) -> "Logger":
def __enter__(self) -> Logger:
logging.info("Starting logger.")
self._variables_enter = self._get_variables()
return self
Expand Down
5 changes: 3 additions & 2 deletions jumanji/training/networks/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Adapted from Brax."""
from __future__ import annotations

import abc

Expand All @@ -39,7 +40,7 @@ def entropy(self) -> chex.Array:
pass

@abc.abstractmethod
def kl_divergence(self, other: "Distribution") -> chex.Array:
def kl_divergence(self, other: Distribution) -> chex.Array:
pass


Expand Down Expand Up @@ -77,7 +78,7 @@ def entropy(self) -> chex.Array:

def kl_divergence( # type: ignore[override]
self,
other: "CategoricalDistribution",
other: CategoricalDistribution,
) -> chex.Array:
log_probs = jax.nn.log_softmax(self.logits)
probs = jax.nn.softmax(self.logits)
Expand Down
3 changes: 2 additions & 1 deletion jumanji/training/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# Inspired from https://stackoverflow.com/questions/51849395/how-can-we-associate-a-python-context-m
# anager-to-the-variables-appearing-in-it#:~:text=also%20inspect%20the-,stack,-for%20locals()%20variables
from __future__ import annotations

import inspect
import logging
Expand Down Expand Up @@ -45,7 +46,7 @@ def _get_variables(self) -> Dict:
"""
return {(k, id(v)): v for k, v in inspect.stack()[2].frame.f_locals.items()}

def __enter__(self) -> "Timer":
def __enter__(self) -> Timer:
self._variables_enter = self._get_variables()
self._start_time = time.perf_counter()
return self
Expand Down
2 changes: 1 addition & 1 deletion jumanji/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,4 @@ def get_valid_dtype(dtype: Union[jnp.dtype, type]) -> jnp.dtype:
Returns:
dtype converted to the correct type precision.
"""
return jnp.empty((), dtype).dtype
return jnp.empty((), dtype).dtype # type: ignore
3 changes: 2 additions & 1 deletion jumanji/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import (
Any,
Expand Down Expand Up @@ -120,7 +121,7 @@ def close(self) -> None:
"""
return self._env.close()

def __enter__(self) -> "Wrapper":
def __enter__(self) -> Wrapper:
return self

def __exit__(self, *args: Any) -> None:
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ chex>=0.1.3
dm-env>=1.5
gym>=0.22.0
jax>=0.2.26
matplotlib>=3.3.4
matplotlib~=3.7.4
numpy>=1.19.5
Pillow>=9.0.0
typing-extensions>=4.0.0

0 comments on commit 92fc862

Please sign in to comment.