Skip to content

Commit 0eea95b

Browse files
committed
Remove underscored funcs
Signed-off-by: Jacob Platin <jacobplatin@google.com>
1 parent 08909f6 commit 0eea95b

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

tests/test_quantization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -647,10 +647,10 @@ def setUp(self):
647647
self.calibration_method = 'max'
648648

649649
@patch(
650-
'tpu_commons.models.jax.utils.quantization.quantization_utils.ptq._create_quantized_param'
650+
'tpu_commons.models.jax.utils.quantization.quantization_utils.ptq.create_quantized_param'
651651
)
652652
def test_manually_quantize_qwix_weight(self, mock_create_param):
653-
"""Test that manually_quantize_qwix_weight calls ptq._create_quantized_param correctly."""
653+
"""Test that manually_quantize_qwix_weight calls ptq.create_quantized_param correctly."""
654654
quantize_qwix.manually_quantize_qwix_weight(
655655
weight=self.weight,
656656
qtype=self.qtype,
@@ -672,12 +672,12 @@ def test_manually_quantize_qwix_weight(self, mock_create_param):
672672
self.calibration_method)
673673

674674
@patch(
675-
'tpu_commons.models.jax.utils.quantization.quantization_utils.ptq._quantize_act'
675+
'tpu_commons.models.jax.utils.quantization.quantization_utils.ptq.quantize_act'
676676
)
677677
@patch('qwix.pallas.get_current_rule')
678678
def test_manually_quantize_qwix_activation(self, mock_get_rule,
679679
mock_quantize_act):
680-
"""Test that manually_quantize_qwix_activation calls ptq._quantize_act correctly."""
680+
"""Test that manually_quantize_qwix_activation calls ptq.quantize_act correctly."""
681681
mock_rule = MagicMock()
682682
mock_rule.act_static_scale = False
683683
mock_get_rule.return_value = mock_rule

tpu_commons/models/jax/utils/quantization/quantization_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ def manually_quantize_qwix_weight(weight: jax.Array, qtype: jnp.dtype,
527527
tiled_axes=tiled_axes,
528528
calibration_method=calibration_method)
529529

530-
return ptq._create_quantized_param(weight, how_to_quantize)
530+
return ptq.create_quantized_param(weight, how_to_quantize)
531531

532532

533533
def manually_quantize_qwix_activation(inputs: jax.Array, rule_name: str,
@@ -561,7 +561,7 @@ def manually_quantize_qwix_activation(inputs: jax.Array, rule_name: str,
561561
# channelwise_axes should be set to (a subset of) non-contraction axes. e.g.
562562
# for ragged_dot [m, k] x [g, k, n], they are [0] and [0, 2]
563563
# TODO (jacobplatin): add support for `act_name`
564-
return ptq._quantize_act(inputs, lhs_how, rule, "")
564+
return ptq.quantize_act(inputs, lhs_how, rule, "")
565565

566566

567567
def get_quant_dtype_from_qwix_config(

0 commit comments

Comments
 (0)