Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(netemx,oohelperd): use oohelperd.NewHandler constructor #1468

Merged
merged 1 commit into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 0 additions & 32 deletions internal/netemx/oohelperd.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"github.com/apex/log"
"github.com/ooni/netem"
"github.com/ooni/probe-cli/v3/internal/logx"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/oohelperd"
)
Expand All @@ -26,36 +25,5 @@ func (f *OOHelperDFactory) NewHandler(env NetStackServerFactoryEnv, unet *netem.
Logger: log.Log,
}
handler := oohelperd.NewHandler(logger, netx)

handler.NewDialer = func(logger model.Logger) model.Dialer {
return netx.NewDialerWithResolver(logger, netx.NewStdlibResolver(logger))
}

handler.NewQUICDialer = func(logger model.Logger) model.QUICDialer {
return netx.NewQUICDialerWithResolver(
netx.NewUDPListener(),
logger,
netx.NewStdlibResolver(logger),
)
}

handler.NewResolver = func(logger model.Logger) model.Resolver {
return netx.NewStdlibResolver(logger)
}

handler.NewHTTPClient = func(logger model.Logger) model.HTTPClient {
return oohelperd.NewHTTPClientWithTransportFactory(
netx, logger,
netxlite.NewHTTPTransportWithResolver,
)
}

handler.NewHTTP3Client = func(logger model.Logger) model.HTTPClient {
return oohelperd.NewHTTPClientWithTransportFactory(
netx, logger,
netxlite.NewHTTP3TransportWithResolver,
)
}

return handler
}
93 changes: 49 additions & 44 deletions internal/oohelperd/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,89 +22,94 @@ import (
"golang.org/x/net/publicsuffix"
)

// MaxAcceptableBodySize is the maximum acceptable body size for incoming
// maxAcceptableBodySize is the maximum acceptable body size for incoming
// API requests as well as when we're measuring webpages.
const MaxAcceptableBodySize = 1 << 24
const maxAcceptableBodySize = 1 << 24

// Handler is an [http.Handler] implementing the Web
// Connectivity test helper HTTP API.
//
// The zero value is invalid; construct using [NewHandler].
type Handler struct {
// BaseLogger is the MANDATORY logger to use.
BaseLogger model.Logger
// baseLogger is the MANDATORY logger to use.
baseLogger model.Logger

// CountRequests is the MANDATORY count of the number of
// countRequests is the MANDATORY count of the number of
// requests that are currently in flight.
CountRequests *atomic.Int64
countRequests *atomic.Int64

// Indexer is the MANDATORY atomic integer used to assign an index to requests.
Indexer *atomic.Int64
// indexer is the MANDATORY atomic integer used to assign an index to requests.
indexer *atomic.Int64

// MaxAcceptableBody is the MANDATORY maximum acceptable response body.
MaxAcceptableBody int64
// maxAcceptableBody is the MANDATORY maximum acceptable response body.
maxAcceptableBody int64

// Measure is the MANDATORY function that the handler should call
// measure is the MANDATORY function that the handler should call
// for producing a response for a valid incoming request.
Measure func(ctx context.Context, config *Handler, creq *model.THRequest) (*model.THResponse, error)
measure func(ctx context.Context, config *Handler, creq *model.THRequest) (*model.THResponse, error)

// NewDialer is the MANDATORY factory to create a new Dialer.
NewDialer func(model.Logger) model.Dialer
// newDialer is the MANDATORY factory to create a new Dialer.
newDialer func(model.Logger) model.Dialer

// NewHTTPClient is the MANDATORY factory to create a new HTTPClient.
NewHTTPClient func(model.Logger) model.HTTPClient
// newHTTPClient is the MANDATORY factory to create a new HTTPClient.
newHTTPClient func(model.Logger) model.HTTPClient

// NewHTTP3Client is the MANDATORY factory to create a new HTTP3Client.
NewHTTP3Client func(model.Logger) model.HTTPClient
// newHTTP3Client is the MANDATORY factory to create a new HTTP3Client.
newHTTP3Client func(model.Logger) model.HTTPClient

// NewQUICDialer is the MANDATORY factory to create a new QUICDialer.
NewQUICDialer func(model.Logger) model.QUICDialer
// newQUICDialer is the MANDATORY factory to create a new QUICDialer.
newQUICDialer func(model.Logger) model.QUICDialer

// NewResolver is the MANDATORY factory for creating a new resolver.
NewResolver func(model.Logger) model.Resolver
// newResolver is the MANDATORY factory for creating a new resolver.
newResolver func(model.Logger) model.Resolver

// NewTLSHandshaker is the MANDATORY factory for creating a new TLS handshaker.
NewTLSHandshaker func(model.Logger) model.TLSHandshaker
// newTLSHandshaker is the MANDATORY factory for creating a new TLS handshaker.
newTLSHandshaker func(model.Logger) model.TLSHandshaker
}

