Skip to content

Commit

Permalink
Support routing by ListenerID or TLS server name
Browse files Browse the repository at this point in the history
  • Loading branch information
folbricht committed Dec 28, 2022
1 parent baf5126 commit 8117a90
Show file tree
Hide file tree
Showing 18 changed files with 118 additions and 60 deletions.
2 changes: 1 addition & 1 deletion cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ func minTTL(answer *dns.Msg) (uint32, bool) {
type AnswerShuffleFunc func(*dns.Msg)

// Randomly re-order the A/AAAA answer records.
func AnswerShuffleRandon(msg *dns.Msg) {
func AnswerShuffleRandom(msg *dns.Msg) {
if len(msg.Answer) < 2 {
return
}
Expand Down
3 changes: 3 additions & 0 deletions cmd/routedns/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type resolver struct {
CA string
ClientKey string `toml:"client-key"`
ClientCrt string `toml:"client-crt"`
ServerName string `toml:"server-name"` // TLS server name presented in the server certificate
BootstrapAddr string `toml:"bootstrap-address"`
LocalAddr string `toml:"local-address"`
EDNS0UDPSize uint16 `toml:"edns0-udp-size"` // UDP resolver option
Expand Down Expand Up @@ -160,6 +161,8 @@ type route struct {
Invert bool // Invert the result of the match
DoHPath string `toml:"doh-path"` // DoH query path if received over DoH (regexp)
Resolver string
Listener string // ID of the listener that received the original request
TLSServerName string `toml:"servername"` // TLS servername
}

// LoadConfig reads a config file and returns the decoded structure.
Expand Down
12 changes: 6 additions & 6 deletions cmd/routedns/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func start(opt options, args []string) error {
resolver, ok := resolvers[l.Resolver]
// All Listeners should route queries (except the admin service).
if !ok && l.Protocol != "admin" {
return fmt.Errorf("listener '%s' references non-existant resolver, group or router '%s'", id, l.Resolver)
return fmt.Errorf("listener '%s' references non-existent resolver, group or router '%s'", id, l.Resolver)
}
allowedNet, err := parseCIDRList(l.AllowedNet)
if err != nil {
Expand Down Expand Up @@ -301,7 +301,7 @@ func instantiateGroup(id string, g group, resolvers map[string]rdns.Resolver) er
for _, rid := range g.Resolvers {
resolver, ok := resolvers[rid]
if !ok {
return fmt.Errorf("group '%s' references non-existant resolver or group '%s'", id, rid)
return fmt.Errorf("group '%s' references non-existent resolver or group '%s'", id, rid)
}
gr = append(gr, resolver)
}
Expand Down Expand Up @@ -547,7 +547,7 @@ func instantiateGroup(id string, g group, resolvers map[string]rdns.Resolver) er
switch g.CacheAnswerShuffle {
case "": // default
case "random":
shuffleFunc = rdns.AnswerShuffleRandon
shuffleFunc = rdns.AnswerShuffleRandom
case "round-robin":
shuffleFunc = rdns.AnswerShuffleRoundRobin
default:
Expand Down Expand Up @@ -693,7 +693,7 @@ func instantiateGroup(id string, g group, resolvers map[string]rdns.Resolver) er
if len(gr) != 1 {
return fmt.Errorf("type response-collapse only supports one resolver in '%s'", id)
}
opt := rdns.ResponseCollapsOptions{
opt := rdns.ResponseCollapseOptions{
NullRCode: g.NullRCode,
}
resolvers[id] = rdns.NewResponseCollapse(id, gr[0], opt)
Expand Down Expand Up @@ -724,13 +724,13 @@ func instantiateRouter(id string, r router, resolvers map[string]rdns.Resolver)
for _, route := range r.Routes {
resolver, ok := resolvers[route.Resolver]
if !ok {
return fmt.Errorf("router '%s' references non-existant resolver or group '%s'", id, route.Resolver)
return fmt.Errorf("router '%s' references non-existent resolver or group '%s'", id, route.Resolver)
}
types := route.Types
if route.Type != "" { // Support the deprecated "Type" by just adding it to "Types" if defined
types = append(types, route.Type)
}
r, err := rdns.NewRoute(route.Name, route.Class, types, route.Weekdays, route.Before, route.After, route.Source, route.DoHPath, resolver)
r, err := rdns.NewRoute(route.Name, route.Class, types, route.Weekdays, route.Before, route.After, route.Source, route.DoHPath, route.Listener, route.TLSServerName, resolver)
if err != nil {
return fmt.Errorf("failure parsing routes for router '%s' : %s", id, err.Error())
}
Expand Down
6 changes: 3 additions & 3 deletions cmd/routedns/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func instantiateResolver(id string, r resolver, resolvers map[string]rdns.Resolv
case "doq":
r.Address = rdns.AddressWithDefault(r.Address, rdns.DoQPort)

tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey)
tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey, r.ServerName)
if err != nil {
return err
}
Expand All @@ -31,7 +31,7 @@ func instantiateResolver(id string, r resolver, resolvers map[string]rdns.Resolv
case "dot":
r.Address = rdns.AddressWithDefault(r.Address, rdns.DoTPort)

tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey)
tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey, r.ServerName)
if err != nil {
return err
}
Expand Down Expand Up @@ -64,7 +64,7 @@ func instantiateResolver(id string, r resolver, resolvers map[string]rdns.Resolv
case "doh":
r.Address = rdns.AddressWithDefault(r.Address, rdns.DoHPort)

tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey)
tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey, r.ServerName)
if err != nil {
return err
}
Expand Down
17 changes: 13 additions & 4 deletions dnslistener.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rdns

import (
"crypto/tls"
"net"

"github.com/miekg/dns"
Expand Down Expand Up @@ -49,10 +50,18 @@ func (s DNSListener) String() string {
func listenHandler(id, protocol, addr string, r Resolver, allowedNet []*net.IPNet) dns.HandlerFunc {
metrics := NewListenerMetrics("listener", id)
return func(w dns.ResponseWriter, req *dns.Msg) {
var (
ci ClientInfo
err error
)
var err error

ci := ClientInfo{
Listener: id,
}

if r, ok := w.(interface{ ConnectionState() *tls.ConnectionState }); ok {
connState := r.ConnectionState()
if connState != nil {
ci.TLSServerName = connState.ServerName
}
}

switch addr := w.RemoteAddr().(type) {
case *net.TCPAddr:
Expand Down
3 changes: 3 additions & 0 deletions doc/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,8 @@ A route has the following fields:
- `before` - Time of day in the format HH:mm before which the rule matches. Uses 24h format. For example `17:30`.
- `invert` - Invert the result of the matching if set to `true`. Optional.
- `doh-path` - Regexp that matches on the DoH query path the client used.
- `listener` - Regexp that matches on the ID of the listener that first received.
- `servername` - Regexp that matches on the TLS server name used in the TLS handshake with the listener.
- `resolver` - The identifier of a resolver, group, or another router. Required.

Examples:
Expand Down Expand Up @@ -1348,6 +1350,7 @@ Secure resolvers such as DoT, DoH, or DoQ offer additional options to configure
- `client-crt` - Client certificate file.
- `client-key` - Client certificate key file
- `ca` - CA certificate to validate server certificates.
- `server-name` - Name of the certificate presented by the server if it does not match the name in the endpoint address.

Examples:

Expand Down
10 changes: 8 additions & 2 deletions dohlistener.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,15 @@ func (s *DoHListener) parseAndRespond(b []byte, w http.ResponseWriter, r *http.R
http.Error(w, "Invalid RemoteAddr", http.StatusBadRequest)
return
}
var tlsServerName string
if r.TLS != nil {
tlsServerName = r.TLS.ServerName
}
ci := ClientInfo{
SourceIP: clientIP,
DoHPath: r.URL.Path,
SourceIP: clientIP,
DoHPath: r.URL.Path,
TLSServerName: tlsServerName,
Listener: s.id,
}
log := Log.WithFields(logrus.Fields{
"id": s.id,
Expand Down
6 changes: 3 additions & 3 deletions dohlistener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestDoHListenerSimple(t *testing.T) {
time.Sleep(time.Second)

// Make a client talking to the listener using POST
tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "")
tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "", "")
require.NoError(t, err)
u := "https://" + addr + "/dns-query"
cPost, err := NewDoHClient("test-doh", u, DoHClientOptions{TLSConfig: tlsConfig, Method: "POST"})
Expand Down Expand Up @@ -82,7 +82,7 @@ func TestDoHListenerMutual(t *testing.T) {

// Make a client talking to the listener. Need to trust the issuer of the server certificate and
// present a client certificate.
tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key")
tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key", "")
require.NoError(t, err)
u := "https://" + addr + "/dns-query"
c, err := NewDoHClient("test-doh", u, DoHClientOptions{TLSConfig: tlsClientConfig})
Expand Down Expand Up @@ -116,7 +116,7 @@ func TestDoHListenerMutualQUIC(t *testing.T) {

// Make a client talking to the listener. Need to trust the issuer of the server certificate and
// present a client certificate.
tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key")
tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key", "")
require.NoError(t, err)
u := "https://" + addr + "/dns-query"
c, err := NewDoHClient("test-doh", u, DoHClientOptions{TLSConfig: tlsClientConfig, Transport: "quic"})
Expand Down
7 changes: 6 additions & 1 deletion doqlistener.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,12 @@ func (s DoQListener) Stop() error {
}

func (s DoQListener) handleConnection(connection quic.Connection) {
var ci ClientInfo
tlsServerName := connection.ConnectionState().TLS.ServerName

ci := ClientInfo{
Listener: s.id,
TLSServerName: tlsServerName,
}
switch addr := connection.RemoteAddr().(type) {
case *net.TCPAddr:
ci.SourceIP = addr.IP
Expand Down
2 changes: 1 addition & 1 deletion dotclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestDoTClientCA(t *testing.T) {
conn.Close()

// Create a config with CA using the temp file
tlsConfig, err := TLSClientConfig(crtFile, "", "")
tlsConfig, err := TLSClientConfig(crtFile, "", "", "")
require.NoError(t, err)

// DoT client with valid CA
Expand Down
6 changes: 3 additions & 3 deletions dotlistener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestDoTListenerSimple(t *testing.T) {
time.Sleep(time.Second)

// Make a client talking to the listener. Need to trust the issue of the server certificate.
tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "")
tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "", "")
require.NoError(t, err)
c, _ := NewDoTClient("test-dot", addr, DoTClientOptions{TLSConfig: tlsConfig})

Expand Down Expand Up @@ -64,7 +64,7 @@ func TestDoTListenerMutual(t *testing.T) {

// Make a client talking to the listener. Need to trust the issue of the server certificate and
// present a client certificate.
tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key")
tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key", "")
require.NoError(t, err)
c, _ := NewDoTClient("test-dot", addr, DoTClientOptions{TLSConfig: tlsClientConfig})

Expand Down Expand Up @@ -99,7 +99,7 @@ func TestDoTListenerPadding(t *testing.T) {
time.Sleep(time.Second)

// Make a client talking to the listener. Need to trust the issue of the server certificate.
tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "")
tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "", "")
require.NoError(t, err)
c, _ := NewDoTClient("test-dot", addr, DoTClientOptions{TLSConfig: tlsConfig})

Expand Down
4 changes: 2 additions & 2 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ func Example_router() {

// Build a router that will send all "*.cloudflare.com" to the cloudflare
// resolver while everything else goes to the google resolver (default)
route1, _ := rdns.NewRoute(`\.cloudflare\.com\.$`, "", nil, nil, "", "", "", "", cloudflare)
route2, _ := rdns.NewRoute("", "", nil, nil, "", "", "", "", google)
route1, _ := rdns.NewRoute(`\.cloudflare\.com\.$`, "", nil, nil, "", "", "", "", "", "", cloudflare)
route2, _ := rdns.NewRoute("", "", nil, nil, "", "", "", "", "", "", google)
r := rdns.NewRouter("my-router")
r.Add(route1, route2)

Expand Down
7 changes: 7 additions & 0 deletions listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ type ClientInfo struct {
// DoH query path used by the client. Only populated when
// the query was received over DoH.
DoHPath string

// TLS SNI server name
TLSServerName string

// Listener ID of the listener that first received the request. Can be
// used to route queries.
Listener string
}

// Metrics that are available from listeners and clients.
Expand Down
8 changes: 4 additions & 4 deletions response-collapse.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@ import (
type ResponseCollapse struct {
id string
resolver Resolver
ResponseCollapsOptions
ResponseCollapseOptions
}

type ResponseCollapsOptions struct {
type ResponseCollapseOptions struct {
NullRCode int // Response code when there's nothing left after collapsing the response
}

var _ Resolver = &ResponseCollapse{}

// NewResponseMinimize returns a new instance of a response minimizer.
func NewResponseCollapse(id string, resolver Resolver, opt ResponseCollapsOptions) *ResponseCollapse {
return &ResponseCollapse{id: id, resolver: resolver, ResponseCollapsOptions: opt}
func NewResponseCollapse(id string, resolver Resolver, opt ResponseCollapseOptions) *ResponseCollapse {
return &ResponseCollapse{id: id, resolver: resolver, ResponseCollapseOptions: opt}
}

// Resolve a DNS query, then collapse the response to remove anything from the
Expand Down
64 changes: 44 additions & 20 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,22 @@ import (
)

type route struct {
types []uint16
class uint16
name *regexp.Regexp
source *net.IPNet
weekdays []time.Weekday
before *TimeOfDay
after *TimeOfDay
inverted bool // invert the matching behavior
dohPath *regexp.Regexp
resolver Resolver
types []uint16
class uint16
name *regexp.Regexp
source *net.IPNet
weekdays []time.Weekday
before *TimeOfDay
after *TimeOfDay
inverted bool // invert the matching behavior
dohPath *regexp.Regexp
resolver Resolver
listenerID *regexp.Regexp
tlsServerName *regexp.Regexp
}

// NewRoute initializes a route from string parameters.
func NewRoute(name, class string, types, weekdays []string, before, after, source, dohPath string, resolver Resolver) (*route, error) {
func NewRoute(name, class string, types, weekdays []string, before, after, source, dohPath, listenerID, tlsServerName string, resolver Resolver) (*route, error) {
if resolver == nil {
return nil, errors.New("no resolver defined for route")
}
Expand Down Expand Up @@ -58,6 +60,14 @@ func NewRoute(name, class string, types, weekdays []string, before, after, sourc
if err != nil {
return nil, err
}
listenerRe, err := regexp.Compile(listenerID)
if err != nil {
return nil, err
}
tlsRe, err := regexp.Compile(tlsServerName)
if err != nil {
return nil, err
}
var sNet *net.IPNet
if source != "" {
_, sNet, err = net.ParseCIDR(source)
Expand All @@ -66,15 +76,17 @@ func NewRoute(name, class string, types, weekdays []string, before, after, sourc
}
}
return &route{
types: t,
class: c,
name: re,
weekdays: w,
before: b,
after: a,
source: sNet,
dohPath: dohRe,
resolver: resolver,
types: t,
class: c,
name: re,
weekdays: w,
before: b,
after: a,
source: sNet,
dohPath: dohRe,
listenerID: listenerRe,
tlsServerName: tlsRe,
resolver: resolver,
}, nil
}

Expand All @@ -95,6 +107,12 @@ func (r *route) match(q *dns.Msg, ci ClientInfo) bool {
if !r.dohPath.MatchString(ci.DoHPath) {
return r.inverted
}
if !r.listenerID.MatchString(ci.Listener) {
return r.inverted
}
if !r.tlsServerName.MatchString(ci.TLSServerName) {
return r.inverted
}
if len(r.weekdays) > 0 || r.before != nil || r.after != nil {
now := time.Now().Local()
hour := now.Hour()
Expand Down Expand Up @@ -151,6 +169,12 @@ func (r *route) String() string {
if r.dohPath.String() != "" {
fragments = append(fragments, "doh-path="+r.dohPath.String())
}
if r.listenerID.String() != "" {
fragments = append(fragments, "listener="+r.listenerID.String())
}
if r.tlsServerName.String() != "" {
fragments = append(fragments, "servername="+r.tlsServerName.String())
}
if len(r.weekdays) > 0 {
fragments = append(fragments, fmt.Sprintf("weekdays=%v", r.weekdays))
}
Expand Down
Loading

0 comments on commit 8117a90

Please sign in to comment.