Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Process sequence of pending packets #580

Merged
merged 19 commits into from
Sep 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG_DEV.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# 0.94.0 - 2022-09-19
- Accumulate packets in a queue and handle paired packets in the correct order. Fixes issue with incorrectly linked Bind packet to inappropriate Parse packet and nil dereferences.

# 0.94.0 - 2022-08-25
- Add support of Hashicorp Consul for `encryptor_config loading`.
- Introduce new Hashicorp Consul flags: `consul_connection_api_string` and `consul_kv_config_path` and corresponded `consul` TLS configuration flags.
Expand Down
3 changes: 3 additions & 0 deletions decryptor/postgresql/packet_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,9 @@ const WithoutMessageType = 0
// ErrUnsupportedPacketType error when recognized unsupported message type or new added to postgresql wire protocol
var ErrUnsupportedPacketType = errors.New("unsupported postgresql message type")

// ErrNilPendingPacket error when took nil instead of pending packet
var ErrNilPendingPacket = errors.New("nil pending packet")

// ReadClientPacket read and recognize packets that may be sent only from client/frontend.
//
// There are two types of messages: startup and general ones.
Expand Down
70 changes: 70 additions & 0 deletions decryptor/postgresql/packet_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ package postgresql
import (
"bufio"
"bytes"
"context"
"encoding/binary"
"encoding/hex"
acracensor "github.com/cossacklabs/acra/acra-censor"
"github.com/cossacklabs/acra/cmd/acra-server/common"
"github.com/cossacklabs/acra/sqlparser"
"testing"

"github.com/cossacklabs/acra/decryptor/base"
Expand Down Expand Up @@ -236,3 +240,69 @@ func TestParseColumns(t *testing.T) {
t.Fatal("Incorrect ")
}
}

func TestSequenceOfParsePackets(t *testing.T) {
// Regression test for T2663 && https://github.com/cossacklabs/acra/issues/575
// used a dump from wireshark and java app with request + response
// We check that Acra correctly handle Bind and BindComplete packets and links Bind to the first Parse packet,
// not to last received

// Parse + Bind + Execute + Parse + Describe + Sync
requestPacket, err := hex.DecodeString("5000000010535f3100424547494e000000420000000f00535f3100000000000000450000000900000000005000000042535f3200494e5345525420494e544f206d797461626c6520286e616d652c20616765292056414c554553202824312c202432290000020000041300000017440000000953535f32005300000004")
if err != nil {
t.Fatal(err)
}
// ParseComplete + BindComplete + CommandComplete + ParseComplete + ParameterDescription + NoData (empty data rows) + ReadyForQuery
responsePacket, err := hex.DecodeString(`31000000043200000004430000000a424547494e003100000004740000000e000200000413000000176e000000045a0000000554`)
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
//emptyStore := &tableSchemaStore{true}
parser := sqlparser.New(sqlparser.ModeStrict)
session, err := common.NewClientSession(ctx, &common.Config{}, nil)
if err != nil {
t.Fatal(err)
}
setting := base.NewProxySetting(parser, nil, nil, nil, acracensor.NewAcraCensor(), nil, false)
proxy, err := NewPgProxy(session, parser, setting)
if err != nil {
t.Fatal(err)
}
reader := bytes.NewReader(requestPacket)
output := &bytes.Buffer{}
writer := bufio.NewWriter(output)
packetHandler, err := NewClientSidePacketHandler(reader, writer, logrus.NewEntry(logrus.StandardLogger()))
if err != nil {
t.Fatal(err)
}
// don't wait startup packet
packetHandler.started = true
for i := 0; i < 6; i++ {
if err := packetHandler.ReadClientPacket(); err != nil {
t.Fatal(err)
}
if censored, err := proxy.handleClientPacket(ctx, packetHandler, logrus.NewEntry(logrus.New())); err != nil {
t.Fatal(err)
} else if censored {
t.Fatal("Should not be censored")
}
}

reader = bytes.NewReader(responsePacket)
output.Reset()
packetHandler, err = NewDbSidePacketHandler(reader, writer, logrus.NewEntry(logrus.StandardLogger()))
if err != nil {
t.Fatal(err)
}

for i := 0; i < 7; i++ {
if err := packetHandler.ReadPacket(); err != nil {
t.Fatal(err)
}
if err := proxy.handleDatabasePacket(ctx, packetHandler, logrus.NewEntry(logrus.New())); err != nil {
t.Fatal(err)
}
}

}
126 changes: 126 additions & 0 deletions decryptor/postgresql/pending_packets.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Copyright 2022, Cossack Labs Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package postgresql

import (
"container/list"
"errors"
"github.com/jackc/pgx/pgproto3"
log "github.com/sirupsen/logrus"
"reflect"
)

// pendingPacketsList stores objects per their type and provides API similar to queue
type pendingPacketsList struct {
lists map[reflect.Type]*list.List
}

func newPendingPacketsList() *pendingPacketsList {
return &pendingPacketsList{lists: make(map[reflect.Type]*list.List)}
}

// ErrUnsupportedPendingPacketType error after using unknown type of structure
var ErrUnsupportedPendingPacketType = errors.New("unsupported pending packet type")

// ErrRemoveFromEmptyPendingList error after trying to remove object from empty list
var ErrRemoveFromEmptyPendingList = errors.New("removing from empty pending list")

