Skip to content

Commit

Permalink
[FRONTEND] Added tl.clock and tl.globaltimer (#485)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptillet authored Mar 28, 2022
1 parent 76a9ee5 commit e0cc488
Show file tree
Hide file tree
Showing 13 changed files with 99 additions and 25 deletions.
3 changes: 3 additions & 0 deletions include/triton/codegen/selection/generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ class generator: public ir::visitor, public analysis::layout_visitor {
void visit_async_wait_inst(ir::async_wait_inst*);
// void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range(ir::make_range*);
void visit_clock_inst(ir::clock_inst*);
void visit_globaltimer_inst(ir::globaltimer_inst*);
// void visit_make_range_sta(ir::make_range_sta*);
void visit_undef_value(ir::undef_value*);
void visit_constant_int(ir::constant_int*);
Expand All @@ -192,6 +194,7 @@ class generator: public ir::visitor, public analysis::layout_visitor {
void visit_argument(ir::argument*);
void visit(ir::module &, llvm::Module &);


// layouts
void visit_layout_mma(analysis::mma_layout*);
void visit_layout_scanline(analysis::scanline_layout*);
Expand Down
4 changes: 3 additions & 1 deletion include/triton/ir/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class builder{
// Constructor
builder(context &ctx);
// Getters
const context& get_context() { return ctx_; }
// const context& get_context() const { return ctx_; }
context& get_context() { return ctx_; }

// Setters
void set_insert_point(iterator instr);
void set_insert_point(instruction* i);
Expand Down
4 changes: 4 additions & 0 deletions include/triton/ir/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ struct dispatch{
static ir::value *sin(ir::value *x, ir::builder *builder);
static ir::value *sqrt(ir::value *x, ir::builder *builder);

// utilities
static ir::value *globaltimer(ir::builder *builder);
static ir::value *clock(ir::builder *builder);

// internal (debug/optimization)
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder);
Expand Down
2 changes: 2 additions & 0 deletions include/triton/ir/enums.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ enum value_id_t: unsigned {
INST_MAKE_RANGE_STA,
INST_MAKE_RANGE,
INST_PREFETCH_S,
INST_GLOBALTIMER,
INST_CLOCK,
};


Expand Down
21 changes: 21 additions & 0 deletions include/triton/ir/instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,27 @@ class make_range: public instruction{
constant_int* last_;
};

/* timing utilities */
// IR instruction that takes no value operands at construction (only a context,
// a name and an optional `next` instruction). Instances are created through
// the static create() factory because the constructor is private.
class clock_inst: public instruction{
// Private constructor; `name` and `next` are forwarded to the instruction base.
clock_inst(context &ctx, const std::string &name, instruction *next);
// Returns the fixed string "clock" for this instruction.
std::string repr_impl() const { return "clock"; }
// NOTE(review): macro bodies are not visible here; presumably they generate
// the clone and visitor-accept members for this class -- confirm.
_TRITON_DEFINE_CLONE(clock_inst)
_TRITON_DEFINE_ACCEPT(clock_inst)

public:
// Factory for clock_inst; `name` defaults to "" and `next` to nullptr.
static clock_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
};

// IR instruction that takes no value operands at construction (only a context,
// a name and an optional `next` instruction). Instances are created through
// the static create() factory because the constructor is private.
class globaltimer_inst: public instruction{
// Private constructor; `name` and `next` are forwarded to the instruction base.
globaltimer_inst(context &ctx, const std::string &name, instruction *next);
// Returns the fixed string "globaltimer" for this instruction.
std::string repr_impl() const { return "globaltimer"; }
// NOTE(review): macro bodies are not visible here; presumably they generate
// the clone and visitor-accept members for this class -- confirm.
_TRITON_DEFINE_CLONE(globaltimer_inst)
_TRITON_DEFINE_ACCEPT(globaltimer_inst)

public:
// Factory for globaltimer_inst; `name` defaults to "" and `next` to nullptr.
static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
};


}
}
Expand Down
4 changes: 4 additions & 0 deletions include/triton/ir/visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class async_wait_inst;
class make_range_dyn;
class make_range;
class prefetch_s_inst;
class clock_inst;
class globaltimer_inst;

class make_range_sta;
class undef_value;
Expand Down Expand Up @@ -157,6 +159,8 @@ class visitor {
virtual void visit_make_range(make_range*) = 0;
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
virtual void visit_function(function*) = 0;
virtual void visit_clock_inst(clock_inst*) = 0;
virtual void visit_globaltimer_inst(globaltimer_inst*) = 0;

virtual void visit_undef_value(undef_value*) = 0;
virtual void visit_constant_int(constant_int*) = 0;
Expand Down
16 changes: 14 additions & 2 deletions lib/codegen/selection/generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1093,10 +1093,10 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
case tt::Xchg: name = "exch", s_ty = "b"; break;
}
std::string s_vec = vec == 2 ? "x2" : "";
std::string mod = nbits == 32 ? "" : ".noftz";
std::string mod = nbits == 16 ? ".noftz" : "";

std::string asm_str = "@$1 atom.global.gpu." + name + mod + "." + s_ty + s_nbits + s_vec + " $0, [$2" + offset + "], $3;";
std::string ty_id = nbits*vec == 32 ? "r" : "h";
std::string ty_id = nbits*vec == 64 ? "l" : (nbits*vec == 32 ? "r" : "h");
std::string constraint = "=" + ty_id + ",b,l," + ty_id;
// create inline asm
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
Expand Down Expand Up @@ -2597,6 +2597,18 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
add_barrier();
}

// Lowers an ir::clock_inst to an inline-asm call that moves the PTX
// %clock64 special register into a 64-bit integer result.
void generator::visit_clock_inst(ir::clock_inst* clock){
  // Signature of the asm snippet: no inputs, one i64 result.
  FunctionType *asm_fn_ty = FunctionType::get(builder_->getInt64Ty(), {});
  // "$0" is the single output operand, bound by the "=l" constraint.
  InlineAsm *read_clock = InlineAsm::get(asm_fn_ty,
                                         "mov.u64 $0, %clock64;",
                                         "=l",
                                         true);
  // The result is recorded for this instruction under the empty index.
  vals_[clock][{}] = call(read_clock);
}

// Lowers an ir::globaltimer_inst to an inline-asm call that moves the PTX
// %globaltimer special register into a 64-bit integer result.
void generator::visit_globaltimer_inst(ir::globaltimer_inst* timer){
  // Signature of the asm snippet: no inputs, one i64 result.
  FunctionType *asm_fn_ty = FunctionType::get(builder_->getInt64Ty(), {});
  // "$0" is the single output operand, bound by the "=l" constraint.
  InlineAsm *read_timer = InlineAsm::get(asm_fn_ty,
                                         "mov.u64 $0, %globaltimer;",
                                         "=l",
                                         true);
  // The result is recorded for this instruction under the empty index.
  vals_[timer][{}] = call(read_timer);
}



void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) {
ir::value *v = i->get_operand(0);
int inc = i->get_inc();
Expand Down
3 changes: 2 additions & 1 deletion lib/driver/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,14 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
// create machine
module->setTargetTriple(triple);
std::string error;
llvm::TargetMachine* machine;
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if(layout.empty())
Expand Down
10 changes: 10 additions & 0 deletions lib/ir/dispatch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,16 @@ ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) {
return builder->create_sqrt(x);
}

