Skip to content

Commit

Permalink
[part 1] Implementation of primitive ir emitter (PaddlePaddle#26)
Browse files Browse the repository at this point in the history
* Implementation of primitive ir emitter

* fix test

* fix code style
  • Loading branch information
Zhang Ting authored Aug 12, 2021
1 parent 1394fcd commit d60e0b2
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 18 deletions.
35 changes: 35 additions & 0 deletions paddle/fluid/compiler/piano/backends/llvm_ir/llvm_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Value.h"

namespace paddle {
namespace piano {
namespace backends {

llvm::Value* CallToLLVMIntrinsic(llvm::IRBuilder<>* ir_builder,
llvm::Intrinsic::ID llvm_Intrinsic) {
llvm::Module* llvm_module = ir_builder->GetInsertBlock()->getModule();
llvm::Function* func =
llvm::Intrinsic::getDeclaration(llvm_module, llvm_Intrinsic);
return ir_builder->CreateCall(func);
}

} // namespace backends
} // namespace piano
} // namespace paddle
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
// limitations under the License.

#include "paddle/fluid/compiler/piano/backends/llvm_ir/nvptx_primitive_ir_emitter.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "paddle/fluid/compiler/piano/backends/llvm_ir/llvm_utils.h"

namespace paddle {
namespace piano {
Expand All @@ -35,39 +37,57 @@ NvptxPrimitiveIrEmitter::GetBinaryOp(const note::Instruction* instr) {
}

llvm::Value* NvptxPrimitiveIrEmitter::ThreadIdx(llvm::IRBuilder<>* ir_builder) {
return nullptr;
llvm::Intrinsic::ID llvm_Intrinsic =
llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x;
return CallToLLVMIntrinsic(ir_builder, llvm_Intrinsic);
}

llvm::Value* NvptxPrimitiveIrEmitter::ThreadIdy(llvm::IRBuilder<>* ir_builder) {
return nullptr;
llvm::Intrinsic::ID llvm_Intrinsic =
llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y;
return CallToLLVMIntrinsic(ir_builder, llvm_Intrinsic);
}

llvm::Value* NvptxPrimitiveIrEmitter::ThreadIdz(llvm::IRBuilder<>* ir_builder) {
return nullptr;
llvm::Intrinsic::ID llvm_Intrinsic =
llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z;
return CallToLLVMIntrinsic(ir_builder, llvm_Intrinsic);
}

llvm::Value* NvptxPrimitiveIrEmitter::BlockDimx(llvm::IRBuilder<>* ir_builder) {
return nullptr;
llvm::Intrinsic::ID llvm_Intrinsic =
llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x;
return CallToLLVMIntrinsic(ir_builder, llvm_Intrinsic);
}

llvm::Value* NvptxPrimitiveIrEmitter::BlockDimy(llvm::IRBuilder<>* ir_builder) {
return nullptr;
llvm::Intrinsic::ID llvm_Intrinsic =
llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y;
return CallToLLVMIntrinsic(ir_builder, llvm_Intrinsic);
}

llvm::Value* NvptxPrimitiveIrEmitter::BlockDimz(llvm::IRBuilder<>* ir_builder) {
return nullptr;
llvm::Intrinsic::ID llvm_Intrinsic =
llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z;
return CallToLLVMIntrinsic(ir_builder, llvm_Intrinsic);
}

llvm::Value* NvptxPrimitiveIrEmitter::BlockIdx(llvm::IRBuilder<>* ir_builder) {
return nullptr;
llvm::Intrinsic::ID llvm_Intrinsic =
llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x;
return CallToLLVMIntrinsic(ir_builder, llvm_Intrinsic);
}

llvm::Value* NvptxPrimitiveIrEmitter::BlockIdy(llvm::IRBuilder<>* ir_builder) {
return nullptr;
llvm::Intrinsic::ID llvm_Intrinsic =
llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_y;
return CallToLLVMIntrinsic(ir_builder, llvm_Intrinsic);
}

llvm::Value* NvptxPrimitiveIrEmitter::BlockIdz(llvm::IRBuilder<>* ir_builder) {
return nullptr;
llvm::Intrinsic::ID llvm_Intrinsic =
llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_z;
return CallToLLVMIntrinsic(ir_builder, llvm_Intrinsic);
}

void NvptxPrimitiveIrEmitter::ThreadSync(llvm::IRBuilder<>* ir_builder) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,25 @@ TEST(NvptxPrimitiveIrEmitter, GetUnaryOp) {

TEST(NvptxPrimitiveIrEmitter, DeviceBaseOp) {
NvptxPrimitiveIrEmitter nvptx_primitive_ir_emitter;
ASSERT_EQ(nvptx_primitive_ir_emitter.ThreadIdx(nullptr), nullptr);
ASSERT_EQ(nvptx_primitive_ir_emitter.ThreadIdy(nullptr), nullptr);
ASSERT_EQ(nvptx_primitive_ir_emitter.ThreadIdz(nullptr), nullptr);
ASSERT_EQ(nvptx_primitive_ir_emitter.BlockDimx(nullptr), nullptr);
ASSERT_EQ(nvptx_primitive_ir_emitter.BlockDimy(nullptr), nullptr);
ASSERT_EQ(nvptx_primitive_ir_emitter.BlockDimz(nullptr), nullptr);
ASSERT_EQ(nvptx_primitive_ir_emitter.BlockIdx(nullptr), nullptr);
ASSERT_EQ(nvptx_primitive_ir_emitter.BlockIdy(nullptr), nullptr);
ASSERT_EQ(nvptx_primitive_ir_emitter.BlockIdz(nullptr), nullptr);
llvm::LLVMContext context;
llvm::Module module("DeviceBaseOp", context);
llvm::IRBuilder<> builder(context);
llvm::FunctionType *func_type =
llvm::FunctionType::get(llvm::Type::getVoidTy(context), false);
llvm::Function *init_fn = llvm::Function::Create(
func_type, llvm::Function::ExternalLinkage, "init", module);
llvm::BasicBlock *entry = llvm::BasicBlock::Create(context, "entry", init_fn);
builder.SetInsertPoint(entry);

ASSERT_NE(nvptx_primitive_ir_emitter.ThreadIdx(&builder), nullptr);
ASSERT_NE(nvptx_primitive_ir_emitter.ThreadIdy(&builder), nullptr);
ASSERT_NE(nvptx_primitive_ir_emitter.ThreadIdz(&builder), nullptr);
ASSERT_NE(nvptx_primitive_ir_emitter.BlockDimx(&builder), nullptr);
ASSERT_NE(nvptx_primitive_ir_emitter.BlockDimy(&builder), nullptr);
ASSERT_NE(nvptx_primitive_ir_emitter.BlockDimz(&builder), nullptr);
ASSERT_NE(nvptx_primitive_ir_emitter.BlockIdx(&builder), nullptr);
ASSERT_NE(nvptx_primitive_ir_emitter.BlockIdy(&builder), nullptr);
ASSERT_NE(nvptx_primitive_ir_emitter.BlockIdz(&builder), nullptr);
ASSERT_EQ(nvptx_primitive_ir_emitter.Alloca(nullptr, 8), nullptr);
}

Expand Down

0 comments on commit d60e0b2

Please sign in to comment.