Skip to content

Commit

Permalink
Improve "JALR" execution with T2C JIT-cache
Browse files Browse the repository at this point in the history
Currently, the "JALR" indirect jump instruction turns the mode of
rv32emu from T2C back to the interpreter. This commit introduces a
"JIT-cache" table lookup to make it redirect to the T2C JIT-ed code
entry and avoid the mode change.

There are several scenarios that benefit from this approach, e.g.
function pointer invocation and far-away function calls. The former,
such as "qsort", can be sped up by a factor of two, and the latter, such as
"fibonacci", which is compiled from hand-written assembly in order to
generate "JALR" instructions, can even reach a 4.8x performance
enhancement.
  • Loading branch information
vacantron committed Aug 7, 2024
1 parent 9759ad2 commit 65fdaef
Show file tree
Hide file tree
Showing 8 changed files with 187 additions and 29 deletions.
Binary file added build/fibonacci.elf
Binary file not shown.
1 change: 1 addition & 0 deletions src/jit.c
Original file line number Diff line number Diff line change
Expand Up @@ -1864,6 +1864,7 @@ static void code_cache_flush(struct jit_state *state, riscv_t *rv)
state->offset = state->org_size;
state->n_blocks = 0;
set_reset(&state->set);
jit_cache_clear(rv->jit_cache);
clear_cache_hot(rv->block_cache, (clear_func_t) clear_hot);
return;
}
Expand Down
22 changes: 21 additions & 1 deletion src/jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ void jit_translate(riscv_t *rv, block_t *block);
typedef void (*exec_block_func_t)(riscv_t *rv, uintptr_t);

#if RV32_HAS(T2C)
void t2c_compile(block_t *block, uint64_t mem_base);
void t2c_compile(riscv_t *, block_t *);
typedef void (*exec_t2c_func_t)(riscv_t *);

/* The jit-cache records program counters together with the entries of the
 * executable code generated by T2C. Like a hardware cache, an existing
 * jit-cache entry is replaced by a new one that maps to the same slot.
 */

/* The size of the jit-cache table must be a power of 2 so that an element
 * can be accessed simply by masking the program counter.
 */
#define N_JIT_CACHE_ENTRIES (1 << 12)

/* One slot of the jit-cache table.
 *
 * The table holds N_JIT_CACHE_ENTRIES of these and is indexed by
 * (pc & (N_JIT_CACHE_ENTRIES - 1)). A slot is all-zero after jit_cache_init()
 * or jit_cache_clear().
 *
 * The field order and widths are mirrored by the LLVM struct type that
 * t2c_compile() builds (a 64-bit integer followed by a pointer), so the two
 * definitions have to be changed together.
 */
struct jit_cache {
    /* Guest program counter of the cached block. jit_cache_update() stores a
     * 32-bit pc here; the field is 64-bit wide to make the LLVM IR easy to
     * build.
     */
    uint64_t pc; /* program counter, easy to build LLVM IR with 64-bit width */
    /* Host address of the T2C-generated code for that block. */
    void *entry; /* entry of JIT-ed code */
};

/* Allocate a zero-filled jit-cache table of N_JIT_CACHE_ENTRIES slots.
 * Returns NULL if the allocation fails. Declared with (void) so that the
 * declaration is a real prototype and stray arguments are diagnosed.
 */
struct jit_cache *jit_cache_init(void);

/* Release a table obtained from jit_cache_init(). */
void jit_cache_exit(struct jit_cache *cache);

/* Record `entry` as the JIT-ed code for the block starting at `pc`,
 * overwriting whatever previously occupied the slot selected by `pc`.
 */
void jit_cache_update(struct jit_cache *cache, uint32_t pc, void *entry);

