Skip to content

Commit

Permalink
Add pre-commit workflow & fix remaining pyright issues (#209)
Browse files Browse the repository at this point in the history
* fix remianing issue and add pre-commit to workflow

* use python 3.10
  • Loading branch information
younik authored Oct 31, 2024
1 parent 0d593d8 commit 9c9e1af
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 22 deletions.
20 changes: 20 additions & 0 deletions .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# https://pre-commit.com
# This GitHub Action assumes that the repo contains a valid .pre-commit-config.yaml file.
---
name: pre-commit
on: [push]

permissions:
contents: read

jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- run: pip install .[all]
- run: pre-commit --version
- run: pre-commit run --all-files
6 changes: 6 additions & 0 deletions src/gfn/containers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):

# Our buffer is full and we will prioritize diverse, high reward additions.
else:
if (
self.training_objects.log_rewards is None
or training_objects.log_rewards is None
):
raise ValueError("log_rewards must be defined for prioritized replay.")

# Sort the incoming elements by their logrewards.
ix = torch.argsort(training_objects.log_rewards, descending=True)
training_objects = training_objects[ix]
Expand Down
17 changes: 4 additions & 13 deletions src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@
from gfn.utils.common import has_log_probs


def is_tensor(t) -> bool:
"""Checks whether t is a torch.Tensor instance."""
return isinstance(t, torch.Tensor)


# TODO: remove env from this class?
class Trajectories(Container):
"""Container for complete trajectories (starting in $s_0$ and ending in $s_f$).
Expand Down Expand Up @@ -113,7 +108,7 @@ def __init__(
)
else:
log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
self.log_probs = log_probs
self.log_probs: torch.Tensor = log_probs

self.estimator_outputs = estimator_outputs
if self.estimator_outputs is not None:
Expand Down Expand Up @@ -187,7 +182,7 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories:
log_rewards = (
self._log_rewards[index] if self._log_rewards is not None else None
)
if is_tensor(self.estimator_outputs):
if self.estimator_outputs is not None:
# TODO: Is there a safer way to index self.estimator_outputs for
# for n-dimensional estimator outputs?
#
Expand Down Expand Up @@ -292,13 +287,9 @@ def extend(self, other: Trajectories) -> None:

# Either set, or append, estimator outputs if they exist in the submitted
# trajectory.
if self.estimator_outputs is None and isinstance(
other.estimator_outputs, torch.Tensor
):
if self.estimator_outputs is None and other.estimator_outputs is not None:
self.estimator_outputs = other.estimator_outputs
elif isinstance(self.estimator_outputs, torch.Tensor) and isinstance(
other.estimator_outputs, torch.Tensor
):
elif self.estimator_outputs is not None and other.estimator_outputs is not None:
batch_shape = self.actions.batch_shape
n_bs = len(batch_shape)

Expand Down
8 changes: 4 additions & 4 deletions src/gfn/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def __init__(
assert s0.shape == state_shape
if sf is None:
sf = torch.full(s0.shape, -float("inf")).to(self.device)
assert sf.shape == state_shape
self.sf = sf
self.sf: torch.Tensor = sf
assert self.sf.shape == state_shape
self.state_shape = state_shape
self.action_shape = action_shape
self.dummy_action = dummy_action
Expand Down Expand Up @@ -381,11 +381,11 @@ def __init__(

# The default dummy action is -1.
if dummy_action is None:
dummy_action = torch.tensor([-1], device=device)
dummy_action: torch.Tensor = torch.tensor([-1], device=device)

# The default exit action index is the final element of the action space.
if exit_action is None:
exit_action = torch.tensor([n_actions - 1], device=device)
exit_action: torch.Tensor = torch.tensor([n_actions - 1], device=device)

assert s0.shape == state_shape
assert dummy_action.shape == action_shape
Expand Down
8 changes: 4 additions & 4 deletions src/gfn/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,11 @@ def __init__(
dtype=torch.bool,
device=self.__class__.device,
)
assert forward_masks.shape == (*self.batch_shape, self.n_actions)
assert backward_masks.shape == (*self.batch_shape, self.n_actions - 1)

self.forward_masks = forward_masks
self.backward_masks = backward_masks
self.forward_masks: torch.Tensor = forward_masks
self.backward_masks: torch.Tensor = backward_masks
assert self.forward_masks.shape == (*self.batch_shape, self.n_actions)
assert self.backward_masks.shape == (*self.batch_shape, self.n_actions - 1)

def clone(self) -> States:
"""Returns a clone of the current instance."""
Expand Down
2 changes: 1 addition & 1 deletion testing/test_parametrizations_and_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
BoxPFMLP,
)
from gfn.modules import DiscretePolicyEstimator, ScalarEstimator
from gfn.utils.modules import DiscreteUniform, MLP, Tabular
from gfn.utils.modules import MLP, DiscreteUniform, Tabular

N = 10 # Number of trajectories from sample_trajectories (changes tests globally).

Expand Down

0 comments on commit 9c9e1af

Please sign in to comment.