diff --git a/contrib/ruby/test/client_test.rb b/contrib/ruby/test/client_test.rb index 8ef7cf6f..dd6d28e2 100644 --- a/contrib/ruby/test/client_test.rb +++ b/contrib/ruby/test/client_test.rb @@ -903,6 +903,18 @@ def test_close_terminate_parent_connection assert_match "TRILOGY_CLOSED_CONNECTION", error.message end + def test_discard_closes_connection + client = new_tcp_client + + assert_equal [1], client.query("SELECT 1").to_a.first + + client.discard! + + assert_raises Trilogy::ConnectionClosed do + client.query("SELECT 1") + end + end + def test_discard_doesnt_terminate_parent_connection skip("Fork isn't supported on this platform") unless Process.respond_to?(:fork) diff --git a/inc/trilogy/socket.h b/inc/trilogy/socket.h index 37086d7a..8debd145 100644 --- a/inc/trilogy/socket.h +++ b/inc/trilogy/socket.h @@ -76,6 +76,7 @@ typedef struct trilogy_sock_t { int (*wait_cb)(struct trilogy_sock_t *self, trilogy_wait_t wait); int (*shutdown_cb)(struct trilogy_sock_t *self); int (*close_cb)(struct trilogy_sock_t *self); + int (*discard_cb)(struct trilogy_sock_t *self); int (*fd_cb)(struct trilogy_sock_t *self); trilogy_sockopt_t opts; @@ -102,6 +103,7 @@ static inline int trilogy_sock_wait_write(trilogy_sock_t *sock) { return sock->w static inline int trilogy_sock_shutdown(trilogy_sock_t *sock) { return sock->shutdown_cb(sock); } static inline int trilogy_sock_close(trilogy_sock_t *sock) { return sock->close_cb(sock); } +static inline int trilogy_sock_discard(trilogy_sock_t *sock) { return sock->discard_cb(sock); } static inline int trilogy_sock_fd(trilogy_sock_t *sock) { return sock->fd_cb(sock); } diff --git a/src/client.c b/src/client.c index 2821753c..72450d56 100644 --- a/src/client.c +++ b/src/client.c @@ -768,6 +768,7 @@ int trilogy_discard(trilogy_conn_t *conn) { int rc = trilogy_sock_discard(conn->socket); if (rc == TRILOGY_OK) { + conn->socket = NULL; trilogy_free(conn); } return rc; diff --git a/src/socket.c b/src/socket.c index e7ade82b..7e39ea31 100644 --- a/src/socket.c +++ b/src/socket.c @@ -109,6 +109,9 @@ static int _cb_raw_close(trilogy_sock_t *_sock) return rc; } +// Close and discard are the same for a raw socket +#define _cb_raw_discard _cb_raw_close + static int _cb_raw_shutdown(trilogy_sock_t *_sock) { return shutdown(trilogy_sock_fd(_sock), SHUT_RDWR); } static int set_nonblocking_fd(int sock) @@ -235,6 +238,7 @@ trilogy_sock_t *trilogy_sock_new(const trilogy_sockopt_t *opts) sock->base.wait_cb = _cb_wait; sock->base.shutdown_cb = _cb_raw_shutdown; sock->base.close_cb = _cb_raw_close; + sock->base.discard_cb = _cb_raw_discard; sock->base.fd_cb = _cb_raw_fd; sock->base.opts = *opts; @@ -347,6 +351,7 @@ static int _cb_ssl_shutdown(trilogy_sock_t *_sock) sock->base.write_cb = _cb_raw_write; sock->base.shutdown_cb = _cb_raw_shutdown; sock->base.close_cb = _cb_raw_close; + sock->base.discard_cb = _cb_raw_discard; sock->ssl = NULL; return _cb_raw_shutdown(_sock); @@ -363,6 +368,16 @@ static int _cb_ssl_close(trilogy_sock_t *_sock) return _cb_raw_close(_sock); } +static int _cb_ssl_discard(trilogy_sock_t *_sock) { + struct trilogy_sock *sock = (struct trilogy_sock *)_sock; + if (sock->ssl != NULL) { + // unlike close, we do not want to send SSL_shutdown + SSL_free(sock->ssl); + sock->ssl = NULL; + } + return _cb_raw_close(_sock); +} + #if OPENSSL_VERSION_NUMBER >= 0x1010000fL static int trilogy_tls_version_map[] = {0, TLS1_VERSION, TLS1_1_VERSION, TLS1_2_VERSION @@ -614,6 +629,7 @@ int trilogy_sock_upgrade_ssl(trilogy_sock_t *_sock) sock->base.write_cb = _cb_ssl_write; sock->base.shutdown_cb = _cb_ssl_shutdown; sock->base.close_cb = _cb_ssl_close; + sock->base.discard_cb = _cb_ssl_discard; return TRILOGY_OK; fail: @@ -621,28 +637,3 @@ int trilogy_sock_upgrade_ssl(trilogy_sock_t *_sock) sock->ssl = NULL; return TRILOGY_OPENSSL_ERR; } - -int trilogy_sock_discard(trilogy_sock_t *_sock) -{ - struct trilogy_sock *sock = (struct trilogy_sock *)_sock; - - if (sock->fd < 0) { - return TRILOGY_OK; - } - - int null_fd = open("/dev/null", O_RDWR | O_CLOEXEC); - if (null_fd < 0) { - return TRILOGY_SYSERR; - } - - if (dup2(null_fd, sock->fd) < 0) { - close(null_fd); - return TRILOGY_SYSERR; - } - - if (close(null_fd) < 0) { - return TRILOGY_SYSERR; - } - - return TRILOGY_OK; -}