From dbe93b5a52ee3abbae10d1dfa845dc9ab13742bc Mon Sep 17 00:00:00 2001 From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com> Date: Wed, 24 Apr 2024 14:27:59 +0800 Subject: [PATCH] polish prim log (#63788) * polish prim log * polish code --- python/paddle/base/core.py | 21 +++++++++++++------ .../symbolic/test_sub_graph_chatglm2_4_st.py | 3 --- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/python/paddle/base/core.py b/python/paddle/base/core.py index d07b3faadbe8d..6931d22c45750 100644 --- a/python/paddle/base/core.py +++ b/python/paddle/base/core.py @@ -432,6 +432,15 @@ def _model_return_data(): return False +# This api is used for check whether prim is on +def _prim_return_log(): + flag = os.getenv("FLAGS_prim_log") + if flag and flag.lower() in ("1", "true"): + return True + else: + return False + + # We have 3 FLAGS to judge whether prim is enabled # FLAGS_prim_forward: Open or close forward prim strategy # FLAGS_prim_backward: Open or close backward prim strategy @@ -577,25 +586,25 @@ def _set_prim_backward_blacklist(*args): def _set_prim_backward_enabled(value): __set_bwd_prim_enabled(bool(value)) - if os.getenv("FLAGS_prim_log") == "1": + if _prim_return_log(): print("backward prim enabled: ", bool(_is_bwd_prim_enabled())) def _set_prim_forward_enabled(value): __set_fwd_prim_enabled(bool(value)) - if os.getenv("FLAGS_prim_log") == "1": + if _prim_return_log(): print("forward prim enabled: ", bool(_is_fwd_prim_enabled())) def set_prim_eager_enabled(value): __set_eager_prim_enabled(bool(value)) - if os.getenv("FLAGS_prim_log") == "1": + if _prim_return_log(): print("eager prim enabled: ", bool(_is_eager_prim_enabled())) def _set_prim_all_enabled(value): __set_all_prim_enabled(bool(value)) - if os.getenv("FLAGS_prim_log") == "1": + if _prim_return_log(): print( "all prim enabled: ", bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()), @@ -605,7 +614,7 @@ def _set_prim_all_enabled(value): def __sync_prim_backward_status(): flag_value = os.getenv("FLAGS_prim_backward") if flag_value is None: - if os.getenv("FLAGS_prim_log") == "1": + if _prim_return_log(): print("backward prim enabled: ", bool(_is_bwd_prim_enabled())) else: __sync_stat_with_flag("FLAGS_prim_backward") @@ -614,7 +623,7 @@ def __sync_prim_backward_status(): def __sync_prim_forward_status(): flag_value = os.getenv("FLAGS_prim_forward") if flag_value is None: - if os.getenv("FLAGS_prim_log") == "1": + if _prim_return_log(): print("forward prim enabled: ", bool(_is_fwd_prim_enabled())) else: __sync_stat_with_flag("FLAGS_prim_forward") diff --git a/test/ir/pir/cinn/symbolic/test_sub_graph_chatglm2_4_st.py b/test/ir/pir/cinn/symbolic/test_sub_graph_chatglm2_4_st.py index 6404c6fa91c2c..b8748500821e3 100644 --- a/test/ir/pir/cinn/symbolic/test_sub_graph_chatglm2_4_st.py +++ b/test/ir/pir/cinn/symbolic/test_sub_graph_chatglm2_4_st.py @@ -15,13 +15,10 @@ # repo: llm_sub_graphs # model: chatglm2 # api:paddle.nn.functional.input.embedding||method:transpose||api:paddle.tensor.creation.ones||api:paddle.tensor.creation.tril||method:astype||api:paddle.tensor.creation.ones||method:astype||method:__and__||api:paddle.tensor.creation.arange||method:__truediv__||method:__rpow__||method:__rtruediv__||api:paddle.tensor.creation.arange||api:paddle.tensor.math.outer||method:astype||api:paddle.tensor.ops.cos||api:paddle.tensor.ops.sin||api:paddle.tensor.manipulation.stack||method:__getitem__||method:transpose -import os import unittest import numpy as np -os.environ["FLAGS_prim_all"] = "False" - import paddle