/* Reset every slot of the table to zero. */
void jit_cache_clear(struct jit_cache *cache);
#endif
5 changes: 3 additions & 2 deletions src/riscv.c
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,7 @@ static void *t2c_runloop(void *arg)
pthread_mutex_lock(&rv->wait_queue_lock);
list_del_init(&entry->list);
pthread_mutex_unlock(&rv->wait_queue_lock);
t2c_compile(entry->block,
(uint64_t) ((memory_t *) PRIV(rv)->mem)->mem_base);
t2c_compile(rv, entry->block);
free(entry);
}
}
Expand Down Expand Up @@ -291,6 +290,7 @@ riscv_t *rv_create(riscv_user_t rv_attr)
mpool_create(sizeof(chain_entry_t) << BLOCK_IR_MAP_CAPACITY_BITS,
sizeof(chain_entry_t));
rv->jit_state = jit_state_init(CODE_CACHE_SIZE);
rv->jit_cache = jit_cache_init();
rv->block_cache = cache_create(BLOCK_MAP_CAPACITY_BITS);
assert(rv->block_cache);
#if RV32_HAS(T2C)
Expand Down Expand Up @@ -392,6 +392,7 @@ void rv_delete(riscv_t *rv)
#endif
mpool_destroy(rv->chain_entry_mp);
jit_state_exit(rv->jit_state);
jit_cache_exit(rv->jit_cache);
cache_free(rv->block_cache);
#endif
free(rv);
Expand Down
1 change: 1 addition & 0 deletions src/riscv_private.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ struct riscv_internal {
struct mpool *block_mp, *block_ir_mp;

void *jit_state;
void *jit_cache;
#if RV32_HAS(GDBSTUB)
/* gdbstub instance */
gdbstub_t gdbstub;
Expand Down
72 changes: 55 additions & 17 deletions src/t2c.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@ FORCE_INLINE LLVMBasicBlockRef t2c_block_map_search(struct LLVM_block_map *map,
return NULL;
}

#define T2C_OP(inst, code) \
static void t2c_##inst( \
LLVMBuilderRef *builder UNUSED, LLVMTypeRef *param_types UNUSED, \
LLVMValueRef start UNUSED, LLVMBasicBlockRef *entry UNUSED, \
LLVMBuilderRef *taken_builder UNUSED, \
LLVMBuilderRef *untaken_builder UNUSED, uint64_t mem_base UNUSED, \
rv_insn_t *ir UNUSED) \
{ \
code; \
#define T2C_OP(inst, code) \
static void t2c_##inst( \
LLVMBuilderRef *builder UNUSED, LLVMTypeRef *param_types UNUSED, \
LLVMValueRef start UNUSED, LLVMBasicBlockRef *entry UNUSED, \
LLVMBuilderRef *taken_builder UNUSED, \
LLVMBuilderRef *untaken_builder UNUSED, riscv_t *rv UNUSED, \
uint64_t mem_base UNUSED, rv_insn_t *ir UNUSED) \
{ \
code; \
}

#define T2C_LLVM_GEN_ADDR(reg, rv_member, ir_member) \
Expand Down Expand Up @@ -135,6 +135,9 @@ FORCE_INLINE void t2c_gen_call_io_func(LLVMValueRef start,
&io_param, 1, "");
}

static LLVMTypeRef t2c_jit_cache_func_type;
static LLVMTypeRef t2c_jit_cache_struct_type;

#include "t2c_template.c"
#undef T2C_OP

Expand Down Expand Up @@ -174,14 +177,15 @@ typedef void (*t2c_codegen_block_func_t)(LLVMBuilderRef *builder UNUSED,
LLVMBasicBlockRef *entry UNUSED,
LLVMBuilderRef *taken_builder UNUSED,
LLVMBuilderRef *untaken_builder UNUSED,
riscv_t *rv UNUSED,
uint64_t mem_base UNUSED,
rv_insn_t *ir UNUSED);

