From e9d8dff7e65513ed994ec0117e70ea06cb2e2431 Mon Sep 17 00:00:00 2001 From: Sebastian Bischoff Date: Tue, 19 Mar 2024 20:31:39 +0100 Subject: [PATCH] Ignore errors despite mismatch in Types but as the code works I would suspect that they are compatible because they're all based on nn.Modules --- sbi/neural_nets/flow.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index 479664620..23c531654 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -1,7 +1,6 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Affero General Public License v3, see . - from functools import partial from typing import List, Optional, Sequence, Union from warnings import warn @@ -15,6 +14,7 @@ rational_quadratic, # pyright: ignore[reportAttributeAccessIssue] ) from torch import Tensor, nn, relu, tanh, tensor, uint8 +from zuko.flows import LazyTransform from sbi.neural_nets.density_estimators import NFlowsFlow, ZukoFlow from sbi.utils.sbiutils import ( @@ -501,13 +501,15 @@ def build_zuko_maf( residual=residual, ) - transforms = maf.transform.transforms + transforms: Union[Sequence[LazyTransform], LazyTransform] + transforms = maf.transform.transforms # pyright: ignore[reportAssignmentType] z_score_x_bool, structured_x = z_score_parser(z_score_x) if z_score_x_bool: # transforms = transforms transforms = ( *transforms, - standardizing_transform(batch_x, structured_x, backend="zuko"), + # Ideally `standardizing_transform` would return a `LazyTransform` instead of ` AffineTransform | Unconditional`, maybe all three are compatible + standardizing_transform(batch_x, structured_x, backend="zuko"), # pyright: ignore[reportAssignmentType] ) z_score_y_bool, structured_y = z_score_parser(z_score_y)