Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial support for WebAssembly in user-defined functions (UDF) #9108

Merged
merged 10 commits into from
Sep 14, 2021
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ set(scylla_sources
locator/simple_strategy.cc
locator/snitch_base.cc
locator/token_metadata.cc
lua.cc
lang/lua.cc
main.cc
memtable.cc
message/messaging_service.cc
Expand Down
2 changes: 2 additions & 0 deletions NOTICE.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ These files are located in utils/arch/powerpc/crc32-vpmsum. Their license may be
It includes modified code from https://gitbox.apache.org/repos/asf?p=cassandra-dtest.git (owned by The Apache Software Foundation)

It includes modified tests from https://github.com/etcd-io/etcd.git (owned by The etcd Authors)

It includes files from https://github.com/bytecodealliance/wasmtime-cpp (owned by Bytecode Alliance), licensed with Apache License 2.0.
4 changes: 3 additions & 1 deletion configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,9 @@ def find_headers(repodir, excluded_dirs):
'mutation_writer/shard_based_splitting_writer.cc',
'mutation_writer/partition_based_splitting_writer.cc',
'mutation_writer/feed_writers.cc',
'lua.cc',
'lang/lua.cc',
'lang/wasm_engine.cc',
'lang/wasm.cc',
'service/raft/schema_raft_state_machine.cc',
'service/raft/raft_sys_table_storage.cc',
'serializer.cc',
Expand Down
40 changes: 24 additions & 16 deletions cql3/functions/user_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
*/

#include "user_function.hh"
#include "lua.hh"
#include "log.hh"
#include "cql_serialization_format.hh"

Expand All @@ -32,12 +31,10 @@ namespace functions {
extern logging::logger log;

user_function::user_function(function_name name, std::vector<data_type> arg_types, std::vector<sstring> arg_names,
sstring body, sstring language, data_type return_type, bool called_on_null_input, sstring bitcode,
lua::runtime_config cfg)
sstring body, sstring language, data_type return_type, bool called_on_null_input, context ctx)
: abstract_function(std::move(name), std::move(arg_types), std::move(return_type)),
_arg_names(std::move(arg_names)), _body(std::move(body)), _language(std::move(language)),
_called_on_null_input(called_on_null_input), _bitcode(std::move(bitcode)),
_cfg(std::move(cfg)) {}
_called_on_null_input(called_on_null_input), _ctx(std::move(ctx)) {}

bool user_function::is_pure() const { return true; }

Expand All @@ -53,20 +50,31 @@ bytes_opt user_function::execute(cql_serialization_format sf, const std::vector<
throw std::logic_error("Wrong number of parameters");
}

std::vector<data_value> values;
values.reserve(parameters.size());
for (int i = 0, n = types.size(); i != n; ++i) {
const data_type& type = types[i];
const bytes_opt& bytes = parameters[i];
if (!bytes && !_called_on_null_input) {
return std::nullopt;
}
values.push_back(bytes ? type->deserialize(*bytes) : data_value::make_null(type));
}
if (!seastar::thread::running_in_thread()) {
on_internal_error(log, "User function cannot be executed in this context");
}
return lua::run_script(lua::bitcode_view{_bitcode}, values, return_type(), _cfg).get0();
return seastar::visit(_ctx,
[&] (lua_context& ctx) -> bytes_opt {
std::vector<data_value> values;
values.reserve(parameters.size());
for (int i = 0, n = types.size(); i != n; ++i) {
const data_type& type = types[i];
const bytes_opt& bytes = parameters[i];
if (!bytes && !_called_on_null_input) {
return std::nullopt;
}
values.push_back(bytes ? type->deserialize(*bytes) : data_value::make_null(type));
}
return lua::run_script(lua::bitcode_view{ctx.bitcode}, values, return_type(), ctx.cfg).get0();
},
[&] (wasm::context& ctx) {
try {
return wasm::run_script(ctx, arg_types(), parameters, return_type(), _called_on_null_input).get0();
} catch (const wasm::exception& e) {
throw exceptions::invalid_request_exception(format("UDF error: {}", e.what()));
}
});
}

}
}
34 changes: 21 additions & 13 deletions cql3/functions/user_function.hh
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,39 @@

