Skip to content

Commit

Permalink
Set runtime_include_dir in Paddle.__init__.py (PaddlePaddle#37886)
Browse files Browse the repository at this point in the history
Paddle don't have to set runtime_include_dir during run CINN.
  • Loading branch information
zhhsplendid authored and Zjq9409 committed Dec 10, 2021
1 parent 96a99bd commit f44cacb
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 0 deletions.
9 changes: 9 additions & 0 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,14 @@ bool IsCompiledWithMKLDNN() {
#endif
}

bool IsCompiledWithCINN() {
#ifndef PADDLE_WITH_CINN
return false;
#else
return true;
#endif
}

bool IsCompiledWithHETERPS() {
#ifndef PADDLE_WITH_HETERPS
return false;
Expand Down Expand Up @@ -2191,6 +2199,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("is_compiled_with_npu", IsCompiledWithNPU);
m.def("is_compiled_with_xpu", IsCompiledWithXPU);
m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN);
m.def("is_compiled_with_cinn", IsCompiledWithCINN);
m.def("_is_compiled_with_heterps", IsCompiledWithHETERPS);
m.def("supports_bfloat16", SupportsBfloat16);
m.def("supports_bfloat16_fast_performance", SupportsBfloat16FastPerformance);
Expand Down
11 changes: 11 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@
from .device import get_cudnn_version # noqa: F401
from .device import set_device # noqa: F401
from .device import get_device # noqa: F401
from .fluid.framework import is_compiled_with_cinn # noqa: F401
from .fluid.framework import is_compiled_with_cuda # noqa: F401
from .fluid.framework import is_compiled_with_rocm # noqa: F401
from .fluid.framework import disable_signal_handler # noqa: F401
Expand Down Expand Up @@ -310,6 +311,16 @@
import paddle.vision # noqa: F401

from .tensor.random import check_shape # noqa: F401

# CINN has to set a flag to include a lib
if is_compiled_with_cinn():
import os
package_dir = os.path.dirname(os.path.abspath(__file__))
runtime_include_dir = os.path.join(package_dir, "libs")
cuh_file = os.path.join(runtime_include_dir, "cinn_cuda_runtime_source.cuh")
if os.path.exists(cuh_file):
os.environ['runtime_include_dir'] = runtime_include_dir

disable_static()

__all__ = [ # noqa
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/device/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.framework import is_compiled_with_cinn # noqa: F401
from paddle.fluid.framework import is_compiled_with_cuda # noqa: F401
from paddle.fluid.framework import is_compiled_with_rocm # noqa: F401
from . import cuda
Expand All @@ -28,6 +29,7 @@
'get_device',
'XPUPlace',
'is_compiled_with_xpu',
'is_compiled_with_cinn',
'is_compiled_with_cuda',
'is_compiled_with_rocm',
'is_compiled_with_npu'
Expand Down
16 changes: 16 additions & 0 deletions python/paddle/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
'xpu_places',
'cuda_pinned_places',
'in_dygraph_mode',
'is_compiled_with_cinn',
'is_compiled_with_cuda',
'is_compiled_with_rocm',
'is_compiled_with_xpu',
Expand Down Expand Up @@ -477,6 +478,21 @@ def disable_signal_handler():
core.disable_signal_handler()


def is_compiled_with_cinn():
"""
Whether this whl package can be used to run the model on CINN.
Returns (bool): `True` if CINN is currently available, otherwise `False`.
Examples:
.. code-block:: python
import paddle
support_cinn = paddle.device.is_compiled_with_cinn()
"""
return core.is_compiled_with_cinn()


def is_compiled_with_cuda():
"""
Whether this whl package can be used to run the model on GPU.
Expand Down
2 changes: 2 additions & 0 deletions python/setup.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,9 @@ if '${WITH_LITE}' == 'ON':

if '${WITH_CINN}' == 'ON':
shutil.copy('${CINN_LIB_LOCATION}/${CINN_LIB_NAME}', libs_path)
shutil.copy('${CINN_INCLUDE_DIR}/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh', libs_path)
package_data['paddle.libs']+=['libcinnapi.so']
package_data['paddle.libs']+=['cinn_cuda_runtime_source.cuh']

if '${WITH_PSLIB}' == 'ON':
shutil.copy('${PSLIB_LIB}', libs_path)
Expand Down

0 comments on commit f44cacb

Please sign in to comment.