Skip to content

Commit

Permalink
Refactoring dns config.
Browse files Browse the repository at this point in the history
  • Loading branch information
gjbae1212 committed Jun 19, 2019
1 parent b03f93d commit 2214a0c
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 87 deletions.
47 changes: 31 additions & 16 deletions server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ const (
defaultNameServer = "localhost."
)

type CommonConfig struct {
domain string
port string
rname string
nameserver string
private bool
}

type AwsConfig struct {
clients map[string]*ec2.EC2 // map[region]client
}
Expand All @@ -32,68 +40,75 @@ type GcpConfig struct {
client *compute.Service
}

func ParseConfig(config map[interface{}]interface{}) (domain, nameserver, port, rname string, private bool, awsConfig *AwsConfig, gcpConfig *GcpConfig, err error) {
//domain, nameserver, port, rname string, private bool,
func ParseConfig(config map[interface{}]interface{}) (commonConfig *CommonConfig, awsConfig *AwsConfig, gcpConfig *GcpConfig, err error) {
if config == nil {
err = fmt.Errorf("[err] ParseConfig empty params")
return
}

commonConfig = &CommonConfig{}

// get domain
if v, ok := config["domain"]; !ok {
commonConfig = nil
err = fmt.Errorf("[err] ParseConfig empty domain")
return
} else {
rawDomain := strings.TrimSpace(v.(string))
if !strings.HasSuffix(rawDomain, ".") {
domain = rawDomain + "."
commonConfig.domain = rawDomain + "."
} else {
domain = rawDomain
commonConfig.domain = rawDomain
}
}

// get nameserver
if v, ok := config["nameserver"]; !ok {
nameserver = defaultNameServer
commonConfig.nameserver = defaultNameServer
} else {
rawNameserver := strings.TrimSpace(v.(string))
if !strings.HasSuffix(rawNameserver, ".") {
nameserver = rawNameserver + "."
commonConfig.nameserver = rawNameserver + "."
} else {
nameserver = rawNameserver
commonConfig.nameserver = rawNameserver
}
}

// get port
if v, ok := config["port"]; !ok {
port = defaultPort
commonConfig.port = defaultPort
} else {
switch v.(type) {
case int, int64:
port = fmt.Sprintf("%d", v)
commonConfig.port = fmt.Sprintf("%d", v)
case string:
port = strings.TrimSpace(v.(string))
commonConfig.port = strings.TrimSpace(v.(string))
}
}

// email
// get email
if v, ok := config["email"]; !ok {
rname = defaultRName
commonConfig.rname = defaultRName
} else {
rname = strings.Replace(v.(string), "@", ".", -1) + "."
commonConfig.rname = strings.Replace(v.(string), "@", ".", -1) + "."
}

// private
if v, ok := config["private"]; !ok {
private = false
commonConfig.private = false
} else {
switch v.(type) {
case bool:
private = v.(bool)
commonConfig.private = v.(bool)
case string:
e, suberr := strconv.ParseBool(strings.TrimSpace(v.(string)))
if suberr != nil {
commonConfig = nil
err = fmt.Errorf("[err] private field is invalid.")
return
}
private = e
commonConfig.private = e
}
}

Expand Down Expand Up @@ -207,4 +222,4 @@ func ParseConfig(config map[interface{}]interface{}) (domain, nameserver, port,
}
}
return
}
}
41 changes: 21 additions & 20 deletions server/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,23 @@ func TestParseConfig(t *testing.T) {
assert := assert.New(t)

tests := map[string]struct {
input map[interface{}]interface{}
domain string
port string
awsConfig *AwsConfig
gcpConfig *GcpConfig
err bool
input map[interface{}]interface{}
commonConfig *CommonConfig
awsConfig *AwsConfig
gcpConfig *GcpConfig
err bool
}{
"empty": {input: nil, err: true},
"emptyDomain": {input: make(map[interface{}]interface{}), err: true},
"success": {input: map[interface{}]interface{}{
"domain": "localhost"}, err: false, domain: "localhost."},
"domain": "localhost"}, err: false, commonConfig: &CommonConfig{domain: "localhost."}},
}

