Skip to content

Commit

Permalink
Ignore errors despite mismatch in Types but as the code works I would…
Browse files Browse the repository at this point in the history
… suspect that they are compatible because they're all based on nn.Modules
  • Loading branch information
Baschdl committed Mar 19, 2024
1 parent bf84dce commit e9d8dff
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions sbi/neural_nets/flow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.


from functools import partial
from typing import List, Optional, Sequence, Union
from warnings import warn
Expand All @@ -15,6 +14,7 @@
rational_quadratic, # pyright: ignore[reportAttributeAccessIssue]
)
from torch import Tensor, nn, relu, tanh, tensor, uint8
from zuko.flows import LazyTransform

from sbi.neural_nets.density_estimators import NFlowsFlow, ZukoFlow
from sbi.utils.sbiutils import (
Expand Down Expand Up @@ -501,13 +501,15 @@ def build_zuko_maf(
residual=residual,
)

transforms = maf.transform.transforms
transforms: Union[Sequence[LazyTransform], LazyTransform]
transforms = maf.transform.transforms # pyright: ignore[reportAssignmentType]
z_score_x_bool, structured_x = z_score_parser(z_score_x)
if z_score_x_bool:
# transforms = transforms
transforms = (
*transforms,
standardizing_transform(batch_x, structured_x, backend="zuko"),
# Ideally `standardizing_transform` would return a `LazyTransform` instead of ` AffineTransform | Unconditional`, maybe all three are compatible
standardizing_transform(batch_x, structured_x, backend="zuko"), # pyright: ignore[reportAssignmentType]
)

z_score_y_bool, structured_y = z_score_parser(z_score_y)
Expand Down

0 comments on commit e9d8dff

Please sign in to comment.