diff --git a/vizier/_src/algorithms/designers/scalarization.py b/vizier/_src/algorithms/designers/scalarization.py index 4cd99806b..076a4a2c8 100644 --- a/vizier/_src/algorithms/designers/scalarization.py +++ b/vizier/_src/algorithms/designers/scalarization.py @@ -51,8 +51,7 @@ def __call__( class LinearScalarization(Scalarization): """Linear Scalarization.""" - @jt.jaxtyped - @typeguard.typechecked + @jt.jaxtyped(typechecker=typeguard.typechecked) def __call__( self, objectives: jt.Float[jax.Array, '*Batch Obj'] ) -> jt.Float[jax.Array, '*Batch']: @@ -62,8 +61,7 @@ def __call__( class ChebyshevScalarization(Scalarization): """Chebyshev Scalarization.""" - @jt.jaxtyped - @typeguard.typechecked + @jt.jaxtyped(typechecker=typeguard.typechecked) def __call__( self, objectives: jt.Float[jax.Array, '*Batch Obj'] ) -> jt.Float[jax.Array, '*Batch']: @@ -77,8 +75,7 @@ class HyperVolumeScalarization(Scalarization): default=None ) - @jt.jaxtyped - @typeguard.typechecked + @jt.jaxtyped(typechecker=typeguard.typechecked) def __call__( self, objectives: jt.Float[jax.Array, '*Batch Obj'] ) -> jt.Float[jax.Array, '*Batch']: @@ -105,8 +102,7 @@ class LinearAugmentedScalarization(Scalarization): default=1.0, converter=jnp.asarray ) - @jt.jaxtyped - @typeguard.typechecked + @jt.jaxtyped(typechecker=typeguard.typechecked) def __call__( self, objectives: jt.Float[jax.Array, '*Batch Obj'] ) -> jt.Float[jax.Array, '*Batch']: