[BACKEND][CPU] Initial plumbing for cpu backend (triton-lang#2)
* [BACKEND][CPU] Implement the empty cpu backend

* Run clang-format

* Fix yapf error

minjang authored and ienkovich committed Dec 6, 2024
1 parent 83078a6 commit b935aec
Showing 9 changed files with 227 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/setup.py
@@ -582,7 +582,7 @@ def get_platform_dependent_src_path(subdir):
f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/{system}-{arch}/cuda-cupti-{version}-0.tar.bz2")
(*version.split('.'))))

backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()]
backends = [*BackendInstaller.copy(["nvidia", "amd", "cpu"]), *BackendInstaller.copy_externals()]


def add_link_to_backends():
11 changes: 11 additions & 0 deletions python/triton/backends/driver.py
@@ -46,3 +46,14 @@ def __init__(self):
    # TODO: remove once TMA is cleaned up
    def assemble_tensormap_to_arg(self, tensormaps_info, args):
        return args


class CPUDriverBase(DriverBase):

    def __init__(self):
        # Right now, we just provide dummy functions.
        # TODO: Consider refactoring the GPU-only code paths in jit.py.
        self.get_device_capability = lambda idx: (0, 0)
        self.get_current_stream = lambda idx: 0
        self.get_current_device = lambda: 0
        self.set_current_device = lambda idx: None
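
For illustration, a minimal sketch of a concrete driver built on these dummies. DemoCPUDriver is hypothetical, and assumes (as CPUDriver later in this commit does) that is_active and get_current_target are the abstract hooks a DriverBase subclass must fill in:

from triton.backends.driver import CPUDriverBase


class DemoCPUDriver(CPUDriverBase):

    @staticmethod
    def is_active():
        return True

    def get_current_target(self):
        return ("cpu", 0, 0)


drv = DemoCPUDriver()
print(drv.get_device_capability(0))  # (0, 0): no CUDA-style compute capability
print(drv.get_current_device())      # 0: a single logical CPU device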
10 changes: 10 additions & 0 deletions python/triton/runtime/driver.py
@@ -1,9 +1,19 @@
import os

from ..backends import backends
from ..backends import DriverBase


def _create_driver():
    if os.getenv("TRITON_CPU_BACKEND", "0") == "1":
        if "cpu" not in backends:
            raise RuntimeError("TRITON_CPU_BACKEND is set, but CPU backend is unavailable.")
        return backends["cpu"].driver()

    actives = [x.driver for x in backends.values() if x.driver.is_active()]
    # Guard against a missing CPU plugin to avoid a KeyError on the lookup below.
    if len(actives) >= 2 and "cpu" in backends and backends["cpu"].driver.is_active():
        print("Both CPU and GPU backends are available. Using the GPU backend.")
        actives.remove(backends["cpu"].driver)
    if len(actives) != 1:
        raise RuntimeError(f"{len(actives)} active drivers ({actives}). There should only be one.")
    return actives[0]()
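
For illustration, a sketch of how a user opts into this path, assuming TRITON_CPU_BACKEND is checked before the runtime resolves a driver:

import os

# Hypothetical opt-in: force the CPU driver even if a GPU driver is also active.
os.environ["TRITON_CPU_BACKEND"] = "1"

import triton  # _create_driver() now returns backends["cpu"].driver()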
7 changes: 7 additions & 0 deletions third_party/cpu/CMakeLists.txt
@@ -0,0 +1,7 @@
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
add_subdirectory(include)
add_subdirectory(lib)
if(TRITON_BUILD_PYTHON_MODULE)
  add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc)
endif()
106 changes: 106 additions & 0 deletions third_party/cpu/backend/compiler.py
@@ -0,0 +1,106 @@
import functools
import hashlib
import re

from dataclasses import dataclass
from typing import Any

from triton._C.libtriton import cpu, ir, passes
from triton.backends.compiler import BaseBackend


@dataclass(frozen=True)
class CPUOptions:
    # GPU-specific options are used in several places.
    # For now, we just provide dummy values.
    num_warps: int = 0
    num_stages: int = 0
    num_ctas: int = 0
    cluster_dims: tuple = (1, 1, 1)
    debug: bool = False

    # TODO: We may introduce CPU-specific options, such as the number of cores.

    def __post_init__(self):
        pass

    def hash(self):
        hash_dict = dict(self.__dict__)
        key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())])
        return hashlib.sha256(key.encode("utf-8")).hexdigest()


class CPUBackend(BaseBackend):

    @staticmethod
    def supports_target(target: tuple):
        return target[0] == "cpu"

    def __init__(self, target: tuple) -> None:
        super().__init__(target)
        self.binary_ext = "exe"

    def parse_options(self, opts) -> Any:
        args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts}
        return CPUOptions(**args)

    def pack_metadata(self, metadata):
        return metadata

    def get_codegen_implementation(self):
        codegen_fns = dict()
        return codegen_fns

    def load_dialects(self, ctx):
        cpu.load_dialects(ctx)

    @staticmethod
    def make_ttir(mod, metadata, opt):
        # This pipeline is the same as the NVIDIA backend's.
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)
        passes.ttir.add_rewrite_tensor_pointer(pm)
        passes.ttir.add_combine(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_reorder_broadcast(pm)
        passes.common.add_cse(pm)
        passes.common.add_licm(pm)
        passes.common.add_symbol_dce(pm)
        pm.run(mod)
        return mod

    @staticmethod
    def make_ttcir(mod, metadata, opt):
        # TODO: Implement the TTIR-to-TTCIR lowering; pass TTIR through for now.
        return mod

    @staticmethod
    def make_llir(src, metadata, options):
        # TODO: Implement the lowering to LLVM IR; pass the input through for now.
        metadata["shared"] = 0
        return src

    @staticmethod
    def make_exe(src, metadata, options):
        # Right now, src is just TTIR. Extract the kernel name from tt.func.
        names = re.findall(r"\s+tt.func public @([a-zA-Z_][a-zA-Z0-9_]*)\(", str(src))
        assert len(names) == 1
        metadata["name"] = names[0]

        # TODO: Call llc to create an executable.
        return src

    def add_stages(self, stages, options):
        stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
        stages["ttcir"] = lambda src, metadata: self.make_ttcir(src, metadata, options)
        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
        stages["exe"] = lambda src, metadata: self.make_exe(src, metadata, options)

    @functools.lru_cache()
    def hash(self):
        # TODO: Get more detailed CPU info, such as the brand string and supported ISAs.
        # Right now this only returns a simple string like "x86_64" or "aarch64".
        import platform

        return f"{platform.machine()}"
68 changes: 68 additions & 0 deletions third_party/cpu/backend/driver.py
@@ -0,0 +1,68 @@
from triton.backends.driver import CPUDriverBase

# ------------------------
# Utils
# ------------------------


class CPUUtils(object):

    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super(CPUUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        pass

    @staticmethod
    def get_device_properties(device):
        # This is just a dummy for now. We will need to implement driver.c.
        return {
            "max_shared_mem": 0,
            "multiprocessor_count": 0,
            "sm_clock_rate": 0,
            "mem_clock_rate": 0,
            "mem_bus_width": 0,
        }

    @staticmethod
    def load_binary(name, kernel_asm, shared, device):
        # This is just a dummy for now. We will need to implement driver.c.
        return (None, kernel_asm, 0, 0)


# ------------------------
# Launcher
# ------------------------


def make_launcher(constants, signature, ids):
    pass


class CPULauncher(object):

    def __init__(self, src, metadata):
        # TODO: Generate a real launcher; launching is a no-op for now.
        self.launch = lambda *args, **kwargs: None

    def __call__(self, *args, **kwargs):
        print("CPULauncher.__call__")
        self.launch(*args, **kwargs)


class CPUDriver(CPUDriverBase):

    def __init__(self):
        self.utils = CPUUtils()
        self.launcher_cls = CPULauncher
        super().__init__()

    def get_current_target(self):
        # Capability and warp size are zero for CPU.
        return ("cpu", 0, 0)

    @staticmethod
    def is_active():
        return True
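
CPUUtils keeps a single instance by caching it on the class in __new__; a standalone sketch of the same pattern:

class Singleton:

    def __new__(cls):
        # Cache the first instance on the class; later calls return it.
        if not hasattr(cls, "instance"):
            cls.instance = super().__new__(cls)
        return cls.instance


a, b = Singleton(), Singleton()
assert a is b  # both names refer to the one cached instance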
Empty file.
Empty file.
24 changes: 24 additions & 0 deletions third_party/cpu/triton_cpu.cc
@@ -0,0 +1,24 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/IR/Constants.h"
#include "llvm/Support/TargetSelect.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>

#include <iostream>

namespace py = pybind11;

void init_triton_passes_ttcpuir(py::module &&m) {
  // TODO:
}

void init_triton_cpu(py::module &&m) {
  auto passes = m.def_submodule("passes");
  init_triton_passes_ttcpuir(passes.def_submodule("ttcpuir"));

  m.def("load_dialects", [](mlir::MLIRContext &context) {
    // TODO:
  });
}
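
On the Python side, the plugin built by add_triton_plugin surfaces as triton._C.libtriton.cpu, which compiler.py above already imports. A sketch of the entry points, with load_dialects still a no-op at this stage:

from triton._C.libtriton import cpu, ir

ctx = ir.context()      # MLIR context from Triton's ir bindings
cpu.load_dialects(ctx)  # currently a no-op; will register the CPU dialects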
