Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/mp/proxy-io.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <capnp/rpc-twoparty.h>

#include <functional>
#include <memory>
#include <string>

namespace mp {
Expand Down Expand Up @@ -199,7 +200,7 @@ class EventLoop
LoggingErrorHandler m_error_handler{*this};

//! Capnp list of pending promises.
boost::optional<kj::TaskSet> m_task_set;
std::unique_ptr<kj::TaskSet> m_task_set;

//! List of connections.
std::list<Connection> m_incoming_connections;
Expand Down
111 changes: 64 additions & 47 deletions src/mp/gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
#include <mp/config.h>
#include <mp/util.h>

#include <boost/optional.hpp>
#include <algorithm>
#include <capnp/schema-parser.h>
#include <fstream>
#include <map>
#include <set>
#include <sstream>
#include <unistd.h>
#include <vector>

#define PROXY_BIN "mpgen"
Expand All @@ -25,12 +26,38 @@ constexpr uint64_t NAME_ANNOTATION_ID = 0xb594888f63f4dbb9ull; // From prox
constexpr uint64_t SKIP_ANNOTATION_ID = 0x824c08b82695d8ddull; // From proxy.capnp

template <typename Reader>
boost::optional<capnp::schema::Value::Reader> GetAnnotation(const Reader& reader, uint64_t id)
static bool AnnotationExists(const Reader& reader, uint64_t id)
{
for (const auto annotation : reader.getAnnotations()) {
if (annotation.getId() == id) return annotation.getValue();
if (annotation.getId() == id) {
return true;
}
}
return false;
}

template <typename Reader>
static bool GetAnnotationText(const Reader& reader, uint64_t id, kj::StringPtr* result)
{
for (const auto annotation : reader.getAnnotations()) {
if (annotation.getId() == id) {
*result = annotation.getValue().getText();
return true;
}
}
return false;
}

template <typename Reader>
static bool GetAnnotationInt32(const Reader& reader, uint64_t id, int32_t* result)
{
for (const auto annotation : reader.getAnnotations()) {
if (annotation.getId() == id) {
*result = annotation.getValue().getInt32();
return true;
}
}
return {};
return false;
}

using CharSlice = kj::ArrayPtr<const char>;
Expand Down Expand Up @@ -162,9 +189,7 @@ void Generate(kj::StringPtr src_prefix,
h << "namespace mp {\n";

kj::StringPtr message_namespace;
if (auto value = GetAnnotation(file_schema.getProto(), NAMESPACE_ANNOTATION_ID)) {
message_namespace = value->getText();
}
GetAnnotationText(file_schema.getProto(), NAMESPACE_ANNOTATION_ID, &message_namespace);

std::string base_name = include_base;
size_t output_slash = base_name.rfind("/");
Expand Down Expand Up @@ -202,9 +227,7 @@ void Generate(kj::StringPtr src_prefix,
kj::StringPtr node_name = node_nested.getName();
const auto& node = file_schema.getNested(node_name);
kj::StringPtr proxied_class_type;
if (auto proxy = GetAnnotation(node.getProto(), WRAP_ANNOTATION_ID)) {
proxied_class_type = proxy->getText();
}
GetAnnotationText(node.getProto(), WRAP_ANNOTATION_ID, &proxied_class_type);

if (node.getProto().isStruct()) {
const auto& struc = node.asStruct();
Expand Down Expand Up @@ -239,7 +262,7 @@ void Generate(kj::StringPtr src_prefix,
dec << " using Accessors = std::tuple<";
size_t i = 0;
for (const auto field : struc.getFields()) {
if (GetAnnotation(field.getProto(), SKIP_ANNOTATION_ID)) {
if (AnnotationExists(field.getProto(), SKIP_ANNOTATION_ID)) {
continue;
}
if (i) dec << ", ";
Expand All @@ -258,14 +281,12 @@ void Generate(kj::StringPtr src_prefix,
inl << " using Struct = " << message_namespace << "::" << node_name << ";\n";
size_t i = 0;
for (const auto field : struc.getFields()) {
if (GetAnnotation(field.getProto(), SKIP_ANNOTATION_ID)) {
if (AnnotationExists(field.getProto(), SKIP_ANNOTATION_ID)) {
continue;
}
auto field_name = field.getProto().getName();
auto member_name = field_name;
if (auto name = GetAnnotation(field.getProto(), NAME_ANNOTATION_ID)) {
member_name = name->getText();
}
GetAnnotationText(field.getProto(), NAME_ANNOTATION_ID, &member_name);
inl << " static auto get(std::integral_constant<size_t, " << i << ">) -> AUTO_RETURN("
<< "&" << proxied_class_type << "::" << member_name << ")\n";
++i;
Expand Down Expand Up @@ -300,9 +321,7 @@ void Generate(kj::StringPtr src_prefix,
for (const auto method : interface.getMethods()) {
kj::StringPtr method_name = method.getProto().getName();
kj::StringPtr proxied_method_name = method_name;
if (auto name = GetAnnotation(method.getProto(), NAME_ANNOTATION_ID)) {
proxied_method_name = name->getText();
}
GetAnnotationText(method.getProto(), NAME_ANNOTATION_ID, &proxied_method_name);

const std::string method_prefix = Format() << message_namespace << "::" << node_name
<< "::" << Cap(method_name);
Expand All @@ -311,8 +330,10 @@ void Generate(kj::StringPtr src_prefix,

struct Field
{
boost::optional<::capnp::StructSchema::Field> param;
boost::optional<::capnp::StructSchema::Field> result;
::capnp::StructSchema::Field param;
bool param_is_set = false;
::capnp::StructSchema::Field result;
bool result_is_set = false;
int args = 0;
bool retval = false;
bool optional = false;
Expand All @@ -326,7 +347,7 @@ void Generate(kj::StringPtr src_prefix,
bool has_result = false;

auto add_field = [&](const ::capnp::StructSchema::Field& schema_field, bool param) {
if (GetAnnotation(schema_field.getProto(), SKIP_ANNOTATION_ID)) {
if (AnnotationExists(schema_field.getProto(), SKIP_ANNOTATION_ID)) {
return;
}

Expand All @@ -336,39 +357,35 @@ void Generate(kj::StringPtr src_prefix,
fields.emplace_back();
}
auto& field = fields[inserted.first->second];
(param ? field.param : field.result) = schema_field;
if (param) {
field.param = schema_field;
field.param_is_set = true;
} else {
field.result = schema_field;
field.result_is_set = true;
}

if (!param && field_name == "result") {
field.retval = true;
has_result = true;
}

if (auto value = GetAnnotation(schema_field.getProto(), EXCEPTION_ANNOTATION_ID)) {
field.exception = value->getText();
}
GetAnnotationText(schema_field.getProto(), EXCEPTION_ANNOTATION_ID, &field.exception);

boost::optional<int> count;
if (auto value = GetAnnotation(schema_field.getProto(), COUNT_ANNOTATION_ID)) {
count = value->getInt32();
} else if (schema_field.getType().isStruct()) {
if (auto value =
GetAnnotation(schema_field.getType().asStruct().getProto(), COUNT_ANNOTATION_ID)) {
count = value->getInt32();
}
} else if (schema_field.getType().isInterface()) {
if (auto value =
GetAnnotation(schema_field.getType().asInterface().getProto(), COUNT_ANNOTATION_ID)) {
count = value->getInt32();
int32_t count = 1;
if (!GetAnnotationInt32(schema_field.getProto(), COUNT_ANNOTATION_ID, &count)) {
if (schema_field.getType().isStruct()) {
GetAnnotationInt32(schema_field.getType().asStruct().getProto(),
COUNT_ANNOTATION_ID, &count);
} else if (schema_field.getType().isInterface()) {
GetAnnotationInt32(schema_field.getType().asInterface().getProto(),
COUNT_ANNOTATION_ID, &count);
}
}


if (inserted.second && !field.retval && !field.exception.size()) {
if (count) {
field.args = *count;
} else {
field.args = 1;
}
field.args = count;
}
};

Expand All @@ -385,7 +402,7 @@ void Generate(kj::StringPtr src_prefix,
fields[field.second].optional = true;
}
auto want_field = field_idx.find("want" + Cap(field.first));
if (want_field != field_idx.end() && fields[want_field->second].param) {
if (want_field != field_idx.end() && fields[want_field->second].param_is_set) {
fields[want_field->second].skip = true;
fields[field.second].requested = true;
}
Expand All @@ -408,12 +425,12 @@ void Generate(kj::StringPtr src_prefix,
for (const auto& field : fields) {
if (field.skip) continue;

auto field_name = field.param ? field.param->getProto().getName() :
field.result ? field.result->getProto().getName() : "";
auto field_type = field.param ? field.param->getType() : field.result->getType();
const auto& f = field.param_is_set ? field.param : field.result;
auto field_name = f.getProto().getName();
auto field_type = f.getType();

std::ostringstream field_flags;
field_flags << (!field.param ? "FIELD_OUT" : field.result ? "FIELD_IN | FIELD_OUT" : "FIELD_IN");
field_flags << (!field.param_is_set ? "FIELD_OUT" : field.result_is_set ? "FIELD_IN | FIELD_OUT" : "FIELD_IN");
if (field.optional) field_flags << " | FIELD_OPTIONAL";
if (field.requested) field_flags << " | FIELD_REQUESTED";
if (BoxedType(field_type)) field_flags << " | FIELD_BOXED";
Expand Down
8 changes: 5 additions & 3 deletions src/mp/proxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

#include <assert.h>
#include <atomic>
#include <boost/optional.hpp>
#include <capnp/blob.h>
#include <capnp/capability.h>
#include <condition_variable>
Expand Down Expand Up @@ -124,13 +123,16 @@ void Connection::addAsyncCleanup(std::function<void()> fn)
}

EventLoop::EventLoop(const char* exe_name, LogFn log_fn, void* context)
: m_exe_name(exe_name), m_io_context(kj::setupAsyncIo()), m_log_fn(std::move(log_fn)), m_context(context)
: m_exe_name(exe_name),
m_io_context(kj::setupAsyncIo()),
m_log_fn(std::move(log_fn)),
m_context(context),
m_task_set(new kj::TaskSet(m_error_handler))
{
int fds[2];
KJ_SYSCALL(socketpair(AF_UNIX, SOCK_STREAM, 0, fds));
m_wait_fd = fds[0];
m_post_fd = fds[1];
m_task_set.emplace(m_error_handler);
}

EventLoop::~EventLoop()
Expand Down