//

// Creates a globaltimer_inst in the builder's context, inserts it through the
// builder, and returns the inserted value.
ir::value *dispatch::globaltimer(ir::builder *builder) {
  auto *timer = globaltimer_inst::create(builder->get_context());
  return builder->insert(timer);
}

// Creates a clock_inst in the builder's context, inserts it through the
// builder, and returns the inserted value.
ir::value *dispatch::clock(ir::builder *builder) {
  auto *clk = clock_inst::create(builder->get_context());
  return builder->insert(clk);
}

//

Expand Down
33 changes: 13 additions & 20 deletions lib/ir/instructions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -861,8 +861,7 @@ copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::stri
}

// barrier
barrier_inst::barrier_inst(context &ctx, const std::string &name,
instruction *next)
barrier_inst::barrier_inst(context &ctx, const std::string &name, instruction *next)
: instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { }

barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
Expand All @@ -881,27 +880,21 @@ prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, cons
return new prefetch_s_inst(ctx, arg, inc, name, next);
}

//// nv_dynamic_program_idx
//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
// global timer
// The result type is the context's 64-bit integer type and the value id is
// INST_GLOBALTIMER. The third base argument is 0 (presumably the operand
// count -- confirm against the instruction base constructor).
globaltimer_inst::globaltimer_inst(context &ctx,
                                   const std::string &name,
                                   instruction *next)
  : instruction(type::get_int64_ty(ctx), INST_GLOBALTIMER, 0, name, next)
{ }

