[NPU] fix npu llava infer (#757)
`paddle.amp.is_bfloat16_supported()` raises an error on NPU devices, and
bfloat16 is supported by default on the 910B NPU device.

Co-authored-by: LokeZhou <aishenghuoaiqq@163.com>
Birdylx and LokeZhou authored Oct 18, 2024
1 parent 67d0764 commit ef12f49
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion paddlemix/examples/llava/run_predict_multiround.py
@@ -36,7 +36,11 @@
 def main(args):
     paddle.seed(seed=0)
     compute_dtype = "float16" if args.fp16 else "bfloat16"
-    if compute_dtype== "bfloat16" and not paddle.amp.is_bfloat16_supported():
+    if "npu" in paddle.get_device():
+        is_bfloat16_supported = True
+    else:
+        is_bfloat16_supported = paddle.amp.is_bfloat16_supported()
+    if compute_dtype== "bfloat16" and not is_bfloat16_supported:
         logger.warning("bfloat16 is not supported on your device,change to float32")
         compute_dtype = "float32"
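The dtype-selection logic in the patch can be sketched in isolation. This is a minimal, framework-free illustration of the same guard: `resolve_compute_dtype`, the `device` string, and the `amp_bf16_ok` probe callback are hypothetical stand-ins for `paddle.get_device()` and `paddle.amp.is_bfloat16_supported()`, not part of the PaddleMIX codebase.

```python
def resolve_compute_dtype(requested: str, device: str, amp_bf16_ok) -> str:
    """Pick a compute dtype, assuming bfloat16 works on NPU (e.g. 910B)
    without calling the framework probe, which raises on NPU devices."""
    if requested != "bfloat16":
        return requested
    # On NPU, skip the probe entirely: 910B supports bfloat16 by default.
    if "npu" in device:
        return "bfloat16"
    # Elsewhere, trust the framework probe; fall back to float32 if unsupported.
    return "bfloat16" if amp_bf16_ok() else "float32"

# Hypothetical usage with a stubbed probe:
print(resolve_compute_dtype("bfloat16", "npu:0", lambda: False))  # bfloat16
print(resolve_compute_dtype("bfloat16", "gpu:0", lambda: False))  # float32
```

The key design point mirrors the patch: the NPU branch is checked before the capability probe is ever invoked, so the erroring call is never reached on NPU.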

