Fix (examples/generative): improved export (#838)
Giuseppe5 authored Feb 12, 2024
1 parent 65360f4 commit e0d78a6
Showing 2 changed files with 10 additions and 6 deletions.
src/brevitas_examples/llm/main.py (7 additions, 5 deletions)
@@ -189,22 +189,20 @@ def model_export(model, ref_input, args):
     elif args.export_target == 'onnx_qcdq':
         if args.weight_quant_granularity == 'per_group':
             export_manager = BlockQuantProxyLevelManager
-            constant_folding = False
         else:
             export_manager = StdQCDQONNXManager
             export_manager.change_weight_export(export_weight_q_node=True)
-            constant_folding = True
 
         print(f"Exporting the model in ./quantized_onnx/{args.model.replace('/', '-')}")
         with torch.no_grad(), brevitas_proxy_export_mode(model, export_manager=export_manager):
             onnx_export(
                 model,
                 f"./quantized_onnx/{args.model.replace('/', '-')}",
                 task="text-generation-with-past",
-                do_validation=False,
-                do_constant_folding=constant_folding)
+                do_validation=False)
     elif args.export_target == 'torch_qcdq':
-        export_torch_qcdq(model, ref_input, export_path=f"{args.model.replace('/', '-')}.pt")
+        export_torch_qcdq(
+            model, ref_input['input_ids'], export_path=f"{args.model.replace('/', '-')}.pt")
 
 
 def validate(args):
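For context on the ONNX branch: change_weight_export(export_weight_q_node=True) asks the QCDQ exporter to keep an explicit quantize step on the weights rather than baking pre-quantized integers into the graph. A runnable analogue of that Q/DQ pattern in plain PyTorch (an illustration of the pattern only, not Brevitas or optimum API):

import torch

# Weight stays float; an explicit quantize ("Q") then dequantize ("DQ") pair
# models it, mirroring ONNX QuantizeLinear -> DequantizeLinear.
w = torch.randn(4, 4)
q = torch.quantize_per_tensor(w, scale=0.02, zero_point=0, dtype=torch.qint8)
w_qdq = q.dequantize()
print((w - w_qdq).abs().max())  # the rounding error the Q/DQ pair introduces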
@@ -251,6 +249,10 @@ def main():
         dtype = torch.float16
 
     kwargs = {"torch_dtype": dtype}
+
+    if args.export_target == 'torch_qcdq':
+        kwargs['torchscript'] = True
+
     print("Model loading...")
     model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs)
     print("Model loaded.")
src/brevitas_examples/stable_diffusion/main.py (3 additions, 1 deletion)
@@ -129,7 +129,8 @@ def bit_width(module):
 
     if args.export_target:
         # Move to cpu and to float32 to enable CPU export
-        pipe.unet.to('cpu').to(torch.float32)
+        if not (args.float16 and args.export_cuda_float16):
+            pipe.unet.to('cpu').to(torch.float32)
         pipe.unet.eval()
         device = next(iter(pipe.unet.parameters())).device
         dtype = next(iter(pipe.unet.parameters())).dtype
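The new guard leaves the UNet on CUDA in float16 only when both --float16 and the new --export-cuda-float16 flag are set; otherwise the old CPU/float32 path runs, and the device/dtype probes that follow pick up whichever placement survived. A self-contained sketch of that bookkeeping, with an nn.Linear standing in for pipe.unet:

import torch
import torch.nn as nn

unet = nn.Linear(4, 4).to(torch.float16)    # stand-in for pipe.unet
float16, export_cuda_float16 = True, False  # stand-ins for args.float16 / args.export_cuda_float16

if not (float16 and export_cuda_float16):
    unet.to('cpu').to(torch.float32)        # default: export on CPU in float32
unet.eval()

device = next(iter(unet.parameters())).device  # cpu here; cuda when the guard is skipped
dtype = next(iter(unet.parameters())).dtype    # float32 here; float16 when skipped
print(device, dtype)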
@@ -233,6 +234,7 @@ def bit_width(module):
         help='Group size for per_group weight quantization. Default: 16.')
     add_bool_arg(
         parser, 'quantize-weight-zero-point', default=True, help='Quantize weight zero-point.')
+    add_bool_arg(parser, 'export-cuda-float16', default=False, help='Export FP16 on CUDA')
     args = parser.parse_args()
     print("Args: " + str(vars(args)))
     main(args)
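add_bool_arg is a helper defined elsewhere in these examples; a hedged sketch of the usual pattern behind such helpers (paired --flag/--no-flag options sharing a default), so the new option's behavior reads in context:

import argparse

def add_bool_arg(parser, name, default, help):
    # Typical implementation pattern; the real helper may differ in details.
    dest = name.replace('-', '_')
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument('--' + name, dest=dest, action='store_true', help=help)
    group.add_argument('--no-' + name, dest=dest, action='store_false')
    parser.set_defaults(**{dest: default})

parser = argparse.ArgumentParser()
add_bool_arg(parser, 'export-cuda-float16', default=False, help='Export FP16 on CUDA')
print(parser.parse_args(['--export-cuda-float16']))  # Namespace(export_cuda_float16=True)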
