diff --git a/olive/passes/onnx/mnb_to_qdq.py b/olive/passes/onnx/mnb_to_qdq.py index 2ec5e333c..a1a5d3d75 100644 --- a/olive/passes/onnx/mnb_to_qdq.py +++ b/olive/passes/onnx/mnb_to_qdq.py @@ -31,11 +31,12 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon "use_transpose_op": PassConfigParam( type_=bool, # TODO(jambayk): decide whether to enable this by default or not - # False gives same output on arm Mac/Windows, but not on x64 Linux/Windows - default_value=True, + # CPU-EP: False gives same output on arm Mac/Windows, but not on x64 Linux/Windows + default_value=False, description=( "Whether to use a Transpose operator after the DequantizeLinear operator. If False, the weight" - " initializer will be transposed instead." + " initializer will be transposed instead. Default is False. True might be more efficient on some" + " EPs such as DirectML." ), ), **get_external_data_config(), @@ -76,8 +77,9 @@ def _run_for_config( block_size = node_attributes["block_size"] num_k_blocks = math.ceil(K / block_size) - # only deal with 4 bits for now + # only deal with 4 bits (int4) for now if node_attributes["bits"] != 4: + logger.debug("%s uses %d bits, only 4 bits is supported", node_name, node_attributes["bits"]) continue # we can only deal with trivial g_idx, dequantize linear does not support g_idx diff --git a/test/unit_test/passes/onnx/test_mnb_to_qdq.py b/test/unit_test/passes/onnx/test_mnb_to_qdq.py index 12f52099f..7722b455f 100644 --- a/test/unit_test/passes/onnx/test_mnb_to_qdq.py +++ b/test/unit_test/passes/onnx/test_mnb_to_qdq.py @@ -61,7 +61,12 @@ def forward(self, x): reason="Int4 DQ is only supported in ORT >= 1.20", ) @pytest.mark.parametrize("use_transpose_op", [True, False]) -def test_mnb_to_qdq(create_mnb_model, use_transpose_op, tmp_path): +@pytest.mark.parametrize("execution_provider", ["CPUExecutionProvider", "CUDAExecutionProvider"]) +def test_mnb_to_qdq(create_mnb_model, execution_provider, use_transpose_op, tmp_path): + available_providers = onnxruntime.get_available_providers() + if execution_provider not in available_providers: + pytest.skip(f"{execution_provider} is not available on this system {available_providers}") + mnb_path, in_dim = create_mnb_model input_model = ONNXModelHandler(mnb_path) @@ -73,14 +78,20 @@ def test_mnb_to_qdq(create_mnb_model, use_transpose_op, tmp_path): qdq_model: ONNXModelHandler = p.run(input_model, output_folder) # validate - original_session = onnxruntime.InferenceSession(str(mnb_path)) - qdq_session = onnxruntime.InferenceSession(str(qdq_model.model_path)) + original_session = onnxruntime.InferenceSession(str(mnb_path), providers=[execution_provider]) + original_session.disable_fallback() + qdq_session = onnxruntime.InferenceSession(str(qdq_model.model_path), providers=[execution_provider]) + qdq_session.disable_fallback() input_data = {"input": np.random.randn(1, 1, in_dim).astype(np.float32)} original_output = original_session.run(None, input_data)[0] qdq_output = qdq_session.run(None, input_data)[0] assert original_output.shape == qdq_output.shape assert original_output.dtype == qdq_output.dtype - if use_transpose_op: - # Pre transposed DQ model does not match the expected output on x64 + if execution_provider == "CPUExecutionProvider" and not use_transpose_op: + # Pre transposed DQ model does not match the expected output on x64 CPU + # check for assertion failure so we know when the test is fixed + with pytest.raises(AssertionError): + np.testing.assert_allclose(original_output, qdq_output, rtol=1e-3, atol=1e-3) + else: np.testing.assert_allclose(original_output, qdq_output, rtol=1e-3, atol=1e-3)