diff --git a/cmd/gvproxy/main.go b/cmd/gvproxy/main.go index 07c6b17..8d901c2 100644 --- a/cmd/gvproxy/main.go +++ b/cmd/gvproxy/main.go @@ -3,6 +3,7 @@ package main import ( "context" "flag" + "fmt" "net" "os" "os/signal" @@ -11,6 +12,7 @@ import ( "github.com/containers/gvisor-tap-vsock/pkg/types" "github.com/containers/gvisor-tap-vsock/pkg/virtualnetwork" "github.com/pkg/errors" + "github.com/sakai135/wsl-vpnkit/pkg/services/dns" "github.com/sakai135/wsl-vpnkit/pkg/transport" log "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" @@ -123,6 +125,17 @@ func run(ctx context.Context, g *errgroup.Group, configuration *types.Configurat return err } + lnDns, err := vn.Listen("tcp", fmt.Sprintf("%s:53", gatewayIP)) + if err != nil { + return err + } + go func() { + err := dns.Serve(lnDns, configuration.DNS) + if err != nil { + log.Error(err) + } + }() + conn := transport.GetStdioConn() err = vn.AcceptQemu(ctx, conn) if err != nil { diff --git a/pkg/services/dns/tcp.go b/pkg/services/dns/tcp.go new file mode 100644 index 0000000..a2c23e4 --- /dev/null +++ b/pkg/services/dns/tcp.go @@ -0,0 +1,123 @@ +package dns + +import ( + "context" + "fmt" + "net" + "strings" + + "github.com/containers/gvisor-tap-vsock/pkg/types" + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" +) + +type dnsHandler struct { + zones []types.Zone +} + +func (h *dnsHandler) handle(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.RecursionAvailable = true + h.addAnswers(m) + if err := w.WriteMsg(m); err != nil { + log.Error(err) + } +} + +func (h *dnsHandler) addAnswers(m *dns.Msg) { + for _, q := range m.Question { + for _, zone := range h.zones { + zoneSuffix := fmt.Sprintf(".%s", zone.Name) + if strings.HasSuffix(q.Name, zoneSuffix) { + if q.Qtype != dns.TypeA { + return + } + for _, record := range zone.Records { + withoutZone := strings.TrimSuffix(q.Name, zoneSuffix) + if (record.Name != "" && record.Name == withoutZone) || + (record.Regexp != nil && record.Regexp.MatchString(withoutZone)) { + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 0, + }, + A: record.IP, + }) + return + } + } + if !zone.DefaultIP.Equal(net.IP("")) { + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 0, + }, + A: zone.DefaultIP, + }) + return + } + m.Rcode = dns.RcodeNameError + return + } + } + + resolver := net.Resolver{ + PreferGo: false, + } + switch q.Qtype { + case dns.TypeNS: + records, err := resolver.LookupNS(context.TODO(), q.Name) + if err != nil { + m.Rcode = dns.RcodeNameError + return + } + for _, ns := range records { + m.Answer = append(m.Answer, &dns.NS{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + Ttl: 0, + }, + Ns: ns.Host, + }) + } + case dns.TypeA: + ips, err := resolver.LookupIPAddr(context.TODO(), q.Name) + if err != nil { + m.Rcode = dns.RcodeNameError + return + } + for _, ip := range ips { + if len(ip.IP.To4()) != net.IPv4len { + continue + } + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 0, + }, + A: ip.IP.To4(), + }) + } + } + } +} + +func Serve(ln net.Listener, zones []types.Zone) error { + mux := dns.NewServeMux() + handler := &dnsHandler{zones: zones} + mux.HandleFunc(".", handler.handle) + srv := &dns.Server{ + Listener: ln, + Handler: mux, + } + return srv.ActivateAndServe() +}