From 14abd7388d44439e343146583953bfeff9e9705b Mon Sep 17 00:00:00 2001 From: Nate Sales Date: Fri, 20 Oct 2023 22:49:34 -0400 Subject: [PATCH] refactor: server URL parser (#66) --- main.go | 238 +++++++++++++++++++---------------------- main_test.go | 4 +- transport/transport.go | 3 + 3 files changed, 113 insertions(+), 132 deletions(-) diff --git a/main.go b/main.go index ed6c0cd..a32b9b3 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "os" "reflect" "regexp" + "slices" "strconv" "strings" "time" @@ -97,161 +98,130 @@ func txtConcat(m *dns.Msg) { m.Answer = answers } -// parseServer parses opts.Server and returns the server address and transport type -func parseServer() (string, transport.Type, error) { - var txp transport.Type - var host, port, scopeId string - var isHTTPS bool - - // Set default protocol - if !strings.Contains(opts.Server, "://") { - txp = transport.TypePlain - } else { - txp = transport.Type(strings.Split(opts.Server, "://")[0]) - if txp == "https" { - isHTTPS = true - txp = transport.TypeHTTP - } - } +// dnsStampToURL converts a DNS stamp string to a URL string +func dnsStampToURL(s string) (string, error) { + var u url.URL - // Parse DNS stamp - if strings.HasPrefix(opts.Server, "sdns://") { - parsedStamp, err := dnsstamps.NewServerStampFromString(opts.Server) - if err != nil { - return "", "", err - } + parsedStamp, err := dnsstamps.NewServerStampFromString(s) + if err != nil { + return "", err + } - switch parsedStamp.Proto { - case dnsstamps.StampProtoTypePlain: - txp = transport.TypePlain - case dnsstamps.StampProtoTypeTLS: - txp = transport.TypeTLS - case dnsstamps.StampProtoTypeDoH: - isHTTPS = true // Default to DoH (HTTPS) - txp = transport.TypeHTTP - case dnsstamps.StampProtoTypeDNSCrypt: - // DNS stamp parsing happens again in the DNSCrypt transport - return opts.Server, transport.TypeDNSCrypt, nil - default: - return "", "", fmt.Errorf("unsupported protocol %s in DNS stamp", parsedStamp.Proto.String()) - } - log.Tracef("DNS stamp parsed as %s", txp) - - // TODO: This might be a source of problems...we might want to be using parsedStamp.ServerAddrStr - host = parsedStamp.ProviderName - } else { // Not DNS stamp - // Remove anything before and including the first :// - host = regexp.MustCompile(`^.*://`).ReplaceAllString(opts.Server, "") - - // Remove port from host - switch { - case strings.Contains(host, "[") && !strings.Contains(host, "]") || - !strings.Contains(host, "[") && strings.Contains(host, "]"): - return "", "", fmt.Errorf("invalid IPv6 bracket notation") - case strings.Contains(host, "[") && strings.Contains(host, "]"): // IPv6 in bracket notation - portSuffix := strings.Split(host, "]:") - if len(portSuffix) > 1 { // With explicit port - port = portSuffix[1] - } else { - port = "" - } + switch parsedStamp.Proto { + case dnsstamps.StampProtoTypePlain: + u.Scheme = string(transport.TypePlain) + case dnsstamps.StampProtoTypeTLS: + u.Scheme = string(transport.TypeTLS) + case dnsstamps.StampProtoTypeDoH: + u.Scheme = string(transport.TypeHTTP) + "s" // default to HTTPS + case dnsstamps.StampProtoTypeDNSCrypt: + // DNS stamp parsing happens again in the DNSCrypt transport, so pass the input along unchanged + return s, nil + default: + return "", fmt.Errorf("unsupported protocol %s in DNS stamp", parsedStamp.Proto.String()) + } - host = strings.Split(strings.Split(host, "[")[1], "]")[0] + // TODO: This might be a source of problems...we might want to be using parsedStamp.ServerAddrStr + u.Host = parsedStamp.ProviderName - // Remove IPv6 scope ID - if strings.Contains(host, "%") { - parts := strings.Split(host, "%") - host = parts[0] - scopeId = parts[1] - } + log.Tracef("DNS stamp parsed into URL as %s", u.String()) + return u.String(), nil +} - host = "[" + host + "]" - log.Tracef("host contains ], treating as v6 with port. host: %s port: %s", host, port) - case strings.Contains(host, ".") && strings.Contains(host, ":"): // IPv4 or hostname with port - parts := strings.Split(host, ":") - host = parts[0] - port = parts[1] - log.Tracef("host contains . and :, treating as (v4 or host) with explicit port. host %s port %s", host, port) - case strings.Contains(host, ":") && !strings.Contains(host, "/"): // IPv6 no port - // Remove IPv6 scope ID - if strings.Contains(host, "%") { - parts := strings.Split(host, "%") - host = parts[0] - scopeId = parts[1] - } +// setPort sets the port of a url.URL +func setPort(u *url.URL, port int) { + if strings.Contains(u.Host, ":") { + if strings.Contains(u.Host, "[") && strings.Contains(u.Host, "]") { + u.Host = fmt.Sprintf("%s]:%d", strings.Split(u.Host, "]")[0], port) + return + } + u.Host = "[" + u.Host + "]" + } + u.Host = fmt.Sprintf("%s:%d", u.Host, port) +} - host = "[" + host + "]" - log.Tracef("host contains :, treating as v6 without port. host %s", host) - default: - log.Tracef("no cases matched for host %s port %s", host, port) +// parseServer is a revised version of parseServer that uses the URL package for parsing +func parseServer(s string) (string, transport.Type, error) { + // Remove IPv6 scope ID if present + var scopeId string + v6scopeRe := regexp.MustCompile(`\[[a-fA-F0-9:]+%[a-zA-Z0-9]+]`) + if v6scopeRe.MatchString(s) { + v6scopeRemoveRe := regexp.MustCompile(`(%[a-zA-Z0-9]+)`) + matches := v6scopeRemoveRe.FindStringSubmatch(s) + if len(matches) > 1 { + scopeId = matches[1] + s = v6scopeRemoveRe.ReplaceAllString(s, "") } + log.Tracef("Removed IPv6 scope ID %s from server %s", scopeId, s) } - // Validate ODoH - if opts.ODoHProxy != "" { - if !strings.HasPrefix(opts.ODoHProxy, "https://") { - return "", "", fmt.Errorf("ODoH proxy must use HTTPS") + // Handle DNS stamp + if strings.HasPrefix(s, "sdns://") { + var err error + s, err = dnsStampToURL(s) + if err != nil { + return "", "", fmt.Errorf("converting DNS stamp to URL: %s", err) } - if !strings.HasPrefix(opts.Server, "https://") { - return "", "", fmt.Errorf("ODoH target must use HTTPS") + // If s is still a DNS stamp, it's DNSCrypt + if strings.HasPrefix(s, "sdns://") { + return s, transport.TypeDNSCrypt, nil } } - if port == "" { - switch txp { - case transport.TypeQUIC: - port = "853" - case transport.TypeTLS: - port = "853" - case transport.TypeHTTP: - if isHTTPS { - port = "443" - } else { - port = "80" - } - case transport.TypePlain, transport.TypeTCP: - port = "53" - } - log.Tracef("Setting port to %s", port) - } else { - log.Tracef("Port is %s, not overriding", port) + // Check if server starts with a scheme, if not, default to plain + schemeRe := regexp.MustCompile(`^[a-zA-Z0-9]+://`) + if !schemeRe.MatchString(s) { + s = "plain://" + s } - urlScheme := string(txp) - if isHTTPS { - urlScheme = "https" + // Parse server as URL + tu, err := url.Parse(s) + if err != nil { + return "", "", fmt.Errorf("parsing %s: %s", s, err) } - fqdn := urlScheme + "://" + host - if txp != transport.TypeHTTP { - fqdn += ":" + port + // Parse transport type + ts := transport.Type(tu.Scheme) + if tu.Scheme == "https" { // Override HTTPS to HTTP, preserving tu.Scheme as HTTPS + ts = transport.TypeHTTP } - log.Tracef("checking FQDN %s", fqdn) - u, err := url.Parse(fqdn) - if err != nil { - return "", "", err + if !slices.Contains(transport.Types, ts) { + return "", "", fmt.Errorf("unsupported transport %s. expected: %+v", ts, transport.Types) } - server := host + ":" + port + // Set default port + if tu.Port() == "" { + switch ts { + case transport.TypeQUIC, transport.TypeTLS: + setPort(tu, 853) + case transport.TypeHTTP: + if tu.Scheme == "https" { + setPort(tu, 443) + } else { + setPort(tu, 80) + } + case transport.TypePlain, transport.TypeTCP: + setPort(tu, 53) + } + } - if txp == transport.TypeHTTP { - port = strings.Split(port, "/")[0] - u.Host += ":" + port - server = u.String() + // Add default path if missing + if ts == transport.TypeHTTP && tu.Path == "" { + tu.Path = "/dns-query" + } - // Add default path if missing - if u.Path == "" { - server += "/dns-query" - log.Tracef("HTTPS scheme and no path, setting server to %s", server) - } + server := tu.String() + // Remove scheme from server if irrelevant to protocol + if ts != transport.TypeHTTP { + server = strings.Split(server, "://")[1] } - // Insert scope ID before ']' + // Add IPv6 scope ID back to server if scopeId != "" { - server = strings.Replace(server, "]", "%"+scopeId+"]", 1) + server = strings.Replace(server, "]", scopeId+"]", 1) } - return server, txp, nil + return server, ts, nil } // driver is the "main" function for this program that accepts a flag slice for testing @@ -387,6 +357,16 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation } } + // Validate ODoH + if opts.ODoHProxy != "" { + if !strings.HasPrefix(opts.ODoHProxy, "https://") { + return fmt.Errorf("ODoH proxy must use HTTPS") + } + if !strings.HasPrefix(opts.Server, "https://") { + return fmt.Errorf("ODoH target must use HTTPS") + } + } + if opts.Chaos { log.Debug("Flag set, using chaos class") opts.Class = dns.ClassCHAOS @@ -439,7 +419,7 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation ) // Parse server address and transport type - server, transportType, err := parseServer() + server, transportType, err := parseServer(opts.Server) if err != nil { return err } diff --git a/main_test.go b/main_test.go index a50669c..8027f94 100644 --- a/main_test.go +++ b/main_test.go @@ -492,9 +492,7 @@ func TestMainParseServer(t *testing.T) { }, } { t.Run(tc.Server, func(t *testing.T) { - clearOpts() - opts.Server = tc.Server - server, transportType, err := parseServer() + server, transportType, err := parseServer(tc.Server) assert.Nil(t, err) assert.Equal(t, tc.ExpectedHost, server) assert.Equal(t, tc.Type, transportType) diff --git a/transport/transport.go b/transport/transport.go index 292252a..1508af3 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -18,6 +18,9 @@ const ( TypeDNSCrypt Type = "dnscrypt" ) +// Types is a list of all supported transports +var Types = []Type{TypePlain, TypeTCP, TypeTLS, TypeHTTP, TypeQUIC, TypeDNSCrypt} + // Interface guards var ( _ Transport = (*Plain)(nil)