|
9 | 9 | from flax import nnx |
10 | 10 | from jax.sharding import Mesh, NamedSharding |
11 | 11 | from jax.sharding import PartitionSpec as P |
| 12 | +from qwix._src.providers import ptq |
12 | 13 |
|
13 | 14 | import tpu_commons.models.jax.utils.quantization.quantization_utils as quantize_qwix # noqa: E402 |
14 | 15 | from tpu_commons.models.jax.model_loader import apply_qwix_quantization |
@@ -631,5 +632,204 @@ def test_get_random_sharded_array_sharding_fallback( |
631 | 632 | self.assertEqual(fallback_sharding, NamedSharding(self.mesh, P())) |
632 | 633 |
|
633 | 634 |
|
class TestManualQwixQuantization(unittest.TestCase):
    """Checks the arguments forwarded by the manual Qwix quantization helpers.

    The ``ptq`` entry points and ``qwix.pallas.get_current_rule`` are mocked,
    so these tests only inspect what the helpers pass along (or that they
    assert) rather than performing real quantization.
    """

    def setUp(self):
        """Create the arrays and quantization settings shared by every test."""
        if not jax.devices():
            self.skipTest(
                "JAX device not found, skipping JAX-dependent tests.")
        self.weight = jnp.ones((4, 4))
        self.inputs = jnp.ones((8, 4))
        self.qtype = jnp.int8
        self.channelwise_axes = [0]
        self.tiled_axes = {}
        self.calibration_method = 'max'

    def _quantization_kwargs(self):
        """Return the quantization keyword arguments common to both helpers."""
        return {
            "qtype": self.qtype,
            "channelwise_axes": self.channelwise_axes,
            "tiled_axes": self.tiled_axes,
            "calibration_method": self.calibration_method,
        }

    def _check_how_to_quantize(self, how):
        """Assert that ``how`` is a HowToQuantize mirroring the fixture settings."""
        self.assertIsInstance(how, ptq.qarray.HowToQuantize)
        self.assertEqual(how.qtype, self.qtype)
        self.assertEqual(how.channelwise_axes, self.channelwise_axes)
        self.assertEqual(how.tiled_axes, self.tiled_axes)
        self.assertEqual(how.calibration_method, self.calibration_method)

    @patch(
        'tpu_commons.models.jax.utils.quantization.quantization_utils.ptq.create_quantized_param'
    )
    def test_manually_quantize_qwix_weight(self, mock_create_param):
        """The weight helper calls ptq.create_quantized_param exactly once."""
        quantize_qwix.manually_quantize_qwix_weight(
            weight=self.weight, **self._quantization_kwargs())

        mock_create_param.assert_called_once()
        # The mocked function is expected to receive exactly two positional
        # arguments: the weight and the HowToQuantize description.
        forwarded_weight, forwarded_how = mock_create_param.call_args.args

        self.assertTrue(jnp.array_equal(forwarded_weight, self.weight))
        self._check_how_to_quantize(forwarded_how)

    @patch(
        'tpu_commons.models.jax.utils.quantization.quantization_utils.ptq.quantize_act'
    )
    @patch('qwix.pallas.get_current_rule')
    def test_manually_quantize_qwix_activation(self, mock_get_rule,
                                               mock_quantize_act):
        """The activation helper looks up the rule and calls ptq.quantize_act."""
        rule_name = "test_rule"
        fake_rule = MagicMock()
        fake_rule.act_static_scale = False
        mock_get_rule.return_value = fake_rule

        quantize_qwix.manually_quantize_qwix_activation(
            inputs=self.inputs,
            rule_name=rule_name,
            **self._quantization_kwargs())

        mock_get_rule.assert_called_once_with(rule_name)
        mock_quantize_act.assert_called_once()

        # Four positional arguments are expected: inputs, how, rule, act_name.
        (forwarded_inputs, forwarded_how, forwarded_rule,
         forwarded_act_name) = mock_quantize_act.call_args.args

        self.assertTrue(jnp.array_equal(forwarded_inputs, self.inputs))
        self._check_how_to_quantize(forwarded_how)
        self.assertIs(forwarded_rule, fake_rule)
        # The helper is expected to pass an empty activation name.
        self.assertEqual(forwarded_act_name, "")

    @patch('qwix.pallas.get_current_rule')
    def test_manually_quantize_qwix_activation_static_scale_raises_error(
            self, mock_get_rule):
        """A rule with act_static_scale set must trigger an AssertionError."""
        static_rule = MagicMock()
        static_rule.act_static_scale = True
        mock_get_rule.return_value = static_rule

        with self.assertRaisesRegex(AssertionError,
                                    "Static scale not supported right now"):
            quantize_qwix.manually_quantize_qwix_activation(
                inputs=self.inputs,
                rule_name="any_rule",
                **self._quantization_kwargs())
| 728 | + |
class TestGetQuantDtypeFromQwixConfig(unittest.TestCase):
    """Covers get_quant_dtype_from_qwix_config for several config shapes.

    Each test installs a different ``additional_config`` on a mocked vLLM
    config and checks the ``(scale_dtype, quant_dtype)`` pair that comes back.
    """

    def setUp(self):
        """Start every test from a mocked config with no additional settings."""
        self.mock_vllm_config = MagicMock()
        self.mock_vllm_config.additional_config = {}

    def _use_qwix_section(self, qwix_section):
        """Install ``qwix_section`` under additional_config's quantization/qwix keys."""
        self.mock_vllm_config.additional_config = {
            "quantization": {
                "qwix": qwix_section
            }
        }

    def _resolve_dtypes(self):
        """Run the function under test against the current mocked config."""
        return quantize_qwix.get_quant_dtype_from_qwix_config(
            self.mock_vllm_config)

    def test_get_quant_dtype_success(self):
        """Both dtypes are read from a config that specifies them."""
        mlp_rule = {"module_path": ".*mlp.*", "weight_qtype": "int4"}
        wildcard_rule = {"module_path": ".*", "weight_qtype": "int8"}
        self._use_qwix_section({
            "scale_dtype": "float16",
            "rules": [mlp_rule, wildcard_rule],
        })

        scale_dtype, quant_dtype = self._resolve_dtypes()

        self.assertEqual(scale_dtype, jnp.float16)
        self.assertEqual(quant_dtype, jnp.int8)

    def test_get_quant_dtype_default_scale(self):
        """scale_dtype is bfloat16 when the config does not set it."""
        self._use_qwix_section(
            {"rules": [{
                "module_path": ".*",
                "weight_qtype": "int8"
            }]})

        scale_dtype, quant_dtype = self._resolve_dtypes()

        self.assertEqual(scale_dtype, jnp.bfloat16)
        self.assertEqual(quant_dtype, jnp.int8)

    def test_no_quantization_config_returns_defaults(self):
        """An empty additional_config yields bfloat16 and no quant dtype."""
        self.mock_vllm_config.additional_config = {}

        scale_dtype, quant_dtype = self._resolve_dtypes()

        self.assertEqual(scale_dtype, jnp.bfloat16)
        self.assertIsNone(quant_dtype)

    def test_get_quant_dtype_no_wildcard_rule_returns_none(self):
        """quant_dtype is None when no rule uses the ".*" module path."""
        self._use_qwix_section(
            {"rules": [{
                "module_path": ".*mlp.*",
                "weight_qtype": "int4"
            }]})

        scale_dtype, quant_dtype = self._resolve_dtypes()

        self.assertEqual(scale_dtype, jnp.bfloat16)
        self.assertIsNone(quant_dtype)

    def test_get_quant_dtype_wildcard_rule_missing_qtype_raises_error(self):
        """A ".*" rule without weight_qtype must trigger an AssertionError."""
        self._use_qwix_section({"rules": [{"module_path": ".*"}]})

        with self.assertRaisesRegex(AssertionError,
                                    "Quantization dtype not found"):
            self._resolve_dtypes()

    def test_get_quant_dtype_no_rules_key_returns_none(self):
        """quant_dtype is None when the qwix section has no 'rules' key."""
        self._use_qwix_section({"scale_dtype": "float16"})

        scale_dtype, quant_dtype = self._resolve_dtypes()

        self.assertEqual(scale_dtype, jnp.float16)
        self.assertIsNone(quant_dtype)
| 833 | + |
# Script entry point: hand control to the unittest runner when this file is
# executed directly rather than imported.
if __name__ == '__main__':
    unittest.main()
0 commit comments