From 003b201c96501bc733b55a4412e17043e74887a0 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Sun, 22 May 2022 14:41:48 +0200 Subject: [PATCH] feat(tls): add stdlib-aware Client-like factory This commit adds a factory that works like tls.Client and, in particular, takes in input a crypto/tls.Config pointer. We convert the config we receive to the config defined in this module by the tls package. If we don't know how to convert specific fields and those fields have a nonzero value, then we return an error. See https://github.com/ooni/probe/issues/2106, which documents the effort of integrating this code inside ooni/probe-cli. --- tls/stdlibwrapper.go | 77 +++++++++++++++++++++++++++++++++++++++ tls/stdlibwrapper_test.go | 56 ++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+) create mode 100644 tls/stdlibwrapper.go create mode 100644 tls/stdlibwrapper_test.go diff --git a/tls/stdlibwrapper.go b/tls/stdlibwrapper.go new file mode 100644 index 00000000..bc9bfb0f --- /dev/null +++ b/tls/stdlibwrapper.go @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: BSD-3-Clause + +package tls + +import ( + "crypto/tls" + "errors" + "fmt" + "net" + "reflect" +) + +// ErrIncompatibleStdlibConfig is returned when NewClientConnStdlib is +// passed an incompatible config (i.e., a config containing fields that +// we don't know how to convert and whose value is not ~zero). +var ErrIncompatibleStdlibConfig = errors.New("ootls: incompatible stdlib config") + +// NewClientConnStdlib is like Client but takes in input a *crypto/tls.Config +// rather than a *github.com/ooni/oocrypto/tls.Config. +// +// The config cannot be nil: users must set either ServerName or +// InsecureSkipVerify in the config. +// +// This function will return ErrIncompatibleStdlibConfig if unsupported +// fields have a nonzero value, because the resulting Conn will not +// be compatible with the configuration you provided us with. +// +// We currently support these fields: +// +// - DynamicRecordSizingDisabled +// +// - InsecureSkipVerify +// +// - MaxVersion +// +// - MinVersion +// +// - NextProtos +// +// - RootCAs +// +// - ServerName +func NewClientConnStdlib(conn net.Conn, config *tls.Config) (*Conn, error) { + supportedFields := map[string]bool{ + "DynamicRecordSizingDisabled": true, + "InsecureSkipVerify": true, + "MaxVersion": true, + "MinVersion": true, + "NextProtos": true, + "RootCAs": true, + "ServerName": true, + } + value := reflect.ValueOf(config).Elem() + kind := value.Type() + for idx := 0; idx < value.NumField(); idx++ { + field := value.Field(idx) + if field.IsZero() { + continue + } + fieldKind := kind.Field(idx) + if supportedFields[fieldKind.Name] { + continue + } + err := fmt.Errorf("%w: field %s is nonzero", ErrIncompatibleStdlibConfig, fieldKind.Name) + return nil, err + } + ourConfig := &Config{ + DynamicRecordSizingDisabled: config.DynamicRecordSizingDisabled, + InsecureSkipVerify: config.InsecureSkipVerify, + MaxVersion: config.MaxVersion, + MinVersion: config.MinVersion, + NextProtos: config.NextProtos, + RootCAs: config.RootCAs, + ServerName: config.ServerName, + } + return Client(conn, ourConfig), nil +} diff --git a/tls/stdlibwrapper_test.go b/tls/stdlibwrapper_test.go new file mode 100644 index 00000000..1f3fdc6c --- /dev/null +++ b/tls/stdlibwrapper_test.go @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: BSD-3-Clause + +package tls + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "net" + "testing" + "time" +) + +func TestNewClientConnStdlib(t *testing.T) { + tests := []struct { + name string + config *tls.Config + err error + }{{ + name: "with only supported config fields", + config: &tls.Config{ + DynamicRecordSizingDisabled: true, + RootCAs: x509.NewCertPool(), + ServerName: "ooni.org", + InsecureSkipVerify: true, + MinVersion: VersionTLS10, + MaxVersion: VersionTLS13, + NextProtos: []string{"h3"}, + }, + err: nil, + }, { + name: "with unsupported fields", + config: &tls.Config{ + Time: func() time.Time { + return time.Now() + }, + }, + err: ErrIncompatibleStdlibConfig, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conn, err := net.Dial("udp", "8.8.8.8:443") // we just want a valid conn + if err != nil { + t.Fatal(err) + } + defer conn.Close() + got, err := NewClientConnStdlib(conn, tt.config) + if !errors.Is(err, tt.err) { + t.Fatal("unexpected error", err) + } + if err == nil && got == nil { + t.Fatal("expected non-nil conn here") + } + }) + } +}