var _ http.Handler = &Handler{}

// NewHandler constructs the [handler].
func NewHandler(logger model.Logger, netx *netxlite.Netx) *Handler {
return &Handler{
BaseLogger: logger,
CountRequests: &atomic.Int64{},
Indexer: &atomic.Int64{},
MaxAcceptableBody: MaxAcceptableBodySize,
Measure: measure,
baseLogger: logger,
countRequests: &atomic.Int64{},
indexer: &atomic.Int64{},
maxAcceptableBody: maxAcceptableBodySize,
measure: measure,

NewHTTPClient: func(logger model.Logger) model.HTTPClient {
newHTTPClient: func(logger model.Logger) model.HTTPClient {
// TODO(https://github.com/ooni/probe/issues/2534): the NewHTTPTransportWithResolver has QUIRKS and
// we should evaluate whether we can avoid using it here
return NewHTTPClientWithTransportFactory(
return newHTTPClientWithTransportFactory(
netx, logger,
netxlite.NewHTTPTransportWithResolver,
)
},

NewHTTP3Client: func(logger model.Logger) model.HTTPClient {
return NewHTTPClientWithTransportFactory(
newHTTP3Client: func(logger model.Logger) model.HTTPClient {
return newHTTPClientWithTransportFactory(
netx, logger,
netxlite.NewHTTP3TransportWithResolver,
)
},

NewDialer: func(logger model.Logger) model.Dialer {
newDialer: func(logger model.Logger) model.Dialer {
return netx.NewDialerWithoutResolver(logger)
},
NewQUICDialer: func(logger model.Logger) model.QUICDialer {

newQUICDialer: func(logger model.Logger) model.QUICDialer {
return netx.NewQUICDialerWithoutResolver(
netx.NewUDPListener(),
logger,
)
},
NewResolver: func(logger model.Logger) model.Resolver {

newResolver: func(logger model.Logger) model.Resolver {
return newResolver(logger, netx)
},
NewTLSHandshaker: func(logger model.Logger) model.TLSHandshaker {

newTLSHandshaker: func(logger model.Logger) model.TLSHandshaker {
return netx.NewTLSHandshakerStdlib(logger)
},
}
Expand Down Expand Up @@ -151,16 +156,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

// protect against too many requests in flight
if handlerShouldThrottleClient(h.CountRequests.Load(), req.Header.Get("user-agent")) {
if handlerShouldThrottleClient(h.countRequests.Load(), req.Header.Get("user-agent")) {
metricRequestsCount.WithLabelValues("503", "service_unavailable").Inc()
w.WriteHeader(503)
return
}
h.CountRequests.Add(1)
defer h.CountRequests.Add(-1)
h.countRequests.Add(1)
defer h.countRequests.Add(-1)

// read and parse request body
reader := io.LimitReader(req.Body, h.MaxAcceptableBody)
reader := io.LimitReader(req.Body, h.maxAcceptableBody)
data, err := netxlite.ReadAllContext(req.Context(), reader)
if err != nil {
metricRequestsCount.WithLabelValues("400", "request_body_too_large").Inc()
Expand All @@ -176,7 +181,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {

// measure the given input
started := time.Now()
cresp, err := h.Measure(req.Context(), h, &creq)
cresp, err := h.measure(req.Context(), h, &creq)
elapsed := time.Since(started)

// track the time required to produce a response
Expand Down Expand Up @@ -219,9 +224,9 @@ func newCookieJar() *cookiejar.Jar {
}))
}

// NewHTTPClientWithTransportFactory creates a new HTTP client
// newHTTPClientWithTransportFactory creates a new HTTP client
// using the given [model.HTTPTransport] factory.
func NewHTTPClientWithTransportFactory(
func newHTTPClientWithTransportFactory(
netx *netxlite.Netx, logger model.Logger,
txpFactory func(*netxlite.Netx, model.DebugLogger, model.Resolver) model.HTTPTransport,
) model.HTTPClient {
Expand Down
4 changes: 2 additions & 2 deletions internal/oohelperd/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,12 @@ func TestHandlerWorkingAsIntended(t *testing.T) {
// create handler and possibly override .Measure
handler := NewHandler(log.Log, &netxlite.Netx{})
if expect.measureFn != nil {
handler.Measure = expect.measureFn
handler.measure = expect.measureFn
}

// configure the CountRequests field if needed
if expect.initialCountRequests > 0 {
handler.CountRequests.Add(expect.initialCountRequests) // 0 + value = value :-)
handler.countRequests.Add(expect.initialCountRequests) // 0 + value = value :-)
}

// create request
Expand Down
20 changes: 10 additions & 10 deletions internal/oohelperd/measure.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ type (
func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResponse, error) {
// create indexed logger
logger := &logx.PrefixLogger{
Prefix: fmt.Sprintf("<#%d> ", config.Indexer.Add(1)),
Logger: config.BaseLogger,
Prefix: fmt.Sprintf("<#%d> ", config.indexer.Add(1)),
Logger: config.baseLogger,
}

// parse input for correctness
Expand All @@ -47,7 +47,7 @@ func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResp
go dnsDo(ctx, &dnsConfig{
Domain: URL.Hostname(),
Logger: logger,
NewResolver: config.NewResolver,
NewResolver: config.newResolver,
Out: dnsch,
Wg: wg,
})
Expand Down Expand Up @@ -91,8 +91,8 @@ func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResp
EnableTLS: endpoint.TLS,
Endpoint: endpoint.Epnt,
Logger: logger,
NewDialer: config.NewDialer,
NewTSLHandshaker: config.NewTLSHandshaker,
NewDialer: config.newDialer,
NewTSLHandshaker: config.newTLSHandshaker,
URLHostname: URL.Hostname(),
Out: tcpconnch,
Wg: wg,
Expand All @@ -105,8 +105,8 @@ func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResp
go httpDo(ctx, &httpConfig{
Headers: creq.HTTPRequestHeaders,
Logger: logger,
MaxAcceptableBody: config.MaxAcceptableBody,
NewClient: config.NewHTTPClient,
MaxAcceptableBody: config.maxAcceptableBody,
NewClient: config.newHTTPClient,
Out: httpch,
URL: creq.HTTPRequest,
Wg: wg,
Expand All @@ -133,7 +133,7 @@ func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResp
Address: endpoint.Addr,
Endpoint: endpoint.Epnt,
Logger: logger,
NewQUICDialer: config.NewQUICDialer,
NewQUICDialer: config.newQUICDialer,
URLHostname: URL.Hostname(),
Out: quicconnch,
Wg: wg,
Expand All @@ -147,8 +147,8 @@ func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResp
go httpDo(ctx, &httpConfig{
Headers: creq.HTTPRequestHeaders,
Logger: logger,
MaxAcceptableBody: config.MaxAcceptableBody,
NewClient: config.NewHTTP3Client,
MaxAcceptableBody: config.maxAcceptableBody,
NewClient: config.newHTTP3Client,
Out: http3ch,
URL: "https://" + cresp.HTTPRequest.DiscoveredH3Endpoint,
Wg: wg,
Expand Down