Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FRONTEND] Added tl.clock and tl.globaltimer #485

Merged
merged 3 commits into from
Mar 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/triton/codegen/selection/generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ class generator: public ir::visitor, public analysis::layout_visitor {
void visit_async_wait_inst(ir::async_wait_inst*);
// void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range(ir::make_range*);
void visit_clock_inst(ir::clock_inst*);
void visit_globaltimer_inst(ir::globaltimer_inst*);
// void visit_make_range_sta(ir::make_range_sta*);
void visit_undef_value(ir::undef_value*);
void visit_constant_int(ir::constant_int*);
Expand All @@ -192,6 +194,7 @@ class generator: public ir::visitor, public analysis::layout_visitor {
void visit_argument(ir::argument*);
void visit(ir::module &, llvm::Module &);


// layouts
void visit_layout_mma(analysis::mma_layout*);
void visit_layout_scanline(analysis::scanline_layout*);
Expand Down
4 changes: 3 additions & 1 deletion include/triton/ir/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class builder{
// Constructor
builder(context &ctx);
// Getters
const context& get_context() { return ctx_; }
// const context& get_context() const { return ctx_; }
context& get_context() { return ctx_; }

// Setters
void set_insert_point(iterator instr);
void set_insert_point(instruction* i);
Expand Down
4 changes: 4 additions & 0 deletions include/triton/ir/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ struct dispatch{
static ir::value *sin(ir::value *x, ir::builder *builder);
static ir::value *sqrt(ir::value *x, ir::builder *builder);

// utilities
static ir::value *globaltimer(ir::builder *builder);
static ir::value *clock(ir::builder *builder);

// internal (debug/optimization)
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder);
Expand Down
2 changes: 2 additions & 0 deletions include/triton/ir/enums.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ enum value_id_t: unsigned {
INST_MAKE_RANGE_STA,
INST_MAKE_RANGE,
INST_PREFETCH_S,
INST_GLOBALTIMER,
INST_CLOCK,
};


Expand Down
21 changes: 21 additions & 0 deletions include/triton/ir/instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,27 @@ class make_range: public instruction{
constant_int* last_;
};

/* timing utilities */
class clock_inst: public instruction{
clock_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "clock"; }
_TRITON_DEFINE_CLONE(clock_inst)
_TRITON_DEFINE_ACCEPT(clock_inst)

public:
static clock_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
};

class globaltimer_inst: public instruction{
globaltimer_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "globaltimer"; }
_TRITON_DEFINE_CLONE(globaltimer_inst)
_TRITON_DEFINE_ACCEPT(globaltimer_inst)

public:
static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
};


}
}
Expand Down
4 changes: 4 additions & 0 deletions include/triton/ir/visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class async_wait_inst;
class make_range_dyn;
class make_range;
class prefetch_s_inst;
class clock_inst;
class globaltimer_inst;

class make_range_sta;
class undef_value;
Expand Down Expand Up @@ -157,6 +159,8 @@ class visitor {
virtual void visit_make_range(make_range*) = 0;
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
virtual void visit_function(function*) = 0;
virtual void visit_clock_inst(clock_inst*) = 0;
virtual void visit_globaltimer_inst(globaltimer_inst*) = 0;

virtual void visit_undef_value(undef_value*) = 0;
virtual void visit_constant_int(constant_int*) = 0;
Expand Down
16 changes: 14 additions & 2 deletions lib/codegen/selection/generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1093,10 +1093,10 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
case tt::Xchg: name = "exch", s_ty = "b"; break;
}
std::string s_vec = vec == 2 ? "x2" : "";
std::string mod = nbits == 32 ? "" : ".noftz";
std::string mod = nbits == 16 ? ".noftz" : "";

std::string asm_str = "@$1 atom.global.gpu." + name + mod + "." + s_ty + s_nbits + s_vec + " $0, [$2" + offset + "], $3;";
std::string ty_id = nbits*vec == 32 ? "r" : "h";
std::string ty_id = nbits*vec == 64 ? "l" : (nbits*vec == 32 ? "r" : "h");
std::string constraint = "=" + ty_id + ",b,l," + ty_id;
// create inline asm
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
Expand Down Expand Up @@ -2597,6 +2597,18 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
add_barrier();
}

void generator::visit_clock_inst(ir::clock_inst* clock){
InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %clock64;", "=l", true);
vals_[clock][{}] = call(iasm);
}

void generator::visit_globaltimer_inst(ir::globaltimer_inst* timer){
InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %globaltimer;", "=l", true);
vals_[timer][{}] = call(iasm);
}



