better typing following the upstream
hyeok9855 committed Jan 19, 2025
1 parent bc4a8f8 commit 9ea2a9a
Showing 2 changed files with 30 additions and 13 deletions.
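
The change swaps explicit typing.Union annotations for PEP 604 "X | Y" unions, matching the style used upstream; Optional is kept where it already appears. A minimal before/after sketch of the pattern (the function names and the Baseline stand-in class below are illustrative, not names from this repo):

# Illustrative only: the annotation pattern changed by this commit.
from typing import Optional, Union


class Baseline:  # stand-in; the real code uses REINFORCEBaseline
    pass


# Old style: typing.Union spelled out
def configure_old(baseline: Union[Baseline, str] = "no", n_ants: Optional[Union[int, dict]] = None) -> None:
    print(type(baseline).__name__, n_ants)


# New style (PEP 604): the Union import is no longer needed
def configure_new(baseline: Baseline | str = "no", n_ants: Optional[int | dict] = None) -> None:
    print(type(baseline).__name__, n_ants)


configure_old("no", {"train": 30})
configure_new("no", {"train": 30})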
4 changes: 2 additions & 2 deletions rl4co/models/zoo/gfacs/model.py
@@ -1,6 +1,6 @@
 import math
 
-from typing import Optional, Union
+from typing import Optional
 
 import numpy as np
 import scipy
@@ -33,7 +33,7 @@ def __init__(
         self,
         env: RL4COEnvBase,
         policy: Optional[GFACSPolicy] = None,
-        baseline: Union[REINFORCEBaseline, str] = "no",
+        baseline: REINFORCEBaseline | str = "no",
         train_with_local_search: bool = True,
         ls_reward_aug_W: float = 0.95,
         policy_kwargs: dict = {},
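
One practical note on the new syntax (a general Python fact, not something introduced by this diff): "X | Y" in annotations is evaluated natively only on Python 3.10+; on 3.8/3.9 it still parses under "from __future__ import annotations", but would fail if the annotation were evaluated at runtime, e.g. via typing.get_type_hints. The sketch below also shows one plausible way an int-or-dict n_ants could be normalised into the per-phase dict that self.n_ants[phase] implies later in the diff; the helper name and defaults are hypothetical:

# Hypothetical helper, not part of this commit.
from __future__ import annotations  # lets "int | dict" appear in annotations pre-3.10

from typing import Optional


def normalize_n_ants(n_ants: Optional[int | dict] = None) -> dict:
    # Accept a single int or a per-phase dict; fall back to a default count.
    if n_ants is None:
        n_ants = 20
    if isinstance(n_ants, int):
        return {"train": n_ants, "val": n_ants, "test": n_ants}
    return n_ants


print(normalize_n_ants(30))             # {'train': 30, 'val': 30, 'test': 30}
print(normalize_n_ants({"train": 50}))  # {'train': 50}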
39 changes: 28 additions & 11 deletions rl4co/models/zoo/gfacs/policy.py
@@ -1,6 +1,5 @@
 from functools import partial
-from typing import Optional, Type, Union
-import math
+from typing import Optional, Type
 
 from tensordict import TensorDict
 import torch
@@ -50,8 +49,8 @@ def __init__(
         aco_class: Optional[Type[AntSystem]] = None,
         aco_kwargs: dict = {},
         train_with_local_search: bool = True,
-        n_ants: Optional[Union[int, dict]] = None,
-        n_iterations: Optional[Union[int, dict]] = None,
+        n_ants: Optional[int | dict] = None,
+        n_iterations: Optional[int | dict] = None,
         **encoder_kwargs,
     ):
         if encoder is None:
@@ -74,7 +73,7 @@ def __init__(
     def forward(
         self,
         td_initial: TensorDict,
-        env: Optional[Union[str, RL4COEnvBase]] = None,
+        env: Optional[str | RL4COEnvBase] = None,
         phase: str = "train",
         return_actions: bool = True,
         return_hidden: bool = False,
@@ -87,7 +86,9 @@ def forward(
         """
         n_ants = self.n_ants[phase]
         # Instantiate environment if needed
-        if (phase != "train" or self.train_with_local_search) and (env is None or isinstance(env, str)):
+        if (phase != "train" or self.train_with_local_search) and (
+            env is None or isinstance(env, str)
+        ):
             env_name = self.env_name if env is None else env
             env = get_env(env_name)
         else:
@@ -102,7 +103,8 @@ def forward(
             logZ = logZ[:, [0]]
 
         select_start_nodes_fn = partial(
-            self.aco_class.select_start_node_fn, start_node=self.aco_kwargs.get("start_node", None)
+            self.aco_class.select_start_node_fn,
+            start_node=self.aco_kwargs.get("start_node", None),
         )
         decoding_kwargs.update(
             {
@@ -113,7 +115,13 @@ def forward(
             }
         )
         logprobs, actions, td, env = self.common_decoding(
-            "multistart_sampling", td_initial, env, hidden, n_ants, actions, **decoding_kwargs
+            "multistart_sampling",
+            td_initial,
+            env,
+            hidden,
+            n_ants,
+            actions,
+            **decoding_kwargs,
         )
         td.set("reward", env.get_reward(td, actions))
 
@@ -122,7 +130,8 @@ def forward(
             "logZ": logZ,
             "reward": unbatchify(td["reward"], n_ants),
             "log_likelihood": unbatchify(
-                get_log_likelihood(logprobs, actions, td.get("mask", None), True), n_ants
+                get_log_likelihood(logprobs, actions, td.get("mask", None), True),
+                n_ants,
             )
         }
 
@@ -138,15 +147,23 @@ def forward(
                 batchify(td_initial, n_ants), env, actions # type:ignore
             )
             ls_logprobs, ls_actions, td, env = self.common_decoding(
-                "evaluate", td_initial, env, hidden, n_ants, ls_actions, **decoding_kwargs
+                "evaluate",
+                td_initial,
+                env,
+                hidden,
+                n_ants,
+                ls_actions,
+                **decoding_kwargs,
             )
             td.set("ls_reward", ls_reward)
             outdict.update(
                 {
                     "ls_logZ": ls_logZ,
                     "ls_reward": unbatchify(ls_reward, n_ants),
                     "ls_log_likelihood": unbatchify(
-                        get_log_likelihood(ls_logprobs, ls_actions, td.get("mask", None), True),
+                        get_log_likelihood(
+                            ls_logprobs, ls_actions, td.get("mask", None), True
+                        ),
                         n_ants,
                     )
                 }
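
Beyond the typing changes, the policy.py hunks only re-wrap the common_decoding calls and the unbatchify(get_log_likelihood(...), n_ants) expressions over more lines; behaviour is unchanged. The pattern those calls rely on is "flatten n_ants samples per instance, then regroup per instance". A self-contained sketch of that regrouping with simplified stand-ins (rl4co's own batchify/unbatchify live in rl4co.utils.ops and may differ in layout):

# Simplified stand-ins for the batchify/unbatchify pattern around common_decoding.
import torch


def batchify(x: torch.Tensor, n: int) -> torch.Tensor:
    # (B, ...) -> (n * B, ...): n stacked copies of the batch
    return x.repeat(n, *([1] * (x.dim() - 1)))


def unbatchify(x: torch.Tensor, n: int) -> torch.Tensor:
    # (n * B, ...) -> (B, n, ...): one row of n samples per original instance
    return x.reshape(n, -1, *x.shape[1:]).transpose(0, 1)


rewards = torch.arange(6.0)          # pretend: 3 instances x 2 ants, flattened
print(unbatchify(rewards, 2).shape)  # torch.Size([3, 2])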
