Skip to content

Commit

Permalink
Apply suggestion from ConnorBaker in #1029
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 596925777
  • Loading branch information
sagipe authored and copybara-github committed Jan 9, 2024
1 parent e9e1671 commit 2877348
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions vizier/_src/algorithms/designers/scalarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ def __call__(
class LinearScalarization(Scalarization):
"""Linear Scalarization."""

@jt.jaxtyped
@typeguard.typechecked
@jt.jaxtyped(typechecker=typeguard.typechecked)
def __call__(
self, objectives: jt.Float[jax.Array, '*Batch Obj']
) -> jt.Float[jax.Array, '*Batch']:
Expand All @@ -62,8 +61,7 @@ def __call__(
class ChebyshevScalarization(Scalarization):
"""Chebyshev Scalarization."""

@jt.jaxtyped
@typeguard.typechecked
@jt.jaxtyped(typechecker=typeguard.typechecked)
def __call__(
self, objectives: jt.Float[jax.Array, '*Batch Obj']
) -> jt.Float[jax.Array, '*Batch']:
Expand All @@ -77,8 +75,7 @@ class HyperVolumeScalarization(Scalarization):
default=None
)

@jt.jaxtyped
@typeguard.typechecked
@jt.jaxtyped(typechecker=typeguard.typechecked)
def __call__(
self, objectives: jt.Float[jax.Array, '*Batch Obj']
) -> jt.Float[jax.Array, '*Batch']:
Expand All @@ -105,8 +102,7 @@ class LinearAugmentedScalarization(Scalarization):
default=1.0, converter=jnp.asarray
)

@jt.jaxtyped
@typeguard.typechecked
@jt.jaxtyped(typechecker=typeguard.typechecked)
def __call__(
self, objectives: jt.Float[jax.Array, '*Batch Obj']
) -> jt.Float[jax.Array, '*Batch']:
Expand Down

0 comments on commit 2877348

Please sign in to comment.