Skip to content

Commit

Permalink
GODRIVER-3260 [master] Code hardening (#1691)
Browse files Browse the repository at this point in the history
Co-authored-by: Qingyang Hu <103950869+qingyang-hu@users.noreply.github.com>
Co-authored-by: Qingyang Hu <qingyang.hu@mongodb.com>
Co-authored-by: Matt Dale <9760375+matthewdale@users.noreply.github.com>
  • Loading branch information
4 people authored Jun 29, 2024
1 parent ea15f7d commit f1f7050
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 44 deletions.
2 changes: 1 addition & 1 deletion bson/default_value_decoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func intDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Va
case reflect.Int64:
return reflect.ValueOf(i64), nil
case reflect.Int:
if int64(int(i64)) != i64 { // Can we fit this inside of an int
if i64 > math.MaxInt { // Can we fit this inside of an int
return emptyValue, fmt.Errorf("%d overflows int", i64)
}

Expand Down
4 changes: 2 additions & 2 deletions bson/extjson_wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) {
return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", val.t)
}

i, err := strconv.ParseInt(val.v.(string), 16, 64)
i, err := strconv.ParseUint(val.v.(string), 16, 8)
if err != nil {
return nil, 0, fmt.Errorf("invalid $binary subType string: %s", val.v.(string))
return nil, 0, fmt.Errorf("invalid $binary subType string: %q: %w", val.v.(string), err)
}

subType = byte(i)
Expand Down
8 changes: 6 additions & 2 deletions bson/uint_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,15 @@ func (uic *uintCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Typ

return reflect.ValueOf(uint64(i64)), nil
case reflect.Uint:
if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint
if i64 < 0 {
return emptyValue, fmt.Errorf("%d overflows uint", i64)
}
v := uint64(i64)
if v > math.MaxUint { // Can we fit this inside of an uint
return emptyValue, fmt.Errorf("%d overflows uint", i64)
}

return reflect.ValueOf(uint(i64)), nil
return reflect.ValueOf(uint(v)), nil
default:
return emptyValue, ValueDecoderError{
Name: "UintDecodeValue",
Expand Down
12 changes: 5 additions & 7 deletions bson/value_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ func (vr *valueReader) peekLength() (int32, error) {
}

idx := vr.offset
return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil
return int32(binary.LittleEndian.Uint32(vr.d[idx:])), nil
}

func (vr *valueReader) readLength() (int32, error) { return vr.readi32() }
Expand All @@ -851,7 +851,7 @@ func (vr *valueReader) readi32() (int32, error) {

idx := vr.offset
vr.offset += 4
return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil
return int32(binary.LittleEndian.Uint32(vr.d[idx:])), nil
}

func (vr *valueReader) readu32() (uint32, error) {
Expand All @@ -861,7 +861,7 @@ func (vr *valueReader) readu32() (uint32, error) {

idx := vr.offset
vr.offset += 4
return (uint32(vr.d[idx]) | uint32(vr.d[idx+1])<<8 | uint32(vr.d[idx+2])<<16 | uint32(vr.d[idx+3])<<24), nil
return binary.LittleEndian.Uint32(vr.d[idx:]), nil
}

func (vr *valueReader) readi64() (int64, error) {
Expand All @@ -871,8 +871,7 @@ func (vr *valueReader) readi64() (int64, error) {

idx := vr.offset
vr.offset += 8
return int64(vr.d[idx]) | int64(vr.d[idx+1])<<8 | int64(vr.d[idx+2])<<16 | int64(vr.d[idx+3])<<24 |
int64(vr.d[idx+4])<<32 | int64(vr.d[idx+5])<<40 | int64(vr.d[idx+6])<<48 | int64(vr.d[idx+7])<<56, nil
return int64(binary.LittleEndian.Uint64(vr.d[idx:])), nil
}

func (vr *valueReader) readu64() (uint64, error) {
Expand All @@ -882,6 +881,5 @@ func (vr *valueReader) readu64() (uint64, error) {

idx := vr.offset
vr.offset += 8
return uint64(vr.d[idx]) | uint64(vr.d[idx+1])<<8 | uint64(vr.d[idx+2])<<16 | uint64(vr.d[idx+3])<<24 |
uint64(vr.d[idx+4])<<32 | uint64(vr.d[idx+5])<<40 | uint64(vr.d[idx+6])<<48 | uint64(vr.d[idx+7])<<56, nil
return binary.LittleEndian.Uint64(vr.d[idx:]), nil
}
4 changes: 2 additions & 2 deletions etc/run-atlas-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ set +x
# Get the atlas secrets.
. ${DRIVERS_TOOLS}/.evergreen/secrets_handling/setup-secrets.sh drivers/atlas_connect

echo "Running cmd/testatlas/main.go"
go run ./internal/cmd/testatlas/main.go "$ATLAS_REPL" "$ATLAS_SHRD" "$ATLAS_FREE" "$ATLAS_TLS11" "$ATLAS_TLS12" "$ATLAS_SERVERLESS" "$ATLAS_SRV_REPL" "$ATLAS_SRV_SHRD" "$ATLAS_SRV_FREE" "$ATLAS_SRV_TLS11" "$ATLAS_SRV_TLS12" "$ATLAS_SRV_SERVERLESS" >> test.suite
echo "Running cmd/testatlas"
go test -v -run ^TestAtlas$ go.mongodb.org/mongo-driver/internal/cmd/testatlas -args "$ATLAS_REPL" "$ATLAS_SHRD" "$ATLAS_FREE" "$ATLAS_TLS11" "$ATLAS_TLS12" "$ATLAS_SERVERLESS" "$ATLAS_SRV_REPL" "$ATLAS_SRV_SHRD" "$ATLAS_SRV_FREE" "$ATLAS_SRV_TLS11" "$ATLAS_SRV_TLS12" "$ATLAS_SRV_SERVERLESS" >> test.suite
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"errors"
"flag"
"fmt"
"os"
"testing"
"time"

"go.mongodb.org/mongo-driver/bson"
Expand All @@ -19,15 +21,19 @@ import (
"go.mongodb.org/mongo-driver/mongo/options"
)

func main() {
func TestMain(m *testing.M) {
flag.Parse()
os.Exit(m.Run())
}

func TestAtlas(t *testing.T) {
uris := flag.Args()
ctx := context.Background()

fmt.Printf("Running atlas tests for %d uris\n", len(uris))
t.Logf("Running atlas tests for %d uris\n", len(uris))

for idx, uri := range uris {
fmt.Printf("Running test %d\n", idx)
t.Logf("Running test %d\n", idx)

// Set a low server selection timeout so we fail fast if there are errors.
clientOpts := options.Client().
Expand All @@ -36,18 +42,18 @@ func main() {

// Run basic connectivity test.
if err := runTest(ctx, clientOpts); err != nil {
panic(fmt.Sprintf("error running test with TLS at index %d: %v", idx, err))
t.Fatalf("error running test with TLS at index %d: %v", idx, err)
}

// Run the connectivity test with InsecureSkipVerify to ensure SNI is done correctly even if verification is
// disabled.
clientOpts.TLSConfig.InsecureSkipVerify = true
if err := runTest(ctx, clientOpts); err != nil {
panic(fmt.Sprintf("error running test with tlsInsecure at index %d: %v", idx, err))
t.Fatalf("error running test with tlsInsecure at index %d: %v", idx, err)
}
}

fmt.Println("Finished!")
t.Logf("Finished!")
}

func runTest(ctx context.Context, clientOpts *options.ClientOptions) error {
Expand Down
7 changes: 6 additions & 1 deletion internal/logger/io_sink.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package logger
import (
"encoding/json"
"io"
"math"
"sync"
"time"
)
Expand Down Expand Up @@ -36,7 +37,11 @@ func NewIOSink(out io.Writer) *IOSink {

// Info will write a JSON-encoded message to the io.Writer.
func (sink *IOSink) Info(_ int, msg string, keysAndValues ...interface{}) {
kvMap := make(map[string]interface{}, len(keysAndValues)/2+2)
mapSize := len(keysAndValues) / 2
if math.MaxInt-mapSize >= 2 {
mapSize += 2
}
kvMap := make(map[string]interface{}, mapSize)

kvMap[KeyTimestamp] = time.Now().UnixNano()
kvMap[KeyMessage] = msg
Expand Down
15 changes: 14 additions & 1 deletion mongo/options/clientoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"errors"
"fmt"
"io/ioutil"
"math"
"net"
"net/http"
"strings"
Expand Down Expand Up @@ -1142,7 +1143,19 @@ func addClientCertFromSeparateFiles(cfg *tls.Config, keyFile, certFile, keyPassw
return "", err
}

data := make([]byte, 0, len(keyData)+len(certData)+1)
keySize := len(keyData)
if keySize > 64*1024*1024 {
return "", errors.New("X.509 key must be less than 64 MiB")
}
certSize := len(certData)
if certSize > 64*1024*1024 {
return "", errors.New("X.509 certificate must be less than 64 MiB")
}
dataSize := int64(keySize) + int64(certSize) + 1
if dataSize > math.MaxInt {
return "", errors.New("size overflow")
}
data := make([]byte, 0, int(dataSize))
data = append(data, keyData...)
data = append(data, '\n')
data = append(data, certData...)
Expand Down
4 changes: 4 additions & 0 deletions mongo/writeconcern/writeconcern.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package writeconcern
import (
"errors"
"fmt"
"math"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
Expand Down Expand Up @@ -157,6 +158,9 @@ func (wc *WriteConcern) MarshalBSONValue() (bson.Type, []byte, error) {
return 0, nil, ErrInconsistent
}

if w > math.MaxInt32 {
return 0, nil, fmt.Errorf("%d overflows int32", w)
}
elems = bsoncore.AppendInt32Element(elems, "w", int32(w))
case string:
elems = bsoncore.AppendStringElement(elems, "w", w)
Expand Down
40 changes: 18 additions & 22 deletions x/bsonx/bsoncore/bsoncore.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package bsoncore

import (
"bytes"
"encoding/binary"
"fmt"
"math"
"strconv"
Expand Down Expand Up @@ -705,17 +706,16 @@ func ReserveLength(dst []byte) (int32, []byte) {

// UpdateLength updates the length at index with length and returns the []byte.
func UpdateLength(dst []byte, index, length int32) []byte {
dst[index] = byte(length)
dst[index+1] = byte(length >> 8)
dst[index+2] = byte(length >> 16)
dst[index+3] = byte(length >> 24)
binary.LittleEndian.PutUint32(dst[index:], uint32(length))
return dst
}

func appendLength(dst []byte, l int32) []byte { return appendi32(dst, l) }

func appendi32(dst []byte, i32 int32) []byte {
return append(dst, byte(i32), byte(i32>>8), byte(i32>>16), byte(i32>>24))
b := []byte{0, 0, 0, 0}
binary.LittleEndian.PutUint32(b, uint32(i32))
return append(dst, b...)
}

// ReadLength reads an int32 length from src and returns the length and the remaining bytes. If
Expand All @@ -733,51 +733,47 @@ func readi32(src []byte) (int32, []byte, bool) {
if len(src) < 4 {
return 0, src, false
}
return (int32(src[0]) | int32(src[1])<<8 | int32(src[2])<<16 | int32(src[3])<<24), src[4:], true
return int32(binary.LittleEndian.Uint32(src)), src[4:], true
}

func appendi64(dst []byte, i64 int64) []byte {
return append(dst,
byte(i64), byte(i64>>8), byte(i64>>16), byte(i64>>24),
byte(i64>>32), byte(i64>>40), byte(i64>>48), byte(i64>>56),
)
b := []byte{0, 0, 0, 0, 0, 0, 0, 0}
binary.LittleEndian.PutUint64(b, uint64(i64))
return append(dst, b...)
}

func readi64(src []byte) (int64, []byte, bool) {
if len(src) < 8 {
return 0, src, false
}
i64 := (int64(src[0]) | int64(src[1])<<8 | int64(src[2])<<16 | int64(src[3])<<24 |
int64(src[4])<<32 | int64(src[5])<<40 | int64(src[6])<<48 | int64(src[7])<<56)
return i64, src[8:], true
return int64(binary.LittleEndian.Uint64(src)), src[8:], true
}

func appendu32(dst []byte, u32 uint32) []byte {
return append(dst, byte(u32), byte(u32>>8), byte(u32>>16), byte(u32>>24))
b := []byte{0, 0, 0, 0}
binary.LittleEndian.PutUint32(b, u32)
return append(dst, b...)
}

func readu32(src []byte) (uint32, []byte, bool) {
if len(src) < 4 {
return 0, src, false
}

return (uint32(src[0]) | uint32(src[1])<<8 | uint32(src[2])<<16 | uint32(src[3])<<24), src[4:], true
return binary.LittleEndian.Uint32(src), src[4:], true
}

func appendu64(dst []byte, u64 uint64) []byte {
return append(dst,
byte(u64), byte(u64>>8), byte(u64>>16), byte(u64>>24),
byte(u64>>32), byte(u64>>40), byte(u64>>48), byte(u64>>56),
)
b := []byte{0, 0, 0, 0, 0, 0, 0, 0}
binary.LittleEndian.PutUint64(b, u64)
return append(dst, b...)
}

func readu64(src []byte) (uint64, []byte, bool) {
if len(src) < 8 {
return 0, src, false
}
u64 := (uint64(src[0]) | uint64(src[1])<<8 | uint64(src[2])<<16 | uint64(src[3])<<24 |
uint64(src[4])<<32 | uint64(src[5])<<40 | uint64(src[6])<<48 | uint64(src[7])<<56)
return u64, src[8:], true
return binary.LittleEndian.Uint64(src), src[8:], true
}

// keep in sync with readcstringbytes
Expand Down

0 comments on commit f1f7050

Please sign in to comment.