diff --git a/.github/actions/go-test-setup/action.yml b/.github/actions/go-test-setup/action.yml index 6b15ea06..7275bc7d 100644 --- a/.github/actions/go-test-setup/action.yml +++ b/.github/actions/go-test-setup/action.yml @@ -7,13 +7,6 @@ runs: shell: bash run: | echo 'CGO_ENABLED=1' >> $GITHUB_ENV - - name: Windows setup - shell: bash - if: ${{ runner.os == 'Windows' }} - run: | - pacman -S --noconfirm mingw-w64-x86_64-toolchain mingw-w64-i686-toolchain - echo '/c/msys64/mingw64/bin' >> $GITHUB_PATH - echo 'PATH_386=/c/msys64/mingw32/bin:${{ env.PATH_386 }}' >> $GITHUB_ENV - name: Linux setup shell: bash if: ${{ runner.os == 'Linux' }} diff --git a/.github/workflows/go-test-ubuntu-22.04.yml b/.github/workflows/go-test-ubuntu-22.04.yml index cb086365..bf16af73 100644 --- a/.github/workflows/go-test-ubuntu-22.04.yml +++ b/.github/workflows/go-test-ubuntu-22.04.yml @@ -27,14 +27,6 @@ jobs: run: | go version go env - - name: Use msys2 on windows - if: startsWith(matrix.os, 'windows') - shell: bash - # The executable for msys2 is also called bash.cmd - # https://github.com/actions/virtual-environments/blob/main/images/win/Windows2019-Readme.md#shells - # If we prepend its location to the PATH - # subsequent 'shell: bash' steps will use msys2 instead of gitbash - run: echo "C:/msys64/usr/bin" >> $GITHUB_PATH - name: Run repo-specific setup uses: ./.github/actions/go-test-setup if: hashFiles('./.github/actions/go-test-setup') != '' @@ -55,7 +47,7 @@ jobs: export "PATH=${{ env.PATH_386 }}:$PATH" go test -v ./... - name: Run tests with race detector - if: startsWith(matrix.os, 'ubuntu') # speed things up. Windows and OSX VMs are slow + if: startsWith(matrix.os, 'ubuntu') # speed things up. OSX VMs is slow uses: protocol/multiple-go-modules@v1.2 with: run: go test -v -race ./... diff --git a/.github/workflows/go-test.yml b/.github/workflows/go-test.yml index 8a1697b2..e5673e86 100644 --- a/.github/workflows/go-test.yml +++ b/.github/workflows/go-test.yml @@ -9,7 +9,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ "ubuntu", "windows", "macos" ] + os: [ "ubuntu", "macos" ] go: [ "1.18.x", "1.19.x" ] env: COVERAGES: "" @@ -26,14 +26,6 @@ jobs: run: | go version go env - - name: Use msys2 on windows - if: ${{ matrix.os == 'windows' }} - shell: bash - # The executable for msys2 is also called bash.cmd - # https://github.com/actions/virtual-environments/blob/main/images/win/Windows2019-Readme.md#shells - # If we prepend its location to the PATH - # subsequent 'shell: bash' steps will use msys2 instead of gitbash - run: echo "C:/msys64/usr/bin" >> $GITHUB_PATH - name: Run repo-specific setup uses: ./.github/actions/go-test-setup if: hashFiles('./.github/actions/go-test-setup') != '' @@ -54,7 +46,7 @@ jobs: export "PATH=${{ env.PATH_386 }}:$PATH" go test -v -shuffle=on ./... - name: Run tests with race detector - if: ${{ matrix.os == 'ubuntu' }} # speed things up. Windows and OSX VMs are slow + if: ${{ matrix.os == 'ubuntu' }} # speed things up. OSX VMs is slow uses: protocol/multiple-go-modules@v1.2 with: run: go test -v -race ./... diff --git a/net.go b/net.go index b2293c7c..f20ae3f7 100644 --- a/net.go +++ b/net.go @@ -15,6 +15,7 @@ package openssl import ( + "context" "errors" "net" "time" @@ -77,8 +78,8 @@ const ( // some certs to the certificate store of the client context you're using. // This library is not nice enough to use the system certificate store by // default for you yet. -func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) { - return DialSession(network, addr, ctx, flags, nil) +func Dial(network, addr string, sslCtx *Ctx, flags DialFlags) (*Conn, error) { + return DialSession(network, addr, sslCtx, flags, nil) } // DialTimeout acts like Dial but takes a timeout for network dial. @@ -87,10 +88,57 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) { // // See func Dial for a description of the network, addr, ctx and flags // parameters. -func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx, +func DialTimeout(network, addr string, timeout time.Duration, sslCtx *Ctx, flags DialFlags) (*Conn, error) { - d := net.Dialer{Timeout: timeout} - return dialSession(d, network, addr, ctx, flags, nil) + host, err := parseHost(addr) + if err != nil { + return nil, err + } + + conn, err := net.DialTimeout(network, addr, timeout) + if err != nil { + return nil, err + } + sslCtx, err = prepareCtx(sslCtx) + if err != nil { + conn.Close() + return nil, err + } + client, err := createSession(conn, flags, host, sslCtx, nil) + if err != nil { + conn.Close() + } + return client, err +} + +// DialContext acts like Dial but takes a context for network dial. +// +// The context includes only network dial. It does not include OpenSSL calls. +// +// See func Dial for a description of the network, addr, ctx and flags +// parameters. +func DialContext(ctx context.Context, network, addr string, + sslCtx *Ctx, flags DialFlags) (*Conn, error) { + host, err := parseHost(addr) + if err != nil { + return nil, err + } + + dialer := net.Dialer{} + conn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + sslCtx, err = prepareCtx(sslCtx) + if err != nil { + conn.Close() + return nil, err + } + client, err := createSession(conn, flags, host, sslCtx, nil) + if err != nil { + conn.Close() + } + return client, err } // DialSession will connect to network/address and then wrap the corresponding @@ -106,61 +154,78 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx, // // If session is not nil it will be used to resume the tls state. The session // can be retrieved from the GetSession method on the Conn. -func DialSession(network, addr string, ctx *Ctx, flags DialFlags, +func DialSession(network, addr string, sslCtx *Ctx, flags DialFlags, session []byte) (*Conn, error) { - var d net.Dialer - return dialSession(d, network, addr, ctx, flags, session) -} - -func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags, - session []byte) (*Conn, error) { - host, _, err := net.SplitHostPort(addr) + host, err := parseHost(addr) if err != nil { return nil, err } - if ctx == nil { - var err error - ctx, err = NewCtx() - if err != nil { - return nil, err - } - // TODO: use operating system default certificate chain? - } - c, err := d.Dial(network, addr) + conn, err := net.Dial(network, addr) if err != nil { return nil, err } - conn, err := Client(c, ctx) + sslCtx, err = prepareCtx(sslCtx) if err != nil { - c.Close() + conn.Close() return nil, err } - if session != nil { - err := conn.setSession(session) - if err != nil { - c.Close() - return nil, err - } + client, err := createSession(conn, flags, host, sslCtx, session) + if err != nil { + conn.Close() + } + return client, err +} + +func prepareCtx(sslCtx *Ctx) (*Ctx, error) { + if sslCtx == nil { + return NewCtx() } + return sslCtx, nil +} + +func parseHost(addr string) (string, error) { + host, _, err := net.SplitHostPort(addr) + return host, err +} + +func handshake(conn *Conn, host string, flags DialFlags) error { + var err error if flags&DisableSNI == 0 { err = conn.SetTlsExtHostName(host) if err != nil { - conn.Close() - return nil, err + return err } } err = conn.Handshake() if err != nil { - conn.Close() - return nil, err + return err } if flags&InsecureSkipHostVerification == 0 { err = conn.VerifyHostname(host) + if err != nil { + return err + } + } + return nil +} + +func createSession(c net.Conn, flags DialFlags, host string, sslCtx *Ctx, + session []byte) (*Conn, error) { + conn, err := Client(c, sslCtx) + if err != nil { + return nil, err + } + if session != nil { + err := conn.setSession(session) if err != nil { conn.Close() return nil, err } } + if err := handshake(conn, host, flags); err != nil { + conn.Close() + return nil, err + } return conn, nil } diff --git a/net_test.go b/net_test.go new file mode 100644 index 00000000..808e254c --- /dev/null +++ b/net_test.go @@ -0,0 +1,101 @@ +package openssl_test + +import ( + "context" + "crypto/rand" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/tarantool/go-openssl" +) + +func sslConnect(t *testing.T, ssl_listener net.Listener) { + for { + var err error + conn, err := ssl_listener.Accept() + if err != nil { + t.Errorf("failed accept: %s", err) + continue + } + io.Copy(conn, io.LimitReader(rand.Reader, 1024)) + break + } +} + +func TestDial(t *testing.T) { + ctx := openssl.GetCtx(t) + if err := ctx.SetCipherList("AES128-SHA"); err != nil { + t.Fatal(err) + } + ssl_listener, err := openssl.Listen("tcp", "localhost:0", ctx) + if err != nil { + t.Fatal(err) + } + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + sslConnect(t, ssl_listener) + wg.Done() + }() + + client, err := openssl.Dial(ssl_listener.Addr().Network(), + ssl_listener.Addr().String(), ctx, openssl.InsecureSkipHostVerification) + + wg.Wait() + + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + n, err := io.Copy(io.Discard, io.LimitReader(client, 1024)) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if n != 1024 { + if n == 0 { + t.Fatal("client is closed after creation") + } + t.Fatalf("client lost some bytes, expected %d, got %d", 1024, n) + } +} + +func TestDialTimeout(t *testing.T) { + ctx := openssl.GetCtx(t) + if err := ctx.SetCipherList("AES128-SHA"); err != nil { + t.Fatal(err) + } + ssl_listener, err := openssl.Listen("tcp", "localhost:0", ctx) + if err != nil { + t.Fatal(err) + } + + client, err := openssl.DialTimeout(ssl_listener.Addr().Network(), + ssl_listener.Addr().String(), time.Nanosecond, ctx, 0) + + if client != nil || err == nil { + t.Fatalf("expected error") + } +} + +func TestDialContext(t *testing.T) { + ctx := openssl.GetCtx(t) + if err := ctx.SetCipherList("AES128-SHA"); err != nil { + t.Fatal(err) + } + ssl_listener, err := openssl.Listen("tcp", "localhost:0", ctx) + if err != nil { + t.Fatal(err) + } + + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + client, err := openssl.DialContext(cancelCtx, ssl_listener.Addr().Network(), + ssl_listener.Addr().String(), ctx, 0) + + if client != nil || err == nil { + t.Fatalf("expected error") + } +} diff --git a/ssl_test.go b/ssl_test.go index b99e57ec..5eb0c514 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -738,7 +738,7 @@ func TestStdlibLotsOfConns(t *testing.T) { }) } -func getCtx(t *testing.T) *Ctx { +func GetCtx(t *testing.T) *Ctx { ctx, err := NewCtx() if err != nil { t.Fatal(err) @@ -761,7 +761,7 @@ func getCtx(t *testing.T) *Ctx { } func TestOpenSSLLotsOfConns(t *testing.T) { - ctx := getCtx(t) + ctx := GetCtx(t) if err := ctx.SetCipherList("AES128-SHA"); err != nil { t.Fatal(err) } @@ -928,7 +928,7 @@ func TestOpenSSLLotsOfConnsWithFail(t *testing.T) { t.Run(name, func(t *testing.T) { LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, func(l net.Listener) net.Listener { - return NewListener(l, getCtx(t)) + return NewListener(l, GetCtx(t)) }, func(c net.Conn) (net.Conn, error) { return Client(c, getClientCtx(t)) })