Skip to content

Commit

Permalink
finish remaining options
Browse files: browse the repository at this point in the history
  • Loading branch information
ksimpson-work committed Feb 4, 2025
2 parents cd655f2 + 8ed9c03 commit ea25e67
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 131 deletions.
88 changes: 37 additions & 51 deletions cuda_core/cuda/core/experimental/_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import ctypes
import warnings
import weakref
from contextlib import contextmanager
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from warnings import warn

from cuda.core.experimental._device import Device
from cuda.core.experimental._module import ObjectCode
Expand All @@ -23,11 +23,11 @@


# Note: this function is reused in the tests
def _decide_nvjitlink_or_driver():
def _decide_nvjitlink_or_driver() -> bool:
"""Returns True if falling back to the cuLink* driver APIs."""
global _driver_ver, _driver, _nvjitlink
if _driver or _nvjitlink:
return
return _driver is not None

_driver_ver = handle_return(driver.cuDriverGetVersion())
_driver_ver = (_driver_ver // 1000, (_driver_ver % 1000) // 10)
Expand All @@ -43,7 +43,7 @@ def _decide_nvjitlink_or_driver():
_nvjitlink = None

if _nvjitlink is None:
warnings.warn(
warn(
"nvJitLink is not installed or too old (<12.3). Therefore it is not usable "
"and the culink APIs will be used instead.",
stacklevel=3,
Expand Down Expand Up @@ -98,78 +98,58 @@ class LinkerOptions:
will be used.
max_register_count : int, optional
Maximum register count.
Maps to: ``-maxrregcount=<N>``.
time : bool, optional
Print timing information to the info log.
Maps to ``-time``.
Default: False.
verbose : bool, optional
Print verbose messages to the info log.
Maps to ``-verbose``.
Default: False.
link_time_optimization : bool, optional
Perform link time optimization.
Maps to: ``-lto``.
Default: False.
ptx : bool, optional
Emit PTX after linking instead of CUBIN; only supported with ``-lto``.
Maps to ``-ptx``.
Emit PTX after linking instead of CUBIN; only supported with ``link_time_optimization=True``.
Default: False.
optimization_level : int, optional
Set optimization level. Only 0 and 3 are accepted.
Maps to ``-O<N>``.
debug : bool, optional
Generate debug information.
Maps to ``-g``
Default: False.
lineinfo : bool, optional
Generate line information.
Maps to ``-lineinfo``.
Default: False.
ftz : bool, optional
Flush denormal values to zero.
Maps to ``-ftz=<n>``.
Default: False.
prec_div : bool, optional
Use precise division.
Maps to ``-prec-div=<n>``.
Default: True.
prec_sqrt : bool, optional
Use precise square root.
Maps to ``-prec-sqrt=<n>``.
Default: True.
fma : bool, optional
Use fast multiply-add.
Maps to ``-fma=<n>``.
Default: True.
kernels_used : List[str], optional
Pass list of kernels that are used; any not in the list can be removed. This option can be specified multiple
times.
Maps to ``-kernels-used=<name>``.
variables_used : List[str], optional
Pass a list of variables that are used; any not in the list can be removed.
Maps to ``-variables-used=<name>``
kernels_used : [Union[str, Tuple[str], List[str]]], optional
Pass a kernel or sequence of kernels that are used; any not in the list can be removed.
variables_used : [Union[str, Tuple[str], List[str]]], optional
Pass a variable or sequence of variables that are used; any not in the list can be removed.
optimize_unused_variables : bool, optional
Assume that if a variable is not referenced in device code, it can be removed.
Maps to: ``-optimize-unused-variables``
Default: False.
xptxas : [Union[str, Tuple[str], List[str]]], optional
ptxas_options : [Union[str, Tuple[str], List[str]]], optional
Pass options to PTXAS.
Maps to: ``-Xptxas=<opt>``.
split_compile : int, optional
Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split
compilation (default).
Maps to ``-split-compile=<N>``.
Default: 1.
split_compile_extended : int, optional
A more aggressive form of split compilation available in LTO mode only. Accepts a maximum thread count value.
Use 0 to use all available processors. Value of 1 disables extended split compilation (default). Note: This
option can potentially impact performance of the compiled binary.
Maps to ``-split-compile-extended=<N>``.
Default: 1.
no_cache : bool, optional
Do not cache the intermediate steps of nvJitLink.
Maps to ``-no-cache``.
Default: False.
"""

Expand All @@ -186,10 +166,10 @@ class LinkerOptions:
prec_div: Optional[bool] = None
prec_sqrt: Optional[bool] = None
fma: Optional[bool] = None
kernels_used: Optional[List[str]] = None
variables_used: Optional[List[str]] = None
kernels_used: Optional[Union[str, Tuple[str], List[str]]] = None
variables_used: Optional[Union[str, Tuple[str], List[str]]] = None
optimize_unused_variables: Optional[bool] = None
xptxas: Optional[Union[str, Tuple[str], List[str]]] = None
ptxas_options: Optional[Union[str, Tuple[str], List[str]]] = None
split_compile: Optional[int] = None
split_compile_extended: Optional[int] = None
no_cache: Optional[bool] = None
Expand Down Expand Up @@ -232,18 +212,24 @@ def _init_nvjitlink(self):
if self.fma is not None:
self.formatted_options.append(f"-fma={'true' if self.fma else 'false'}")
if self.kernels_used is not None:
for kernel in self.kernels_used:
self.formatted_options.append(f"-kernels-used={kernel}")
if isinstance(self.kernels_used, str):
self.formatted_options.append(f"-kernels-used={self.kernels_used}")
elif isinstance(self.kernels_used, list):
for kernel in self.kernels_used:
self.formatted_options.append(f"-kernels-used={kernel}")
if self.variables_used is not None:
for variable in self.variables_used:
self.formatted_options.append(f"-variables-used={variable}")
if isinstance(self.variables_used, str):
self.formatted_options.append(f"-variables-used={self.variables_used}")
elif isinstance(self.variables_used, list):
for variable in self.variables_used:
self.formatted_options.append(f"-variables-used={variable}")
if self.optimize_unused_variables is not None:
self.formatted_options.append("-optimize-unused-variables")
if self.xptxas is not None:
if isinstance(self.xptxas, str):
self.formatted_options.append(f"-Xptxas={self.xptxas}")
elif is_sequence(self.xptxas):
for opt in self.xptxas:
if self.ptxas_options is not None:
if isinstance(self.ptxas_options, str):
self.formatted_options.append(f"-Xptxas={self.ptxas_options}")
elif is_sequence(self.ptxas_options):
for opt in self.ptxas_options:
self.formatted_options.append(f"-Xptxas={opt}")
if self.split_compile is not None:
self.formatted_options.append(f"-split-compile={self.split_compile}")
Expand Down Expand Up @@ -293,21 +279,21 @@ def _init_driver(self):
self.formatted_options.append(1)
self.option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
if self.ftz is not None:
raise ValueError("ftz option is deprecated in the driver API")
warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
if self.prec_div is not None:
raise ValueError("prec_div option is deprecated in the driver API")
warn("prec_div option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
if self.prec_sqrt is not None:
raise ValueError("prec_sqrt option is deprecated in the driver API")
warn("prec_sqrt option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
if self.fma is not None:
raise ValueError("fma options is deprecated in the driver API")
warn("fma options is deprecated in the driver API", DeprecationWarning, stacklevel=3)
if self.kernels_used is not None:
raise ValueError("kernels_used is deprecated in the driver API")
warn("kernels_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
if self.variables_used is not None:
raise ValueError("variables_used is deprecated in the driver API")
warn("variables_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
if self.optimize_unused_variables is not None:
raise ValueError("optimize_unused_variables is deprecated in the driver API")
if self.xptxas is not None:
raise ValueError("xptxas option is not supported by the driver API")
warn("optimize_unused_variables is deprecated in the driver API", DeprecationWarning, stacklevel=3)
if self.ptxas_options is not None:
raise ValueError("ptxas_options option is not supported by the driver API")
if self.split_compile is not None:
raise ValueError("split_compile option is not supported by the driver API")
if self.split_compile_extended is not None:
Expand Down
Loading

0 comments on commit ea25e67

Please sign in to comment.