Skip to content

Commit 08909f6

Browse files
committed
Fix small bug with import statement
Signed-off-by: Jacob Platin <jacobplatin@google.com>
1 parent 8aeee3a commit 08909f6

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tpu_commons/models/jax/utils/quantization/quantization_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import jax
88
import jax.numpy as jnp
99
import qwix
10+
import qwix.pallas as qpl
1011
import yaml
1112
from flax import nnx
1213
from flax.typing import PRNGKey
@@ -549,7 +550,7 @@ def manually_quantize_qwix_activation(inputs: jax.Array, rule_name: str,
549550
Returns:
550551
The quantized activation tensor.
551552
"""
552-
rule = qwix.pallas.get_current_rule(rule_name)
553+
rule = qpl.get_current_rule(rule_name)
553554
lhs_how = ptq.qarray.HowToQuantize(qtype=qtype,
554555
channelwise_axes=channelwise_axes,
555556
tiled_axes=tiled_axes,

0 commit comments

Comments
 (0)