Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Print hooked sort elements correctly #914

Merged
merged 5 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bin/llvm-kompile-clang
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@ if [[ "$OSTYPE" == "darwin"* ]]; then
flags=(
"-L@BREW_PREFIX@/opt/libffi/lib"
"-L@BREW_PREFIX@/lib"
"-Wl,-u,_sort_table"
"-Wl,-u,_table_getArgumentSortsForTag"
"-I" "@BREW_PREFIX@/include"
)
else
libraries=("-ltinfo")
flags=("-Wl,-u,sort_table")
flags=("-Wl,-u,table_getArgumentSortsForTag")
fi

llc_flags+=("--relocation-model=pic")
Expand Down
2 changes: 1 addition & 1 deletion debug/kgdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def append(self, subject, isVar, sort):
argData = layoutData['args'] + i
arg = subject.cast(self.long_int) + int(argData.dereference()['offset'])
cat = argData.dereference()['cat']
sort = gdb.lookup_global_symbol("sort_table").value()[tag][i].string("iso-8859-1")
sort = gdb.lookup_global_symbol("table_getArgumentSortsForTag").value()[tag][i].string("iso-8859-1")
if cat == @MAP_LAYOUT@:
self.appendMap(arg.cast(self.map_ptr), sort)
elif cat == @RANGEMAP_LAYOUT@:
Expand Down
1 change: 1 addition & 0 deletions include/runtime/header.h
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ bool hook_STRING_eq(SortString, SortString);

const char *getSymbolNameForTag(uint32_t tag);
const char *getReturnSortForTag(uint32_t tag);
const char **getArgumentSortsForTag(uint32_t tag);
const char *topSort(void);

typedef struct {
Expand Down
78 changes: 28 additions & 50 deletions lib/codegen/EmitConfigParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1248,26 +1248,13 @@ static void emitInjTags(KOREDefinition *def, llvm::Module *mod) {
}
}

static void emitSortTable(KOREDefinition *definition, llvm::Module *module) {
llvm::LLVMContext &Ctx = module->getContext();
auto &syms = definition->getSymbols();
auto ty = llvm::PointerType::getUnqual(llvm::Type::getInt8PtrTy(Ctx));
auto dity = getPointerDebugType(getCharPtrDebugType(), "char *");
auto tableType = llvm::ArrayType::get(ty, syms.size());
auto table = module->getOrInsertGlobal("sort_table", tableType);
llvm::GlobalVariable *globalVar = llvm::dyn_cast<llvm::GlobalVariable>(table);
initDebugGlobal(
"sort_table",
getArrayDebugType(
dity, syms.size(), llvm::DataLayout(module).getABITypeAlign(ty)),
globalVar);
std::vector<llvm::Constant *> values;
for (auto iter = syms.begin(); iter != syms.end(); ++iter) {
auto entry = *iter;
auto symbol = entry.second;
static void emitSortTable(KOREDefinition *def, llvm::Module *mod) {
auto getter = [](KOREDefinition *definition, llvm::Module *module,
KORESymbol *symbol) -> llvm::Constant * {
auto &ctx = module->getContext();

auto subtableType = llvm::ArrayType::get(
llvm::Type::getInt8PtrTy(Ctx), symbol->getArguments().size());
llvm::Type::getInt8PtrTy(ctx), symbol->getArguments().size());
auto subtable = module->getOrInsertGlobal(
fmt::format("sorts_{}", ast_to_string(*symbol)), subtableType);
llvm::GlobalVariable *subtableVar
Expand All @@ -1277,30 +1264,35 @@ static void emitSortTable(KOREDefinition *definition, llvm::Module *module) {
getArrayDebugType(
getCharPtrDebugType(), symbol->getArguments().size(),
llvm::DataLayout(module).getABITypeAlign(
llvm::Type::getInt8PtrTy(Ctx))),
llvm::Type::getInt8PtrTy(ctx))),
subtableVar);
llvm::Constant *zero
= llvm::ConstantInt::get(llvm::Type::getInt64Ty(Ctx), 0);
= llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
auto indices = std::vector<llvm::Constant *>{zero, zero};
values.push_back(llvm::ConstantExpr::getInBoundsGetElementPtr(
subtableType, subtableVar, indices));

std::vector<llvm::Constant *> subvalues;
for (size_t i = 0; i < symbol->getArguments().size(); ++i) {
auto arg_str = ast_to_string(*symbol->getArguments()[i]);
auto strType = llvm::ArrayType::get(
llvm::Type::getInt8Ty(Ctx), arg_str.size() + 1);
llvm::Type::getInt8Ty(ctx), arg_str.size() + 1);
auto sortName = module->getOrInsertGlobal(
fmt::format("sort_name_{}", arg_str), strType);
subvalues.push_back(llvm::ConstantExpr::getInBoundsGetElementPtr(
strType, sortName, indices));
}
subtableVar->setInitializer(
llvm::ConstantArray::get(subtableType, subvalues));
}
if (!globalVar->hasInitializer()) {
globalVar->setInitializer(llvm::ConstantArray::get(tableType, values));
}

