-
Notifications
You must be signed in to change notification settings - Fork 250
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Pull request 338: AGDNS-1982 Fix BeforeRequestHandler
Squashed commit of the following: commit bc4e3c6 Merge: bb485bd 0368683 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Tue Apr 9 16:18:01 2024 +0300 Merge branch 'master' into AGDNS-1982-fix-before-handler commit bb485bd Merge: c9672b6 480eb52 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Mon Apr 8 18:03:33 2024 +0300 Merge branch 'master' into AGDNS-1982-fix-before-handler commit c9672b6 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Mon Apr 8 16:52:22 2024 +0300 proxy: imp tests commit 28d2c84 Merge: 7259240 0e2cfca Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Mon Apr 8 16:50:18 2024 +0300 Merge branch 'master' into AGDNS-1982-fix-before-handler commit 7259240 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Fri Apr 5 19:45:48 2024 +0300 peoxy: imp code, docs commit f226aff Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Fri Apr 5 19:28:15 2024 +0300 proxy: fix test, imp code commit 3533c3f Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Fri Apr 5 19:13:40 2024 +0300 proxy: add servfail commit 2dae2da Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Fri Apr 5 19:04:32 2024 +0300 proxy: imp code commit 2f776e9 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Fri Apr 5 17:18:34 2024 +0300 proxy: imp tests, code, docs commit 5ce78d6 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Fri Apr 5 17:18:34 2024 +0300 proxy: imp tests commit 486766b Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Fri Apr 5 16:21:39 2024 +0300 proxy: make before handler an interface
- Loading branch information
1 parent
0368683
commit 9b75951
Showing
5 changed files
with
230 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
package proxy | ||
|
||
import ( | ||
"fmt" | ||
|
||
"github.com/AdguardTeam/golibs/errors" | ||
"github.com/AdguardTeam/golibs/log" | ||
"github.com/miekg/dns" | ||
) | ||
|
||
// BeforeRequestError is an error that signals that the request should be | ||
// responded with the given response message. | ||
type BeforeRequestError struct { | ||
// Err is the error that caused the response. It must not be nil. | ||
Err error | ||
|
||
// Response is the response message to be sent to the client. It must be a | ||
// valid response message. | ||
Response *dns.Msg | ||
} | ||
|
||
// type check | ||
var _ error = (*BeforeRequestError)(nil) | ||
|
||
// Error implements the [error] interface for *BeforeRequestError. | ||
func (e *BeforeRequestError) Error() (msg string) { | ||
return fmt.Sprintf("%s; respond with %s", e.Err, dns.RcodeToString[e.Response.Rcode]) | ||
} | ||
|
||
// type check | ||
var _ errors.Wrapper = (*BeforeRequestError)(nil) | ||
|
||
// Unwrap implements the [errors.Wrapper] interface for *BeforeRequestError. | ||
func (e *BeforeRequestError) Unwrap() (unwrapped error) { | ||
return e.Err | ||
} | ||
|
||
// BeforeRequestHandler is an object that can handle the request before it's | ||
// processed by [Proxy]. | ||
type BeforeRequestHandler interface { | ||
// HandleBefore is called before each DNS request is started processing. | ||
// The passed [DNSContext] contains the Req, Addr, and IsLocalClient fields | ||
// set accordingly. | ||
// | ||
// If returned err is a [BeforeRequestError], the given response message is | ||
// used, on any other error a SERVFAIL response used. If err is nil, the | ||
// request is processed further. [Proxy] assumes a handler itself doesn't | ||
// set the [DNSContext.Res] field. | ||
HandleBefore(p *Proxy, dctx *DNSContext) (err error) | ||
} | ||
|
||
// noopRequestHandler is a no-op implementation of [BeforeRequestHandler] that | ||
// always returns nil. | ||
type noopRequestHandler struct{} | ||
|
||
// type check | ||
var _ BeforeRequestHandler = noopRequestHandler{} | ||
|
||
// HandleBefore implements the [BeforeRequestHandler] interface for | ||
// noopRequestHandler. | ||
func (noopRequestHandler) HandleBefore(_ *Proxy, _ *DNSContext) (err error) { | ||
return nil | ||
} | ||
|
||
// handleBefore calls the [BeforeRequestHandler] if it's set and returns true if | ||
// the request should be processed further. It sets the SERVFAIL response to | ||
// [DNSContext.Res] if an error returned, or the [BeforeRequestError.Response] | ||
// on an appropriate error. | ||
func (p *Proxy) handleBefore(d *DNSContext) (cont bool) { | ||
err := p.beforeRequestHandler.HandleBefore(p, d) | ||
if err == nil { | ||
return true | ||
} | ||
|
||
log.Debug("dnsproxy: handling before request: %s", err) | ||
|
||
if befReqErr := (&BeforeRequestError{}); errors.As(err, &befReqErr) { | ||
d.Res = befReqErr.Response | ||
} else { | ||
d.Res = p.messages.NewMsgSERVFAIL(d.Req) | ||
} | ||
|
||
p.logDNSMessage(d.Res) | ||
p.respond(d) | ||
|
||
return false | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
package proxy | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net" | ||
"testing" | ||
"time" | ||
|
||
"github.com/AdguardTeam/dnsproxy/upstream" | ||
"github.com/AdguardTeam/golibs/errors" | ||
"github.com/AdguardTeam/golibs/netutil" | ||
"github.com/AdguardTeam/golibs/testutil" | ||
"github.com/miekg/dns" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
// testBeforeRequestHandler is a mock before request handler implementation to | ||
// simplify testing. | ||
type testBeforeRequestHandler struct { | ||
onHandleBefore func(p *Proxy, dctx *DNSContext) (err error) | ||
} | ||
|
||
// type check | ||
var _ BeforeRequestHandler = (*testBeforeRequestHandler)(nil) | ||
|
||
// HandleBefore implements the [BeforeRequestHandler] interface for | ||
// *testBeforeRequestHandler. | ||
func (h *testBeforeRequestHandler) HandleBefore(p *Proxy, dctx *DNSContext) (err error) { | ||
return h.onHandleBefore(p, dctx) | ||
} | ||
|
||
func TestProxy_HandleDNSRequest_beforeRequestHandler(t *testing.T) { | ||
t.Parallel() | ||
|
||
const ( | ||
allowedID = iota | ||
failedID | ||
errorID | ||
) | ||
|
||
allowedRequest := (&dns.Msg{}).SetQuestion("allowed.", dns.TypeA) | ||
allowedRequest.Id = allowedID | ||
allowedResponse := (&dns.Msg{}).SetReply(allowedRequest) | ||
|
||
failedRequest := (&dns.Msg{}).SetQuestion("failed.", dns.TypeA) | ||
failedRequest.Id = failedID | ||
|
||
errorRequest := (&dns.Msg{}).SetQuestion("error.", dns.TypeA) | ||
errorRequest.Id = errorID | ||
errorResponse := (&dns.Msg{}).SetReply(errorRequest) | ||
|
||
p := mustNew(t, &Config{ | ||
TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)}, | ||
UpstreamConfig: &UpstreamConfig{ | ||
Upstreams: []upstream.Upstream{&fakeUpstream{ | ||
onExchange: func(m *dns.Msg) (resp *dns.Msg, err error) { | ||
return allowedResponse.Copy(), nil | ||
}, | ||
onAddress: func() (addr string) { return "general" }, | ||
onClose: func() (err error) { return nil }, | ||
}}, | ||
}, | ||
TrustedProxies: defaultTrustedProxies, | ||
PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed), | ||
BeforeRequestHandler: &testBeforeRequestHandler{ | ||
onHandleBefore: func(p *Proxy, dctx *DNSContext) (err error) { | ||
switch dctx.Req.Id { | ||
case allowedID: | ||
return nil | ||
case failedID: | ||
return errors.Error("servfail") | ||
case errorID: | ||
return &BeforeRequestError{ | ||
Err: errors.Error("just error"), | ||
Response: errorResponse, | ||
} | ||
default: | ||
panic(fmt.Sprintf("unexpected request id: %d", dctx.Req.Id)) | ||
} | ||
}, | ||
}, | ||
}) | ||
ctx := context.Background() | ||
require.NoError(t, p.Start(ctx)) | ||
testutil.CleanupAndRequireSuccess(t, func() (err error) { return p.Shutdown(ctx) }) | ||
|
||
client := &dns.Client{ | ||
Net: string(ProtoTCP), | ||
Timeout: 200 * time.Millisecond, | ||
} | ||
addr := p.Addr(ProtoTCP).String() | ||
|
||
testCases := []struct { | ||
req *dns.Msg | ||
wantResp *dns.Msg | ||
name string | ||
}{{ | ||
req: allowedRequest, | ||
wantResp: allowedResponse, | ||
name: "allowed", | ||
}, { | ||
req: failedRequest, | ||
wantResp: p.messages.NewMsgSERVFAIL(failedRequest), | ||
name: "failed", | ||
}, { | ||
req: errorRequest, | ||
wantResp: errorResponse, | ||
name: "error", | ||
}} | ||
|
||
for _, tc := range testCases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
t.Parallel() | ||
|
||
resp, _, err := client.Exchange(tc.req, addr) | ||
require.NoError(t, err) | ||
assert.Equal(t, tc.wantResp, resp) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters