diff --git a/CMakeLists.txt b/CMakeLists.txt index 54ff25193faf..6c6bce723728 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -508,7 +508,7 @@ set(scylla_sources locator/simple_strategy.cc locator/snitch_base.cc locator/token_metadata.cc - lua.cc + lang/lua.cc main.cc memtable.cc message/messaging_service.cc diff --git a/NOTICE.txt b/NOTICE.txt index d1dfbacc89cf..b0c99fb437fd 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -7,3 +7,5 @@ These files are located in utils/arch/powerpc/crc32-vpmsum. Their license may be It includes modified code from https://gitbox.apache.org/repos/asf?p=cassandra-dtest.git (owned by The Apache Software Foundation) It includes modified tests from https://github.com/etcd-io/etcd.git (owned by The etcd Authors) + +It includes files from https://github.com/bytecodealliance/wasmtime-cpp (owned by Bytecode Alliance), licensed with Apache License 2.0. diff --git a/configure.py b/configure.py index d3959e128867..573b614ae510 100755 --- a/configure.py +++ b/configure.py @@ -1006,7 +1006,9 @@ def find_headers(repodir, excluded_dirs): 'mutation_writer/shard_based_splitting_writer.cc', 'mutation_writer/partition_based_splitting_writer.cc', 'mutation_writer/feed_writers.cc', - 'lua.cc', + 'lang/lua.cc', + 'lang/wasm_engine.cc', + 'lang/wasm.cc', 'service/raft/schema_raft_state_machine.cc', 'service/raft/raft_sys_table_storage.cc', 'serializer.cc', diff --git a/cql3/functions/user_function.cc b/cql3/functions/user_function.cc index ca9d7058a759..5b81c32c7308 100644 --- a/cql3/functions/user_function.cc +++ b/cql3/functions/user_function.cc @@ -20,7 +20,6 @@ */ #include "user_function.hh" -#include "lua.hh" #include "log.hh" #include "cql_serialization_format.hh" @@ -32,12 +31,10 @@ namespace functions { extern logging::logger log; user_function::user_function(function_name name, std::vector arg_types, std::vector arg_names, - sstring body, sstring language, data_type return_type, bool called_on_null_input, sstring bitcode, - lua::runtime_config cfg) + sstring body, sstring language, data_type return_type, bool called_on_null_input, context ctx) : abstract_function(std::move(name), std::move(arg_types), std::move(return_type)), _arg_names(std::move(arg_names)), _body(std::move(body)), _language(std::move(language)), - _called_on_null_input(called_on_null_input), _bitcode(std::move(bitcode)), - _cfg(std::move(cfg)) {} + _called_on_null_input(called_on_null_input), _ctx(std::move(ctx)) {} bool user_function::is_pure() const { return true; } @@ -53,20 +50,31 @@ bytes_opt user_function::execute(cql_serialization_format sf, const std::vector< throw std::logic_error("Wrong number of parameters"); } - std::vector values; - values.reserve(parameters.size()); - for (int i = 0, n = types.size(); i != n; ++i) { - const data_type& type = types[i]; - const bytes_opt& bytes = parameters[i]; - if (!bytes && !_called_on_null_input) { - return std::nullopt; - } - values.push_back(bytes ? type->deserialize(*bytes) : data_value::make_null(type)); - } if (!seastar::thread::running_in_thread()) { on_internal_error(log, "User function cannot be executed in this context"); } - return lua::run_script(lua::bitcode_view{_bitcode}, values, return_type(), _cfg).get0(); + return seastar::visit(_ctx, + [&] (lua_context& ctx) -> bytes_opt { + std::vector values; + values.reserve(parameters.size()); + for (int i = 0, n = types.size(); i != n; ++i) { + const data_type& type = types[i]; + const bytes_opt& bytes = parameters[i]; + if (!bytes && !_called_on_null_input) { + return std::nullopt; + } + values.push_back(bytes ? type->deserialize(*bytes) : data_value::make_null(type)); + } + return lua::run_script(lua::bitcode_view{ctx.bitcode}, values, return_type(), ctx.cfg).get0(); + }, + [&] (wasm::context& ctx) { + try { + return wasm::run_script(ctx, arg_types(), parameters, return_type(), _called_on_null_input).get0(); + } catch (const wasm::exception& e) { + throw exceptions::invalid_request_exception(format("UDF error: {}", e.what())); + } + }); } + } } diff --git a/cql3/functions/user_function.hh b/cql3/functions/user_function.hh index 7d126b5e55ee..c5318728292e 100644 --- a/cql3/functions/user_function.hh +++ b/cql3/functions/user_function.hh @@ -25,31 +25,39 @@ #include "abstract_function.hh" #include "scalar_function.hh" -#include "lua.hh" +#include "lang/lua.hh" +#include "lang/wasm.hh" namespace cql3 { namespace functions { + class user_function final : public abstract_function, public scalar_function { +public: + struct lua_context { + sstring bitcode; + // FIXME: We should not need a copy in each function. It is here + // because user_function::execute is only passed the + // cql_serialization_format and the runtime arguments. We could + // avoid it by having a runtime->execute(user_function) instead, + // but that is a large refactoring. We could also store a + // lua_runtime in a thread_local variable, but that is one extra + // global. + lua::runtime_config cfg; + }; + + using context = std::variant; + +private: std::vector _arg_names; sstring _body; sstring _language; bool _called_on_null_input; - sstring _bitcode; - - // FIXME: We should not need a copy in each function. It is here - // because user_function::execute is only passed the - // cql_serialization_format and the runtime arguments. We could - // avoid it by having a runtime->execute(user_function) instead, - // but that is a large refactoring. We could also store a - // lua_runtime in a thread_local variable, but that is one extra - // global. - lua::runtime_config _cfg; + context _ctx; public: user_function(function_name name, std::vector arg_types, std::vector arg_names, sstring body, - sstring language, data_type return_type, bool called_on_null_input, sstring bitcode, - lua::runtime_config cfg); + sstring language, data_type return_type, bool called_on_null_input, context ctx); const std::vector& arg_names() const { return _arg_names; } diff --git a/cql3/statements/create_function_statement.cc b/cql3/statements/create_function_statement.cc index d53341f10f4f..120d774fcda4 100644 --- a/cql3/statements/create_function_statement.cc +++ b/cql3/statements/create_function_statement.cc @@ -25,7 +25,7 @@ #include "prepared_statement.hh" #include "service/migration_manager.hh" #include "service/storage_proxy.hh" -#include "lua.hh" +#include "lang/lua.hh" #include "database.hh" #include "cql3/query_processor.hh" @@ -37,7 +37,7 @@ void create_function_statement::create(service::storage_proxy& proxy, functions: if (old && !dynamic_cast(old)) { throw exceptions::invalid_request_exception(format("Cannot replace '{}' which is not a user defined function", *old)); } - if (_language != "lua") { + if (_language != "lua" && _language != "xwasm") { throw exceptions::invalid_request_exception(format("Language '{}' is not supported", _language)); } data_type return_type = prepare_type(proxy, *_return_type); @@ -47,13 +47,25 @@ void create_function_statement::create(service::storage_proxy& proxy, functions: } auto&& db = proxy.get_db().local(); - lua::runtime_config cfg = lua::make_runtime_config(db.get_config()); + if (_language == "lua") { + auto cfg = lua::make_runtime_config(db.get_config()); + functions::user_function::context ctx = functions::user_function::lua_context { + .bitcode = lua::compile(cfg, arg_names, _body), + .cfg = cfg, + }; - // Checking that the function compiles also produces bitcode - auto bitcode = lua::compile(cfg, arg_names, _body); - - _func = ::make_shared(_name, _arg_types, std::move(arg_names), _body, _language, - std::move(return_type), _called_on_null_input, std::move(bitcode), std::move(cfg)); + _func = ::make_shared(_name, _arg_types, std::move(arg_names), _body, _language, + std::move(return_type), _called_on_null_input, std::move(ctx)); + } else if (_language == "xwasm") { + wasm::context ctx{db.wasm_engine(), _name.name}; + try { + wasm::compile(ctx, arg_names, _body); + _func = ::make_shared(_name, _arg_types, std::move(arg_names), _body, _language, + std::move(return_type), _called_on_null_input, std::move(ctx)); + } catch (const wasm::exception& we) { + throw exceptions::invalid_request_exception(we.what()); + } + } return; } diff --git a/database.hh b/database.hh index 7541c7261d92..11899202209e 100644 --- a/database.hh +++ b/database.hh @@ -134,6 +134,10 @@ class abstract_replication_strategy; } // namespace locator +namespace wasm { +class engine; +} + class mutation_reordered_with_truncate_exception : public std::exception {}; using shared_memtable = lw_shared_ptr; @@ -1343,6 +1347,8 @@ private: bool _supports_infinite_bound_range_deletions = false; gms::feature::listener_registration _infinite_bound_range_deletions_reg; + wasm::engine* _wasm_engine; + future<> init_commitlog(); public: const gms::feature_service& features() const { return _feat; } @@ -1351,6 +1357,14 @@ public: void set_local_id(utils::UUID uuid) noexcept { _local_host_id = std::move(uuid); } + wasm::engine* wasm_engine() { + return _wasm_engine; + } + + void set_wasm_engine(wasm::engine* engine) { + _wasm_engine = engine; + } + private: using system_keyspace = bool_class; void create_in_memory_keyspace(const lw_shared_ptr& ksm, system_keyspace system); diff --git a/db/schema_tables.cc b/db/schema_tables.cc index b2bc24a1ab9c..bccd8cd56d81 100644 --- a/db/schema_tables.cc +++ b/db/schema_tables.cc @@ -90,7 +90,7 @@ #include "user_types_metadata.hh" #include "index/target_parser.hh" -#include "lua.hh" +#include "lang/lua.hh" #include "db/query_context.hh" #include "serializer.hh" @@ -1548,12 +1548,26 @@ static shared_ptr create_func(database& db, cons auto arg_names = get_list(row, "argument_names"); auto body = row.get_nonnull("body"); - lua::runtime_config cfg = lua::make_runtime_config(db.get_config()); - auto bitcode = lua::compile(cfg, arg_names, body); + auto language = row.get_nonnull("language"); + if (language == "lua") { + lua::runtime_config cfg = lua::make_runtime_config(db.get_config()); + cql3::functions::user_function::context ctx = cql3::functions::user_function::lua_context { + .bitcode = lua::compile(cfg, arg_names, body), + .cfg = cfg, + }; - return ::make_shared(std::move(name), std::move(arg_types), std::move(arg_names), - std::move(body), row.get_nonnull("language"), std::move(return_type), - row.get_nonnull("called_on_null_input"), std::move(bitcode), std::move(cfg)); + return ::make_shared(std::move(name), std::move(arg_types), std::move(arg_names), + std::move(body), language, std::move(return_type), + row.get_nonnull("called_on_null_input"), std::move(ctx)); + } else if (language == "xwasm") { + wasm::context ctx{db.wasm_engine(), name.name}; + wasm::compile(ctx, arg_names, body); + return ::make_shared(std::move(name), std::move(arg_types), std::move(arg_names), + std::move(body), language, std::move(return_type), + row.get_nonnull("called_on_null_input"), std::move(ctx)); + } else { + throw std::runtime_error(format("Unsupported language for UDF: {}", language)); + } } static shared_ptr create_aggregate(database& db, const query::result_set_row& row) { diff --git a/docs/design-notes/wasm.md b/docs/design-notes/wasm.md new file mode 100644 index 000000000000..4a24c4b10ba6 --- /dev/null +++ b/docs/design-notes/wasm.md @@ -0,0 +1,209 @@ +# WASM support for user-defined functions + +This document describes the details of WASM language support in user-defined functions (UDF). `wasm` is one of the possible languages to implement these functions in, aside `lua`. + +## Experimental status + +Before the design of WebAssembly integration and ABI is finalized, it's only available in experimental mode. +User-defined functions are already experimental at the time of this writing, but in order to be ready +for backward incompatible changes, the language accepted by CQL is currently named "xwasm". +Once the ABI is set in stone, it should be changed to "wasm". + +## Supported types + +Due to the limitations imposed by WebAssembly specification, the following types can be natively supported with CQL: + - int + - bigint + - smallint + - tinyint + - bool + - float + - double + +The rest of CQL types (text, date, timestamp, etc.) are implemented by putting their serialized representation into wasm module memory +and passing each parameter as a pointer to a struct of the form: +```c +{ + int32_t size; + char buf[0]; +} +``` + +## Support for NULL values + +Native WebAssembly types can only be represented directly if the function does not operate on NULL values. Fortunately, user-defined functions +explicitly specify whether they accept NULL or not. + +If the function is specified not to accept NULL, all parameters are represented +as in the description above. + +If the function is specified to accept NULL, each parameter should be represented in WebAssembly as a struct, +which starts with its size, followed by a serialized form explained in the paragraph above, i.e. +```c +{ + int32_t size; + char buf[0]; +} +``` + +the important distinction is that size equal to -1 (minus one)indicates that the value is NULL and should not be parsed. + +## Return values + +NOTE: ABI for return values is experimental and subject to change. It can (and should) be redesigned +after implementing helper libraries for a few popular languages (including C++, C, Rust). + +Natively supported types are returned as is. All the other types are returned via memory, similarly to the way +they are passed as parameters: the wasm function should return the serialized form of the returned value, +preceded by its size: +```c +{ + int32_t size; + char buf[0]; +} +``` + +Currently, returning NULL values is possible only for functions declared to be `CALLED ON NULL INPUT`. +For such functions, the return value is always expected to be presented in the serialized form (which +allows representing nulls), even for types natively supported by WebAssembly. +The decision is experimental and it was done in order to allow returning some values as native WebAssembly types +without having to allocate memory for them and serialize them first. +Alternative ways of expressing whether a function can **return** null should be considered - perhaps +as CQL syntax extension. + +## How to generate a correct wasm UDF source code + +Scylla accepts UDF's source code in WebAssembly text format - also known as `wat`. The source can use and define whatever's needed for execution, including multiple helper functions and symbols. The only requirement for it to be accepted as correct UDF source is that the WebAssembly module exports a symbol with the same name as the function, and this symbol's type is indeed a function with correct signature. + +UDF's source code can be, naturally, simply coded by hand in wat. It is not often very convenient to program directly in assembly, so here are a few tips. + +### Compiling from Rust to wasm + +#### C + +Clang is capable of compiling C source code to wasm and it also supports useful built-ins +for using wasm-specific interfaces, like `__builtin_wasm_memory_size` and `__builtin_wasm_memory_grow` +for memory management. + +Example source code: + +```c +struct __attribute__((packed)) nullable_bigint { + int size; + long long v; +}; + +static long long swap_int64(long long val) { + val = ((val << 8) & 0xFF00FF00FF00FF00ULL ) | ((val >> 8) & 0x00FF00FF00FF00FFULL ); + val = ((val << 16) & 0xFFFF0000FFFF0000ULL ) | ((val >> 16) & 0x0000FFFF0000FFFFULL ); + return (val << 32) | ((val >> 32) & 0xFFFFFFFFULL); +} + +long long fib_aux(long long n) { + if (n < 2) { + return n; + } + return fib_aux(n-1) + fib_aux(n-2); +} + +struct nullable_bigint* fib(struct nullable_bigint* p) { + // Initialize memory for the return struct + struct nullable_bigint* ret = (struct nullable_bigint*)__builtin_wasm_memory_size(0); + __builtin_wasm_memory_grow(0, sizeof(struct nullable_bigint)); + + ret->size = sizeof(long long); + if (p->size == -1) { + ret->v = swap_int64(42); + } else { + ret->v = swap_int64(fib_aux(swap_int64(p->v))); + } + return ret; +} +``` + +Compilation instructions: +```bash + clang -O2 --target=wasm32 --no-standard-libraries -Wl,--export-all -Wl,--no-entry fibnull.c -o fibnull.wasm + wasm2wat fibnull.wasm > fibnull.wat +``` + +#### Rust + +Rust ecosystem exposes a rather convenient way of generating WebAssembly, with the help of `cargo wasi` and `wasm_bindgen`. + +As a short example, here's a sample Rust code which can be compiled to WebAssembly: +```rust +use wasm_bindgen::prelude::*; + +#[wasm_bindgen] +pub fn fib(n: i32) -> i32 { + if n < 2 { + n + } else { + fib(n - 1) + fib(n - 2) + } +} +``` + +A more detailed guide and examples can be found here: +https://bytecodealliance.github.io/cargo-wasi/hello-world.html +https://rustwasm.github.io/wasm-bindgen/ + +### Generating wat from wasm + +For those who want to use precompiled WASM modules, it's enough to translate WASM bytecode to `wat` representation. On Linux, it can be achieved by a `wasm2wat` tool, available in most distributions in the `wabt` package. + +## Example + +Here's how a `wasm` function can be declared: + +```cql +CREATE FUNCTION ks.fib (input bigint) RETURNS NULL ON NULL INPUT RETURNS bigint LANGUAGE xwasm +AS '(module + (func $fib (param $n i64) (result i64) + (if + (i64.lt_s (local.get $n) (i64.const 2)) + (return (local.get $n)) + ) + (i64.add + (call $fib (i64.sub (local.get $n) (i64.const 1))) + (call $fib (i64.sub (local.get $n) (i64.const 2))) + ) + ) + (export "fib" (func $fib)) +)' +``` + +and it can be invoked just like a regular UDF: +```cql +scylla@cqlsh:ks> CREATE TABLE t(id int, n bigint, PRIMARY KEY(id,n)); +scylla@cqlsh:ks> INSERT INTO t(id, n) VALUES (0, 0); +scylla@cqlsh:ks> INSERT INTO t(id, n) VALUES (0, 1); +scylla@cqlsh:ks> INSERT INTO t(id, n) VALUES (0, 2); +scylla@cqlsh:ks> INSERT INTO t(id, n) VALUES (0, 3); +scylla@cqlsh:ks> INSERT INTO t(id, n) VALUES (0, 4); +scylla@cqlsh:ks> INSERT INTO t(id, n) VALUES (0, 5); +scylla@cqlsh:ks> INSERT INTO t(id, n) VALUES (0, 6); +scylla@cqlsh:ks> INSERT INTO t(id, n) VALUES (0, 7); +scylla@cqlsh:ks> INSERT INTO t(id, n) VALUES (0, 8); +scylla@cqlsh:ks> INSERT INTO t(id, n) VALUES (0, 9); +scylla@cqlsh:ks> INSERT INTO t(id, n) VALUES (0, 10); +scylla@cqlsh:ks> SELECT n, ks.fib(n) FROM t; + + n | ks.fib(n) +----+----------- + 0 | 0 + 1 | 1 + 2 | 1 + 3 | 2 + 4 | 3 + 5 | 5 + 6 | 8 + 7 | 13 + 8 | 21 + 9 | 34 + 10 | 55 + +(11 rows) +``` + diff --git a/lua.cc b/lang/lua.cc similarity index 100% rename from lua.cc rename to lang/lua.cc diff --git a/lua.hh b/lang/lua.hh similarity index 100% rename from lua.hh rename to lang/lua.hh diff --git a/lang/wasm.cc b/lang/wasm.cc new file mode 100644 index 000000000000..8caaa133205c --- /dev/null +++ b/lang/wasm.cc @@ -0,0 +1,315 @@ +/* + * Copyright (C) 2021-present ScyllaDB + */ + +/* + * This file is part of Scylla. + * + * Scylla is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Scylla is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with Scylla. If not, see . + */ + +#ifdef SCYLLA_ENABLE_WASMTIME + +#include "wasm.hh" +#include "concrete_types.hh" +#include "utils/utf8.hh" +#include "utils/ascii.hh" +#include "utils/date.h" +#include "db/config.hh" +#include +#include +#include "seastarx.hh" + +static logging::logger wasm_logger("wasm"); + +namespace wasm { + +context::context(wasm::engine* engine_ptr, std::string name) : engine_ptr(engine_ptr), function_name(name) { +} + +static std::pair create_instance_and_func(context& ctx, wasmtime::Store& store) { + auto instance_res = wasmtime::Instance::create(store, *ctx.module, {}); + if (!instance_res) { + throw wasm::exception(format("Creating a wasm runtime instance failed: {}", instance_res.err().message())); + } + auto instance = instance_res.unwrap(); + auto function_obj = instance.get(store, ctx.function_name); + if (!function_obj) { + throw wasm::exception(format("Function {} was not found in given wasm source code", ctx.function_name)); + } + wasmtime::Func* func = std::get_if(&*function_obj); + if (!func) { + throw wasm::exception(format("Exported object {} is not a function", ctx.function_name)); + } + return std::make_pair(std::move(instance), std::move(*func)); +} + +void compile(context& ctx, const std::vector& arg_names, std::string script) { + wasm_logger.debug("Compiling script {}", script); + auto module = wasmtime::Module::compile(ctx.engine_ptr->get(), script); + if (!module) { + throw wasm::exception(format("Compilation failed: {}", module.err().message())); + } + ctx.module = module.unwrap(); + // Create the instance and extract function definition for validation purposes only + wasmtime::Store store(ctx.engine_ptr->get()); + create_instance_and_func(ctx, store); +} + +struct init_arg_visitor { + const bytes_opt& param; + std::vector& argv; + wasmtime::Store& store; + wasmtime::Instance& instance; + + void operator()(const boolean_type_impl&) { + auto dv = boolean_type->deserialize(*param); + auto val = wasmtime::Val(int32_t(value_cast(dv))); + argv.push_back(std::move(val)); + } + void operator()(const byte_type_impl&) { + auto dv = byte_type->deserialize(*param); + auto val = wasmtime::Val(int32_t(value_cast(dv))); + argv.push_back(std::move(val)); + } + void operator()(const short_type_impl&) { + auto dv = short_type->deserialize(*param); + auto val = wasmtime::Val(int32_t(value_cast(dv))); + argv.push_back(std::move(val)); + } + void operator()(const double_type_impl&) { + auto dv = double_type->deserialize(*param); + auto val = wasmtime::Val(value_cast(dv)); + argv.push_back(std::move(val)); + } + void operator()(const float_type_impl&) { + auto dv = float_type->deserialize(*param); + auto val = wasmtime::Val(value_cast(dv)); + argv.push_back(std::move(val)); + } + void operator()(const int32_type_impl&) { + auto dv = int32_type->deserialize(*param); + auto val = wasmtime::Val(value_cast(dv)); + argv.push_back(std::move(val)); + } + void operator()(const long_type_impl&) { + auto dv = long_type->deserialize(*param); + auto val = wasmtime::Val(value_cast(dv)); + argv.push_back(std::move(val)); + } + + void operator()(const abstract_type& t) { + // set up exported memory's underlying buffer, + // `memory` is required to be exported in the WebAssembly module + auto memory_export = instance.get(store, "memory"); + if (!memory_export) { + throw wasm::exception("memory export not found - please export `memory` in the wasm module"); + } + auto memory = std::get(*memory_export); + uint8_t* data = memory.data(store).data(); + size_t mem_size = memory.size(store); + if (!param) { + on_internal_error(wasm_logger, "init_arg_visitor does not accept null values"); + } + int32_t serialized_size = param->size(); + if (serialized_size > std::numeric_limits::max()) { + throw wasm::exception(format("Serialized parameter is too large: {} > {}", serialized_size, std::numeric_limits::max())); + } + auto grown = memory.grow(store, sizeof(int32_t) + serialized_size); // for fitting serialized size + the buffer itself + if (!grown) { + throw wasm::exception(format("Failed to grow wasm memory to {}: {}", serialized_size, grown.err().message())); + } + // put the size in wasm module's memory + std::memcpy(data + mem_size, reinterpret_cast(&serialized_size), sizeof(int32_t)); + // put the argument in wasm module's memory + std::memcpy(data + mem_size + sizeof(int32_t), param->data(), serialized_size); + + // the place inside wasm memory where the struct is placed + argv.push_back(int32_t(mem_size)); + } +}; + +struct init_nullable_arg_visitor { + const bytes_opt& param; + std::vector& argv; + wasmtime::Store& store; + wasmtime::Instance& instance; + + void operator()(const abstract_type& t) { + // set up exported memory's underlying buffer, + // `memory` is required to be exported in the WebAssembly module + auto memory_export = instance.get(store, "memory"); + if (!memory_export) { + throw wasm::exception("memory export not found - please export `memory` in the wasm module"); + } + auto memory = std::get(*memory_export); + uint8_t* data = memory.data(store).data(); + size_t mem_size = memory.size(store); + const int32_t serialized_size = param ? param->size() : 0; + if (serialized_size > std::numeric_limits::max()) { + throw wasm::exception(format("Serialized parameter is too large: {} > {}", param->size(), std::numeric_limits::max())); + } + auto grown = memory.grow(store, sizeof(int32_t) + serialized_size); // for fitting the serialized size + the buffer itself + if (!grown) { + throw wasm::exception(format("Failed to grow wasm memory to {}: {}", serialized_size, grown.err().message())); + } + if (param) { + // put the size in wasm module's memory + std::memcpy(data + mem_size, reinterpret_cast(&serialized_size), sizeof(int32_t)); + // put the argument in wasm module's memory + std::memcpy(data + mem_size + sizeof(int32_t), param->data(), serialized_size); + } else { + // size of -1 means that the value is null + const int32_t is_null = -1; + std::memcpy(data + mem_size, reinterpret_cast(&is_null), sizeof(int32_t)); + } + + // the place inside wasm memory where the struct is placed + argv.push_back(int32_t(mem_size)); + } +}; + +struct from_val_visitor { + const wasmtime::Val& val; + wasmtime::Store& store; + wasmtime::Instance& instance; + + bytes_opt operator()(const boolean_type_impl&) { + expect_kind(wasmtime::ValKind::I32); + return boolean_type->decompose(bool(val.i32())); + } + bytes_opt operator()(const byte_type_impl&) { + expect_kind(wasmtime::ValKind::I32); + return byte_type->decompose(int8_t(val.i32())); + } + bytes_opt operator()(const short_type_impl&) { + expect_kind(wasmtime::ValKind::I32); + return short_type->decompose(int16_t(val.i32())); + } + bytes_opt operator()(const double_type_impl&) { + expect_kind(wasmtime::ValKind::F64); + return double_type->decompose(val.f64()); + } + bytes_opt operator()(const float_type_impl&) { + expect_kind(wasmtime::ValKind::F32); + return float_type->decompose(val.f32()); + } + bytes_opt operator()(const int32_type_impl&) { + expect_kind(wasmtime::ValKind::I32); + return int32_type->decompose(val.i32()); + } + bytes_opt operator()(const long_type_impl&) { + expect_kind(wasmtime::ValKind::I64); + return long_type->decompose(val.i64()); + } + + bytes_opt operator()(const abstract_type& t) { + expect_kind(wasmtime::ValKind::I32); + auto memory_export = instance.get(store, "memory"); + if (!memory_export) { + throw wasm::exception("memory export not found - please export `memory` in the wasm module"); + } + auto memory = std::get(*memory_export); + uint8_t* mem_base = memory.data(store).data(); + uint8_t* data = mem_base + val.i32(); + int32_t ret_size; + std::memcpy(reinterpret_cast(&ret_size), data, 4); + if (ret_size == -1) { + return bytes_opt{}; + } + data += sizeof(int32_t); // size of the return type was consumed + return t.decompose(t.deserialize(bytes_view(reinterpret_cast(data), ret_size))); + } + + void expect_kind(wasmtime::ValKind expected) { + // Created to match wasmtime::ValKind order + static constexpr std::string_view kind_str[] = { + "i32", + "i64", + "f32", + "f64", + "v128", + "externref", + "funcref", + }; + if (val.kind() != expected) { + throw wasm::exception(format("Incorrect wasm value kind returned. Expected {}, got {}", kind_str[size_t(expected)], kind_str[size_t(val.kind())])); + } + } +}; + +seastar::future run_script(context& ctx, const std::vector& arg_types, const std::vector& params, data_type return_type, bool allow_null_input) { + wasm_logger.debug("Running function {}", ctx.function_name); + + auto store = wasmtime::Store(ctx.engine_ptr->get()); + // Replenish the store with initial amount of fuel + auto added = store.context().add_fuel(ctx.engine_ptr->initial_fuel_amount()); + if (!added) { + co_return coroutine::make_exception(wasm::exception(added.err().message())); + } + auto [instance, func] = create_instance_and_func(ctx, store); + std::vector argv; + for (size_t i = 0; i < arg_types.size(); ++i) { + const abstract_type& type = *arg_types[i]; + const bytes_opt& param = params[i]; + // If nulls are allowed, each type will be passed indirectly + // as a struct {bool is_null; int32_t serialized_size, char[] serialized_buf} + if (allow_null_input) { + visit(type, init_nullable_arg_visitor{param, argv, store, instance}); + } else if (param) { + visit(type, init_arg_visitor{param, argv, store, instance}); + } else { + co_return coroutine::make_exception(wasm::exception(format("Function {} cannot be called on null values", ctx.function_name))); + } + } + uint64_t fuel_before = *store.context().fuel_consumed(); + + auto result = func.call(store, argv); + + uint64_t consumed = *store.context().fuel_consumed() - fuel_before; + wasm_logger.debug("Consumed {} fuel units", consumed); + + if (!result) { + co_return coroutine::make_exception(wasm::exception("Calling wasm function failed: " + result.err().message())); + } + std::vector result_vec = std::move(result).unwrap(); + if (result_vec.size() != 1) { + co_return coroutine::make_exception(wasm::exception(format("Unexpected number of returned values: {} (expected: 1)", result_vec.size()))); + } + + // TODO: ABI for return values is experimental and subject to change in the future. + // Currently, if a function is marked with `CALLED ON NULL INPUT` it is also capable + // of returning nulls - which implies that all types are returned in its serialized form. + // Otherwise, it is expected to return non-null values, which makes it possible to return + // values of types natively supported by wasm via registers, without prior serialization + // and avoiding allocations. This is however not ideal, especially that theoretically + // it's perfectly fine for a function which `RETURNS NULL ON NULL INPUT` to also want to + // return null on non-null input. The workaround for UDF programmers now is to always use + // CALLED ON NULL INPUT if they want to be able to return nulls. + // In order to properly decide on the ABI, an attempt should be made to provide library + // wrappers for a few languages (C++, C, Rust), and see whether the ABI makes it easy + // to interact with - we want to avoid poor user experience, and it's hard to judge it + // before we actually have helper libraries. + if (allow_null_input) { + // Force calling the default method for abstract_type, which checks for nulls + // and expects a serialized input + co_return from_val_visitor{result_vec[0], store, instance}(static_cast(*return_type)); + } else { + co_return visit(*return_type, from_val_visitor{result_vec[0], store, instance}); + } +} + +} + +#endif diff --git a/lang/wasm.hh b/lang/wasm.hh new file mode 100644 index 000000000000..80a179e2c4c6 --- /dev/null +++ b/lang/wasm.hh @@ -0,0 +1,71 @@ +/* + * Copyright (C) 2021-present ScyllaDB + */ + +/* + * This file is part of Scylla. + * + * Scylla is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Scylla is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with Scylla. If not, see . + */ + +#pragma once + +#include "types.hh" +#include +#include "lang/wasm_engine.hh" + +namespace wasm { + +struct exception : public std::exception { + std::string _msg; +public: + explicit exception(std::string_view msg) : _msg(msg) {} + const char* what() const noexcept { + return _msg.c_str(); + } +}; + +#ifdef SCYLLA_ENABLE_WASMTIME + +struct context { + wasm::engine* engine_ptr; + std::optional module; + std::string function_name; + + context(wasm::engine* engine_ptr, std::string name); +}; + +void compile(context& ctx, const std::vector& arg_names, std::string script); + +seastar::future run_script(context& ctx, const std::vector& arg_types, const std::vector& params, data_type return_type, bool allow_null_input); + +#else + +struct context { + context(wasm::engine*, std::string) { + throw wasm::exception("WASM support was not enabled during compilation!"); + } +}; + +inline void compile(context&, const std::vector&, std::string) { + throw wasm::exception("WASM support was not enabled during compilation!"); +} + +inline seastar::future run_script(context&, const std::vector&, const std::vector&, data_type, bool) { + throw wasm::exception("WASM support was not enabled during compilation!"); +} + +#endif + +} diff --git a/lang/wasm_engine.cc b/lang/wasm_engine.cc new file mode 100644 index 000000000000..53e2804f0938 --- /dev/null +++ b/lang/wasm_engine.cc @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2021-present ScyllaDB + */ + +/* + * This file is part of Scylla. + * + * Scylla is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Scylla is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with Scylla. If not, see . + */ + +#include "wasm_engine.hh" + +namespace wasm { + +#ifdef SCYLLA_ENABLE_WASMTIME + +seastar::future<> init_sharded_engine(seastar::sharded& e) { + // Fuel defines more or less how many bytecode instructions + // can be performed at once. Empirically, 20k units + // allow for considerably less than 0.5ms of preemption-free execution time. + // TODO: investigate other configuration variables. + // We're particularly interested in limiting resource usage + // and yielding in the middle of execution - which is possible + // in the original wasmtime implementation for Rust and tightly + // bound with its native async support, but not yet possible + // in wasmtime.hh binding at the time of this writing. + // It's highly probable that a more generic support for yielding + // can be contributed to wasmtime. + const uint64_t initial_fuel_amount = 20*1024; + return e.start(initial_fuel_amount); +} + +#else + +seastar::future<> init_sharded_engine(seastar::sharded& e) { + return e.start(); +} + + +#endif + +} // namespace wasm diff --git a/lang/wasm_engine.hh b/lang/wasm_engine.hh new file mode 100644 index 000000000000..8917962118b1 --- /dev/null +++ b/lang/wasm_engine.hh @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2021-present ScyllaDB + */ + +/* + * This file is part of Scylla. + * + * Scylla is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Scylla is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with Scylla. If not, see . + */ + +#pragma once + +#include +#include + +#ifdef SCYLLA_ENABLE_WASMTIME + +#include "wasmtime.hh" + +namespace wasm { + +class engine { + wasmtime::Engine _engine; + uint64_t _initial_fuel_amount; +public: + engine(uint64_t initial_fuel_amount) + : _engine(make_config()) + , _initial_fuel_amount(initial_fuel_amount) + {} + wasmtime::Engine& get() { return _engine; } + uint64_t initial_fuel_amount() { return _initial_fuel_amount; }; +private: + wasmtime::Config make_config() { + wasmtime::Config cfg; + cfg.consume_fuel(true); + return cfg; + } +}; + +} + +#else + +namespace wasm { +class engine {}; +} + +#endif + +namespace wasm { +seastar::future<> init_sharded_engine(seastar::sharded& e); +} diff --git a/lang/wasmtime.hh b/lang/wasmtime.hh new file mode 100644 index 000000000000..a909eeae9341 --- /dev/null +++ b/lang/wasmtime.hh @@ -0,0 +1,2665 @@ +/** + * Downloaded from https://github.com/bytecodealliance/wasmtime-cpp/blob/main/include/wasmtime.hh + * License: Apache License 2.0 + */ + +/** + * \mainpage + * + * This project is a C++ API for + * [Wasmtime](https://github.com/bytecodealliance/wasmtime). Support for the + * C++ API is exclusively built on the [C API of + * Wasmtime](https://docs.wasmtime.dev/c-api/), so the C++ support for this is + * simply a single header file. To use this header file, though, it must be + * combined with the header and binary of Wasmtime's C API. Note, though, that + * while this header is built on top of the `wasmtime.h` header file you should + * only need to use the contents of this header file to interact with Wasmtime. + * + * Examples can be [found + * online](https://github.com/bytecodealliance/wasmtime-cpp/tree/main/examples) + * and otherwise be sure to check out the + * [README](https://github.com/bytecodealliance/wasmtime-cpp/blob/main/README.md) + * for simple usage instructions. Otherwise you can dive right in to the + * reference documentation of \ref wasmtime.hh + * + * \example hello.cc + * \example gcd.cc + * \example linking.cc + * \example memory.cc + * \example interrupt.cc + * \example externref.cc + */ + +/** + * \file wasmtime.hh + */ + +#ifndef WASMTIME_HH +#define WASMTIME_HH + +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef __cpp_lib_span +#include +#endif + +#include "wasmtime.h" + +namespace wasmtime { + +#ifdef __cpp_lib_span + +/// \brief Alias to C++20 std::span when it is available +template +using Span = std::span; + +#else + +/// \brief Means number of elements determined at runtime +inline constexpr size_t dynamic_extent = + std::numeric_limits::max(); + +/** + * \brief Span class used when c++20 is not available + * @tparam T Type of data + * @tparam Extent Static size of data refered by Span class + */ +template class Span { +public: + /// \brief Type used to iterate over this span (a raw pointer) + using iterator = T *; + + /// \brief Constructor of Span class + Span(T *t, std::size_t n) : ptr_{t}, size_{n} {} + + /// \brief Constructor of Span class for containers + template