diff --git a/paddlemix/examples/llava/run_predict_multiround.py b/paddlemix/examples/llava/run_predict_multiround.py index fb1a9746c..811c4e6a1 100644 --- a/paddlemix/examples/llava/run_predict_multiround.py +++ b/paddlemix/examples/llava/run_predict_multiround.py @@ -36,7 +36,11 @@ def main(args): paddle.seed(seed=0) compute_dtype = "float16" if args.fp16 else "bfloat16" - if compute_dtype== "bfloat16" and not paddle.amp.is_bfloat16_supported(): + if "npu" in paddle.get_device(): + is_bfloat16_supported = True + else: + is_bfloat16_supported = paddle.amp.is_bfloat16_supported() + if compute_dtype== "bfloat16" and not is_bfloat16_supported: logger.warning("bfloat16 is not supported on your device,change to float32") compute_dtype = "float32"