for _, t := range tests {
do, _, _, _, _, ac, gc, err := ParseConfig(t.input)
assert.Equal(t.domain, do)
co, ac, gc, err := ParseConfig(t.input)
if co != nil {
assert.Equal(t.commonConfig.domain, co.domain)
}
assert.Equal(t.awsConfig, ac)
assert.Equal(t.gcpConfig, gc)
if t.err {
Expand All @@ -48,14 +49,14 @@ func TestParseConfig(t *testing.T) {
assert.NoError(err)

// both enable
do, na, po, rn, pr, ac, gc, err := ParseConfig(config)
co, ac, gc, err := ParseConfig(config)
assert.NoError(err)
assert.NotEqual("", do)
assert.NotEqual("", na)
assert.NotEqual("", po)
assert.NotEqual("", rn)
assert.True(strings.HasSuffix(do, "."))
assert.Equal(config["private"], pr)
assert.NotEqual("", co.domain)
assert.True(strings.HasSuffix(co.domain, "."))
assert.NotEqual("", co.nameserver)
assert.NotEqual("", co.port)
assert.NotEqual("", co.rname)
assert.Equal(config["private"], co.private)
assert.NotEmpty(ac)
assert.NotEmpty(gc)
assert.NotEqual(0, len(ac.clients))
Expand All @@ -65,18 +66,18 @@ func TestParseConfig(t *testing.T) {

// gcp enable off
config["gcp"].(map[interface{}]interface{})["enable"] = "false"
do, _, _, _, _, ac, gc, err = ParseConfig(config)
co, ac, gc, err = ParseConfig(config)
assert.NoError(err)
assert.NotEqual("", do)
assert.NotEqual("", co.domain)
assert.NotEmpty(ac)
assert.Empty(gc)

// aws enable off
config["gcp"].(map[interface{}]interface{})["enable"] = "true"
config["aws"].(map[interface{}]interface{})["enable"] = "false"
do, _, _, _, _, ac, gc, err = ParseConfig(config)
co, ac, gc, err = ParseConfig(config)
assert.NoError(err)
assert.NotEqual("", do)
assert.NotEqual("", co.domain)
assert.Empty(ac)
assert.NotEmpty(gc)
}
Expand Down
108 changes: 62 additions & 46 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,28 @@ type Server interface {
}

type server struct {
domain string
port string
rname string
nameserver string
publicIP string
private bool
store *Store
publicIP string
config *CommonConfig
store *Store
}

func (s *server) Start() {
udpServer := &dns.Server{Addr: ":" + s.port, Net: "udp"}
udpServer := &dns.Server{Addr: ":" + s.config.port, Net: "udp"}
go func() {
if err := udpServer.ListenAndServe(); err != nil {
log.Panic(err)
}
}()
tcpServer := &dns.Server{Addr: ":" + s.port, Net: "tcp"}
tcpServer := &dns.Server{Addr: ":" + s.config.port, Net: "tcp"}
mode := "PUBLIC-IP"
if s.private {
if s.config.private {
mode = "PRIVATE-IP"
}
log.Printf("%s listen(%s) nameserver(%s) domain(%s) Serving %s\n",
aurora.Green("[start]"),
aurora.Blue(fmt.Sprintf("%s:%s", s.publicIP, s.port)),
aurora.Yellow(s.nameserver),
aurora.Cyan(s.domain),
aurora.Blue(fmt.Sprintf("%s:%s", s.publicIP, s.config.port)),
aurora.Yellow(s.config.nameserver),
aurora.Cyan(s.config.domain),
aurora.Magenta(mode),
)
if err := tcpServer.ListenAndServe(); err != nil {
Expand Down Expand Up @@ -124,23 +120,23 @@ func (s *server) dnsRequest(w dns.ResponseWriter, r *dns.Msg) {
for _, msg := range m.Question {
switch msg.Qtype {
case dns.TypeNS: // dns nameserver
if msg.Name == s.domain {
if msg.Name == s.config.domain {
m.Answer = append(m.Answer, s.ns())
}
case dns.TypeSOA: // dns info
if msg.Name == s.domain {
if msg.Name == s.config.domain {
m.Answer = append(m.Answer, s.soa())
}
case dns.TypeA: // ipv4
if strings.HasSuffix(msg.Name, s.domain) {
prefix := strings.TrimSpace(strings.TrimSuffix(msg.Name, "."+s.domain))
if strings.HasSuffix(msg.Name, s.config.domain) {
prefix := strings.TrimSpace(strings.TrimSuffix(msg.Name, "."+s.config.domain))
records, err := s.Lookup(prefix)
if err != nil {
log.Printf("[err] lookup %+v\n", err)
} else {
for _, record := range records {
ip := record.PublicIP
if s.private {
if s.config.private {
ip = record.PrivateIP
}
m.Answer = append(m.Answer, &dns.A{
Expand Down Expand Up @@ -168,16 +164,16 @@ func (s *server) dnsRequest(w dns.ResponseWriter, r *dns.Msg) {

func (s *server) ns() *dns.NS {
return &dns.NS{
Hdr: dns.RR_Header{Name: s.domain, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: uint32(TTL / time.Second)},
Ns: s.nameserver,
Hdr: dns.RR_Header{Name: s.config.domain, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: uint32(TTL / time.Second)},
Ns: s.config.nameserver,
}
}

func (s *server) soa() *dns.SOA {
return &dns.SOA{
Hdr: dns.RR_Header{Name: s.domain, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: uint32(TTL / time.Second)},
Ns: s.nameserver,
Mbox: s.rname,
Hdr: dns.RR_Header{Name: s.config.domain, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: uint32(TTL / time.Second)},
Ns: s.config.nameserver,
Mbox: s.config.rname,
Serial: uint32(s.store.cacheUpdatedAt.Unix()), // cache updatedAt
Refresh: uint32((6 * time.Hour) / time.Second),
Retry: uint32((30 * time.Minute) / time.Second),
Expand All @@ -200,24 +196,55 @@ func NewServer(yamlPath string) (Server, error) {
return nil, err
}

domain, nameserver, port, rname, private, awsconfig, gcpconfig, err := ParseConfig(config)
// parse yaml
commonConfig, awsconfig, gcpconfig, err := ParseConfig(config)
if err != nil {
return nil, err
}

if commonConfig == nil {
return nil, fmt.Errorf("[err] required config missing")
}

if awsconfig == nil && gcpconfig == nil {
return nil, fmt.Errorf("[err] the aws or the gcp must be useful at least one")
}

// get machine public ip
publicIP, err := goip.GetPublicIPV4()
if err != nil {
log.Printf("%s not found machine public ip \n", aurora.Red("[fail]"))
}

// check NS Record
nsrecords, err := net.LookupNS(domain)
// generate dns table
store, err := NewStore(awsconfig, gcpconfig)
if err != nil {
return nil, err
}

// check normal config
checkedConfig, err := checkConfig(commonConfig)
if err != nil {
return nil, err
}

s := &server{config: checkedConfig,
publicIP: publicIP, store: store}

// register handler
dns.HandleFunc(s.config.domain, s.dnsRequest)
return Server(s), nil
}

func checkConfig(config *CommonConfig) (*CommonConfig, error) {
if config == nil {
return nil, fmt.Errorf("[err] empty checkConfig")
}

// check a machine state to be associated nameserver.
nsrecords, err := net.LookupNS(config.domain)
if err != nil {
log.Printf("%s %s not found NS Record %v\n", aurora.Red("[fail]"), aurora.Magenta(domain), err)
log.Printf("%s %s not found NS Record %v\n", aurora.Red("[fail]"), aurora.Magenta(config.domain), err)
} else {
for _, ns := range nsrecords {
ips, err := net.LookupIP(ns.Host)
Expand All @@ -226,33 +253,22 @@ func NewServer(yamlPath string) (Server, error) {
} else {
check := false
for _, ip := range ips {
if ip.String() == nameserver {
if ip.String() == config.nameserver {
check = true
}
}
if check {
if nameserver == defaultNameServer {
log.Printf("%s matched %s \n", aurora.Green("[success-auto-detect]"), aurora.Magenta(ns.Host))
nameserver = ns.Host
if config.nameserver == defaultNameServer {
log.Printf("%s matched %s \n", aurora.Green("[success-match-with-detect]"), aurora.Magenta(ns.Host))
config.nameserver = ns.Host
} else {
log.Printf("%s matched %s \n", aurora.Green("[success-match]"), aurora.Magenta(nameserver))
log.Printf("%s matched %s \n", aurora.Green("[success-match]"), aurora.Magenta(config.nameserver))
}

} else {
log.Printf("%s not matched %s \n", aurora.Red("[fail]"), aurora.Magenta(nameserver))
log.Printf("%s not matched %s \n", aurora.Red("[fail]"), aurora.Magenta(config.nameserver))
}
}
}
}

store, err := NewStore(awsconfig, gcpconfig)
if err != nil {
return nil, err
}
s := &server{domain: domain, port: port, nameserver: nameserver, rname: rname, private: private,
publicIP: publicIP, store: store}

// register handler
dns.HandleFunc(s.domain, s.dnsRequest)
return Server(s), nil
}
return config, nil
}
Loading

0 comments on commit 2214a0c

Please sign in to comment.