static void t2c_trace_ebb(LLVMBuilderRef *builder,
LLVMTypeRef *param_types UNUSED,
LLVMValueRef start,
LLVMBasicBlockRef *entry,
uint64_t mem_base,
riscv_t *rv,
rv_insn_t *ir,
set_t *set,
struct LLVM_block_map *map)
Expand All @@ -194,7 +198,8 @@ static void t2c_trace_ebb(LLVMBuilderRef *builder,

while (1) {
((t2c_codegen_block_func_t) dispatch_table[ir->opcode])(
builder, param_types, start, entry, &tk, &utk, mem_base, ir);
builder, param_types, start, entry, &tk, &utk, rv,
(uint64_t) ((memory_t *) PRIV(rv)->mem)->mem_base, ir);
if (!ir->next)
break;
ir = ir->next;
Expand All @@ -214,8 +219,7 @@ static void t2c_trace_ebb(LLVMBuilderRef *builder,
LLVMPositionBuilderAtEnd(untaken_builder, untaken_entry);
LLVMBuildBr(utk, untaken_entry);
t2c_trace_ebb(&untaken_builder, param_types, start,
&untaken_entry, mem_base, ir->branch_untaken, set,
map);
&untaken_entry, rv, ir->branch_untaken, set, map);
}
}
if (ir->branch_taken) {
Expand All @@ -230,13 +234,13 @@ static void t2c_trace_ebb(LLVMBuilderRef *builder,
LLVMPositionBuilderAtEnd(taken_builder, taken_entry);
LLVMBuildBr(tk, taken_entry);
t2c_trace_ebb(&taken_builder, param_types, start, &taken_entry,
mem_base, ir->branch_taken, set, map);
rv, ir->branch_taken, set, map);
}
}
}
}

void t2c_compile(block_t *block, uint64_t mem_base)
void t2c_compile(riscv_t *rv, block_t *block)
{
LLVMModuleRef module = LLVMModuleCreateWithName("my_module");
LLVMTypeRef io_members[] = {
Expand All @@ -254,6 +258,16 @@ void t2c_compile(block_t *block, uint64_t mem_base)
LLVMTypeRef param_types[] = {LLVMPointerType(struct_rv, 0)};
LLVMValueRef start = LLVMAddFunction(
module, "start", LLVMFunctionType(LLVMVoidType(), param_types, 1, 0));

LLVMTypeRef t2c_args[1] = {LLVMInt64Type()};
t2c_jit_cache_func_type =
LLVMFunctionType(LLVMVoidType(), t2c_args, 1, false);

/* Mind the alignment: this layout must match 'struct jit_cache' */
LLVMTypeRef jit_cache_memb[2] = {LLVMInt64Type(),
LLVMPointerType(LLVMVoidType(), 0)};
t2c_jit_cache_struct_type = LLVMStructType(jit_cache_memb, 2, false);

LLVMBasicBlockRef first_block = LLVMAppendBasicBlock(start, "first_block");
LLVMBuilderRef first_builder = LLVMCreateBuilder();
LLVMPositionBuilderAtEnd(first_builder, first_block);
Expand All @@ -266,8 +280,8 @@ void t2c_compile(block_t *block, uint64_t mem_base)
struct LLVM_block_map map;
map.count = 0;
/* Translate custom IR into LLVM IR */
t2c_trace_ebb(&builder, param_types, start, &entry, mem_base,
block->ir_head, &set, &map);
t2c_trace_ebb(&builder, param_types, start, &entry, rv, block->ir_head,
&set, &map);
/* Offload LLVM IR to LLVM backend */
char *error = NULL, *triple = LLVMGetDefaultTargetTriple();
LLVMExecutionEngineRef engine;
Expand Down Expand Up @@ -298,5 +312,29 @@ void t2c_compile(block_t *block, uint64_t mem_base)

/* Return the function pointer of T2C generated machine code */
block->func = (exec_t2c_func_t) LLVMGetPointerToGlobal(engine, start);
jit_cache_update(rv->jit_cache, block->pc_start, block->func);
block->hot2 = true;
}

/* Allocate the jit-cache table.
 *
 * Returns a zero-filled array of N_JIT_CACHE_ENTRIES slots, or NULL when the
 * allocation fails. The table is released with jit_cache_exit().
 *
 * The parameter list is spelled (void): an empty pair of parentheses is not
 * a prototype in C before C23.
 */
struct jit_cache *jit_cache_init(void)
{
    return calloc(N_JIT_CACHE_ENTRIES, sizeof(struct jit_cache));
}

/* Destroy a jit-cache table.
 *
 * The table itself is the only allocation owned by the cache; the JIT-ed
 * code that the slots point to is not released here.
 */
void jit_cache_exit(struct jit_cache *cache)
{
    free(cache);
}

/* Install `entry` as the JIT-ed code for the block that starts at `pc`.
 *
 * The slot is chosen by masking `pc` with the table size, so a block whose
 * pc maps to the same slot is simply overwritten.
 */
void jit_cache_update(struct jit_cache *cache, uint32_t pc, void *entry)
{
    struct jit_cache *slot = &cache[pc & (N_JIT_CACHE_ENTRIES - 1)];

    slot->pc = pc;
    slot->entry = entry;
}

void jit_cache_clear(struct jit_cache *cache)
{
memset(cache, 0, N_JIT_CACHE_ENTRIES * sizeof(struct jit_cache));
}
72 changes: 63 additions & 9 deletions src/t2c_template.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,63 @@ T2C_OP(jal, {
}
})

