Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add basic dashboard for bandit tests #1

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 97 additions & 7 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ import (
"database/sql"
"encoding/json"
"errors"
"gobandit/models"
"gobandit/templates"
"log"
"math"
"math/rand"
"net/http"
"time"

. "gobandit/models"

"github.com/gorilla/mux"
_ "github.com/lib/pq"
)
Expand All @@ -37,11 +37,30 @@ func (s *Server) routes() {
s.router.HandleFunc("/tests", s.handleCreateTest).Methods("POST")
s.router.HandleFunc("/tests/{testID}/arm", s.handleGetArm).Methods("GET")
s.router.HandleFunc("/tests/{testID}/arms/{armID}/result", s.handleRecordResult).Methods("POST")

// Dashboard routes
s.router.HandleFunc("/", s.handleDashboard).Methods("GET")
s.router.HandleFunc("/tests/{testID}/arms", s.handleGetArmStats).Methods("GET")
}

// handleDashboard renders the main dashboard
func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
tests, err := s.getAllTests()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

err = templates.Dashboard(tests).Render(r.Context(), w)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}

// handleCreateTest creates a new test with the specified arms
func (s *Server) handleCreateTest(w http.ResponseWriter, r *http.Request) {
var test Test
var test models.Test
if err := json.NewDecoder(r.Body).Decode(&test); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
Expand Down Expand Up @@ -88,6 +107,77 @@ func (s *Server) handleCreateTest(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(test)
}

// handleGetArmStats returns stats for all arms in a test
func (s *Server) handleGetArmStats(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
testID := vars["testID"]

arms, err := s.getTestArms(testID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

err = templates.ArmStats(arms).Render(r.Context(), w)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}

// getAllTests retrieves all tests from the database
func (s *Server) getAllTests() ([]models.Test, error) {
rows, err := s.db.Query(`
SELECT id, name, description, created_at, updated_at
FROM tests
ORDER BY created_at DESC
`)
if err != nil {
return nil, err
}
defer rows.Close()

var tests []models.Test
for rows.Next() {
var test models.Test
err := rows.Scan(&test.ID, &test.Name, &test.Description, &test.CreatedAt, &test.UpdatedAt)
if err != nil {
return nil, err
}
tests = append(tests, test)
}
return tests, nil
}

// getTestArms retrieves all arms for a specific test
func (s *Server) getTestArms(testID string) ([]models.Arm, error) {
rows, err := s.db.Query(`
SELECT id, name, description, successes, failures, created_at, updated_at
FROM arms
WHERE test_id = $1
ORDER BY name
`, testID)
if err != nil {
return nil, err
}
defer rows.Close()

var arms []models.Arm
for rows.Next() {
var arm models.Arm
err := rows.Scan(
&arm.ID, &arm.Name, &arm.Description,
&arm.Successes, &arm.Failures,
&arm.CreatedAt, &arm.UpdatedAt,
)
if err != nil {
return nil, err
}
arms = append(arms, arm)
}
return arms, nil
}

// handleGetArm returns the next arm to test using Thompson Sampling
func (s *Server) handleGetArm(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
Expand All @@ -105,9 +195,9 @@ func (s *Server) handleGetArm(w http.ResponseWriter, r *http.Request) {
}
defer rows.Close()

var arms []Arm
var arms []models.Arm
for rows.Next() {
var arm Arm
var arm models.Arm
err := rows.Scan(&arm.ID, &arm.Name, &arm.Successes, &arm.Failures)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -167,12 +257,12 @@ func (s *Server) handleRecordResult(w http.ResponseWriter, r *http.Request) {
}

// thompsonSampling implements the Thompson Sampling algorithm
func thompsonSampling(arms []Arm) Arm {
func thompsonSampling(arms []models.Arm) models.Arm {
rand.Seed(time.Now().UnixNano())

var (
maxSample float64
selected Arm
selected models.Arm
)

for _, arm := range arms {
Expand Down
4 changes: 2 additions & 2 deletions templates/dashboard.templ
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ templ layout() {
</html>
}

templ dashboard(tests []Test) {
templ Dashboard(tests []Test) {
@layout() {
<h1 class="text-3xl font-bold mb-8">A/B Test Dashboard</h1>
<div class="grid grid-cols-1 gap-6">
Expand Down Expand Up @@ -52,7 +52,7 @@ templ testCard(test Test) {
</div>
}

templ armStats(arms []Arm) {
templ ArmStats(arms []Arm) {
for _, arm := range arms {
@armCard(arm)
}
Expand Down
4 changes: 2 additions & 2 deletions templates/dashboard_templ.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.