Skip to content

Commit

Permalink
Don't abort connection in mysql after encoding error (#535)
Browse files Browse the repository at this point in the history
* Refactor sending error to client into a function

why waste time say lot word when few word do trick

* Add states into database handler

Right now it just replaces the `firstPacket` variable with a
separate state. But in the future it can be extended to support other
states to model a state machine if needed.

* Add skipping packets in stateSkipResponse

* Add tests for reusing connection after error

* Update changelog

[skip ci]
  • Loading branch information
G1gg1L3s authored May 11, 2022
1 parent ba63733 commit 30d12e4
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 34 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG_DEV.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# 0.93.0 - 2022-05-10
- Don't abort connection of mysql after encoding error.

# 0.93.0 - 2022-05-04
- Add mysql support for `response_on_fail` options.

Expand Down
102 changes: 68 additions & 34 deletions decryptor/mysql/response_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,21 @@ const (
TypeGeometry
)

type databaseHandlerState int

const (
// stateFirstPacket is the starting state of the handler. Most of the time
// it is the same as `serve` expect allows setting some parameters of
// connection.
stateFirstPacket databaseHandlerState = iota
// stateServe is the most common state of the handler. It means normal
// processing of packets
stateServe
// stateSkipResponse is a state of a handler when it skips a select response
// from database
stateSkipResponse
)

// ResponseHandler database response header
type ResponseHandler func(ctx context.Context, packet *Packet, dbConnection, clientConnection net.Conn) error

