diff --git a/client.go b/client.go index ebf1574..4772969 100644 --- a/client.go +++ b/client.go @@ -31,6 +31,23 @@ type ResolverInfo struct { SharedKey [keySize]byte // Shared key that is to be used to encrypt/decrypt messages } +// SetCertInfo and validates DNSCrypt certificate from the given dns message +// Data received during this call is then used for DNS requests encryption/decryption +// stampStr is an sdns:// address which is parsed using go-dnsstamps package +func (c *Client) SetCertInfo(stampStr string, r dns.Msg) (*ResolverInfo, error) { + stamp, err := dnsstamps.NewServerStampFromString(stampStr) + if err != nil { + // Invalid SDNS stamp + return nil, err + } + + if stamp.Proto != dnsstamps.StampProtoTypeDNSCrypt { + return nil, ErrInvalidDNSStamp + } + + return c.SetCertInfoStamp(stamp, r) +} + // Dial fetches and validates DNSCrypt certificate from the given server // Data received during this call is then used for DNS requests encryption/decryption // stampStr is an sdns:// address which is parsed using go-dnsstamps package @@ -48,6 +65,34 @@ func (c *Client) Dial(stampStr string) (*ResolverInfo, error) { return c.DialStamp(stamp) } +// SetCertInfoStamp set and validates DNSCrypt certificate from the given server +// Data received during this call is then used for DNS requests encryption/decryption +func (c *Client) SetCertInfoStamp(stamp dnsstamps.ServerStamp, r dns.Msg) (*ResolverInfo, error) { + resolverInfo := &ResolverInfo{} + + // Generate the secret/public pair + resolverInfo.SecretKey, resolverInfo.PublicKey = generateRandomKeyPair() + + // Set the provider properties + resolverInfo.ServerPublicKey = stamp.ServerPk + resolverInfo.ServerAddress = stamp.ServerAddrStr + resolverInfo.ProviderName = stamp.ProviderName + + cert, err := c.setCertInfo(stamp, r) + if err != nil { + return nil, err + } + resolverInfo.ResolverCert = cert + + // Compute shared key that we'll use to encrypt/decrypt messages + sharedKey, err := computeSharedKey(cert.EsVersion, &resolverInfo.SecretKey, &cert.ResolverPk) + if err != nil { + return nil, err + } + resolverInfo.SharedKey = sharedKey + return resolverInfo, nil +} + // DialStamp fetches and validates DNSCrypt certificate from the given server // Data received during this call is then used for DNS requests encryption/decryption func (c *Client) DialStamp(stamp dnsstamps.ServerStamp) (*ResolverInfo, error) { @@ -203,13 +248,12 @@ func (c *Client) decrypt(b []byte, resolverInfo *ResolverInfo) (*dns.Msg, error) return res, nil } -// fetchCert loads DNSCrypt cert from the specified server +//fetchCertPlaintext loads DNSCrypt record from the specified server func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) { providerName := stamp.ProviderName if !strings.HasSuffix(providerName, ".") { providerName = providerName + "." } - query := new(dns.Msg) query.SetQuestion(providerName, dns.TypeTXT) client := dns.Client{Net: c.Net, UDPSize: uint16(maxQueryLen), Timeout: c.Timeout} @@ -221,7 +265,15 @@ func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) { if r.Rcode != dns.RcodeSuccess { return nil, ErrFailedToFetchCert } + return c.setCertInfo(stamp, *r) +} +// fetchCert set DNSCrypt cert +func (c *Client) setCertInfo(stamp dnsstamps.ServerStamp, r dns.Msg) (*Cert, error) { + providerName := stamp.ProviderName + if !strings.HasSuffix(providerName, ".") { + providerName = providerName + "." + } var certErr error currentCert := &Cert{} foundValid := false