/* Emit the LLVM IR of an indirect jump that probes the jit-cache first.
 *
 * The generated code masks the jump destination to select a slot of
 * rv->jit_cache. On a hit (the slot's pc equals the destination and the slot
 * holds a non-NULL entry) it calls the cached T2C entry and returns. On a
 * miss it stores the destination to the PC and returns, i.e. falls back to
 * the interpreter.
 *
 * @builder: builder positioned where the lookup is emitted; the block it
 *           points at is terminated by the conditional branch built here
 * @start:   LLVM function under construction
 * @addr:    32-bit LLVM value holding the jump destination
 * @rv:      emulator instance; its address and the address of its jit-cache
 *           are embedded in the generated code as constants
 * @ir:      instruction being translated, forwarded to t2c_gen_PC_addr()
 */
FORCE_INLINE void t2c_jit_cache_helper(LLVMBuilderRef *builder,
                                       LLVMValueRef start,
                                       LLVMValueRef addr,
                                       riscv_t *rv,
                                       rv_insn_t *ir)
{
    LLVMBasicBlockRef true_path = LLVMAppendBasicBlock(start, "");
    LLVMBuilderRef true_builder = LLVMCreateBuilder();
    LLVMPositionBuilderAtEnd(true_builder, true_path);

    LLVMBasicBlockRef false_path = LLVMAppendBasicBlock(start, "");
    LLVMBuilderRef false_builder = LLVMCreateBuilder();
    LLVMPositionBuilderAtEnd(false_builder, false_path);

    /* get jit-cache base address; go through uintptr_t because 'long' is
     * narrower than a pointer on LLP64 hosts
     */
    LLVMValueRef base = LLVMConstIntToPtr(
        LLVMConstInt(LLVMInt64Type(), (uintptr_t) rv->jit_cache, false),
        LLVMPointerType(t2c_jit_cache_struct_type, 0));

    /* get index */
    LLVMValueRef hash = LLVMBuildAnd(
        *builder, addr,
        LLVMConstInt(LLVMInt32Type(), N_JIT_CACHE_ENTRIES - 1, false), "");
    LLVMValueRef index =
        LLVMBuildIntCast2(*builder, hash, LLVMInt64Type(), false, "");
    LLVMValueRef element_ptr = LLVMBuildInBoundsGEP2(
        *builder, t2c_jit_cache_struct_type, base, &index, 1, "");

    /* get jit_cache_t::pc; the field is 64-bit wide, so load all of it and
     * zero-extend the destination instead of reading only the first four
     * bytes, which would be the wrong half on a big-endian host
     */
    LLVMValueRef pc_ptr = LLVMBuildStructGEP2(
        *builder, t2c_jit_cache_struct_type, element_ptr, 0, "");
    LLVMValueRef pc = LLVMBuildLoad2(*builder, LLVMInt64Type(), pc_ptr, "");
    LLVMValueRef addr64 =
        LLVMBuildIntCast2(*builder, addr, LLVMInt64Type(), false, "");

    /* compare with calculated destination */
    LLVMValueRef pc_match =
        LLVMBuildICmp(*builder, LLVMIntEQ, pc, addr64, "");

    /* get jit_cache_t::entry, loaded as a function pointer so that it is a
     * valid callee
     */
    LLVMTypeRef entry_type = LLVMPointerType(t2c_jit_cache_func_type, 0);
    LLVMValueRef entry_ptr = LLVMBuildStructGEP2(
        *builder, t2c_jit_cache_struct_type, element_ptr, 1, "");
    LLVMValueRef entry = LLVMBuildLoad2(*builder, entry_type, entry_ptr, "");

    /* A zero-filled slot has pc == 0 and entry == NULL, so a jump to guest
     * address 0 would match an empty slot. Require a non-NULL entry before
     * treating the probe as a hit.
     */
    LLVMValueRef null_entry = LLVMConstIntToPtr(
        LLVMConstInt(LLVMInt64Type(), 0, false), entry_type);
    LLVMValueRef has_entry =
        LLVMBuildICmp(*builder, LLVMIntNE, entry, null_entry, "");
    LLVMValueRef hit = LLVMBuildAnd(*builder, pc_match, has_entry, "");

    LLVMBuildCondBr(*builder, hit, true_path, false_path);

    /* invoke T2C JIT-ed code */
    LLVMValueRef t2c_args[1] = {
        LLVMConstInt(LLVMInt64Type(), (uintptr_t) rv, false)};

    LLVMBuildCall2(true_builder, t2c_jit_cache_func_type, entry, t2c_args, 1,
                   "");
    LLVMBuildRetVoid(true_builder);

    /* return to interpreter if cache-miss */
    LLVMBuildStore(false_builder, addr,
                   t2c_gen_PC_addr(start, &false_builder, ir));
    LLVMBuildRetVoid(false_builder);

    /* Both blocks are terminated; the temporary builders are not needed
     * any more and would otherwise leak on every translated indirect jump.
     */
    LLVMDisposeBuilder(true_builder);
    LLVMDisposeBuilder(false_builder);
}

