Skip to content

Commit

Permalink
Merge remote-tracking branch 'leofang/cluster' into HEAD
Browse files Browse the repository at this point in the history
  • Loading branch information
ksimpson-work committed Dec 13, 2024
2 parents b8004e9 + 2c3a619 commit 4b95ba4
Show file tree
Hide file tree
Showing 16 changed files with 220 additions and 69 deletions.
7 changes: 7 additions & 0 deletions .github/actions/test/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ runs:
shell: bash --noprofile --norc -xeuo pipefail {0}
run: nvidia-smi

# The cache action needs this
- name: Install zstd
shell: bash --noprofile --norc -xeuo pipefail {0}
run: |
apt update
apt install zstd
- name: Download bindings build artifacts
uses: actions/download-artifact@v4
with:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ci-gh.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ jobs:
upload-enabled:
- false
python-version:
- "3.13"
- "3.12"
- "3.11"
- "3.10"
Expand Down
12 changes: 7 additions & 5 deletions .github/workflows/gh-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,19 @@ jobs:
test:
# TODO: improve the name once a separate test matrix is defined
name: Test (CUDA ${{ inputs.cuda-version }})
# TODO: enable testing once linux-aarch64 & win-64 GPU runners are up
# TODO: enable testing once win-64 GPU runners are up
if: ${{ (github.repository_owner == 'nvidia') &&
startsWith(inputs.host-platform, 'linux-x64') }}
startsWith(inputs.host-platform, 'linux') }}
permissions:
id-token: write # This is required for configure-aws-credentials
contents: read # This is required for actions/checkout
runs-on: ${{ (inputs.host-platform == 'linux-x64' && 'linux-amd64-gpu-v100-latest-1') }}
# TODO: use a different (nvidia?) container, or just run on bare image
runs-on: ${{ (inputs.host-platform == 'linux-x64' && 'linux-amd64-gpu-v100-latest-1') ||
(inputs.host-platform == 'linux-aarch64' && 'linux-arm64-gpu-a100-latest-1') }}
# Our self-hosted runners require a container
# TODO: use a different (nvidia?) container
container:
options: -u root --security-opt seccomp=unconfined --privileged --shm-size 16g
image: condaforge/miniforge3:latest
image: ubuntu:22.04
env:
NVIDIA_VISIBLE_DEVICES: ${{ env.NVIDIA_VISIBLE_DEVICES }}
needs:
Expand Down
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
#
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

__version__ = "0.1.0"
__version__ = "0.1.1"
5 changes: 5 additions & 0 deletions cuda_core/cuda/core/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@
from cuda.core.experimental._linker import Linker, LinkerOptions
from cuda.core.experimental._program import Program
from cuda.core.experimental._stream import Stream, StreamOptions
from cuda.core.experimental._system import System

system = System()
__import__("sys").modules[__spec__.name + ".system"] = system
del System
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/experimental/_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def link(self, target_type) -> ObjectCode:
return ObjectCode(bytes(code), target_type)

def get_error_log(self) -> str:
""" Get the error log generated by the linker.
"""Get the error log generated by the linker.
Returns
-------
Expand Down
40 changes: 27 additions & 13 deletions cuda_core/cuda/core/experimental/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import importlib.metadata

from cuda import cuda
from cuda.core.experimental._utils import handle_return
from cuda.core.experimental._utils import handle_return, precondition

