forked from triton-lang/triton
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[BACKEND][CPU] Initial plumbing for cpu backend (triton-lang#2)
* [BACKEND][CPU] Implement the empty cpu backend * Run clang-format * Fix yadf error Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
- Loading branch information
Showing
9 changed files
with
227 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Build configuration for the TritonCPU backend plugin.
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
add_subdirectory(include)
add_subdirectory(lib)
# The pybind11 plugin is only needed when building Triton's Python bindings.
if(TRITON_BUILD_PYTHON_MODULE)
  add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc)
endif()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import functools | ||
import hashlib | ||
import re | ||
|
||
from dataclasses import dataclass | ||
from typing import Any | ||
|
||
from triton._C.libtriton import cpu, ir, passes | ||
from triton.backends.compiler import BaseBackend | ||
|
||
|
||
@dataclass(frozen=True)
class CPUOptions:
    """Compilation options for the CPU backend.

    Several shared code paths read GPU-oriented fields (num_warps,
    num_stages, ...), so dummy zero values are provided for now.
    """
    num_warps: int = 0
    num_stages: int = 0
    num_ctas: int = 0
    cluster_dims: tuple = (1, 1, 1)
    debug: bool = False

    # TODO: We may introduce CPU-specific options like # of cores.

    def __post_init__(self):
        pass

    def hash(self):
        """Return a stable hex digest identifying this option set."""
        sorted_items = sorted(self.__dict__.items())
        key = "_".join(f"{name}-{value}" for name, value in sorted_items)
        return hashlib.sha256(key.encode("utf-8")).hexdigest()
