From 9cb58fc05213e57daf7415a410bea9929565acd6 Mon Sep 17 00:00:00 2001 From: Eoin H Date: Mon, 13 Jan 2025 16:05:18 +0000 Subject: [PATCH] feat: Add basic dashboard for tests --- main.go | 104 ++++++++++++++++++++++++++++++++--- templates/dashboard.templ | 4 +- templates/dashboard_templ.go | 4 +- 3 files changed, 101 insertions(+), 11 deletions(-) diff --git a/main.go b/main.go index d54ce1e..f8f594b 100644 --- a/main.go +++ b/main.go @@ -4,14 +4,14 @@ import ( "database/sql" "encoding/json" "errors" + "gobandit/models" + "gobandit/templates" "log" "math" "math/rand" "net/http" "time" - . "gobandit/models" - "github.com/gorilla/mux" _ "github.com/lib/pq" ) @@ -37,11 +37,30 @@ func (s *Server) routes() { s.router.HandleFunc("/tests", s.handleCreateTest).Methods("POST") s.router.HandleFunc("/tests/{testID}/arm", s.handleGetArm).Methods("GET") s.router.HandleFunc("/tests/{testID}/arms/{armID}/result", s.handleRecordResult).Methods("POST") + + // Dashboard routes + s.router.HandleFunc("/", s.handleDashboard).Methods("GET") + s.router.HandleFunc("/tests/{testID}/arms", s.handleGetArmStats).Methods("GET") +} + +// handleDashboard renders the main dashboard +func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) { + tests, err := s.getAllTests() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + err = templates.Dashboard(tests).Render(r.Context(), w) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } // handleCreateTest creates a new test with the specified arms func (s *Server) handleCreateTest(w http.ResponseWriter, r *http.Request) { - var test Test + var test models.Test if err := json.NewDecoder(r.Body).Decode(&test); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -88,6 +107,77 @@ func (s *Server) handleCreateTest(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(test) } +// handleGetArmStats returns stats for all arms in a test +func (s *Server) handleGetArmStats(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + testID := vars["testID"] + + arms, err := s.getTestArms(testID) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + err = templates.ArmStats(arms).Render(r.Context(), w) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +// getAllTests retrieves all tests from the database +func (s *Server) getAllTests() ([]models.Test, error) { + rows, err := s.db.Query(` + SELECT id, name, description, created_at, updated_at + FROM tests + ORDER BY created_at DESC + `) + if err != nil { + return nil, err + } + defer rows.Close() + + var tests []models.Test + for rows.Next() { + var test models.Test + err := rows.Scan(&test.ID, &test.Name, &test.Description, &test.CreatedAt, &test.UpdatedAt) + if err != nil { + return nil, err + } + tests = append(tests, test) + } + return tests, nil +} + +// getTestArms retrieves all arms for a specific test +func (s *Server) getTestArms(testID string) ([]models.Arm, error) { + rows, err := s.db.Query(` + SELECT id, name, description, successes, failures, created_at, updated_at + FROM arms + WHERE test_id = $1 + ORDER BY name + `, testID) + if err != nil { + return nil, err + } + defer rows.Close() + + var arms []models.Arm + for rows.Next() { + var arm models.Arm + err := rows.Scan( + &arm.ID, &arm.Name, &arm.Description, + &arm.Successes, &arm.Failures, + &arm.CreatedAt, &arm.UpdatedAt, + ) + if err != nil { + return nil, err + } + arms = append(arms, arm) + } + return arms, nil +} + // handleGetArm returns the next arm to test using Thompson Sampling func (s *Server) handleGetArm(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) @@ -105,9 +195,9 @@ func (s *Server) handleGetArm(w http.ResponseWriter, r *http.Request) { } defer rows.Close() - var arms []Arm + var arms []models.Arm for rows.Next() { - var arm Arm + var arm models.Arm err := rows.Scan(&arm.ID, &arm.Name, &arm.Successes, &arm.Failures) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -167,12 +257,12 @@ func (s *Server) handleRecordResult(w http.ResponseWriter, r *http.Request) { } // thompsonSampling implements the Thompson Sampling algorithm -func thompsonSampling(arms []Arm) Arm { +func thompsonSampling(arms []models.Arm) models.Arm { rand.Seed(time.Now().UnixNano()) var ( maxSample float64 - selected Arm + selected models.Arm ) for _, arm := range arms { diff --git a/templates/dashboard.templ b/templates/dashboard.templ index d65bde1..45e4fa1 100644 --- a/templates/dashboard.templ +++ b/templates/dashboard.templ @@ -24,7 +24,7 @@ templ layout() { } -templ dashboard(tests []Test) { +templ Dashboard(tests []Test) { @layout() {

A/B Test Dashboard

@@ -52,7 +52,7 @@ templ testCard(test Test) {
} -templ armStats(arms []Arm) { +templ ArmStats(arms []Arm) { for _, arm := range arms { @armCard(arm) } diff --git a/templates/dashboard_templ.go b/templates/dashboard_templ.go index 4c1f4a6..a163748 100644 --- a/templates/dashboard_templ.go +++ b/templates/dashboard_templ.go @@ -52,7 +52,7 @@ func layout() templ.Component { }) } -func dashboard(tests []Test) templ.Component { +func Dashboard(tests []Test) templ.Component { return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) { templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil { @@ -190,7 +190,7 @@ func testCard(test Test) templ.Component { }) } -func armStats(arms []Arm) templ.Component { +func ArmStats(arms []Arm) templ.Component { return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) { templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {