Skip to content

Commit

Permalink
polish prim log (#63788)
Browse files Browse the repository at this point in the history
* polish prim log

* polish code
  • Loading branch information
cyber-pioneer authored Apr 24, 2024
1 parent 5f6e9d4 commit dbe93b5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
21 changes: 15 additions & 6 deletions python/paddle/base/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,15 @@ def _model_return_data():
return False


# This api is used for check whether prim is on
def _prim_return_log():
flag = os.getenv("FLAGS_prim_log")
if flag and flag.lower() in ("1", "true"):
return True
else:
return False


# We have 3 FLAGS to judge whether prim is enabled
# FLAGS_prim_forward: Open or close forward prim strategy
# FLAGS_prim_backward: Open or close backward prim strategy
Expand Down Expand Up @@ -577,25 +586,25 @@ def _set_prim_backward_blacklist(*args):

def _set_prim_backward_enabled(value):
__set_bwd_prim_enabled(bool(value))
if os.getenv("FLAGS_prim_log") == "1":
if _prim_return_log():
print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))


def _set_prim_forward_enabled(value):
__set_fwd_prim_enabled(bool(value))
if os.getenv("FLAGS_prim_log") == "1":
if _prim_return_log():
print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))


def set_prim_eager_enabled(value):
__set_eager_prim_enabled(bool(value))
if os.getenv("FLAGS_prim_log") == "1":
if _prim_return_log():
print("eager prim enabled: ", bool(_is_eager_prim_enabled()))


def _set_prim_all_enabled(value):
__set_all_prim_enabled(bool(value))
if os.getenv("FLAGS_prim_log") == "1":
if _prim_return_log():
print(
"all prim enabled: ",
bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
Expand All @@ -605,7 +614,7 @@ def _set_prim_all_enabled(value):
def __sync_prim_backward_status():
flag_value = os.getenv("FLAGS_prim_backward")
if flag_value is None:
if os.getenv("FLAGS_prim_log") == "1":
if _prim_return_log():
print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
else:
__sync_stat_with_flag("FLAGS_prim_backward")
Expand All @@ -614,7 +623,7 @@ def __sync_prim_backward_status():
def __sync_prim_forward_status():
flag_value = os.getenv("FLAGS_prim_forward")
if flag_value is None:
if os.getenv("FLAGS_prim_log") == "1":
if _prim_return_log():
print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
else:
__sync_stat_with_flag("FLAGS_prim_forward")
Expand Down
3 changes: 0 additions & 3 deletions test/ir/pir/cinn/symbolic/test_sub_graph_chatglm2_4_st.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,10 @@
# repo: llm_sub_graphs
# model: chatglm2
# api:paddle.nn.functional.input.embedding||method:transpose||api:paddle.tensor.creation.ones||api:paddle.tensor.creation.tril||method:astype||api:paddle.tensor.creation.ones||method:astype||method:__and__||api:paddle.tensor.creation.arange||method:__truediv__||method:__rpow__||method:__rtruediv__||api:paddle.tensor.creation.arange||api:paddle.tensor.math.outer||method:astype||api:paddle.tensor.ops.cos||api:paddle.tensor.ops.sin||api:paddle.tensor.manipulation.stack||method:__getitem__||method:transpose
import os
import unittest

import numpy as np

os.environ["FLAGS_prim_all"] = "False"

import paddle


Expand Down

0 comments on commit dbe93b5

Please sign in to comment.