diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 4c7e9dccb..2af4875d9 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -393,7 +393,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype): gemm_input_role=GemmInputRole.WEIGHT, ) assert torch.equal(float8_eager._scale, float8_compile._scale) - assert torch.equal(float8_eager._data, float_compile._data) + assert torch.equal(float8_eager._data, float8_compile._data) if __name__ == "__main__":