Expand Down Expand Up @@ -296,9 +311,8 @@ func (handler *Handler) ProxyClientConnection(ctx context.Context, errCh chan<-
"for connections AcraServer->Database and CA certificate which will be used to verify certificate " +
"from database")
handler.logger.Debugln("Send error to db")
errPacket := NewQueryInterruptedError(handler.clientProtocol41, QueryExecutionWasInterrupted)
packet.SetData(errPacket)
if _, err := handler.clientConnection.Write(packet.Dump()); err != nil {

if err := handler.sendClientError(QueryExecutionWasInterrupted, packet); err != nil {
handler.logger.WithError(err).WithField(logging.FieldKeyEventCode, logging.EventCodeErrorResponseConnectorCantWriteToClient).
Debugln("Can't write response with error to client")
}
Expand Down Expand Up @@ -395,9 +409,7 @@ func (handler *Handler) ProxyClientConnection(ctx context.Context, errCh chan<-
if err := handler.acracensor.HandleQuery(query); err != nil {
censorSpan.End()
clientLog.WithError(err).WithField(logging.FieldKeyEventCode, logging.EventCodeErrorCensorQueryIsNotAllowed).Errorln("Error on AcraCensor check")
errPacket := NewQueryInterruptedError(handler.clientProtocol41, QueryExecutionWasInterrupted)
packet.SetData(errPacket)
if _, err := handler.clientConnection.Write(packet.Dump()); err != nil {
if err := handler.sendClientError(QueryExecutionWasInterrupted, packet); err != nil {
handler.logger.WithError(err).WithField(logging.FieldKeyEventCode, logging.EventCodeErrorResponseConnectorCantWriteToClient).
Errorln("Can't write response with error to client")
}
Expand Down Expand Up @@ -790,7 +802,7 @@ func (handler *Handler) ProxyDatabaseConnection(ctx context.Context, errCh chan<
defer span.End()
serverLog := handler.logger.WithField("proxy", "server")
serverLog.Debugln("Start proxy db responses")
firstPacket := true
var state databaseHandlerState = stateFirstPacket
var responseHandler ResponseHandler
// use pointers to function where should be stored some function that should be called if code return error and interrupt loop
// default value empty func to avoid != nil check
Expand Down Expand Up @@ -847,42 +859,64 @@ func (handler *Handler) ProxyDatabaseConnection(ctx context.Context, errCh chan<
if packet.IsErr() {
handler.resetQueryHandler()
}
if firstPacket {
firstPacket = false

switch state {
case stateSkipResponse:
// Ok, EOF or ERROR packets are the last in the query response
// sequence. After them, we should continue serving.
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html
last := packet.IsErr() || packet.IsEOF()
if last {
state = stateServe
}
serverLog.WithField("last", last).Debugln("Skipping the packet")
continue
case stateFirstPacket:
state = stateServe
handler.serverProtocol41 = packet.ServerSupportProtocol41()
serverLog.Debugf("Set support protocol 41 %v", handler.serverProtocol41)
}
// reset previously matched zoneID
accessContext := base.AccessContextFromContext(ctx)
accessContext.SetZoneID(nil)
responseHandler = handler.getResponseHandler()
err = responseHandler(ctx, packet, handler.dbConnection, handler.clientConnection)

// EncodingError is the only one that we should forward to the client
if encodingError, ok := err.(*base.EncodingError); ok {
handler.logger.WithError(encodingError).Debugln("Sending encoding error to the client")
errPacket := NewQueryInterruptedError(handler.clientProtocol41, encodingError.Error())
packet.SetData(errPacket)
if _, err := handler.clientConnection.Write(packet.Dump()); err != nil {
handler.logger.WithError(err).
WithField(logging.FieldKeyEventCode, logging.EventCodeErrorResponseConnectorCantWriteToClient).
Debugln("Can't write response with error to client")
fallthrough

case stateServe:
// reset previously matched zoneID
accessContext := base.AccessContextFromContext(ctx)
accessContext.SetZoneID(nil)
responseHandler = handler.getResponseHandler()
err = responseHandler(ctx, packet, handler.dbConnection, handler.clientConnection)

// EncodingError is the only one that we should forward to the client
if encodingError, ok := err.(*base.EncodingError); ok {
handler.logger.WithError(err).Debugln("Sending encoding error to the client")
if err := handler.sendClientError(encodingError.Error(), packet); err != nil {
handler.logger.WithError(err).
WithField(logging.FieldKeyEventCode, logging.EventCodeErrorResponseConnectorCantWriteToClient).
Debugln("Can't write response with error to client")
errCh <- base.NewDBProxyError(err)
return
}
// Now we should flush the rest of the database packets because
// the client doesn't expect them
state = stateSkipResponse
continue
}

if err != nil {
handler.resetQueryHandler()
errCh <- base.NewDBProxyError(err)
return
}
// Continue serving packet, though we should skip them till the end
// of the response.
continue
}

if err != nil {
handler.resetQueryHandler()
errCh <- base.NewDBProxyError(err)
return
}
}
}

// sendClientError sends an `QueryInterruptedError` with a custom message
func (handler *Handler) sendClientError(msg string, packet *Packet) error {
errPacket := NewQueryInterruptedError(handler.clientProtocol41, msg)
packet.SetData(errPacket)
_, err := handler.clientConnection.Write(packet.Dump())
return err
}

// AddClientIDObserver subscribe new observer for clientID changes
func (handler *Handler) AddClientIDObserver(observer base.ClientIDObserver) {
handler.clientIDObserverManager.AddClientIDObserver(observer)
Expand Down
11 changes: 11 additions & 0 deletions tests/encryptor_configs/transparent_type_aware_decryption.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,14 @@ schemas:
- column: value_empty_str
data_type: "str"
response_on_fail: ciphertext

- table: test_proper_db_flushing_on_error
columns:
- id
- value_bytes

encrypted:
# We want to test the proper workflow, so one field will do the trick
- column: value_bytes
data_type: "bytes"
response_on_fail: error
104 changes: 104 additions & 0 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9755,6 +9755,110 @@ async def _testPlainConnectionAfterDeny():
loop = asyncio.new_event_loop() # create new to avoid concurrent usage of the loop in the current thread and allow parallel execution in the future
loop.run_until_complete(_testPlainConnectionAfterDeny())

class TestMySQLDbFlushingOnError(BaseBinaryMySQLTestCase, BaseTransparentEncryption):
encryptor_table = sa.Table(
'test_proper_db_flushing_on_error', sa.MetaData(),
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('value_bytes', sa.LargeBinary),
)
ENCRYPTOR_CONFIG = get_encryptor_config('tests/encryptor_configs/transparent_type_aware_decryption.yaml')

def checkSkip(self):
if not (TEST_MYSQL and TEST_WITH_TLS):
self.skipTest("Test only for MySQL with TLS")

def testConnectionIsNotAborted(self):
"""
Test that connection is not closed in case of "encoding error". Test
that we can reuse connection for queries after.
"""

self.encryptor_table.create(bind=self.engine_raw, checkfirst=True)
# Insert data that will trigger decryption error
corrupted_data = {
'id': get_random_id(),
'value_bytes': random_bytes(),
}
self.engine_raw.execute(self.encryptor_table.insert(), corrupted_data)

with self.engine1.connect() as conn:
# Insert some data
data = {
'id': get_random_id(),
'value_bytes': random_bytes(),
}
conn.execute(self.encryptor_table.insert(), data)

query = sa \
.select([self.encryptor_table]) \
.where(self.encryptor_table.c.id == data['id'])
row = conn.execute(query).fetchone()
self.assertEqual(data['value_bytes'], row['value_bytes'])

# Expect "encoding error"
ids = (data['id'], corrupted_data['id'])
query = sa \
.select([self.encryptor_table]) \
.where(self.encryptor_table.c.id.in_(ids))

with self.assertRaises(sa.exc.OperationalError) as ex:
conn.execute(query).fetchall()
self.assertEqual(ex.exception.orig.args, (1317, 'encoding error in column "value_bytes"'))

# Insert and select new data using the same connection to be sure
# it doesn't close or get out of sync
data = {
'id': get_random_id(),
'value_bytes': random_bytes(),
}
conn.execute(self.encryptor_table.insert(), data)

query = sa \
.select([self.encryptor_table]) \
.where(self.encryptor_table.c.id == data['id'])
row = conn.execute(query).fetchone()
self.assertEqual(data['value_bytes'], row['value_bytes'])

def testTransactionRollback(self):
"""
Test that connection is not closed in case of "encoding error" and
sqlaclchemy can do rollback in transaction after that.
"""

self.encryptor_table.create(bind=self.engine_raw, checkfirst=True)
# Insert data that will trigger decryption error
corrupted_data = {
'id': get_random_id(),
'value_bytes': random_bytes(),
}
self.engine_raw.execute(self.encryptor_table.insert(), corrupted_data)
data = {
'id': get_random_id(),
'value_bytes': random_bytes(),
}
select_data = sa \
.select([self.encryptor_table]) \
.where(self.encryptor_table.c.id == data['id'])

with self.assertRaises(sa.exc.OperationalError) as ex:
with self.engine1.begin() as conn:
conn.execute(self.encryptor_table.insert(), data)

row = conn.execute(select_data).fetchone()
self.assertEqual(data['value_bytes'], row['value_bytes'])

# Expect "encoding error"
ids = (data['id'], corrupted_data['id'])
query = sa \
.select([self.encryptor_table]) \
.where(self.encryptor_table.c.id.in_(ids))

conn.execute(query).fetchall()

self.assertEqual(ex.exception.orig.args, (1317, 'encoding error in column "value_bytes"'))
row = self.engine1.execute(select_data).fetchone()
self.assertEqual(row, None)


if __name__ == '__main__':
import xmlrunner
Expand Down

0 comments on commit 30d12e4

Please sign in to comment.