//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
// return new make_range_dyn(ty, name, next);
//}

//// nv_static_program_idx
//make_range_sta::make_range_sta(make_range *range)
// : constant(range->get_type(), 0), range_(range) { }
// Factory: heap-allocates a globaltimer_inst and returns the raw pointer.
globaltimer_inst* globaltimer_inst::create(context &ctx, const std::string &name, instruction *next) {
  globaltimer_inst *timer = new globaltimer_inst(ctx, name, next);
  return timer;
}

//make_range* make_range_sta::get_range() const
//{ return range_; }
// clock
// The result type is the context's 64-bit integer type and the value id is
// INST_CLOCK. The third base argument is 0 (presumably the operand count --
// confirm against the instruction base constructor).
clock_inst::clock_inst(context &ctx,
                       const std::string &name,
                       instruction *next)
  : instruction(type::get_int64_ty(ctx), INST_CLOCK, 0, name, next)
{ }

//make_range_sta* make_range_sta::get(make_range* range) {
// static std::map<make_range*, make_range_sta*> cache;
// if(cache.find(range) == cache.end())
// cache.insert({range, new make_range_sta(range)});
// return cache.at(range);
//}
// Factory: heap-allocates a clock_inst and returns the raw pointer.
clock_inst* clock_inst::create(context &ctx, const std::string &name, instruction *next) {
  clock_inst *clk = new clock_inst(ctx, name, next);
  return clk;
}


// make_range
Expand Down
3 changes: 3 additions & 0 deletions python/src/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,9 @@ void init_triton_frontend(py::module &&m) {
m.def("cos", &ir::dispatch::cos, ret::reference);
m.def("sin", &ir::dispatch::sin, ret::reference);
m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
// utilities
m.def("clock", &ir::dispatch::clock, ret::reference);
m.def("globaltimer", &ir::dispatch::globaltimer, ret::reference);
// internal (debugging only)
m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
m.def("max_contiguous", &ir::dispatch::max_contiguous, ret::reference);
Expand Down
13 changes: 13 additions & 0 deletions python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,19 @@ def sum(input, axis, _builder=None):
def xor_sum(input, axis, _builder=None):
return frontend.xor_sum(input, axis, _builder)

# -----------------------
# Utilities
# -----------------------


@builtin
def globaltimer(_builder=None):
    """
    Emit a global-timer read by delegating to `frontend.globaltimer`.

    :param _builder: IR builder forwarded unchanged to the frontend call.
    :return: whatever `frontend.globaltimer` returns for this builder.
    """
    timer_value = frontend.globaltimer(_builder)
    return timer_value


@builtin
def clock(_builder=None):
    """
    Emit a clock read by delegating to `frontend.clock`.

    :param _builder: IR builder forwarded unchanged to the frontend call.
    :return: whatever `frontend.clock` returns for this builder.
    """
    clock_value = frontend.clock(_builder)
    return clock_value

# -----------------------
# Internal for debugging
Expand Down
8 changes: 7 additions & 1 deletion python/tutorials/01-vector-add.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ def add_kernel(
y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
time_start_ptr, time_end_ptr,
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
# NOTE: `constexpr` so it can be used as a shape value
):
tl.atomic_min(time_start_ptr, tl.clock())
# There are multiple 'program's processing different data. We identify which program
# we are here
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
Expand All @@ -45,6 +47,7 @@ def add_kernel(
output = x + y
# Write x + y back to DRAM
tl.store(output_ptr + offsets, output, mask=mask)
tl.atomic_max(time_end_ptr, tl.clock())


# %%
Expand All @@ -53,6 +56,8 @@ def add_kernel(


def add(x: torch.Tensor, y: torch.Tensor):
time_start = torch.zeros(1, dtype=torch.int64, device='cuda')
time_end = torch.zeros(1, dtype=torch.int64, device='cuda')
# We need to preallocate the output
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
Expand All @@ -65,9 +70,10 @@ def add(x: torch.Tensor, y: torch.Tensor):
# - each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
# - don't forget to pass meta-parameters as keywords arguments
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
add_kernel[grid](x, y, output, n_elements, time_start, time_end, BLOCK_SIZE=1024)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously at this point.
print((time_end, time_start))
return output


Expand Down

0 comments on commit e0cc488

Please sign in to comment.