_backend = {
"old": {
Expand Down Expand Up @@ -106,30 +106,43 @@ class ObjectCode:
"""

__slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map")
__slots__ = ("_handle", "_backend_version", "_jit_options", "_code_type", "_module", "_loader", "_sym_map")
_supported_code_type = ("cubin", "ptx", "ltoir", "fatbin")

def __init__(self, module, code_type, jit_options=None, *, symbol_mapping=None):
if code_type not in self._supported_code_type:
raise ValueError
_lazy_init()

# handle is assigned during _lazy_load
self._handle = None
self._jit_options = jit_options

self._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
self._loader = _backend[self._backend_version]

self._code_type = code_type
self._module = module
self._sym_map = {} if symbol_mapping is None else symbol_mapping

backend = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
self._loader = _backend[backend]
# TODO: do we want to unload in a finalizer? Probably not..

def _lazy_load_module(self, *args, **kwargs):
if self._handle is not None:
return
jit_options = self._jit_options
module = self._module
if isinstance(module, str):
# TODO: this option is only taken by the new library APIs, but we have
# a bug that we can't easily support it just yet (NVIDIA/cuda-python#73).
if jit_options is not None:
raise ValueError
module = module.encode()
self._handle = handle_return(self._loader["file"](module))
else:
assert isinstance(module, bytes)
if jit_options is None:
jit_options = {}
if backend == "new":
if self._backend_version == "new":
args = (
module,
list(jit_options.keys()),
Expand All @@ -141,15 +154,15 @@ def __init__(self, module, code_type, jit_options=None, *, symbol_mapping=None):
0,
)
else: # "old" backend
args = (module, len(jit_options), list(jit_options.keys()), list(jit_options.values()))
args = (
module,
len(jit_options),
list(jit_options.keys()),
list(jit_options.values()),
)
self._handle = handle_return(self._loader["data"](*args))

self._code_type = code_type
self._module = module
self._sym_map = {} if symbol_mapping is None else symbol_mapping

# TODO: do we want to unload in a finalizer? Probably not..

@precondition(_lazy_load_module)
def get_kernel(self, name):
"""Return the :obj:`Kernel` of a specified name from this object code.
Expand All @@ -168,6 +181,7 @@ def get_kernel(self, name):
name = self._sym_map[name]
except KeyError:
name = name.encode()

data = handle_return(self._loader["kernel"](self._handle, name))
return Kernel._from_obj(data, self)

Expand Down
67 changes: 67 additions & 0 deletions cuda_core/cuda/core/experimental/_system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

from typing import Tuple

from cuda import cuda, cudart
from cuda.core.experimental._device import Device
from cuda.core.experimental._utils import handle_return


class System:
    """Provide information about the CUDA system.

    This class is a singleton and should not be instantiated directly.
    """

    _instance = None

    def __new__(cls):
        # Hand back the shared instance, creating it on the first call only.
        existing = cls._instance
        if existing is not None:
            return existing
        cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ``System()`` call, even though __new__ returns
        # the same object each time; only the first call needs to set the flag.
        if getattr(self, "_initialized", False):
            return
        self._initialized = True

    @property
    def driver_version(self) -> Tuple[int, int]:
        """
        Query the CUDA driver version.

        Returns
        -------
        tuple of int
            A 2-tuple of (major, minor) version numbers.
        """
        raw_version = handle_return(cuda.cuDriverGetVersion())
        # Decode the integer as 1000 * major + 10 * minor.
        major, remainder = divmod(raw_version, 1000)
        return (major, remainder // 10)

    @property
    def num_devices(self) -> int:
        """
        Query the number of available GPUs.

        Returns
        -------
        int
            The number of available GPU devices.
        """
        device_count = handle_return(cudart.cudaGetDeviceCount())
        return device_count

    @property
    def devices(self) -> tuple:
        """
        Query the available device instances.

        Returns
        -------
        tuple of Device
            A tuple containing instances of available devices.
        """
        # One Device per ordinal in [0, num_devices).
        return tuple(map(Device, range(self.num_devices)))
11 changes: 11 additions & 0 deletions cuda_core/docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ CUDA compilation toolchain
LinkerOptions


CUDA system information
-----------------------

.. autodata:: cuda.core.experimental.system.driver_version
:no-value:
.. autodata:: cuda.core.experimental.system.num_devices
:no-value:
.. autodata:: cuda.core.experimental.system.devices
:no-value:


.. module:: cuda.core.experimental.utils

Utility functions
Expand Down
28 changes: 28 additions & 0 deletions cuda_core/docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,31 @@

napoleon_google_docstring = False
napoleon_numpy_docstring = True


# Docstring section titles that get rewritten into ``.. rubric::`` directives.
section_titles = ["Returns"]


def autodoc_process_docstring(app, what, name, obj, options, lines):
    """Patch the docstring autodoc collected for ``cuda.core.experimental.system`` members.

    Handler for the ``autodoc-process-docstring`` event. For any documented
    name starting with ``cuda.core.experimental.system``, the docstring is
    re-read from the attribute of the same name on the ``System`` class,
    lightly reformatted, and written back into ``lines`` *in-place*.

    Should docstrings include section titles other than those listed in
    ``section_titles``, this will need to be modified to handle them.

    Parameters
    ----------
    app, what, obj, options
        Unused; part of the event's callback signature.
    name : str
        Fully qualified name of the object being documented.
    lines : list of str
        Docstring lines; mutated in-place when ``name`` matches.

    Notes
    -----
    If ``System`` has no attribute with the requested name, or that attribute
    has no docstring, ``lines`` is left untouched instead of raising.
    """
    if not name.startswith("cuda.core.experimental.system"):
        return

    attr = name.split(".")[-1]
    from cuda.core.experimental._system import System

    member = getattr(System, attr, None)
    doc = None if member is None else getattr(member, "__doc__", None)
    if doc is None:
        # Nothing to substitute; keep whatever autodoc already extracted.
        return

    formatted_lines = []
    for line in doc.split("\n"):
        stripped = line.strip()
        if stripped in section_titles:
            # "Returns" -> ".. rubric:: Returns", keeping the line's indentation.
            formatted_lines.append(line.replace(stripped, f".. rubric:: {stripped}"))
        elif stripped == "-" * len(stripped):
            # A dash-only underline (or a blank line, which trivially matches)
            # is replaced by as many spaces as it had dashes.
            formatted_lines.append(" " * len(stripped))
        else:
            formatted_lines.append(line)

    # Swap the contents while keeping the same list object the caller holds.
    lines[:] = formatted_lines


def setup(app):
    """Register ``autodoc_process_docstring`` as a handler on ``app``.

    The handler is connected to the ``autodoc-process-docstring`` event.
    """
    event_name = "autodoc-process-docstring"
    app.connect(event_name, autodoc_process_docstring)
1 change: 1 addition & 0 deletions cuda_core/docs/source/release.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ maxdepth: 3
0.1.1 <release/0.1.1-notes>
0.1.0 <release/0.1.0-notes>
```
1 change: 1 addition & 0 deletions cuda_core/docs/source/release/0.1.1-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Released on Dec XX, 2024
- Add `Linker` that can link one or multiple `ObjectCode` instances generated by `Program`s. Under
the hood, it uses either the nvJitLink or cuLink APIs depending on the CUDA version detected
in the current environment.
- Add a `cuda.core.experimental.system` module for querying system- or process-wide information.
- Support TCC devices with a default synchronous memory resource to avoid the use of memory pools

## New features
Expand Down
10 changes: 8 additions & 2 deletions cuda_core/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
import sys

try:
from cuda.bindings import driver
from cuda.bindings import driver, nvrtc
except ImportError:
from cuda import cuda as driver

from cuda import nvrtc
import pytest

from cuda.core.experimental import Device, _device
Expand Down Expand Up @@ -65,3 +65,9 @@ def clean_up_cffi_files():
os.remove(f)
except FileNotFoundError:
pass # noqa: SIM105


def can_load_generated_ptx():
    """Return True when the NVRTC version does not exceed the driver version.

    The NVRTC (major, minor) pair is encoded as ``major * 1000 + minor * 10``
    and compared against the integer reported by ``cuDriverGetVersion``. The
    status codes returned by both queries are ignored.
    """
    _, driver_ver = driver.cuDriverGetVersion()
    _, nvrtc_major, nvrtc_minor = nvrtc.nvrtcVersion()
    nvrtc_ver = nvrtc_major * 1000 + nvrtc_minor * 10
    return driver_ver >= nvrtc_ver
55 changes: 17 additions & 38 deletions cuda_core/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,22 @@
# this software and related documentation outside the terms of the EULA
# is strictly prohibited.

import importlib

import pytest

from cuda.core.experimental._module import ObjectCode


@pytest.mark.skipif(
    int(importlib.metadata.version("cuda-python").split(".")[0]) < 12,
    reason="Module loading for older drivers validate require valid module code.",
)
def test_object_code_initialization():
    """Check ObjectCode construction for accepted and rejected code types."""
    payload = b"dummy_data"

    # Each of these code types must be accepted, and the resulting object must
    # record its inputs and already hold a handle right after construction.
    for kind in ("cubin", "ptx", "fatbin"):
        loaded = ObjectCode(payload, kind)
        assert loaded._code_type == kind
        assert loaded._module == payload
        assert loaded._handle is not None

    # An unknown code type must be rejected with ValueError.
    with pytest.raises(ValueError):
        ObjectCode(b"dummy_data", "unsupported_code_type")


# TODO add ObjectCode tests which provide the appropriate data for cuLibraryLoadFromFile
def test_object_code_initialization_with_str():
    """Placeholder test: performs no real checks and always passes."""
    assert True


def test_object_code_initialization_with_jit_options():
    """Placeholder test: performs no real checks and always passes."""
    assert True


def test_object_code_get_kernel():
    """Placeholder test: performs no real checks and always passes."""
    assert True


def test_kernel_from_obj():
    """Placeholder test: performs no real checks and always passes."""
    assert True
from conftest import can_load_generated_ptx

from cuda.core.experimental import Program


@pytest.mark.xfail(not can_load_generated_ptx(), reason="PTX version too new")
def test_get_kernel():
    """Check that the compiled object code only gets a handle once a kernel is requested."""
    source = """
extern __device__ int B();
extern __device__ int C(int a, int b);
__global__ void A() { int result = C(B(), 1);}
"""
    compiled = Program(source, "c++").compile("ptx", options=("-rdc=true",))
    # Straight after compilation no handle is expected yet.
    assert compiled._handle is None

    entry_point = compiled.get_kernel("A")
    # Requesting a kernel must populate both the object-code and kernel handles.
    assert compiled._handle is not None
    assert entry_point._handle is not None
10 changes: 1 addition & 9 deletions cuda_core/tests/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,12 @@
# is strictly prohibited.

import pytest
from conftest import can_load_generated_ptx

from cuda import cuda, nvrtc
from cuda.core.experimental import Device, Program
from cuda.core.experimental._module import Kernel, ObjectCode


def can_load_generated_ptx():
    """Return True when the NVRTC version does not exceed the driver version.

    The NVRTC (major, minor) pair is encoded as ``major * 1000 + minor * 10``
    and compared against the integer reported by ``cuDriverGetVersion``. The
    status codes returned by both queries are ignored.
    """
    _, driver_ver = cuda.cuDriverGetVersion()
    _, nvrtc_major, nvrtc_minor = nvrtc.nvrtcVersion()
    nvrtc_ver = nvrtc_major * 1000 + nvrtc_minor * 10
    return nvrtc_ver <= driver_ver


def test_program_init_valid_code_type():
code = 'extern "C" __global__ void my_kernel() {}'
program = Program(code, "c++")
Expand Down
Loading

0 comments on commit 4b95ba4

Please sign in to comment.