Skip to content

Commit

Permalink
Revert unreviewed changes. (triton-lang#5)
Browse files Browse the repository at this point in the history
Co-authored-by: Shane Nay <snay@meta.com>
  • Loading branch information
2 people authored and minjang committed Oct 23, 2024
1 parent 5670f69 commit 3f6e029
Show file tree
Hide file tree
Showing 40 changed files with 77 additions and 2,744 deletions.
8 changes: 0 additions & 8 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

#include "cpu/include/TritonCPUToLLVM/Passes.h"
#include "cpu/include/TritonToTritonCPU/Passes.h"
#include "nvidia/include/NVGPUToLLVM/Passes.h"
#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
Expand Down Expand Up @@ -66,12 +64,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerTritonAMDGPUStreamPipelineV2();
mlir::registerTritonAMDGPUCanonicalizePointers();

// CPU passes
mlir::triton::cpu::registerTritonToTritonCPUPasses();
mlir::triton::cpu::registerTritonToTritonCPUPipeline();
mlir::triton::cpu::registerTritonCPUToLLVMPasses();
mlir::triton::cpu::registerTritonCPUToLLVMPipeline();

// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::cpu::TritonCPUDialect,
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Conversion/TritonCPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUToLLVM)
add_public_tablegen_target(TritonCPUToLLVMConversionPassIncGen)
add_public_tablegen_target(TritonCPUConversionPassIncGen)
5 changes: 3 additions & 2 deletions include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ class TritonCPU_Attr<string name, string attrMnemonic, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<dialect, name, !listconcat([TritonCPU_AttrTrait], traits), baseCppClass> {

let description = [{TritonCPU attr.}];
let attrName = "triton.cpu." # attrMnemonic;
let description = [{
WIP...
}];
}

#endif
3 changes: 0 additions & 3 deletions include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@ def TritonCPU_Dialect : Dialect {
let dependentDialects = [
"triton::TritonDialect",
"tensor::TensorDialect",
"mlir::memref::MemRefDialect",
];

let extraClassDeclaration = [{
void registerTypes();
}];

let useDefaultTypePrinterParser = 1;
let useDefaultAttributePrinterParser = 1;
let usePropertiesForAttributes = 1;
}

#endif
51 changes: 0 additions & 51 deletions include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,57 +7,6 @@ include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"

class TTC_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonCPU_Dialect, mnemonic,
!listconcat(traits, [])> {
}

def TTC_ExtractMemRefOp : TTC_Op<"extract_memref", [NoMemoryEffect]> {
let summary = "Extract base memref from a block pointer";

let description = [{
Extract base memref from a block pointer. It covers whole base tensor memory,
not only the block referenced. Base pointer, shape, and strides are used
in the resulting memref. Offsets and block shape are ignored.

}];

let arguments = (ins TT_TensorPtr:$src);

let results = (outs AnyRankedOrUnrankedMemRef:$result);

let hasCanonicalizer = 1;

let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
}

def TTC_ExtractIndicesOp : TTC_Op<"extract_indices", [NoMemoryEffect]> {
let summary = "Extract indices from a block pointer.";

let description = [{
Extract indices that can be used to access the block using its base memref.
Indices are supposed to be used for vector loads/stores with the base
memref extracted from the same block pointer.
}];

let arguments = (ins TT_TensorPtr:$src);

let results = (outs Variadic<Index>:$result);

let builders = [
OpBuilder<(ins "Value":$src)>
];

let hasCanonicalizer = 1;

let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
}

#endif
4 changes: 2 additions & 2 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#add_subdirectory(TritonToTritonCPU)
add_subdirectory(TritonToTritonCPU)
add_subdirectory(TritonToTritonGPU)
#add_subdirectory(TritonCPUToLLVM)
add_subdirectory(TritonCPUToLLVM)
add_subdirectory(TritonGPUToLLVM)
37 changes: 1 addition & 36 deletions lib/Dialect/TritonCPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,16 @@

#include <numeric>

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonCPU/IR/Dialect.cpp.inc"
#include "triton/Dialect/TritonCPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/ADT/TypeSwitch.h"

#include "triton/Dialect/TritonCPU/IR/Dialect.cpp.inc"

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::cpu;

//===----------------------------------------------------------------------===//
Expand All @@ -23,35 +20,6 @@ using namespace mlir::triton::cpu;
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.cpp.inc"

void ExtractMemRefOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {}

void ExtractIndicesOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {}

/// Parse an attribute registered to this dialect.
::mlir::Attribute
TritonCPUDialect::parseAttribute(::mlir::DialectAsmParser &parser,
::mlir::Type type) const {
llvm_unreachable("parse stub called");
}