return llvm::ConstantExpr::getInBoundsGetElementPtr(
subtableType, subtableVar, indices);
};

auto i8_ptr_ty = llvm::Type::getInt8PtrTy(mod->getContext());
auto entry_ty = llvm::PointerType::getUnqual(i8_ptr_ty);
auto debug_ty = getPointerDebugType(getCharPtrDebugType(), "char **");

emitDataTableForSymbol(
"getArgumentSortsForTag", entry_ty, debug_ty, def, mod, getter);
}

/*
Expand All @@ -1311,23 +1303,12 @@ static void emitSortTable(KOREDefinition *definition, llvm::Module *module) {
*
* Each value in the table is a pointer to a global variable containing the
* relevant sort name as a null-terminated string.
*
* The function `getReturnSortForTag` abstracts accesses to the data in this
* table.
*/
static void
emitReturnSortTable(KOREDefinition *definition, llvm::Module *module) {
auto &ctx = module->getContext();
static void emitReturnSortTable(KOREDefinition *def, llvm::Module *mod) {
auto getter = [](KOREDefinition *definition, llvm::Module *module,
KORESymbol *symbol) -> llvm::Constant * {
auto &ctx = module->getContext();

auto const &syms = definition->getSymbols();

auto element_type = llvm::Type::getInt8PtrTy(ctx);
auto table_type = llvm::ArrayType::get(element_type, syms.size());

auto table = module->getOrInsertGlobal("return_sort_table", table_type);
auto values = std::vector<llvm::Constant *>{};

for (auto [tag, symbol] : syms) {
auto sort = symbol->getSort();
auto sort_str = ast_to_string(*sort);

Expand All @@ -1340,16 +1321,13 @@ emitReturnSortTable(KOREDefinition *definition, llvm::Module *module) {
auto i64_type = llvm::Type::getInt64Ty(ctx);
auto zero = llvm::ConstantInt::get(i64_type, 0);

auto pointer = llvm::ConstantExpr::getInBoundsGetElementPtr(
return llvm::ConstantExpr::getInBoundsGetElementPtr(
str_type, sort_name, std::vector<llvm::Constant *>{zero});
};

values.push_back(pointer);
}

auto global = llvm::dyn_cast<llvm::GlobalVariable>(table);
if (!global->hasInitializer()) {
global->setInitializer(llvm::ConstantArray::get(table_type, values));
}
emitDataTableForSymbol(
"getReturnSortForTag", llvm::Type::getInt8PtrTy(mod->getContext()),
getCharPtrDebugType(), def, mod, getter);
}

void emitConfigParserFunctions(
Expand Down
5 changes: 4 additions & 1 deletion runtime/collections/lists.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ void printList(
return;
}

auto tag = getTagForSymbolName(element);
auto arg_sorts = getArgumentSortsForTag(tag);

sfprintf(file, "\\left-assoc{}(%s(", concat);

bool once = true;
Expand All @@ -256,7 +259,7 @@ void printList(
sfprintf(file, ",");
}
sfprintf(file, "%s(", element);
printConfigurationInternal(file, *iter, "SortKItem{}", false, state);
printConfigurationInternal(file, *iter, arg_sorts[0], false, state);
sfprintf(file, ")");
}
sfprintf(file, "))");
Expand Down
7 changes: 5 additions & 2 deletions runtime/collections/maps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ void printMap(
return;
}

auto tag = getTagForSymbolName(element);
auto arg_sorts = getArgumentSortsForTag(tag);

sfprintf(file, "\\left-assoc{}(%s(", concat);

bool once = true;
Expand All @@ -212,9 +215,9 @@ void printMap(

sfprintf(file, "%s(", element);
auto entry = *iter;
printConfigurationInternal(file, entry.first, "SortKItem{}", false, state);
printConfigurationInternal(file, entry.first, arg_sorts[0], false, state);
sfprintf(file, ",");
printConfigurationInternal(file, entry.second, "SortKItem{}", false, state);
printConfigurationInternal(file, entry.second, arg_sorts[1], false, state);
sfprintf(file, ")");
}
sfprintf(file, "))");
Expand Down
5 changes: 4 additions & 1 deletion runtime/collections/rangemaps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,9 @@ void printRangeMap(
return;
}

auto tag = getTagForSymbolName(element);
auto arg_sorts = getArgumentSortsForTag(tag);

sfprintf(file, "\\left-assoc{}(%s(", concat);

bool once = true;
Expand All @@ -285,7 +288,7 @@ void printRangeMap(
printConfigurationInternal(
file, entry.first.end(), "SortKItem{}", false, state);
sfprintf(file, "),");
printConfigurationInternal(file, entry.second, "SortKItem{}", false, state);
printConfigurationInternal(file, entry.second, arg_sorts[1], false, state);
sfprintf(file, ")");
}
sfprintf(file, "))");
Expand Down
5 changes: 4 additions & 1 deletion runtime/collections/sets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ void printSet(
return;
}

auto tag = getTagForSymbolName(element);
auto arg_sorts = getArgumentSortsForTag(tag);

sfprintf(file, "\\left-assoc{}(%s(", concat);

bool once = true;
Expand All @@ -157,7 +160,7 @@ void printSet(
}

sfprintf(file, "%s(", element);
printConfigurationInternal(file, *iter, "SortKItem{}", false, state);
printConfigurationInternal(file, *iter, arg_sorts[0], false, state);
sfprintf(file, ")");
}
sfprintf(file, "))");
Expand Down
22 changes: 17 additions & 5 deletions runtime/util/ConfigurationSerializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,14 @@ void serializeMap(
return;
}

auto tag = getTagForSymbolName(element);
auto arg_sorts = getArgumentSortsForTag(tag);

for (auto iter = map->begin(); iter != map->end(); ++iter) {
serializeConfigurationInternal(
file, iter->first, "SortKItem{}", false, state);
file, iter->first, arg_sorts[0], false, state);
serializeConfigurationInternal(
file, iter->second, "SortKItem{}", false, state);
file, iter->second, arg_sorts[1], false, state);
emitSymbol(instance, element, 2);

if (iter != map->begin()) {
Expand All @@ -138,6 +141,9 @@ void serializeRangeMap(
return;
}

auto tag = getTagForSymbolName(element);
auto arg_sorts = getArgumentSortsForTag(tag);

bool once = true;
for (auto iter = rng_map::ConstRangeMapIterator<KElem, KElem>(*map);
iter.has_next(); ++iter) {
Expand All @@ -147,7 +153,7 @@ void serializeRangeMap(
file, iter->first.end(), "SortKItem{}", false, state);
emitSymbol(instance, "LblRangemap'Coln'Range{}", 2);
serializeConfigurationInternal(
file, iter->second, "SortKItem{}", false, state);
file, iter->second, arg_sorts[1], false, state);
emitSymbol(instance, element, 2);

if (once) {
Expand All @@ -169,8 +175,11 @@ void serializeList(
return;
}

auto tag = getTagForSymbolName(element);
auto arg_sorts = getArgumentSortsForTag(tag);

for (auto iter = list->begin(); iter != list->end(); ++iter) {
serializeConfigurationInternal(file, *iter, "SortKItem{}", false, state);
serializeConfigurationInternal(file, *iter, arg_sorts[0], false, state);
emitSymbol(instance, element, 1);

if (iter != list->begin()) {
Expand All @@ -190,8 +199,11 @@ void serializeSet(
return;
}

auto tag = getTagForSymbolName(element);
auto arg_sorts = getArgumentSortsForTag(tag);

for (auto iter = set->begin(); iter != set->end(); ++iter) {
serializeConfigurationInternal(file, *iter, "SortKItem{}", false, state);
serializeConfigurationInternal(file, *iter, arg_sorts[0], false, state);
emitSymbol(instance, element, 1);

if (iter != set->begin()) {
Expand Down
6 changes: 0 additions & 6 deletions runtime/util/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,6 @@

extern "C" {

extern char *return_sort_table;

const char *getReturnSortForTag(uint32_t tag) {
return (&return_sort_table)[tag];
}

block *dot_k() {
return leaf_block(getTagForSymbolName("dotk{}"));
}
Expand Down
85 changes: 85 additions & 0 deletions test/defn/k-files/wasm-maps.k
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
module WASM-MAPS
imports MAP-INT-TO-VAL
// imports MAP-INT-TO-INT

configuration
<wasm>
<instrs> $PGM </instrs>
<locals> .MapIntToVal </locals>
// <locals> .MapIntToInt </locals>
</wasm>

syntax Instr ::= "init_local" Int Int
// -----------------------------------
rule <instrs> init_local INDEX VALUE => . ... </instrs>
<locals> LOCALS => LOCALS {{ INDEX <- wrap(VALUE) }} </locals>

endmodule

module MAP-INT-TO-INT
imports WRAPPED-INT

syntax MapIntToInt [hook(MAP.Map)]
syntax MapIntToInt ::= MapIntToInt MapIntToInt
[ left, function, hook(MAP.concat), klabel(_MapIntToInt_),
symbol, assoc, comm, unit(.MapIntToInt), element(_Int2Int|->_)
]
syntax MapIntToInt ::= ".MapIntToInt"
[function, total, hook(MAP.unit),klabel(.MapIntToInt), symbol]
syntax MapIntToInt ::= WrappedInt "Int2Int|->" WrappedInt
[function, total, hook(MAP.element), klabel(_Int2Int|->_), symbol]

syntax MapIntToInt ::= MapIntToInt "[" key: WrappedInt "<-" value: WrappedInt "]" [function, total, klabel(MapInt2Int:update), symbol, hook(MAP.update), prefer]

syntax priorities _Int2Int|->_ > _MapIntToInt_ .MapIntToInt
syntax non-assoc _Int2Int|->_

syntax MapIntToInt ::= MapIntToInt "{{" key: Int "<-" value: Int "}}"
[ function, total, klabel(MapIntToInt:primitiveUpdate), symbol,
prefer
]
rule M:MapIntToInt {{ Key:Int <- Value:Int }}
=> M[wrap(Key) <- wrap(Value)]

endmodule

module MAP-INT-TO-VAL
imports WRAPPED-INT

syntax Val ::= WrappedInt
// -----------------------

syntax MapIntToVal [hook(MAP.Map)]
syntax MapIntToVal ::= MapIntToVal MapIntToVal
[ left, function, hook(MAP.concat), klabel(_MapIntToVal_),
symbol, assoc, comm, unit(.MapIntToVal), element(_Int2Val|->_)
]
syntax MapIntToVal ::= ".MapIntToVal"
[function, total, hook(MAP.unit),klabel(.MapIntToVal), symbol]
syntax MapIntToVal ::= WrappedInt "Int2Val|->" Val
[function, total, hook(MAP.element), klabel(_Int2Val|->_), symbol]

syntax MapIntToVal ::= MapIntToVal "[" key: WrappedInt "<-" value: Val "]" [function, total, klabel(MapInt2Val:update), symbol, hook(MAP.update), prefer]

syntax priorities _Int2Val|->_ > _MapIntToVal_ .MapIntToVal
syntax non-assoc _Int2Val|->_

syntax MapIntToVal ::= MapIntToVal "{{" key: Int "<-" value: Val "}}"
[ function, total, klabel(MapIntToVal:primitiveUpdate), symbol,
prefer
]
rule M:MapIntToVal {{ Key:Int <- Value:Val }}
=> M[wrap(Key) <- Value]

endmodule

module WRAPPED-INT
imports INT

syntax WrappedInt ::= wrap(Int) [symbol, klabel(wrapInt)]
// -------------------------------------------------------

syntax Int ::= unwrap(WrappedInt) [function, total, injective, symbol, klabel(unwrapInt)]
// ---------------------------------------------------------------------------------------
rule unwrap(wrap(A:Int)) => A
endmodule
Loading
Loading