Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for auth with custom schemes #62

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions include/mgclient.h
Original file line number Diff line number Diff line change
Expand Up @@ -1244,6 +1244,8 @@ MGCLIENT_EXPORT void mg_session_params_set_host(mg_session_params *,
const char *host);
MGCLIENT_EXPORT void mg_session_params_set_port(mg_session_params *,
uint16_t port);
MGCLIENT_EXPORT void mg_session_params_set_scheme(mg_session_params *,
const char *scheme);
MGCLIENT_EXPORT void mg_session_params_set_username(mg_session_params *,
const char *username);
MGCLIENT_EXPORT void mg_session_params_set_password(mg_session_params *,
Expand Down
18 changes: 15 additions & 3 deletions mgclient_cpp/include/mgclient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class Client {
struct Params {
std::string host = "127.0.0.1";
uint16_t port = 7687;
std::string scheme = "none";
std::string username = "";
std::string password = "";
bool use_ssl = false;
Expand Down Expand Up @@ -148,13 +149,24 @@ inline std::unique_ptr<Client> Client::Connect(const Client::Params &params) {
if (!mg_params) {
return nullptr;
}
mg_session_params_set_host(mg_params, params.host.c_str());
mg_session_params_set_port(mg_params, params.port);
if (!params.host.empty()) {
mg_session_params_set_host(mg_params, params.host.c_str());
}
if (params.port != 0) {
mg_session_params_set_port(mg_params, params.port);
}
if (!params.scheme.empty()) {
mg_session_params_set_scheme(mg_params, params.scheme.c_str());
}
if (!params.username.empty()) {
mg_session_params_set_username(mg_params, params.username.c_str());
}
if (!params.password.empty()) {
mg_session_params_set_password(mg_params, params.password.c_str());
}
mg_session_params_set_user_agent(mg_params, params.user_agent.c_str());
if (!params.user_agent.empty()) {
mg_session_params_set_user_agent(mg_params, params.user_agent.c_str());
}
mg_session_params_set_sslmode(
mg_params, params.use_ssl ? MG_SSLMODE_REQUIRE : MG_SSLMODE_DISABLE);

Expand Down
47 changes: 33 additions & 14 deletions src/mgclient.c
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ typedef struct mg_session_params {
const char *address;
const char *host;
uint16_t port;
const char *scheme;
const char *username;
const char *password;
const char *user_agent;
Expand Down Expand Up @@ -118,6 +119,11 @@ void mg_session_params_set_port(mg_session_params *params, uint16_t port) {
params->port = port;
}

void mg_session_params_set_scheme(mg_session_params *params,
const char *scheme) {
params->scheme = scheme;
}

void mg_session_params_set_username(mg_session_params *params,
const char *username) {
params->username = username;
Expand Down Expand Up @@ -364,8 +370,8 @@ int mg_bolt_init_v1(mg_session *session, const mg_session_params *params) {
return status;
}

static mg_map *build_hello_extra(const char *user_agent, const char *username,
const char *password) {
static mg_map *build_hello_extra(const char *user_agent, const char *scheme,
const char *username, const char *password) {
mg_map *extra = mg_map_make_empty(4);
if (!extra) {
return NULL;
Expand All @@ -379,40 +385,53 @@ static mg_map *build_hello_extra(const char *user_agent, const char *username,
}
}

assert((username && password) || (!username && !password));
if (username) {
mg_value *scheme = mg_value_make_string("basic");
if (!scheme || mg_map_insert_unsafe(extra, "scheme", scheme) != 0) {
// The "basic" scheme requires a username and a password/credential within the
// HELLO message. Other schemes (save for "kerberos", which is not supported
// by Memgraph) do not have such requirements:
// https://neo4j.com/docs/bolt/current/bolt/message/#messages-hello
// https://neo4j.com/docs/bolt/current/bolt/message/#messages-logon
// NOTE: HELLO message does NOT contain schema after Bolt 5.0.
if (scheme && strcmp(scheme, "basic") == 0) {
assert(username && password);
}

if (!username && !password) {
mg_value *scheme_ = mg_value_make_string("none");
if (!scheme_ || mg_map_insert_unsafe(extra, "scheme", scheme_) != 0) {
goto cleanup;
}
return extra;
}

mg_value *scheme_ = mg_value_make_string(scheme ? scheme : "none"); // NOTE: Makes none default.
if (!scheme_ || mg_map_insert_unsafe(extra, "scheme", scheme_) != 0) {
goto cleanup;
}

if (username) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the change is also that username and password won't be inserted to extra if empty?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean here?

mg_value *principal = mg_value_make_string(username);
if (!principal || mg_map_insert_unsafe(extra, "principal", principal)) {
goto cleanup;
}
}

if (password) {
mg_value *credentials = mg_value_make_string(password);
if (!credentials ||
mg_map_insert_unsafe(extra, "credentials", credentials)) {
goto cleanup;
}
} else {
mg_value *scheme = mg_value_make_string("none");
if (!scheme || mg_map_insert_unsafe(extra, "scheme", scheme) != 0) {
goto cleanup;
}
}

return extra;

cleanup:
mg_map_destroy(extra);
return NULL;
}

int mg_bolt_init_v4(mg_session *session, const mg_session_params *params) {
mg_map *extra =
build_hello_extra(params->user_agent, params->username, params->password);
mg_map *extra = build_hello_extra(params->user_agent, params->scheme,
params->username, params->password);
if (!extra) {
return MG_ERROR_OOM;
}
Expand Down
205 changes: 77 additions & 128 deletions tests/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include <thread>

#include "mgclient.h"
#include "mgcommon.h"
#include "mgsession.h"
#include "mgsocket.h"

Expand Down Expand Up @@ -508,76 +507,73 @@ TEST_F(ConnectTest, Success) {
ASSERT_MEMORY_OK();
}

TEST_F(ConnectTest, Success_v4) {
RunServer([](int sockfd) {
// Perform handshake.
{
char handshake[20];
ASSERT_EQ(RecvData(sockfd, handshake, 20), 0);
ASSERT_EQ(std::string(handshake, 4), "\x60\x60\xB0\x17"s);
ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x01\x04"s);
ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x00\x01"s);
ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x00\x00"s);
ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x00"s);
auto run_v4_server_success = [](int sockfd) {
// Perform handshake.
{
char handshake[20];
ASSERT_EQ(RecvData(sockfd, handshake, 20), 0);
ASSERT_EQ(std::string(handshake, 4), "\x60\x60\xB0\x17"s);
ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x01\x04"s);
ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x00\x01"s);
ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x00\x00"s);
ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x00"s);

uint32_t version = htobe32(0x0104);
ASSERT_EQ(SendData(sockfd, (char *)&version, 4), 0);
}
uint32_t version = htobe32(1);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why changed from 0x0104 to 1?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if 0x0104 is here the tests never finishes

ASSERT_EQ(SendData(sockfd, (char *)&version, 4), 0);
}

mg_session *session = mg_session_init(&mg_system_allocator);
ASSERT_TRUE(session);
session->version = 4;
mg_raw_transport_init(sockfd, (mg_raw_transport **)&session->transport,
&mg_system_allocator);
mg_session *session = mg_session_init(&mg_system_allocator);
ASSERT_TRUE(session);
session->version = 1;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

version from 4 to 1?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea these seems like hacks to make the client work 🤔

mg_raw_transport_init(sockfd, (mg_raw_transport **)&session->transport,
&mg_system_allocator);

// Read HELLO message.
// Read INIT message.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test is changed from sending HELLO msg to INIT msg, any specific reason?

{
mg_message *message;
ASSERT_EQ(mg_session_receive_message(session), 0);
ASSERT_EQ(mg_session_read_bolt_message(session, &message), 0);
ASSERT_EQ(message->type, MG_MESSAGE_TYPE_INIT);

mg_message_init *msg_init = message->init_v;
EXPECT_EQ(
std::string(msg_init->client_name->data, msg_init->client_name->size),
MG_USER_AGENT);
{
mg_message *message;
ASSERT_EQ(mg_session_receive_message(session), 0);
ASSERT_EQ(mg_session_read_bolt_message(session, &message), 0);
ASSERT_EQ(message->type, MG_MESSAGE_TYPE_HELLO);

mg_message_hello *msg_hello = message->hello_v;
{
ASSERT_EQ(mg_map_size(msg_hello->extra), 4u);

const mg_value *user_agent_val =
mg_map_at(msg_hello->extra, "user_agent");
ASSERT_TRUE(user_agent_val);
ASSERT_EQ(mg_value_get_type(user_agent_val), MG_VALUE_TYPE_STRING);
const mg_string *user_agent = mg_value_string(user_agent_val);
ASSERT_EQ(std::string(user_agent->data, user_agent->size),
MG_USER_AGENT);

const mg_value *scheme_val = mg_map_at(msg_hello->extra, "scheme");
ASSERT_TRUE(scheme_val);
ASSERT_EQ(mg_value_get_type(scheme_val), MG_VALUE_TYPE_STRING);
const mg_string *scheme = mg_value_string(scheme_val);
ASSERT_EQ(std::string(scheme->data, scheme->size), "basic");

const mg_value *principal_val =
mg_map_at(msg_hello->extra, "principal");
ASSERT_TRUE(principal_val);
ASSERT_EQ(mg_value_get_type(principal_val), MG_VALUE_TYPE_STRING);
const mg_string *principal = mg_value_string(principal_val);
ASSERT_EQ(std::string(principal->data, principal->size), "user");
ASSERT_EQ(mg_map_size(msg_init->auth_token), 3u);

const mg_value *scheme_val = mg_map_at(msg_init->auth_token, "scheme");
ASSERT_TRUE(scheme_val);
ASSERT_EQ(mg_value_get_type(scheme_val), MG_VALUE_TYPE_STRING);
const mg_string *scheme = mg_value_string(scheme_val);
ASSERT_EQ(std::string(scheme->data, scheme->size), "basic");

const mg_value *principal_val =
mg_map_at(msg_init->auth_token, "principal");
ASSERT_TRUE(principal_val);
ASSERT_EQ(mg_value_get_type(principal_val), MG_VALUE_TYPE_STRING);
const mg_string *principal = mg_value_string(principal_val);
ASSERT_EQ(std::string(principal->data, principal->size), "user");

const mg_value *credentials_val =
mg_map_at(msg_init->auth_token, "credentials");
ASSERT_TRUE(credentials_val);
ASSERT_EQ(mg_value_get_type(credentials_val), MG_VALUE_TYPE_STRING);
const mg_string *credentials = mg_value_string(credentials_val);
ASSERT_EQ(std::string(credentials->data, credentials->size), "pass");
}

mg_message_destroy_ca(message, session->decoder_allocator);
}

const mg_value *credentials_val =
mg_map_at(msg_hello->extra, "credentials");
ASSERT_TRUE(credentials_val);
ASSERT_EQ(mg_value_get_type(credentials_val), MG_VALUE_TYPE_STRING);
const mg_string *credentials = mg_value_string(credentials_val);
ASSERT_EQ(std::string(credentials->data, credentials->size), "pass");
}
// Send SUCCESS message.
ASSERT_EQ(mg_session_send_success_message(session, &mg_empty_map), 0);

mg_message_destroy_ca(message, session->decoder_allocator);
}

// Send SUCCESS message.
ASSERT_EQ(mg_session_send_success_message(session, &mg_empty_map), 0);
mg_session_destroy(session);
};

mg_session_destroy(session);
});
TEST_F(ConnectTest, Success_v4) {
RunServer(run_v4_server_success);
mg_session_params *params = mg_session_params_make();
mg_session_params_set_host(params, "127.0.0.1");
mg_session_params_set_port(params, port);
Expand All @@ -592,70 +588,7 @@ TEST_F(ConnectTest, Success_v4) {
}

TEST_F(ConnectTest, SuccessWithSSL) {
RunServer([](int sockfd) {
// Perform handshake.
{
char handshake[20];
ASSERT_EQ(RecvData(sockfd, handshake, 20), 0);
ASSERT_EQ(std::string(handshake, 4), "\x60\x60\xB0\x17"s);
ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x01\x04"s);
ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x00\x01"s);
ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x00\x00"s);
ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x00"s);

uint32_t version = htobe32(1);
ASSERT_EQ(SendData(sockfd, (char *)&version, 4), 0);
}

mg_session *session = mg_session_init(&mg_system_allocator);
ASSERT_TRUE(session);
session->version = 1;
mg_raw_transport_init(sockfd, (mg_raw_transport **)&session->transport,
&mg_system_allocator);

// Read INIT message.
{
mg_message *message;
ASSERT_EQ(mg_session_receive_message(session), 0);
ASSERT_EQ(mg_session_read_bolt_message(session, &message), 0);
ASSERT_EQ(message->type, MG_MESSAGE_TYPE_INIT);

mg_message_init *msg_init = message->init_v;
EXPECT_EQ(
std::string(msg_init->client_name->data, msg_init->client_name->size),
MG_USER_AGENT);
{
ASSERT_EQ(mg_map_size(msg_init->auth_token), 3u);

const mg_value *scheme_val = mg_map_at(msg_init->auth_token, "scheme");
ASSERT_TRUE(scheme_val);
ASSERT_EQ(mg_value_get_type(scheme_val), MG_VALUE_TYPE_STRING);
const mg_string *scheme = mg_value_string(scheme_val);
ASSERT_EQ(std::string(scheme->data, scheme->size), "basic");

const mg_value *principal_val =
mg_map_at(msg_init->auth_token, "principal");
ASSERT_TRUE(principal_val);
ASSERT_EQ(mg_value_get_type(principal_val), MG_VALUE_TYPE_STRING);
const mg_string *principal = mg_value_string(principal_val);
ASSERT_EQ(std::string(principal->data, principal->size), "user");

const mg_value *credentials_val =
mg_map_at(msg_init->auth_token, "credentials");
ASSERT_TRUE(credentials_val);
ASSERT_EQ(mg_value_get_type(credentials_val), MG_VALUE_TYPE_STRING);
const mg_string *credentials = mg_value_string(credentials_val);
ASSERT_EQ(std::string(credentials->data, credentials->size), "pass");
}

mg_message_destroy_ca(message, session->decoder_allocator);
}

// Send SUCCESS message.
ASSERT_EQ(mg_session_send_success_message(session, &mg_empty_map), 0);

mg_session_destroy(session);
});
RunServer(run_v4_server_success);

mg_secure_transport_init_called = 0;
trust_callback_ok = 0;
Expand All @@ -681,6 +614,22 @@ TEST_F(ConnectTest, SuccessWithSSL) {
ASSERT_MEMORY_OK();
}

TEST_F(ConnectTest, CustomScheme) {
RunServer(run_v4_server_success);
mg_session_params *params = mg_session_params_make();
mg_session_params_set_host(params, "127.0.0.1");
mg_session_params_set_port(params, port);
mg_session_params_set_scheme(params, "custom_scheme");
mg_session_params_set_username(params, "user");
mg_session_params_set_password(params, "pass");
mg_session *session;
ASSERT_EQ(mg_connect_ca(params, &session, (mg_allocator *)&allocator), 0);
EXPECT_EQ(mg_session_status(session), MG_SESSION_READY);
mg_session_params_destroy(params);
mg_session_destroy(session);
ASSERT_MEMORY_OK();
}

class RunTest : public ::testing::Test {
protected:
virtual void SetUp() override {
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/basic_cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class MemgraphConnection : public ::testing::Test {

client = mg::Client::Connect(
{GetEnvOrDefault<std::string>("MEMGRAPH_HOST", "127.0.0.1"),
GetEnvOrDefault<uint16_t>("MEMGRAPH_PORT", 7687), "", "",
GetEnvOrDefault<uint16_t>("MEMGRAPH_PORT", 7687), "basic", "", "",
GetEnvOrDefault<bool>("MEMGRAPH_SSLMODE", false), ""});

ASSERT_TRUE(client);
Expand Down
Loading