diff --git a/client.go b/client.go index 55f9417..edbad9a 100644 --- a/client.go +++ b/client.go @@ -141,37 +141,6 @@ type Client struct { msgpackUsage msgpackUsage } -var h2CTransport = http2.Transport{ - AllowHTTP: true, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - // Skip TLS dial - return net.DialTimeout(network, addr, 2*time.Second) - }, -} - -var h2Transport = http2.Transport{ - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - dialer := net.Dialer{Timeout: 2 * time.Second} - cn, err := tls.DialWithDialer(&dialer, network, addr, cfg) - if err != nil { - return nil, err - } - if err := cn.Handshake(); err != nil { - return nil, err - } - if !cfg.InsecureSkipVerify { - if err := cn.VerifyHostname(cfg.ServerName); err != nil { - return nil, err - } - } - state := cn.ConnectionState() - if p := state.NegotiatedProtocol; p != http2.NextProtoTLS { - return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2.NextProtoTLS) - } - return cn, nil - }, -} - // NewClient creates a RESTful client instance. // The instance has a semi-permanent transport TCP connection. func NewClient() *Client { @@ -223,8 +192,19 @@ func NewClientWInterface(networkInterface string) *Client { // NewH2Client creates a RESTful client instance, forced to use HTTP2 with TLS (H2) (a.k.a. prior knowledge). func NewH2Client() *Client { + return NewH2ClientWInterface("") +} + +// NewH2CClient creates a RESTful client instance, forced to use HTTP2 Cleartext (H2C). +func NewH2CClient() *Client { + return NewH2CClientWInterface("") +} + +// NewH2ClientWInterface creates a RESTful client instance with the http2 protocol bound to that network interface. +// The instance has a semi-permanent transport TCP connection. +func NewH2ClientWInterface(networkInterface string) *Client { c := &Client{Kind: KindH2} - var rt http.RoundTripper = &h2Transport + var rt http.RoundTripper = getH2Transport(networkInterface) if isTraced && tracer.GetOTel() { rt = otelhttp.NewTransport(rt) } @@ -232,10 +212,12 @@ func NewH2Client() *Client { return c } -// NewH2CClient creates a RESTful client instance, forced to use HTTP2 Cleartext (H2C). -func NewH2CClient() *Client { +// NewH2ClientWInterface creates a RESTful client instance with the http2 clear text protocol bound to that network interface. +// In other words, the http2 clear text is the http2 but without TLS handshake. +// The instance has a semi-permanent transport TCP connection. +func NewH2CClientWInterface(networkInterface string) *Client { c := &Client{Kind: KindH2C} - var rt http.RoundTripper = &h2CTransport + var rt http.RoundTripper = getH2CTransport(networkInterface) if isTraced && tracer.GetOTel() { rt = otelhttp.NewTransport(rt) } @@ -243,6 +225,72 @@ func NewH2CClient() *Client { return c } +func getH2Transport(iface string) *http2.Transport { + return &http2.Transport{ + DialTLS: getDialTLSCallback(iface, true), + } +} + +func getH2CTransport(iface string) *http2.Transport { + return &http2.Transport{ + AllowHTTP: true, + DialTLS: getDialTLSCallback(iface, false), + } +} + +func getDialTLSCallback(iface string, withTLS bool) func(string,string,*tls.Config) (net.Conn, error) { + return func(network, addr string, cfg *tls.Config) (net.Conn, error) { + dialer := net.Dialer{Timeout: 2 * time.Second} + + var conn net.Conn + var err error + if iface != "" { + IPs := getIPFromInterface(iface) + if IPs.IPv4 != nil { + dialer.LocalAddr = IPs.IPv4 + conn, err = dialWithDialer(&dialer, network, addr, cfg, withTLS) + } + + // Try IPv6 if IPv4 is unavailable or connection fails. + if IPs.IPv4 == nil || (IPs.IPv6 != nil && err != nil && !errDeadlineOrCancel(err)) { + dialer.LocalAddr = IPs.IPv6 + conn, err = dialWithDialer(&dialer, network, addr, cfg, withTLS) + } + } else { + conn, err = dialWithDialer(&dialer, network, addr, cfg, withTLS) + } + + if err != nil { + return nil, err + } + + // Skip TLS dial if it is the H2C + if withTLS { + if err := conn.(*tls.Conn).Handshake(); err != nil { + return nil, err + } + if !cfg.InsecureSkipVerify { + if err := conn.(*tls.Conn).VerifyHostname(cfg.ServerName); err != nil { + return nil, err + } + } + state := conn.(*tls.Conn).ConnectionState() + if p := state.NegotiatedProtocol; p != http2.NextProtoTLS { + return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2.NextProtoTLS) + } + } + return conn, nil + } +} + +func dialWithDialer(dialer *net.Dialer, network, addr string, cfg *tls.Config, withTLS bool) (net.Conn, error) { + if withTLS { + return tls.DialWithDialer(dialer, network, addr, cfg) + } else { + return dialer.Dial(network, addr) + } +} + // UserAgent to be sent as User-Agent HTTP header. If not set then default Go client settings are used. func (c *Client) UserAgent(userAgent string) *Client { c.userAgent = userAgent @@ -377,7 +425,7 @@ func (c *Client) SetOauth2Conf(config oauth2.Config, tokenClient *http.Client, g // SetOauth2H2 makes OAuth2 token client communicate using h2 transport with Authorization Server. func (c *Client) SetOauth2H2() *Client { - c.oauth2.client = &http.Client{Timeout: 10 * time.Second, Transport: &h2Transport} + c.oauth2.client = &http.Client{Timeout: 10 * time.Second, Transport: getH2Transport("")} return c } diff --git a/client_test.go b/client_test.go index 0f9355b..8c00158 100644 --- a/client_test.go +++ b/client_test.go @@ -6,7 +6,10 @@ package restful import ( "context" + "crypto/tls" + "encoding/json" "errors" + "fmt" "io" "net" "net/http" @@ -14,10 +17,13 @@ import ( "net/http/httptest" "net/url" "strings" + "sync" "testing" "time" "github.com/stretchr/testify/assert" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" "golang.org/x/oauth2" ) @@ -855,3 +861,92 @@ func TestCientInterface(t *testing.T) { c := NewClientWInterface(theUsedInterface) assert.NotNil(t, c) } + +func startH2Server(mux *http.ServeMux, wg *sync.WaitGroup) *http.Server { + defer wg.Done() + server := &http.Server{ + Addr: "localhost:8443", + Handler: mux, + TLSConfig: &tls.Config{ + NextProtos: []string{"h2"}, + }, + } + + go func() { + if err := server.ListenAndServeTLS("test_certs/tls.crt", "test_certs/tls.key"); err != nil && err != http.ErrServerClosed { + fmt.Printf("Failed to start server: %v", err) + } + }() + return server +} + +func startH2CServer(mux *http.ServeMux, wg *sync.WaitGroup) *http.Server { + defer wg.Done() + server := &http.Server{ + Addr: "localhost:8440", + Handler: h2c.NewHandler(mux, &http2.Server{}), + } + + go func() { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + fmt.Printf("Failed to start server: %v", err) + } + }() + return server +} + +func TestClients(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + response := map[string]string{"message": "Hello, world!"} + json.NewEncoder(w).Encode(response) + }) + var wg sync.WaitGroup + wg.Add(2) + h2Server := startH2Server(mux, &wg) + h2cServer := startH2CServer(mux, &wg) + defer func() { + h2Server.Close() + h2cServer.Close() + }() + + h2Client := NewH2Client() + h2Client.Client.Transport.(*http2.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true,} + h2cClient := NewH2CClient() + + wg.Wait() + + tests := []struct { + name string + client *Client + serverURL string + }{ + { + name: "HTTP/2 Client (H2)", + client: h2Client, + serverURL: "https://localhost:8443", // H2 server + }, + { + name: "HTTP/2 Cleartext Client (H2C)", + client: h2cClient, + serverURL: "http://localhost:8440", // H2C server + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + + var resp any + err := test.client.Get(context.Background(), test.serverURL, &resp) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + + b, _ := json.Marshal(resp) + if string(b) != "{\"message\":\"Hello, world!\"}" { + t.Fatalf("Unexpected response: %s", b) + } + }) + } +} \ No newline at end of file