diff --git a/server/config.go b/server/config.go index 26124eb..ddf5a34 100644 --- a/server/config.go +++ b/server/config.go @@ -22,6 +22,14 @@ const ( defaultNameServer = "localhost." ) +type CommonConfig struct { + domain string + port string + rname string + nameserver string + private bool +} + type AwsConfig struct { clients map[string]*ec2.EC2 // map[region]client } @@ -32,68 +40,75 @@ type GcpConfig struct { client *compute.Service } -func ParseConfig(config map[interface{}]interface{}) (domain, nameserver, port, rname string, private bool, awsConfig *AwsConfig, gcpConfig *GcpConfig, err error) { +//domain, nameserver, port, rname string, private bool, +func ParseConfig(config map[interface{}]interface{}) (commonConfig *CommonConfig, awsConfig *AwsConfig, gcpConfig *GcpConfig, err error) { if config == nil { err = fmt.Errorf("[err] ParseConfig empty params") return } + commonConfig = &CommonConfig{} + // get domain if v, ok := config["domain"]; !ok { + commonConfig = nil err = fmt.Errorf("[err] ParseConfig empty domain") + return } else { rawDomain := strings.TrimSpace(v.(string)) if !strings.HasSuffix(rawDomain, ".") { - domain = rawDomain + "." + commonConfig.domain = rawDomain + "." } else { - domain = rawDomain + commonConfig.domain = rawDomain } } + // get nameserver if v, ok := config["nameserver"]; !ok { - nameserver = defaultNameServer + commonConfig.nameserver = defaultNameServer } else { rawNameserver := strings.TrimSpace(v.(string)) if !strings.HasSuffix(rawNameserver, ".") { - nameserver = rawNameserver + "." + commonConfig.nameserver = rawNameserver + "." } else { - nameserver = rawNameserver + commonConfig.nameserver = rawNameserver } } // get port if v, ok := config["port"]; !ok { - port = defaultPort + commonConfig.port = defaultPort } else { switch v.(type) { case int, int64: - port = fmt.Sprintf("%d", v) + commonConfig.port = fmt.Sprintf("%d", v) case string: - port = strings.TrimSpace(v.(string)) + commonConfig.port = strings.TrimSpace(v.(string)) } } - // email + // get email if v, ok := config["email"]; !ok { - rname = defaultRName + commonConfig.rname = defaultRName } else { - rname = strings.Replace(v.(string), "@", ".", -1) + "." + commonConfig.rname = strings.Replace(v.(string), "@", ".", -1) + "." } // private if v, ok := config["private"]; !ok { - private = false + commonConfig.private = false } else { switch v.(type) { case bool: - private = v.(bool) + commonConfig.private = v.(bool) case string: e, suberr := strconv.ParseBool(strings.TrimSpace(v.(string))) if suberr != nil { + commonConfig = nil err = fmt.Errorf("[err] private field is invalid.") return } - private = e + commonConfig.private = e } } @@ -207,4 +222,4 @@ func ParseConfig(config map[interface{}]interface{}) (domain, nameserver, port, } } return -} +} \ No newline at end of file diff --git a/server/config_test.go b/server/config_test.go index 330c943..4dae043 100644 --- a/server/config_test.go +++ b/server/config_test.go @@ -14,22 +14,23 @@ func TestParseConfig(t *testing.T) { assert := assert.New(t) tests := map[string]struct { - input map[interface{}]interface{} - domain string - port string - awsConfig *AwsConfig - gcpConfig *GcpConfig - err bool + input map[interface{}]interface{} + commonConfig *CommonConfig + awsConfig *AwsConfig + gcpConfig *GcpConfig + err bool }{ "empty": {input: nil, err: true}, "emptyDomain": {input: make(map[interface{}]interface{}), err: true}, "success": {input: map[interface{}]interface{}{ - "domain": "localhost"}, err: false, domain: "localhost."}, + "domain": "localhost"}, err: false, commonConfig: &CommonConfig{domain: "localhost."}}, } for _, t := range tests { - do, _, _, _, _, ac, gc, err := ParseConfig(t.input) - assert.Equal(t.domain, do) + co, ac, gc, err := ParseConfig(t.input) + if co != nil { + assert.Equal(t.commonConfig.domain, co.domain) + } assert.Equal(t.awsConfig, ac) assert.Equal(t.gcpConfig, gc) if t.err { @@ -48,14 +49,14 @@ func TestParseConfig(t *testing.T) { assert.NoError(err) // both enable - do, na, po, rn, pr, ac, gc, err := ParseConfig(config) + co, ac, gc, err := ParseConfig(config) assert.NoError(err) - assert.NotEqual("", do) - assert.NotEqual("", na) - assert.NotEqual("", po) - assert.NotEqual("", rn) - assert.True(strings.HasSuffix(do, ".")) - assert.Equal(config["private"], pr) + assert.NotEqual("", co.domain) + assert.True(strings.HasSuffix(co.domain, ".")) + assert.NotEqual("", co.nameserver) + assert.NotEqual("", co.port) + assert.NotEqual("", co.rname) + assert.Equal(config["private"], co.private) assert.NotEmpty(ac) assert.NotEmpty(gc) assert.NotEqual(0, len(ac.clients)) @@ -65,18 +66,18 @@ func TestParseConfig(t *testing.T) { // gcp enable off config["gcp"].(map[interface{}]interface{})["enable"] = "false" - do, _, _, _, _, ac, gc, err = ParseConfig(config) + co, ac, gc, err = ParseConfig(config) assert.NoError(err) - assert.NotEqual("", do) + assert.NotEqual("", co.domain) assert.NotEmpty(ac) assert.Empty(gc) // aws enable off config["gcp"].(map[interface{}]interface{})["enable"] = "true" config["aws"].(map[interface{}]interface{})["enable"] = "false" - do, _, _, _, _, ac, gc, err = ParseConfig(config) + co, ac, gc, err = ParseConfig(config) assert.NoError(err) - assert.NotEqual("", do) + assert.NotEqual("", co.domain) assert.Empty(ac) assert.NotEmpty(gc) } diff --git a/server/server.go b/server/server.go index 57e332c..75d8726 100644 --- a/server/server.go +++ b/server/server.go @@ -20,32 +20,28 @@ type Server interface { } type server struct { - domain string - port string - rname string - nameserver string - publicIP string - private bool - store *Store + publicIP string + config *CommonConfig + store *Store } func (s *server) Start() { - udpServer := &dns.Server{Addr: ":" + s.port, Net: "udp"} + udpServer := &dns.Server{Addr: ":" + s.config.port, Net: "udp"} go func() { if err := udpServer.ListenAndServe(); err != nil { log.Panic(err) } }() - tcpServer := &dns.Server{Addr: ":" + s.port, Net: "tcp"} + tcpServer := &dns.Server{Addr: ":" + s.config.port, Net: "tcp"} mode := "PUBLIC-IP" - if s.private { + if s.config.private { mode = "PRIVATE-IP" } log.Printf("%s listen(%s) nameserver(%s) domain(%s) Serving %s\n", aurora.Green("[start]"), - aurora.Blue(fmt.Sprintf("%s:%s", s.publicIP, s.port)), - aurora.Yellow(s.nameserver), - aurora.Cyan(s.domain), + aurora.Blue(fmt.Sprintf("%s:%s", s.publicIP, s.config.port)), + aurora.Yellow(s.config.nameserver), + aurora.Cyan(s.config.domain), aurora.Magenta(mode), ) if err := tcpServer.ListenAndServe(); err != nil { @@ -124,23 +120,23 @@ func (s *server) dnsRequest(w dns.ResponseWriter, r *dns.Msg) { for _, msg := range m.Question { switch msg.Qtype { case dns.TypeNS: // dns nameserver - if msg.Name == s.domain { + if msg.Name == s.config.domain { m.Answer = append(m.Answer, s.ns()) } case dns.TypeSOA: // dns info - if msg.Name == s.domain { + if msg.Name == s.config.domain { m.Answer = append(m.Answer, s.soa()) } case dns.TypeA: // ipv4 - if strings.HasSuffix(msg.Name, s.domain) { - prefix := strings.TrimSpace(strings.TrimSuffix(msg.Name, "."+s.domain)) + if strings.HasSuffix(msg.Name, s.config.domain) { + prefix := strings.TrimSpace(strings.TrimSuffix(msg.Name, "."+s.config.domain)) records, err := s.Lookup(prefix) if err != nil { log.Printf("[err] lookup %+v\n", err) } else { for _, record := range records { ip := record.PublicIP - if s.private { + if s.config.private { ip = record.PrivateIP } m.Answer = append(m.Answer, &dns.A{ @@ -168,16 +164,16 @@ func (s *server) dnsRequest(w dns.ResponseWriter, r *dns.Msg) { func (s *server) ns() *dns.NS { return &dns.NS{ - Hdr: dns.RR_Header{Name: s.domain, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: uint32(TTL / time.Second)}, - Ns: s.nameserver, + Hdr: dns.RR_Header{Name: s.config.domain, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: uint32(TTL / time.Second)}, + Ns: s.config.nameserver, } } func (s *server) soa() *dns.SOA { return &dns.SOA{ - Hdr: dns.RR_Header{Name: s.domain, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: uint32(TTL / time.Second)}, - Ns: s.nameserver, - Mbox: s.rname, + Hdr: dns.RR_Header{Name: s.config.domain, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: uint32(TTL / time.Second)}, + Ns: s.config.nameserver, + Mbox: s.config.rname, Serial: uint32(s.store.cacheUpdatedAt.Unix()), // cache updatedAt Refresh: uint32((6 * time.Hour) / time.Second), Retry: uint32((30 * time.Minute) / time.Second), @@ -200,24 +196,55 @@ func NewServer(yamlPath string) (Server, error) { return nil, err } - domain, nameserver, port, rname, private, awsconfig, gcpconfig, err := ParseConfig(config) + // parse yaml + commonConfig, awsconfig, gcpconfig, err := ParseConfig(config) if err != nil { return nil, err } + if commonConfig == nil { + return nil, fmt.Errorf("[err] required config missing") + } + if awsconfig == nil && gcpconfig == nil { return nil, fmt.Errorf("[err] the aws or the gcp must be useful at least one") } + // get machine public ip publicIP, err := goip.GetPublicIPV4() if err != nil { log.Printf("%s not found machine public ip \n", aurora.Red("[fail]")) } - // check NS Record - nsrecords, err := net.LookupNS(domain) + // generate dns table + store, err := NewStore(awsconfig, gcpconfig) + if err != nil { + return nil, err + } + + // check normal config + checkedConfig, err := checkConfig(commonConfig) + if err != nil { + return nil, err + } + + s := &server{config: checkedConfig, + publicIP: publicIP, store: store} + + // register handler + dns.HandleFunc(s.config.domain, s.dnsRequest) + return Server(s), nil +} + +func checkConfig(config *CommonConfig) (*CommonConfig, error) { + if config == nil { + return nil, fmt.Errorf("[err] empty checkConfig") + } + + // check a machine state to be associated nameserver. + nsrecords, err := net.LookupNS(config.domain) if err != nil { - log.Printf("%s %s not found NS Record %v\n", aurora.Red("[fail]"), aurora.Magenta(domain), err) + log.Printf("%s %s not found NS Record %v\n", aurora.Red("[fail]"), aurora.Magenta(config.domain), err) } else { for _, ns := range nsrecords { ips, err := net.LookupIP(ns.Host) @@ -226,33 +253,22 @@ func NewServer(yamlPath string) (Server, error) { } else { check := false for _, ip := range ips { - if ip.String() == nameserver { + if ip.String() == config.nameserver { check = true } } if check { - if nameserver == defaultNameServer { - log.Printf("%s matched %s \n", aurora.Green("[success-auto-detect]"), aurora.Magenta(ns.Host)) - nameserver = ns.Host + if config.nameserver == defaultNameServer { + log.Printf("%s matched %s \n", aurora.Green("[success-match-with-detect]"), aurora.Magenta(ns.Host)) + config.nameserver = ns.Host } else { - log.Printf("%s matched %s \n", aurora.Green("[success-match]"), aurora.Magenta(nameserver)) + log.Printf("%s matched %s \n", aurora.Green("[success-match]"), aurora.Magenta(config.nameserver)) } - } else { - log.Printf("%s not matched %s \n", aurora.Red("[fail]"), aurora.Magenta(nameserver)) + log.Printf("%s not matched %s \n", aurora.Red("[fail]"), aurora.Magenta(config.nameserver)) } } } } - - store, err := NewStore(awsconfig, gcpconfig) - if err != nil { - return nil, err - } - s := &server{domain: domain, port: port, nameserver: nameserver, rname: rname, private: private, - publicIP: publicIP, store: store} - - // register handler - dns.HandleFunc(s.domain, s.dnsRequest) - return Server(s), nil -} + return config, nil +} \ No newline at end of file diff --git a/server/store_test.go b/server/store_test.go index 5d68cec..e36266b 100644 --- a/server/store_test.go +++ b/server/store_test.go @@ -20,7 +20,7 @@ func TestNewStore(t *testing.T) { assert.NoError(err) yaml.Unmarshal(bys, &config) assert.NoError(err) - _, _, _, _, _, awsconfig, gcpconfig, err := ParseConfig(config) + _, awsconfig, gcpconfig, err := ParseConfig(config) assert.NoError(err) store, err := NewStore(awsconfig, gcpconfig) @@ -40,7 +40,7 @@ func TestStore_Lookup(t *testing.T) { assert.NoError(err) yaml.Unmarshal(bys, &config) assert.NoError(err) - _, _, _, _, _, awsconfig, gcpconfig, err := ParseConfig(config) + _, awsconfig, gcpconfig, err := ParseConfig(config) assert.NoError(err) store, err := NewStore(awsconfig, gcpconfig) @@ -86,7 +86,7 @@ func TestRecord_TTL(t *testing.T) { assert.NoError(err) yaml.Unmarshal(bys, &config) assert.NoError(err) - _, _, _, _, _, awsconfig, gcpconfig, err := ParseConfig(config) + _, awsconfig, gcpconfig, err := ParseConfig(config) assert.NoError(err) store, err := NewStore(awsconfig, gcpconfig) @@ -109,7 +109,7 @@ func BenchmarkStore_Lookup(b *testing.B) { config := make(map[interface{}]interface{}) bys, _ := ioutil.ReadFile(yamlPath) yaml.Unmarshal(bys, &config) - _, _, _, _, _, awsconfig, gcpconfig, _ := ParseConfig(config) + _, awsconfig, gcpconfig, _ := ParseConfig(config) store, _ := NewStore(awsconfig, gcpconfig) for i := 0; i < b.N; i++ { store.Lookup(os.Getenv("TEST_AWS_1")) diff --git a/template.yaml b/template.yaml index 71f3cd0..f4c68f8 100644 --- a/template.yaml +++ b/template.yaml @@ -1,4 +1,4 @@ -domain: your-name-server-domain (eg) test.example.com .. and so on) +domain: your-name-server-domain (eg) localhost, test.example.com .. and so on) nameserver: your-machine hostname or public domain(default -> localhost or auto detect value. it should be a value that dns will resolve) port: tcp, udp open port number(default -> 53) email: your-email(responsible name)