diff --git a/python/paddle/base/core.py b/python/paddle/base/core.py index d0f95343fe26fc..2cf6858e569fc2 100644 --- a/python/paddle/base/core.py +++ b/python/paddle/base/core.py @@ -566,6 +566,11 @@ def _enable_dist_prim_all(): def _enable_auto_recompute(): + # NOTE(chenxi67): open recompute when cinn is enabled + from paddle.base.framework import in_cinn_mode + + if in_cinn_mode(): + return True flag = os.getenv("FLAGS_enable_auto_recompute") if flag and flag.lower() in ("1", "true"): return True