#include "abstract_function.hh"
#include "scalar_function.hh"
#include "lua.hh"
#include "lang/lua.hh"
#include "lang/wasm.hh"

namespace cql3 {
namespace functions {


class user_function final : public abstract_function, public scalar_function {
public:
// State needed to run a Lua user-defined function. It is one of the two
// alternatives of the `context` variant stored in user_function.
struct lua_context {
// Compiled form of the function body as returned by lua::compile();
// execute() passes it to lua::run_script() wrapped in a lua::bitcode_view.
sstring bitcode;
// FIXME: We should not need a copy in each function. It is here
// because user_function::execute is only passed the
// cql_serialization_format and the runtime arguments. We could
// avoid it by having a runtime->execute(user_function) instead,
// but that is a large refactoring. We could also store a
// lua_runtime in a thread_local variable, but that is one extra
// global.
lua::runtime_config cfg;
};

using context = std::variant<lua_context, wasm::context>;

private:
std::vector<sstring> _arg_names;
sstring _body;
sstring _language;
bool _called_on_null_input;
sstring _bitcode;

// FIXME: We should not need a copy in each function. It is here
// because user_function::execute is only passed the
// cql_serialization_format and the runtime arguments. We could
// avoid it by having a runtime->execute(user_function) instead,
// but that is a large refactoring. We could also store a
// lua_runtime in a thread_local variable, but that is one extra
// global.
lua::runtime_config _cfg;
context _ctx;

public:
user_function(function_name name, std::vector<data_type> arg_types, std::vector<sstring> arg_names, sstring body,
sstring language, data_type return_type, bool called_on_null_input, sstring bitcode,
lua::runtime_config cfg);
sstring language, data_type return_type, bool called_on_null_input, context ctx);

const std::vector<sstring>& arg_names() const { return _arg_names; }

Expand Down
28 changes: 20 additions & 8 deletions cql3/statements/create_function_statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#include "prepared_statement.hh"
#include "service/migration_manager.hh"
#include "service/storage_proxy.hh"
#include "lua.hh"
#include "lang/lua.hh"
#include "database.hh"
#include "cql3/query_processor.hh"

Expand All @@ -37,7 +37,7 @@ void create_function_statement::create(service::storage_proxy& proxy, functions:
if (old && !dynamic_cast<functions::user_function*>(old)) {
throw exceptions::invalid_request_exception(format("Cannot replace '{}' which is not a user defined function", *old));
}
if (_language != "lua") {
if (_language != "lua" && _language != "xwasm") {
throw exceptions::invalid_request_exception(format("Language '{}' is not supported", _language));
}
data_type return_type = prepare_type(proxy, *_return_type);
Expand All @@ -47,13 +47,25 @@ void create_function_statement::create(service::storage_proxy& proxy, functions:
}

auto&& db = proxy.get_db().local();
lua::runtime_config cfg = lua::make_runtime_config(db.get_config());
if (_language == "lua") {
auto cfg = lua::make_runtime_config(db.get_config());
functions::user_function::context ctx = functions::user_function::lua_context {
.bitcode = lua::compile(cfg, arg_names, _body),
.cfg = cfg,
};

// Checking that the function compiles also produces bitcode
auto bitcode = lua::compile(cfg, arg_names, _body);

_func = ::make_shared<functions::user_function>(_name, _arg_types, std::move(arg_names), _body, _language,
std::move(return_type), _called_on_null_input, std::move(bitcode), std::move(cfg));
_func = ::make_shared<functions::user_function>(_name, _arg_types, std::move(arg_names), _body, _language,
std::move(return_type), _called_on_null_input, std::move(ctx));
} else if (_language == "xwasm") {
wasm::context ctx{db.wasm_engine(), _name.name};
try {
wasm::compile(ctx, arg_names, _body);
_func = ::make_shared<functions::user_function>(_name, _arg_types, std::move(arg_names), _body, _language,
std::move(return_type), _called_on_null_input, std::move(ctx));
} catch (const wasm::exception& we) {
throw exceptions::invalid_request_exception(we.what());
}
}
return;
}

Expand Down
14 changes: 14 additions & 0 deletions database.hh
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ class abstract_replication_strategy;

} // namespace locator

namespace wasm {
class engine;
}

class mutation_reordered_with_truncate_exception : public std::exception {};

using shared_memtable = lw_shared_ptr<memtable>;
Expand Down Expand Up @@ -1343,6 +1347,8 @@ private:
bool _supports_infinite_bound_range_deletions = false;
gms::feature::listener_registration _infinite_bound_range_deletions_reg;

// Non-owning pointer to the WebAssembly engine, installed through
// set_wasm_engine(). It defaults to nullptr so that wasm_engine() never
// yields an indeterminate value when no engine has been provided.
wasm::engine* _wasm_engine = nullptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: why an old-style C pointer instead of some C++ type like unique_ptr or shared_ptr?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a non-owning pointer. I don't want the database to own this wasm engine, so that the engine does not become a hard dependency. Instead, a programmer who wants to provide the wasm engine can do so with a simple setter method.


future<> init_commitlog();
public:
const gms::feature_service& features() const { return _feat; }
Expand All @@ -1351,6 +1357,14 @@ public:

void set_local_id(utils::UUID uuid) noexcept { _local_host_id = std::move(uuid); }

// Returns the WebAssembly engine pointer stored by set_wasm_engine().
// The database only holds this pointer and hands it back unchanged.
// NOTE(review): presumably callers must cope with no engine having been
// registered -- confirm against the wasm::context users.
wasm::engine* wasm_engine() { return _wasm_engine; }

// Registers the WebAssembly engine that wasm_engine() will return.
// Only the pointer is stored; nothing else is done with the engine here.
// NOTE(review): presumably the engine must outlive every use made of it
// through this database -- confirm with the code that calls this setter.
void set_wasm_engine(wasm::engine* engine) { _wasm_engine = engine; }

private:
using system_keyspace = bool_class<struct system_keyspace_tag>;
void create_in_memory_keyspace(const lw_shared_ptr<keyspace_metadata>& ksm, system_keyspace system);
Expand Down
26 changes: 20 additions & 6 deletions db/schema_tables.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
#include "user_types_metadata.hh"

#include "index/target_parser.hh"
#include "lua.hh"
#include "lang/lua.hh"

#include "db/query_context.hh"
#include "serializer.hh"
Expand Down Expand Up @@ -1548,12 +1548,26 @@ static shared_ptr<cql3::functions::user_function> create_func(database& db, cons

auto arg_names = get_list<sstring>(row, "argument_names");
auto body = row.get_nonnull<sstring>("body");
lua::runtime_config cfg = lua::make_runtime_config(db.get_config());
auto bitcode = lua::compile(cfg, arg_names, body);
auto language = row.get_nonnull<sstring>("language");
if (language == "lua") {
lua::runtime_config cfg = lua::make_runtime_config(db.get_config());
cql3::functions::user_function::context ctx = cql3::functions::user_function::lua_context {
.bitcode = lua::compile(cfg, arg_names, body),
.cfg = cfg,
};

return ::make_shared<cql3::functions::user_function>(std::move(name), std::move(arg_types), std::move(arg_names),
std::move(body), row.get_nonnull<sstring>("language"), std::move(return_type),
row.get_nonnull<bool>("called_on_null_input"), std::move(bitcode), std::move(cfg));
return ::make_shared<cql3::functions::user_function>(std::move(name), std::move(arg_types), std::move(arg_names),
std::move(body), language, std::move(return_type),
row.get_nonnull<bool>("called_on_null_input"), std::move(ctx));
} else if (language == "xwasm") {
wasm::context ctx{db.wasm_engine(), name.name};
wasm::compile(ctx, arg_names, body);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it seems we have two functions checking whether language == "lua" or "xwasm" and compiling things, and a third place checking whether it is neither and reporting an error... Could we create just one function that does this, and have all three locations call it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could have a scripting_engine class and derive lua_scripting_engine and wasm_scripting_engine. But I'm fine with waiting a bit to understand what that class wants to look like first.

return ::make_shared<cql3::functions::user_function>(std::move(name), std::move(arg_types), std::move(arg_names),
std::move(body), language, std::move(return_type),
row.get_nonnull<bool>("called_on_null_input"), std::move(ctx));
} else {
throw std::runtime_error(format("Unsupported language for UDF: {}", language));
}
}

static shared_ptr<cql3::functions::user_aggregate> create_aggregate(database& db, const query::result_set_row& row) {
Expand Down
Loading