Skip to content
This repository has been archived by the owner on Oct 30, 2018. It is now read-only.

Commit

Permalink
Cleaned up how bind info is tracked by the API
Browse files Browse the repository at this point in the history
  • Loading branch information
zealws committed Nov 13, 2015
1 parent f4eba81 commit 304fac4
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

var (
API_URL_MAP = map[string]HandlerFunc{
"/": RedirectHandler(ServerData.ApiUrl + "shape"),
"/": RedirectHandler(ROOT_URL + "/shape"),
"/info": InfoHandler,
"/group": GroupsHandler,
"/group/{id}": GroupHandler,
Expand Down
10 changes: 3 additions & 7 deletions src/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ type testServer struct {
}

func newTestServer(t *testing.T) testServer {
srv, err := ListenAndServe(net.JoinHostPort(ServerAddr, "0"), "", "", "sqlite3", ":memory:")
srv, err := ListenAndServe(net.JoinHostPort(ServerAddr, "0"), "", "", "sqlite3", ":memory:", "0.0.0.0", "::")
if err != nil {
t.Fatal(err)
}
Expand All @@ -249,12 +249,8 @@ func (s testServer) Cleanup() {
}

func (s testServer) client(addr string) testClient {
if addr == "" {
addr = s.GetAddr()
} else {
_, port, _ := net.SplitHostPort(s.GetAddr())
addr = net.JoinHostPort(addr, port)
}
_, port, _ := net.SplitHostPort(s.GetAddr())
addr = net.JoinHostPort(addr, port)
return testClient{addr, s.t}
}

Expand Down
67 changes: 54 additions & 13 deletions src/api/assets.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,67 @@ package api
import (
"fmt"
"html/template"
"net"
"net/http"

"github.com/gorilla/mux"
)

func rootHandler(w http.ResponseWriter, r *http.Request) {
data, err := Asset("static/index.htm")
if err != nil {
fmt.Println(err)
w.WriteHeader(404)
return
type bindInfo struct {
ApiUrl string
IP4 string
IP6 string
Port string
}

type templateData struct {
ApiUrl string
Primary string
Secondary string
}

func (info *bindInfo) templateFor(r *http.Request) *templateData {
data := &templateData{
ApiUrl: info.ApiUrl,
}
tmpl, err := template.New("root").Parse(string(data))
if err != nil {
fmt.Println(err)
w.WriteHeader(500)
return
addr, _, _ := net.SplitHostPort(r.RemoteAddr)
if len(net.ParseIP(addr)) == net.IPv6len {
// client is ipv6
data.Primary = info.IP6
data.Secondary = info.IP4
} else {
// client is ipv4
data.Primary = info.IP4
data.Secondary = info.IP6
}
// If the user didn't provide one of the two addresses, we pass the UI an
// empty string.
if data.Primary != "" {
data.Primary = net.JoinHostPort(data.Primary, info.Port)
}
if data.Secondary != "" {
data.Secondary = net.JoinHostPort(data.Secondary, info.Port)
}
return data
}

func rootHandler(info *bindInfo) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
data, err := Asset("static/index.htm")
if err != nil {
fmt.Println(err)
w.WriteHeader(404)
return
}
tmpl, err := template.New("root").Parse(string(data))
if err != nil {
fmt.Println(err)
w.WriteHeader(500)
return
}
w.WriteHeader(200)
tmpl.Execute(w, info.templateFor(r))
}
w.WriteHeader(200)
tmpl.Execute(w, ServerData)
}

func cachedAssetHandler(w http.ResponseWriter, r *http.Request) {
Expand Down
22 changes: 12 additions & 10 deletions src/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,9 @@ var (
// HTTP Connection timeouts for read/write
TIMEOUT = time.Second * 30

ServerData = serverData{
ApiUrl: "/api/v1/",
}
ROOT_URL = "/api/v1"
)

type serverData struct {
ApiUrl string
}

type Server struct {
Addr string
Timeout time.Duration
Expand All @@ -30,13 +24,15 @@ type Server struct {
db *DbRunner
thrift_proto string
thrift_addr string
bind_info *bindInfo
}

func ListenAndServe(addr, thrift_addr, thrift_proto, dbdriver, dbconn string) (*Server, error) {
func ListenAndServe(addr, thrift_addr, thrift_proto, dbdriver, dbconn, v4, v6 string) (*Server, error) {
db, err := NewDbRunner(dbdriver, dbconn)
if err != nil {
return nil, err
}
_, port, _ := net.SplitHostPort(addr)
srv := &Server{
Addr: addr,
listener: nil,
Expand All @@ -46,6 +42,12 @@ func ListenAndServe(addr, thrift_addr, thrift_proto, dbdriver, dbconn string) (*
thrift_proto: thrift_proto,
Atcd: nil,
db: db,
bind_info: &bindInfo{
ApiUrl: ROOT_URL,
IP4: v4,
IP6: v6,
Port: port,
},
}
srv.setupHandlers()
err = srv.ListenAndServe()
Expand Down Expand Up @@ -97,13 +99,13 @@ func (srv *Server) Serve() {

func (srv *Server) setupHandlers() {
r := mux.NewRouter()
apir := r.PathPrefix(ServerData.ApiUrl).Subrouter()
apir := r.PathPrefix(ROOT_URL).Subrouter()
for url, f := range API_URL_MAP {
h := NewHandler(srv, f)
apir.HandleFunc(url, h)
apir.HandleFunc(url+"/", h)
}
r.HandleFunc("/", rootHandler)
r.HandleFunc("/", rootHandler(srv.bind_info))
r.HandleFunc("/static/{folder}/{name}", cachedAssetHandler)
srv.Handler = r
}
Expand Down
13 changes: 12 additions & 1 deletion src/atc_api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ func TestAtcdConnection(addr, proto string) error {
func main() {
args := ParseArgs()

// Make sure connection to the daemon is working.
err := TestAtcdConnection(args.ThriftAddr, args.ThriftProtocol)
if err != nil {
api.Log.Println("failed to connect to atcd server:", err)
Expand All @@ -32,9 +33,13 @@ func main() {
api.Log.Println("Connected to atcd socket on", args.ThriftAddr)
}

if args.IPv4 == "" || args.IPv6 == "" {
api.Log.Fatalln("You must provide either -4 or -6 arguments to run the API.")
}

api.Log.Println("Listening on", args.BindAddr)

srv, err := api.ListenAndServe(args.BindAddr, args.ThriftAddr, args.ThriftProtocol, args.DbDriver, args.DbConnstr)
srv, err := api.ListenAndServe(args.BindAddr, args.ThriftAddr, args.ThriftProtocol, args.DbDriver, args.DbConnstr, args.IPv4, args.IPv6)
if err != nil {
api.Log.Fatalln("failed to listen and serve:", err)
}
Expand All @@ -52,6 +57,8 @@ type Arguments struct {
DbDriver string
DbConnstr string
WarnOnly bool
IPv4 string
IPv6 string
}

func ParseArgs() Arguments {
Expand All @@ -61,6 +68,8 @@ func ParseArgs() Arguments {
db_driver := flag.String("D", "sqlite3", "database driver")
db_connstr := flag.String("Q", "atc_api.db", "database driver connection parameters")
warn_only := flag.Bool("W", false, "only warn if the thrift server isn't reachable")
ipv4 := flag.String("4", "", "IPv4 address for the API")
ipv6 := flag.String("6", "", "IPv6 address for the API")
flag.Parse()

return Arguments{
Expand All @@ -70,5 +79,7 @@ func ParseArgs() Arguments {
DbDriver: *db_driver,
DbConnstr: *db_connstr,
WarnOnly: *warn_only,
IPv4: *ipv4,
IPv6: *ipv6,
}
}

0 comments on commit 304fac4

Please sign in to comment.