T2C_OP(jalr, {
if (ir->rd)
T2C_LLVM_GEN_STORE_IMM32(*builder, ir->pc + 4,
Expand All @@ -40,8 +97,7 @@ T2C_OP(jalr, {
T2C_LLVM_GEN_LOAD_VMREG(rs1, 32, t2c_gen_rs1_addr(start, builder, ir));
val_rs1 = T2C_LLVM_GEN_ALU32_IMM(Add, val_rs1, ir->imm);
val_rs1 = T2C_LLVM_GEN_ALU32_IMM(And, val_rs1, ~1U);
LLVMBuildStore(*builder, val_rs1, t2c_gen_PC_addr(start, builder, ir));
LLVMBuildRetVoid(*builder);
t2c_jit_cache_helper(builder, start, val_rs1, rv, ir);
})

#define BRANCH_FUNC(type, cond) \
Expand Down Expand Up @@ -672,8 +728,7 @@ T2C_OP(clwsp, {

T2C_OP(cjr, {
T2C_LLVM_GEN_LOAD_VMREG(rs1, 32, t2c_gen_rs1_addr(start, builder, ir));
LLVMBuildStore(*builder, val_rs1, t2c_gen_PC_addr(start, builder, ir));
LLVMBuildRetVoid(*builder);
t2c_jit_cache_helper(builder, start, val_rs1, rv, ir);
})

T2C_OP(cmv, {
Expand All @@ -692,8 +747,7 @@ T2C_OP(cjalr, {
T2C_LLVM_GEN_STORE_IMM32(*builder, ir->pc + 2,
t2c_gen_ra_addr(start, builder, ir));
T2C_LLVM_GEN_LOAD_VMREG(rs1, 32, t2c_gen_rs1_addr(start, builder, ir));
LLVMBuildStore(*builder, val_rs1, t2c_gen_PC_addr(start, builder, ir));
LLVMBuildRetVoid(*builder);
t2c_jit_cache_helper(builder, start, val_rs1, rv, ir);
})

T2C_OP(cadd, {
Expand Down Expand Up @@ -785,15 +839,15 @@ T2C_OP(fuse5, {
switch (fuse[i].opcode) {
case rv_insn_slli:
t2c_slli(builder, param_types, start, entry, taken_builder,
untaken_builder, mem_base, (rv_insn_t *) (&fuse[i]));
untaken_builder, rv, mem_base, (rv_insn_t *) (&fuse[i]));
break;
case rv_insn_srli:
t2c_srli(builder, param_types, start, entry, taken_builder,
untaken_builder, mem_base, (rv_insn_t *) (&fuse[i]));
untaken_builder, rv, mem_base, (rv_insn_t *) (&fuse[i]));
break;
case rv_insn_srai:
t2c_srai(builder, param_types, start, entry, taken_builder,
untaken_builder, mem_base, (rv_insn_t *) (&fuse[i]));
untaken_builder, rv, mem_base, (rv_insn_t *) (&fuse[i]));
break;
default:
__UNREACHABLE;
Expand Down
43 changes: 43 additions & 0 deletions tests/fibonacci.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Recursive Fibonacci benchmark, hand-written so that every recursive call is
# an indirect jump: the address of fib is materialised with "la" and the call
# is made with an explicit "jalr" rather than a direct call.
#
# fib: a0 = n on entry, a0 = result on return.
#      Returns 1 when n <= 1 (unsigned compare), else fib(n-1) + fib(n-2).
fib:
# base case: n <= 1
li a5, 1
bleu a0, a5, .L3
# 16-byte frame holding ra, s0 and s1
addi sp, sp, -16
sw ra, 12(sp)
sw s0, 8(sp)
sw s1, 4(sp)
# s0 = n, kept across the calls
mv s0, a0
# s1 = fib(n - 1), called indirectly through t0
addi a0, a0, -1
la t0, fib
jalr ra, 0(t0)
mv s1, a0
# a0 = fib(n - 2), called indirectly through t0
addi a0, s0, -2
la t0, fib
jalr ra, 0(t0)
# a0 = fib(n - 1) + fib(n - 2)
add a0, s1, a0
# restore saved registers, release the frame and return
lw ra, 12(sp)
lw s0, 8(sp)
lw s1, 4(sp)
addi sp, sp, 16
jr ra
.L3:
# base case result
li a0, 1
ret
# format string passed to printf by main
# NOTE(review): this string is emitted before the ".text" directive below and
# without a section directive of its own - confirm the placement is intended
.LC0:
.string "%d\n"
.text
.align 1
.globl main
.type main, @function
# main: prints fib(42) followed by a newline and returns 0.
main:
addi sp, sp, -16
sw ra, 12(sp)
# a0 = fib(42)
li a0, 42
call fib
# printf(.LC0, fib(42))
mv a1, a0
lui a0, %hi(.LC0)
addi a0, a0, %lo(.LC0)
call printf
# return 0
li a0, 0
lw ra, 12(sp)
addi sp, sp, 16
jr ra

0 comments on commit 65fdaef

Please sign in to comment.