|
||
|
||
class CPUBackend(BaseBackend):
    """Triton compiler backend targeting the host CPU.

    Pipeline: ttir -> ttcir -> llir -> exe.  Most stages are still stubs;
    only the common TTIR cleanup passes are wired up.
    """

    @staticmethod
    def supports_target(target: tuple):
        """A target descriptor's first entry names the device kind."""
        return target[0] == "cpu"

    def __init__(self, target: tuple) -> None:
        super().__init__(target)
        # The final compilation artifact is intended to be a native executable.
        self.binary_ext = "exe"

    def parse_options(self, opts) -> Any:
        """Build CPUOptions from a mapping, silently ignoring unknown keys."""
        args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts}
        return CPUOptions(**args)

    def pack_metadata(self, metadata):
        # Nothing CPU-specific to pack yet; pass metadata through unchanged.
        return metadata

    def get_codegen_implementation(self):
        # No CPU-specific codegen hooks yet.
        return {}

    def load_dialects(self, ctx):
        cpu.load_dialects(ctx)

    @staticmethod
    def make_ttir(mod, metadata, opt):
        """Run the common TTIR passes (same sequence as the NVIDIA backend)."""
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)
        passes.ttir.add_rewrite_tensor_pointer(pm)
        passes.ttir.add_combine(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_reorder_broadcast(pm)
        passes.common.add_cse(pm)
        passes.common.add_licm(pm)
        passes.common.add_symbol_dce(pm)
        pm.run(mod)
        return mod

    @staticmethod
    def make_ttcir(mod, metadata, opt):
        # TODO: lower TTIR to a CPU-specific IR.
        return mod

    @staticmethod
    def make_llir(src, metadata, options):
        # TODO: lower to LLVM IR.  CPUs have no shared memory, so report 0.
        metadata["shared"] = 0
        return src

    @staticmethod
    def make_exe(src, metadata, options):
        """Record the kernel name in metadata.

        src is still TTIR at this point, so the name is scraped from the
        single public tt.func.  Raises RuntimeError if there is not exactly
        one (was a bare `assert`, which disappears under `python -O`).
        """
        names = re.findall(r"\s+tt.func public @([a-zA-Z_][a-zA-Z0-9_]*)\(", str(src))
        if len(names) != 1:
            raise RuntimeError(
                f"expected exactly one public tt.func in the module, found {len(names)}")
        metadata["name"] = names[0]

        # TODO: Call llc to create an executable.
        return src

    def add_stages(self, stages, options):
        """Register the compilation stages in pipeline order."""
        stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
        stages["ttcir"] = lambda src, metadata: self.make_ttcir(src, metadata, options)
        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
        stages["exe"] = lambda src, metadata: self.make_exe(src, metadata, options)

    def hash(self):
        """Return a cache key identifying this backend/host combination.

        NOTE: previously decorated with @functools.lru_cache(), which on an
        instance method keys the cache on `self` and keeps every backend
        instance alive for the process lifetime (flake8 B019).  The value is
        cheap to compute, so just compute it on every call.
        """
        # TODO: Get more detailed CPU info like raw brand name with supported ISAs.
        # Right now it would only return a simple string like "x86_64" or "aarch64".
        import platform

        return f"{platform.machine()}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from triton.backends.driver import CPUDriverBase | ||
|
||
# ------------------------ | ||
# Utils | ||
# ------------------------ | ||
|
||
|
||
class CPUUtils(object):
    """Process-wide singleton exposing device query/load helpers."""

    def __new__(cls):
        # Lazily create one shared instance and hand it out on every call.
        if not hasattr(cls, "instance"):
            cls.instance = super(CPUUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        pass

    @staticmethod
    def get_device_properties(device):
        # Dummy values until driver.c is implemented — a CPU has none of
        # these GPU-style properties, so everything reports zero.
        return dict.fromkeys(
            (
                "max_shared_mem",
                "multiprocessor_count",
                "sm_clock_rate",
                "mem_clock_rate",
                "mem_bus_width",
            ),
            0,
        )

    @staticmethod
    def load_binary(name, kernel_asm, shared, device):
        # Dummy placeholder until driver.c is implemented: echo the asm
        # back with no module handle and zero resource usage.
        return (None, kernel_asm, 0, 0)
|
||
|
||
# ------------------------ | ||
# Launcher | ||
# ------------------------ | ||
|
||
|
||
def make_launcher(constants, signature, ids):
    """Generate a host-side kernel launcher.  Not implemented yet; returns None."""
    return None
|
||
|
||
class CPULauncher(object):
    """Callable wrapper around a (not yet implemented) kernel launch."""

    def __init__(self, src, metadata):
        # TODO: compile src/metadata into a real launch function.
        # For now every launch is a no-op.
        def _noop_launch(*args, **kwargs):
            return None

        self.launch = _noop_launch

    def __call__(self, *args, **kwargs):
        print("CPULauncher.__call__")
        self.launch(*args, **kwargs)
|
||
|
||
class CPUDriver(CPUDriverBase):
    """Driver entry point for the CPU backend."""

    def __init__(self):
        self.launcher_cls = CPULauncher
        self.utils = CPUUtils()
        super().__init__()

    def get_current_target(self):
        # A target is (device_kind, capability, warp_size); the CPU has no
        # compute capability or warps, so both are reported as zero.
        return ("cpu", 0, 0)

    @staticmethod
    def is_active():
        # The host CPU is always available.
        return True
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Pass/PassManager.h" | ||
#include "llvm/IR/Constants.h" | ||
#include "llvm/Support/TargetSelect.h" | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
#include <pybind11/stl_bind.h> | ||
|
||
#include <iostream> | ||
|
||
namespace py = pybind11; | ||
|
||
// Register TritonCPU-specific MLIR passes on the given pybind11 submodule.
// No passes are exposed yet.
void init_triton_passes_ttcpuir(py::module &&m) {
  // TODO:
}
|
||
// Entry point for the TritonCPU Python extension: creates the
// `passes.ttcpuir` submodule hierarchy and exposes `load_dialects`.
void init_triton_cpu(py::module &&m) {
  // Build passes.ttcpuir and hand it to the pass registration hook.
  init_triton_passes_ttcpuir(m.def_submodule("passes").def_submodule("ttcpuir"));

  m.def("load_dialects", [](mlir::MLIRContext &context) {
    // TODO:
  });
}