/// Print an attribute registered to this dialect.
void TritonCPUDialect::printAttribute(::mlir::Attribute attr,
::mlir::DialectAsmPrinter &os) const {
llvm_unreachable("print stub called");
}

void ExtractIndicesOp::build(::mlir::OpBuilder &builder,
::mlir::OperationState &state, Value src) {
assert(triton::isTensorPointerType(src.getType()) &&
"Unexecpeted source type");
auto tensorTy = dyn_cast<RankedTensorType>(
dyn_cast<PointerType>(src.getType()).getPointeeType());
SmallVector<Type> resTypes(tensorTy.getRank(), builder.getIndexType());
build(builder, state, resTypes, src);
}

void TritonCPUDialect::initialize() {
registerTypes();

Expand All @@ -66,9 +34,6 @@ void TritonCPUDialect::initialize() {
>();
}

#define GET_OP_CLASSES
#include "triton/Dialect/TritonCPU/IR/Ops.cpp.inc"

// verify TritonCPU ops
LogicalResult TritonCPUDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
Expand Down
67 changes: 0 additions & 67 deletions python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
Expand All @@ -23,7 +21,6 @@
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/TargetParser/Host.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include <csignal>
Expand Down Expand Up @@ -393,70 +390,6 @@ void init_triton_llvm(py::module &&m) {
py::arg("flags") = std::vector<std::string>{},
py::arg("enable_fp_fusion") = false);

m.def("set_host_target", [](llvm::Module *mod) {
mod->setTargetTriple(llvm::sys::getDefaultTargetTriple());
std::string error;
auto target =
llvm::TargetRegistry::lookupTarget(mod->getTargetTriple(), error);
std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
mod->getTargetTriple(), llvm::sys::getHostCPUName(), "", {},
llvm::Reloc::PIC_)};
mod->setDataLayout(machine->createDataLayout());
});

m.def(
"translate_to_host_asm",
[](std::string llvmIR) -> py::object {
std::string res;
{
// when allow_threads goes out of scope, gil will be released
py::gil_scoped_release allow_threads;
// create LLVM module from C++
llvm::LLVMContext context;
std::unique_ptr<llvm::MemoryBuffer> buffer =
llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str());
llvm::SMDiagnostic error;
std::unique_ptr<llvm::Module> module =
llvm::parseIR(buffer->getMemBufferRef(), error, context);
if (!module) {
llvm::report_fatal_error(
"failed to parse IR: " + error.getMessage() +
"lineno: " + std::to_string(error.getLineNo()));
}
res = translateLLVMIRToASM(
*module, llvm::sys::getDefaultTargetTriple(),
llvm::sys::getHostCPUName().str(), "", {}, false, false);
}
return py::str(res);
},
ret::take_ownership);

m.def(
"translate_to_bc",
[](const std::string llvmIR) -> py::object {
py::gil_scoped_release allow_threads;
// create LLVM module
llvm::LLVMContext context;
std::unique_ptr<llvm::MemoryBuffer> buffer =
llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str());
llvm::SMDiagnostic error;
std::unique_ptr<llvm::Module> module =
llvm::parseIR(buffer->getMemBufferRef(), error, context);
if (!module) {
llvm::report_fatal_error(
"failed to parse IR: " + error.getMessage() +
"lineno: " + std::to_string(error.getLineNo()));
}
// Write bitcode to a buffer.
llvm::SmallVector<char, 0> buf;
llvm::BitcodeWriter writer(buf);
writer.writeModule(*module);
writer.writeStrtab();
std::string bitcode(buf.begin(), buf.end());
return py::bytes(bitcode);
},
ret::take_ownership);

m.def(
"translate_to_asm",
[](std::string llvmIR, std::string triple, std::string proc,
Expand Down
5 changes: 2 additions & 3 deletions python/src/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ void init_triton_passes_ttir(py::module &&m) {
ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir",
createConvertTritonToTritonGPUPass, const std::string &,
int, int, int);
// ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir",
// createConvertTritonToTritonCPUPass);
ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir",
createConvertTritonToTritonCPUPass);
}

void init_triton_passes_ttgpuir(py::module &&m) {
Expand Down Expand Up @@ -81,7 +81,6 @@ void init_triton_passes_convert(py::module &&m) {
ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass);
ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass);
ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass);
ADD_PASS_WRAPPER_0("add_math_to_llvmir", createConvertMathToLLVMPass);
}

void init_triton_passes_llvmir(py::module &&m) {
Expand Down
1 change: 0 additions & 1 deletion python/test/unit/language/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@

def pytest_configure(config):
config.addinivalue_line("markers", "interpreter: indicate whether interpreter supports the test")
config.addinivalue_line("markers", "cpu: indicate whether test is supported on cpu")
Loading

0 comments on commit 3f6e029

Please sign in to comment.