diff --git a/include/PgSQL_Data_Stream.h b/include/PgSQL_Data_Stream.h index d830fbfda..5b7dbc454 100644 --- a/include/PgSQL_Data_Stream.h +++ b/include/PgSQL_Data_Stream.h @@ -136,7 +136,7 @@ class PgSQL_Data_Stream AUTHENTICATION_METHOD auth_method = AUTHENTICATION_METHOD::NO_PASSWORD; uint32_t auth_next_pkt_type = 0; bool auth_received_startup = false; - + unsigned char tmp_login_salt[4]; ScramState* scram_state; unsigned int connect_tries; diff --git a/lib/PgSQL_Protocol.cpp b/lib/PgSQL_Protocol.cpp index 40ff303e9..513f0b1f9 100644 --- a/lib/PgSQL_Protocol.cpp +++ b/lib/PgSQL_Protocol.cpp @@ -1,16 +1,7 @@ -//#include "openssl/rand.h" + +#include #include "proxysql.h" #include "cpp.h" -/* -#include "re2/re2.h" -#include "re2/regexp.h" -#include "MySQL_PreparedStatement.h" - - -#include "MySQL_LDAP_Authentication.hpp" -#include "MySQL_Variables.h" -#include -*/ #include "PgSQL_Authentication.h" #include "PgSQL_Data_Stream.h" #include "PgSQL_Protocol.h" @@ -408,7 +399,15 @@ bool PgSQL_Protocol::generate_pkt_initial_handshake(bool send, void** _ptr, unsi pgpkt.write_generic(type, "i", PG_PKT_AUTH_PLAIN); break; case AUTHENTICATION_METHOD::MD5_PASSWORD: - pgpkt.write_generic(type, "i", PG_PKT_AUTH_MD5); + memset((*myds)->tmp_login_salt, 0, sizeof((*myds)->tmp_login_salt)); + if (RAND_bytes((*myds)->tmp_login_salt, sizeof((*myds)->tmp_login_salt)) != 1) { + // Fallback method: using a basic pseudo-random generator + srand((unsigned int)time(NULL)); + for (int i = 0; i < sizeof((*myds)->tmp_login_salt); i++) { + (*myds)->tmp_login_salt[i] = rand() % 256; + } + } + pgpkt.write_generic(type, "ib", PG_PKT_AUTH_MD5, (*myds)->tmp_login_salt, sizeof((*myds)->tmp_login_salt)); break; case AUTHENTICATION_METHOD::SASL_SCRAM_SHA_256: pgpkt.write_generic(type, "iss", PG_PKT_AUTH_SASL, "SCRAM-SHA-256", ""); @@ -781,8 +780,55 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* if (password) { - proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "Session=%p , DS=%p , user='%s' , auth_method=%d\n", (*myds), (*myds)->sess, user, (int)(*myds)->auth_method); + proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "Session=%p , DS=%p , user='%s' , auth_method=%s\n", (*myds), (*myds)->sess, user, AUTHENTICATION_METHOD_STR[(int)(*myds)->auth_method]); switch ((*myds)->auth_method) { + case AUTHENTICATION_METHOD::MD5_PASSWORD: + { + uint32_t pass_len = hdr.data.size; + pass = (char*)malloc(pass_len + 1); + memcpy(pass, hdr.data.ptr, pass_len); + pass[pass_len] = 0; + + using_password = (pass_len > 0); + + if (pass_len) { + if (pass[pass_len - 1] == 0) { + pass_len--; // remove the extra 0 if present + } + } + + if (!pass || *pass == '\0') { + proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "Session=%p , DS=%p , user='%s'. Empty password returned by client.\n", (*myds), (*myds)->sess, user); + generate_error_packet(true, false, "empty password returned by client", PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, true); + break; + } + + unsigned char md5_digest[MD5_DIGEST_LENGTH]; + char md5_string[MD5_DIGEST_LENGTH * 2 + sizeof((*myds)->tmp_login_salt)]; + MD5_CTX md5_context; + // needs to be precalculated and stored in DB + MD5_Init(&md5_context); + MD5_Update(&md5_context, password, strlen(password)); + MD5_Update(&md5_context, user, strlen(user)); + MD5_Final(md5_digest, &md5_context); + for (int i = 0; i < MD5_DIGEST_LENGTH; i++) { + sprintf(&md5_string[i * 2], "%02x", (unsigned int)md5_digest[i]); + } + // + memcpy(md5_string+(MD5_DIGEST_LENGTH*2), (*myds)->tmp_login_salt, sizeof((*myds)->tmp_login_salt)); + MD5_Init(&md5_context); + MD5_Update(&md5_context, md5_string, (MD5_DIGEST_LENGTH*2)+sizeof((*myds)->tmp_login_salt)); + MD5_Final(md5_digest, &md5_context); + memcpy(md5_string, "md5", 3); + for (int i = 0, j = 3; i < MD5_DIGEST_LENGTH; i++, j+=2) { + sprintf(&md5_string[j], "%02x", (unsigned int)md5_digest[i]); + } + + if (strlen(md5_string) == pass_len && strcmp(md5_string, pass) == 0) { + ret = EXECUTION_STATE::SUCCESSFUL; + } + } + break; case AUTHENTICATION_METHOD::CLEAR_TEXT_PASSWORD: { uint32_t pass_len = hdr.data.size; @@ -804,7 +850,7 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* break; } - if (strcmp(password, pass) == 0) { + if (strlen(password) == pass_len && strcmp(password, pass) == 0) { ret = EXECUTION_STATE::SUCCESSFUL; } } @@ -903,6 +949,7 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* break; default: proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "Session=%p , DS=%p , user='%s' . goto __exit_process_pkt_handshake_response . Unknown auth method\n", (*myds), (*myds)->sess, user); + //generate_error_packet(true, false, "authentication method not supported", PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, true); break; } } else { diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index b0ce1c90e..dff8a1079 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -2845,7 +2845,7 @@ int PgSQL_Session::get_pkts_from_client(bool& wrong_pass, PtrSize_t& pkt) { return handler_ret; } else { - proxy_error("Not implemented yet"); + proxy_error("Not implemented yet\n"); assert(0); } } @@ -3036,7 +3036,7 @@ int PgSQL_Session::get_pkts_from_client(bool& wrong_pass, PtrSize_t& pkt) { return handler_ret; break; default: - proxy_error("Not implemented yet"); + proxy_error("Not implemented yet\n"); assert(0); } } @@ -4344,9 +4344,8 @@ void PgSQL_Session::handler___status_CONNECTING_CLIENT___STATE_SERVER_HANDSHAKE( } l_free(pkt->size, pkt->ptr); //if (client_myds->encrypted==false) { - if (client_myds->myconn->userinfo->dbname == NULL) { - client_myds->myconn->userinfo->set_dbname(default_schema); - } + assert(client_myds->myconn->userinfo->dbname); + int free_users = 0; int used_users = 0; if ( diff --git a/lib/PgSQL_Thread.cpp b/lib/PgSQL_Thread.cpp index ebfeb6ff0..a199d1dc3 100644 --- a/lib/PgSQL_Thread.cpp +++ b/lib/PgSQL_Thread.cpp @@ -2057,7 +2057,7 @@ char** PgSQL_Threads_Handler::get_variables_list() { // initialize VariablesPointers_int // it is safe to do it here because get_variables_list() is the first function called during start time if (VariablesPointers_int.size() == 0) { - VariablesPointers_int["authentication_method"] = make_tuple(&variables.authentication_method, 0, 4, false); + VariablesPointers_int["authentication_method"] = make_tuple(&variables.authentication_method, 1, 3, false); // Monitor variables VariablesPointers_int["monitor_history"] = make_tuple(&variables.monitor_history, 1000, 7 * 24 * 3600 * 1000, false);