// Add packet to pending list of packets of this type
func (packets *pendingPacketsList) Add(packet interface{}) error {
switch packet.(type) {
case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription:
packetList, ok := packets.lists[reflect.TypeOf(packet)]
if !ok {
packetList = list.New()
packets.lists[reflect.TypeOf(packet)] = packetList
}
log.WithField("packet", packet).Debugln("Add pending packet")
packetList.PushBack(packet)
return nil
}
return ErrUnsupportedPendingPacketType
}

// RemoveNextPendingPacket removes first in the list pending packet
func (packets *pendingPacketsList) RemoveNextPendingPacket(packet interface{}) error {
switch packet.(type) {
case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription:
packetList, ok := packets.lists[reflect.TypeOf(packet)]
if !ok {
return ErrRemoveFromEmptyPendingList
}
currentElement := packetList.Front()
if currentElement == nil {
return nil
}
log.WithField("packet", currentElement.Value).Debugln("Remove pending packet")
packetList.Remove(currentElement)
return nil
}
return ErrUnsupportedPendingPacketType
}

// RemoveAll pending packets of packet's type
func (packets *pendingPacketsList) RemoveAll(packet interface{}) error {
switch packet.(type) {
case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription:
packetList, ok := packets.lists[reflect.TypeOf(packet)]
if !ok {
return nil
}
log.Debugln("Remove all pending packets")
packetList.Init()
return nil
}
return ErrUnsupportedPendingPacketType
}

// GetPendingPacket returns next pending packet
func (packets *pendingPacketsList) GetPendingPacket(packet interface{}) (interface{}, error) {
switch packet.(type) {
case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription:
packetList, ok := packets.lists[reflect.TypeOf(packet)]
if !ok {
return nil, nil
}
currentElement := packetList.Front()
if currentElement == nil {
return nil, nil
}
log.WithField("packet", currentElement.Value).Debugln("Return pending packet")
return currentElement.Value, nil
}
return nil, ErrUnsupportedPendingPacketType
}

// GetLastPending return last added pending packet
func (packets *pendingPacketsList) GetLastPending(packet interface{}) (interface{}, error) {
switch packet.(type) {
case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription:
packetList, ok := packets.lists[reflect.TypeOf(packet)]
if !ok {
return nil, nil
}
currentElement := packetList.Back()
if currentElement == nil {
return nil, nil
}
log.WithField("packet", currentElement.Value).Debugln("Return last added packet")
return currentElement.Value, nil
}
return nil, ErrUnsupportedPendingPacketType
}
94 changes: 94 additions & 0 deletions decryptor/postgresql/pending_packets_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package postgresql

import (
"testing"
)

func Test_newPendingPackets(t *testing.T) {
pendingPackets := newPendingPacketsList()
if packet, err := pendingPackets.GetPendingPacket(&BindPacket{}); err != nil {
t.Fatal(err)
} else if packet != nil {
t.Fatal("Packet should be nil")
}

if packet, err := pendingPackets.GetLastPending(&BindPacket{}); err != nil {
t.Fatal(err)
} else if packet != nil {
t.Fatal("Packet should be nil")
}

if err := pendingPackets.Add(&BindPacket{portal: "portal1"}); err != nil {
t.Fatal(err)
}

if packet, err := pendingPackets.GetPendingPacket(&BindPacket{}); err != nil {
t.Fatal(err)
} else if packet == nil {
t.Fatal("Packet should not be nil")
} else if packet.(*BindPacket).portal != "portal1" {
t.Fatal("Unexpected value")
}

if packet, err := pendingPackets.GetLastPending(&BindPacket{}); err != nil {
t.Fatal(err)
} else if packet == nil {
t.Fatal("Packet should not be nil")
} else if packet.(*BindPacket).portal != "portal1" {
t.Fatal("Unexpected value")
}

if err := pendingPackets.Add(&BindPacket{portal: "portal2"}); err != nil {
t.Fatal(err)
}

if packet, err := pendingPackets.GetPendingPacket(&BindPacket{}); err != nil {
t.Fatal(err)
} else if packet == nil {
t.Fatal("Packet should not be nil")
} else if packet.(*BindPacket).portal != "portal1" {
t.Fatal("Unexpected value")
}

if packet, err := pendingPackets.GetLastPending(&BindPacket{}); err != nil {
t.Fatal(err)
} else if packet == nil {
t.Fatal("Packet should not be nil")
} else if packet.(*BindPacket).portal != "portal2" {
t.Fatal("Unexpected value")
}

if err := pendingPackets.RemoveNextPendingPacket(&BindPacket{}); err != nil {
t.Fatal(err)
}
if packet, err := pendingPackets.GetPendingPacket(&BindPacket{}); err != nil {
t.Fatal(err)
} else if packet == nil {
t.Fatal("Packet should not be nil")
} else if packet.(*BindPacket).portal != "portal2" {
t.Fatal("Unexpected value")
}

if packet, err := pendingPackets.GetLastPending(&BindPacket{}); err != nil {
t.Fatal(err)
} else if packet == nil {
t.Fatal("Packet should not be nil")
} else if packet.(*BindPacket).portal != "portal2" {
t.Fatal("Unexpected value")
}

if err := pendingPackets.RemoveAll(&BindPacket{}); err != nil {
t.Fatal(err)
}
if packet, err := pendingPackets.GetPendingPacket(&BindPacket{}); err != nil {
t.Fatal(err)
} else if packet != nil {
t.Fatal("Packet should be nil")
}

if packet, err := pendingPackets.GetLastPending(&BindPacket{}); err != nil {
t.Fatal(err)
} else if packet != nil {
t.Fatal("Packet should be nil")
}
}
Loading