Commit

updated docstrings
josephdviviano committed Nov 14, 2024
1 parent db13637 commit be1087c
Showing 1 changed file with 53 additions and 19 deletions.
72 changes: 53 additions & 19 deletions src/gfn/modules.py
@@ -10,6 +10,13 @@
from gfn.utils.distributions import UnsqueezedCategorical


REDUCTION_FXNS = {
"mean": torch.mean,
"sum": torch.sum,
"prod": torch.prod,
}
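For orientation (not part of this diff): each key in REDUCTION_FXNS maps to the plain torch reduction of the same name, which collapses a non-scalar module output into one value per input. A minimal sketch, assuming the reduction is taken over the last dimension:

import torch

out = torch.randn(4, 8)               # e.g. module output: batch of 4 states, 8 values each
print(torch.mean(out, dim=-1).shape)  # "mean" -> torch.Size([4])
print(torch.sum(out, dim=-1).shape)   # "sum"  -> torch.Size([4])
print(torch.prod(out, dim=-1).shape)  # "prod" -> torch.Size([4])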


class GFNModule(ABC, nn.Module):
r"""Base class for modules mapping states distributions.
@@ -41,9 +48,11 @@ class GFNModule(ABC, nn.Module):
`gfn.utils.modules`), then the environment preprocessor needs to be an
`EnumPreprocessor`.
preprocessor: Preprocessor from the environment.
_output_dim_is_checked: Flag for tracking whether the output dimenions of
_output_dim_is_checked: Flag for tracking whether the output dimensions of
the states (after being preprocessed and transformed by the modules) have
been verified.
_is_backward: Flag for tracking whether this estimator is used for predicting
probability distributions over parents.
"""

def __init__(
@@ -52,7 +61,7 @@ def __init__(
preprocessor: Preprocessor | None = None,
is_backward: bool = False,
) -> None:
"""Initalize the FunctionEstimator with an environment and a module.
"""Initialize the GFNModule with nn.Module and a preprocessor.
Args:
module: The module to use. If the module is a Tabular module (from
`gfn.utils.modules`), then the environment preprocessor needs to be an
@@ -152,9 +161,12 @@ class ScalarEstimator(GFNModule):
`gfn.utils.modules`), then the environment preprocessor needs to be an
`EnumPreprocessor`.
preprocessor: Preprocessor from the environment.
_output_dim_is_checked: Flag for tracking whether the output dimenions of
_output_dim_is_checked: Flag for tracking whether the output dimensions of
the states (after being preprocessed and transformed by the modules) have
been verified.
_is_backward: Flag for tracking whether this estimator is used for predicting
probability distributions over parents.
reduction_fxn: The reduction function used to aggregate the module output into a scalar (one of REDUCTION_FXNS: torch.mean, torch.sum, or torch.prod).
"""

def __init__(
@@ -164,14 +176,22 @@ def __init__(
is_backward: bool = False,
reduction: str = "mean",
):
"""Initialize the GFNModule with a scalar output.
Args:
module: The module to use. If the module is a Tabular module (from
`gfn.utils.modules`), then the environment preprocessor needs to be an
`EnumPreprocessor`.
preprocessor: Preprocessor object.
is_backward: Flags estimators of probability distributions over parents.
reduction: str name of one of the REDUCTION_FXNS keys: {}
""".format(
list(REDUCTION_FXNS.keys())
)
super().__init__(module, preprocessor, is_backward)
reduction_fxns = {
"mean": torch.mean,
"sum": torch.sum,
"prod": torch.prod,
}
assert reduction in reduction_fxns
self.reduction_fxn = reduction_fxns[reduction]
assert reduction in REDUCTION_FXNS, "reduction function not one of {}".format(
REDUCTION_FXNS.keys()
)
self.reduction_fxn = REDUCTION_FXNS[reduction]

def expected_output_dim(self) -> int:
return 1
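As an illustration (not part of this commit), a ScalarEstimator wraps a module whose raw output is reduced down to this single expected dimension. A minimal sketch, assuming the MLP helper in gfn.utils.modules accepts input_dim/output_dim keywords and that the preprocessor default shown in GFNModule.__init__ above is acceptable here; all sizes are made up:

from gfn.modules import ScalarEstimator
from gfn.utils.modules import MLP

# Hypothetical dimensions: 16-dim preprocessed states, 8 raw outputs per state;
# the "sum" reduction aggregates the 8 values into the single expected output.
logF_module = MLP(input_dim=16, output_dim=8)
logF_estimator = ScalarEstimator(module=logF_module, reduction="sum")

# An unknown reduction name fails the assertion in __init__ above:
# ScalarEstimator(module=logF_module, reduction="max")  # AssertionError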
@@ -359,8 +379,8 @@ class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator):
conditional GFN, the logZ or logF estimate is also conditional. This Estimator is
designed for those cases.
The function approximator used for `module` need not directly output a scalar. If
it does not, `reduction` will be used to aggregate the outputs of the module into
The function approximator used for `final_module` need not directly output a scalar.
If it does not, `reduction` will be used to aggregate the outputs of the module into
a single scalar.
Attributes:
@@ -375,6 +395,9 @@
_output_dim_is_checked: Flag for tracking whether the output dimensions of
_output_dim_is_checked: Flag for tracking whether the output dimensions of
the states (after being preprocessed and transformed by the modules) have
been verified.
_is_backward: Flag for tracking whether this estimator is used for predicting
probability distributions over parents.
reduction_fxn: The reduction function used to aggregate the module output into a scalar (one of REDUCTION_FXNS: torch.mean, torch.sum, or torch.prod).
"""

def __init__(
@@ -386,6 +409,20 @@ def __init__(
is_backward: bool = False,
reduction: str = "mean",
):
"""Initialize a conditional GFNModule with a scalar output.
Args:
state_module: The module to use for state representations. If the module is
a Tabular module (from `gfn.utils.modules`), then the environment
preprocessor needs to be an `EnumPreprocessor`.
conditioning_module: The module to use for conditioning representations.
final_module: The module to use for computing the final output.
preprocessor: Preprocessor object.
is_backward: Flags estimators of probability distributions over parents.
reduction: str name of one of the REDUCTION_FXNS keys: {}
""".format(
list(REDUCTION_FXNS.keys())
)

super().__init__(
state_module,
conditioning_module,
@@ -394,13 +431,10 @@ def __init__(
preprocessor=preprocessor,
is_backward=is_backward,
)
reduction_fxns = {
"mean": torch.mean,
"sum": torch.sum,
"prod": torch.prod,
}
assert reduction in reduction_fxns
self.reduction_fxn = reduction_fxns[reduction]
assert reduction in REDUCTION_FXNS, "reduction function not one of {}".format(
REDUCTION_FXNS.keys()
)
self.reduction_fxn = REDUCTION_FXNS[reduction]
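To make the three-module layout concrete (again, not part of this commit): every size and module choice below is hypothetical, the MLP keywords are assumed as above, and the state and conditioning encodings are assumed to be concatenated by the parent class before final_module is applied:

from gfn.modules import ConditionalScalarEstimator
from gfn.utils.modules import MLP

# Hypothetical sizes: 16-dim preprocessed states and 4-dim conditioning vectors,
# each encoded to 32 features; final_module maps the combined 64 features to 8
# values, which the "mean" reduction collapses to one conditional logF/logZ value.
state_module = MLP(input_dim=16, output_dim=32)
conditioning_module = MLP(input_dim=4, output_dim=32)
final_module = MLP(input_dim=64, output_dim=8)

cond_logF = ConditionalScalarEstimator(
    state_module, conditioning_module, final_module, reduction="mean"
)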

def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
"""Forward pass of the module.
