Skip to content

Commit

Permalink
Extract client name from the URL (DoH and DoT) (#317)
Browse files Browse the repository at this point in the history
* Extract client name from the URL (DoH and DoT) #304

* improved tests
  • Loading branch information
0xERR0R authored Oct 13, 2021
1 parent cd76796 commit a90fb5d
Show file tree
Hide file tree
Showing 12 changed files with 188 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.idea/
*.iml
*.pem
/*.pem
bin/
dist/
docs/swagger.json
Expand Down
3 changes: 3 additions & 0 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ const (

// PathQueryPath defines the REST endpoint for query
PathQueryPath = "/api/query"

// PathDohQuery DoH Url
PathDohQuery = "/dns-query"
)

// QueryRequest is a data structure for a DNS request
Expand Down
29 changes: 23 additions & 6 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,17 +158,34 @@ client.lan.net to 192.170.1.2 and 192.170.1.3.

## Client name lookup

Blocky can try to resolve a user-friendly client name from the IP address. This is useful for defining of blocking
groups, since IP address can change dynamically. Blocky uses rDNS to retrieve client's name. To use this feature, you
can configure a DNS server for client lookup (typically your router). You can also define client names manually per IP
address.
Blocky can try to resolve a user-friendly client name from the IP address or server URL (DoT and DoH). This is useful
for defining of blocking groups, since IP address can change dynamically.

### Single name order
### Resolving client name from URL/Host

If DoT or DoH is enabled, you can use a subdomain prefixed with `id-` to provide a client name (wildcard ssl certificate
recommended).

Example: domain `example.com`

DoT Host: `id-bob.example.com` -> request's client name is `bob`
DoH URL: `https://id-bob.example.com/dns-query` -> request's client name is `bob`

For DoH you can also pass the client name as url parameter:

DoH URL: `htpps://blocky.example.com/dns-query/alice` -> request's client name is `alice`

### Resolving client name from IP address

Blocky uses rDNS to retrieve client's name. To use this feature, you can configure a DNS server for client lookup (
typically your router). You can also define client names manually per IP address.

#### Single name order

Some routers return multiple names for the client (host name and user defined name). With
parameter `clientLookup.singleNameOrder` you can specify, which of retrieved names should be used.

### Custom client name mapping
#### Custom client name mapping

You can also map a particular client name to one (or more) IP (ipv4/ipv6) addresses. Parameter `clientLookup.clients`
contains a map of client name and multiple IP addresses.
Expand Down
15 changes: 8 additions & 7 deletions model/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,17 @@ type Response struct {

// RequestProtocol represents the server protocol ENUM(
// TCP // is the TPC protocol
// UDP // is the UDP protocol
// UDP // is the UDP protocol
// )
type RequestProtocol uint8

// Request represents client's DNS request
type Request struct {
ClientIP net.IP
Protocol RequestProtocol
ClientNames []string
Req *dns.Msg
Log *logrus.Entry
RequestTS time.Time
ClientIP net.IP
RequestClientID string
Protocol RequestProtocol
ClientNames []string
Req *dns.Msg
Log *logrus.Entry
RequestTS time.Time
}
5 changes: 5 additions & 0 deletions resolver/client_names_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func NewClientNamesResolver(cfg config.ClientLookupConfig) ChainedResolver {
func (r *ClientNamesResolver) Configuration() (result []string) {
if r.externalResolver != nil || len(r.clientIPMapping) > 0 {
result = append(result, fmt.Sprintf("singleNameOrder = \"%v\"", r.singleNameOrder))

if r.externalResolver != nil {
result = append(result, fmt.Sprintf("externalResolver = \"%s\"", r.externalResolver))
}
Expand Down Expand Up @@ -89,6 +90,10 @@ func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
}
}

if request.RequestClientID != "" {
return []string{request.RequestClientID}
}

names := r.resolveClientNames(ip, withPrefix(request.Log, "client_names_resolver"))
r.cache.Set(ip.String(), names, cache.DefaultExpiration)

Expand Down
23 changes: 23 additions & 0 deletions resolver/client_names_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,29 @@ var _ = Describe("ClientResolver", func() {

})

Describe("Resolve client name from request clientID", func() {

It("should use clientID if set", func() {
request := newRequestWithClientID("google1.de.", dns.TypeA, "1.2.3.4", "client123")
resp, err = sut.Resolve(request)

Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(request.ClientNames).Should(HaveLen(1))
Expect(request.ClientNames[0]).Should(Equal("client123"))
Expect(mockReverseUpstreamCallCount).Should(Equal(0))
})
It("should use IP as fallback if clientID not set", func() {
request := newRequestWithClientID("google2.de.", dns.TypeA, "1.2.3.4", "")
resp, err = sut.Resolve(request)

Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess))
Expect(request.ClientNames).Should(HaveLen(1))
Expect(request.ClientNames[0]).Should(Equal("1.2.3.4"))
Expect(mockReverseUpstreamCallCount).Should(Equal(1))
})

})

Describe("Resolve client name with custom name mapping", func() {
BeforeEach(func() {
sutConfig = config.ClientLookupConfig{
Expand Down
11 changes: 11 additions & 0 deletions resolver/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ func newRequestWithClient(question string, rType uint16, ip string, clientNames
}
}

func newRequestWithClientID(question string, rType uint16, ip string, requestClientID string) *model.Request {
return &model.Request{
ClientIP: net.ParseIP(ip),
RequestClientID: requestClientID,
Req: util.NewMsgWithQuestion(question, rType),
Log: logrus.NewEntry(log.Log()),
RequestTS: time.Time{},
Protocol: model.RequestProtocolUDP,
}
}

// Resolver generic interface for all resolvers
type Resolver interface {

Expand Down
42 changes: 33 additions & 9 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,30 +271,54 @@ func (s *Server) Stop() {
}
}

func createResolverRequest(remoteAddress net.Addr, request *dns.Msg) *model.Request {
clientIP, protocol := resolveClientIPAndProtocol(remoteAddress)
func createResolverRequest(rw dns.ResponseWriter, request *dns.Msg) *model.Request {
var hostName string

return newRequest(clientIP, protocol, request)
var remoteAddr net.Addr

if rw != nil {
remoteAddr = rw.RemoteAddr()
}

clientIP, protocol := resolveClientIPAndProtocol(remoteAddr)
con, ok := rw.(dns.ConnectionStater)

if ok && con.ConnectionState() != nil {
hostName = con.ConnectionState().ServerName
}

return newRequest(clientIP, protocol, extractClientIDFromHost(hostName), request)
}

func newRequest(clientIP net.IP, protocol model.RequestProtocol, request *dns.Msg) *model.Request {
func extractClientIDFromHost(hostName string) string {
const clientIDPrefix = "id-"
if strings.HasPrefix(hostName, clientIDPrefix) && strings.Contains(hostName, ".") {
return hostName[len(clientIDPrefix):strings.Index(hostName, ".")]
}

return ""
}

func newRequest(clientIP net.IP, protocol model.RequestProtocol,
requestClientID string, request *dns.Msg) *model.Request {
return &model.Request{
ClientIP: clientIP,
Protocol: protocol,
Req: request,
RequestTS: time.Now(),
ClientIP: clientIP,
RequestClientID: requestClientID,
Protocol: protocol,
Req: request,
Log: log.Log().WithFields(logrus.Fields{
"question": util.QuestionToString(request.Question),
"client_ip": clientIP,
}),
RequestTS: time.Now(),
}
}

// OnRequest will be executed if a new DNS request is received
func (s *Server) OnRequest(w dns.ResponseWriter, request *dns.Msg) {
logger().Debug("new request")

r := createResolverRequest(w.RemoteAddr(), request)
r := createResolverRequest(w, request)

response, err := s.queryResolver.Resolve(r)

Expand Down
57 changes: 26 additions & 31 deletions server/server_endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ const (
func (s *Server) registerAPIEndpoints(router *chi.Mux) {
router.Post(api.PathQueryPath, s.apiQuery)

router.Get("/dns-query", s.dohGetRequestHandler)
router.Post("/dns-query", s.dohPostRequestHandler)
router.Get(api.PathDohQuery, s.dohGetRequestHandler)
router.Get(api.PathDohQuery+"/{clientID}", s.dohGetRequestHandler)
router.Post(api.PathDohQuery, s.dohPostRequestHandler)
router.Post(api.PathDohQuery+"/{clientID}", s.dohPostRequestHandler)
}

func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) {
Expand Down Expand Up @@ -94,14 +96,17 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
return
}

r := newRequest(net.ParseIP(extractIP(req)), model.RequestProtocolTCP, msg)
clientID := chi.URLParam(req, "clientID")
if clientID == "" {
clientID = extractClientIDFromHost(req.Host)
}

r := newRequest(net.ParseIP(extractIP(req)), model.RequestProtocolTCP, clientID, msg)

resResponse, err := s.queryResolver.Resolve(r)

if err != nil {
logger().Error("unable to process query: ", err)
http.Error(rw, err.Error(), http.StatusInternalServerError)

logAndResponseWithError(err, "unable to process query: ", rw)
return
}

Expand All @@ -112,17 +117,14 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h

b, err := resResponse.Res.Pack()
if err != nil {
logger().Error("can't serialize message: ", err)
http.Error(rw, err.Error(), http.StatusInternalServerError)
logAndResponseWithError(err, "can't serialize message: ", rw)
return
}

rw.Header().Set("content-type", dnsContentType)

_, err = rw.Write(b)
if err != nil {
logger().Error("can't write response: ", err)
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
logAndResponseWithError(err, "can't write response: ", rw)
}

func extractIP(r *http.Request) string {
Expand Down Expand Up @@ -158,18 +160,15 @@ func (s *Server) apiQuery(rw http.ResponseWriter, req *http.Request) {
err := json.NewDecoder(req.Body).Decode(&queryRequest)

if err != nil {
logger().Error("can't read request: ", err)
http.Error(rw, err.Error(), http.StatusInternalServerError)

logAndResponseWithError(err, "can't read request: ", rw)
return
}

// validate query type
qType := dns.StringToType[queryRequest.Type]
if qType == dns.TypeNone {
err = fmt.Errorf("unknown query type '%s'", queryRequest.Type)
logger().Error(err)
http.Error(rw, err.Error(), http.StatusInternalServerError)
logAndResponseWithError(err, "unknown query type: ", rw)

return
}
Expand All @@ -187,9 +186,7 @@ func (s *Server) apiQuery(rw http.ResponseWriter, req *http.Request) {
response, err := s.queryResolver.Resolve(r)

if err != nil {
logger().Error("unable to process query: ", err)
http.Error(rw, err.Error(), http.StatusInternalServerError)

logAndResponseWithError(err, "unable to process query: ", rw)
return
}

Expand All @@ -200,13 +197,7 @@ func (s *Server) apiQuery(rw http.ResponseWriter, req *http.Request) {
ReturnCode: dns.RcodeToString[response.Res.Rcode],
})
_, err = rw.Write(jsonResponse)

if err != nil {
logger().Error("unable to write response ", err)
http.Error(rw, err.Error(), http.StatusInternalServerError)

return
}
logAndResponseWithError(err, "unable to write response: ", rw)
}

func createRouter(cfg *config.Config) *chi.Mux {
Expand Down Expand Up @@ -261,13 +252,17 @@ func configureRootHandler(cfg *config.Config, router *chi.Mux) {
}

err := t.Execute(writer, pd)
if err != nil {
log.Log().Error("can't write index template: ", err)
writer.WriteHeader(http.StatusInternalServerError)
}
logAndResponseWithError(err, "can't write index template: ", writer)
})
}

func logAndResponseWithError(err error, message string, writer http.ResponseWriter) {
if err != nil {
log.Log().Error(message, err)
http.Error(writer, err.Error(), http.StatusInternalServerError)
}
}

func configureDebugHandler(router *chi.Mux) {
router.Mount("/debug", middleware.Profiler())
}
Expand Down
27 changes: 25 additions & 2 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,12 @@ var _ = Describe("Running DNS server", func() {
Upstream: upstreamClient,
},

Port: "55555",
HTTPPort: "4000",
Port: "55555",
TLSPort: "8853",
CertFile: "../testdata/cert.pem",
KeyFile: "../testdata/key.pem",
HTTPPort: "4000",
HTTPSPort: "4443",
Prometheus: config.PrometheusConfig{
Enable: true,
Path: "/metrics",
Expand Down Expand Up @@ -423,6 +427,25 @@ var _ = Describe("Running DNS server", func() {
err = msg.Unpack(rawMsg)
Expect(err).Should(Succeed())

Expect(msg.Answer).Should(BeDNSRecord("www.example.com.", dns.TypeA, 0, "123.124.122.122"))
})
It("should get a valid response, clientId is passed", func() {
msg := util.NewMsgWithQuestion("www.example.com.", dns.TypeA)
rawDNSMessage, err := msg.Pack()
Expect(err).Should(Succeed())

resp, err := http.Post("http://localhost:4000/dns-query/client123",
"application/dns-message", bytes.NewReader(rawDNSMessage))
Expect(err).Should(Succeed())
defer resp.Body.Close()
Expect(resp).Should(HaveHTTPStatus(http.StatusOK))
rawMsg, err := ioutil.ReadAll(resp.Body)
Expect(err).Should(Succeed())

msg = new(dns.Msg)
err = msg.Unpack(rawMsg)
Expect(err).Should(Succeed())

Expect(msg.Answer).Should(BeDNSRecord("www.example.com.", dns.TypeA, 0, "123.124.122.122"))
})
})
Expand Down
14 changes: 14 additions & 0 deletions testdata/cert.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
-----BEGIN CERTIFICATE-----
MIICMzCCAZygAwIBAgIRAJCCrDTGEtZfRpxDY1KAoswwDQYJKoZIhvcNAQELBQAw
EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2
MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAw
gYkCgYEA4mEaF5yWYYrTfMgRXdBpgGnqsHIADQWlw7BIJWD/gNp+fgp4TUZ/7ggV
rrvRORvRFjw14avd9L9EFP7XLi8ViU3uoE1UWI32MlrKqLbGNCXyUIApIoqlbRg6
iErxIk5+ChzFuysQOx01S2yv/ML6dx7NOGHs1S38MUzRZtcXBH8CAwEAAaOBhjCB
gzAOBgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/
BAUwAwEB/zAdBgNVHQ4EFgQUslNI6tYIv909RttHaZVMS/u/VYYwLAYDVR0RBCUw
I4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMA0GCSqGSIb3DQEB
CwUAA4GBAJ2gRpQHr5Qj7dt26bYVMdN4JGXTsvjbVrJfKI0VfPGJ+SUY/uTVBUeX
+Cwv4DFEPBlNx/lzuUkwmRaExC4/w81LWwxe5KltYsjyJuYowiUbLZ6tzLaQ9Bcx
jxClAVvgj90TGYOwsv6ESOX7GWteN1FlD3+jk7vefjFagaKKFYR9
-----END CERTIFICATE-----
Loading

0 comments on commit a90fb5d

Please sign in to comment.