void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) {
ir::value *v = i->get_operand(0);
int inc = i->get_inc();
Expand Down
3 changes: 2 additions & 1 deletion lib/driver/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,14 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
// create machine
module->setTargetTriple(triple);
std::string error;
llvm::TargetMachine* machine;
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if(layout.empty())
Expand Down
10 changes: 10 additions & 0 deletions lib/ir/dispatch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,16 @@ ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) {
return builder->create_sqrt(x);
}

//

ir::value *dispatch::globaltimer(ir::builder *builder) {
return builder->insert(globaltimer_inst::create(builder->get_context()));
}

ir::value *dispatch::clock(ir::builder *builder) {
return builder->insert(clock_inst::create(builder->get_context()));

}

//

Expand Down
33 changes: 13 additions & 20 deletions lib/ir/instructions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -861,8 +861,7 @@ copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::stri
}

// barrier
barrier_inst::barrier_inst(context &ctx, const std::string &name,
instruction *next)
barrier_inst::barrier_inst(context &ctx, const std::string &name, instruction *next)
: instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { }

barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
Expand All @@ -881,27 +880,21 @@ prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, cons
return new prefetch_s_inst(ctx, arg, inc, name, next);
}

//// nv_dynamic_program_idx
//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
// global timer
globaltimer_inst::globaltimer_inst(context &ctx, const std::string &name, instruction *next)
: instruction(type::get_int64_ty(ctx), INST_GLOBALTIMER, 0, name, next) { }

//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
// return new make_range_dyn(ty, name, next);
//}

//// nv_static_program_idx
//make_range_sta::make_range_sta(make_range *range)
// : constant(range->get_type(), 0), range_(range) { }
globaltimer_inst* globaltimer_inst::create(context &ctx, const std::string &name, instruction *next) {
return new globaltimer_inst(ctx, name, next);
}

//make_range* make_range_sta::get_range() const
//{ return range_; }
// clock
clock_inst::clock_inst(context &ctx, const std::string &name, instruction *next)
: instruction(type::get_int64_ty(ctx), INST_CLOCK, 0, name, next) { }

//make_range_sta* make_range_sta::get(make_range* range) {
// static std::map<make_range*, make_range_sta*> cache;
// if(cache.find(range) == cache.end())
// cache.insert({range, new make_range_sta(range)});
// return cache.at(range);
//}
clock_inst* clock_inst::create(context &ctx, const std::string &name, instruction *next) {
return new clock_inst(ctx, name, next);
}


// make_range
Expand Down
3 changes: 3 additions & 0 deletions python/src/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,9 @@ void init_triton_frontend(py::module &&m) {
m.def("cos", &ir::dispatch::cos, ret::reference);
m.def("sin", &ir::dispatch::sin, ret::reference);
m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
// utilities
m.def("clock", &ir::dispatch::clock, ret::reference);
m.def("globaltimer", &ir::dispatch::globaltimer, ret::reference);
// internal (debugging only)
m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
m.def("max_contiguous", &ir::dispatch::max_contiguous, ret::reference);
Expand Down
13 changes: 13 additions & 0 deletions python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,19 @@ def sum(input, axis, _builder=None):
def xor_sum(input, axis, _builder=None):
return frontend.xor_sum(input, axis, _builder)

# -----------------------
# Utilities
# -----------------------


@builtin
def globaltimer(_builder=None):
return frontend.globaltimer(_builder)


@builtin
def clock(_builder=None):
return frontend.clock(_builder)

# -----------------------
# Internal for debugging
Expand Down
8 changes: 7 additions & 1 deletion python/tutorials/01-vector-add.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ def add_kernel(
y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
time_start_ptr, time_end_ptr,
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
# NOTE: `constexpr` so it can be used as a shape value
):
tl.atomic_min(time_start_ptr, tl.clock())
# There are multiple 'program's processing different data. We identify which program
# we are here
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
Expand All @@ -45,6 +47,7 @@ def add_kernel(
output = x + y
# Write x + y back to DRAM
tl.store(output_ptr + offsets, output, mask=mask)
tl.atomic_max(time_end_ptr, tl.clock())


# %%
Expand All @@ -53,6 +56,8 @@ def add_kernel(


def add(x: torch.Tensor, y: torch.Tensor):
time_start = torch.zeros(1, dtype=torch.int64, device='cuda')
time_end = torch.zeros(1, dtype=torch.int64, device='cuda')
# We need to preallocate the output
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
Expand All @@ -65,9 +70,10 @@ def add(x: torch.Tensor, y: torch.Tensor):
# - each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
# - don't forget to pass meta-parameters as keywords arguments
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
add_kernel[grid](x, y, output, n_elements, time_start, time_end, BLOCK_SIZE=1024)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously at this point.
print((time_end, time_start))
return output


Expand Down