Skip to content

Commit

Permalink
added workaround for equal_cpu error
Browse files (browse the repository at this point in the history)
  • Loading branch information
costigt-dev committed Apr 8, 2024
1 parent e3c125e commit 6552b68
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
12 changes: 7 additions & 5 deletions src/brevitas/export/common/handler/qcdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,10 @@ def prepare_for_export(self, module):
else:
self.symbolic_kwargs = None

def quantize_from_floating_point(self, x: Tensor):
def quantize_from_floating_point(self, x: Tensor, zp):
quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
# workaround for equal_cpu issue
quantize_symbolic_kwargs['zero_point'] = zp
# Before quantization, cast input to float32
if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
x = self.cast_fn(x, torch.float32)
Expand All @@ -225,19 +227,19 @@ def quantize_from_integer(self, x: Tensor):

def symbolic_execution(self, x: Tensor):
assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'
dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs'])
zero_point = dequantize_symbolic_kwargs['zero_point']
if self._export_q_node:
x = self.quantize_from_floating_point(x)
x = self.quantize_from_floating_point(x, zero_point)
else:
x = self.quantize_from_integer(x)
clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs']
# Copy dict to allow for popping kwargs even on shared quantizers
dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs'])
scale = dequantize_symbolic_kwargs['scale']
zero_point = dequantize_symbolic_kwargs['zero_point']
bit_width = self.symbolic_kwargs['bit_width']
scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
# Workaround to trick the tracer into believing all return values are used
self.assert_ge_zero(scale, zero_point, bit_width)
self.assert_ge_zero(scale, bit_width)
if clip_symbolic_kwargs is not None:
x = self.clip_fn(x, *clip_symbolic_kwargs.values())
x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/quant/experimental/scaled_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ class Fp8e4m3OCPWeightPerTensorFloat(Fp8e4m3Weight,
MaxStatsScaling,
PerTensorFloatScaling8bit,
WeightQuantSolver):
pass
pass
3 changes: 2 additions & 1 deletion tests/brevitas/export/test_onnx_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ def test_simple_fp8_export():


if __name__ == "__main__":
test_simple_fp8_export()
test_simple_fp8_export()
print("Done")

0 comments on commit 6552b68

Please sign in to comment.