diff --git a/packages/orchestrator/cmd/mock-sandbox/mock.go b/packages/orchestrator/cmd/mock-sandbox/mock.go index be701d9c6..306d8a01f 100644 --- a/packages/orchestrator/cmd/mock-sandbox/mock.go +++ b/packages/orchestrator/cmd/mock-sandbox/mock.go @@ -41,7 +41,7 @@ func main() { cancel() }() - dnsServer := dns.New() + dnsServer := dns.New(func(sandboxID string) error { return nil }) go func() { log.Printf("Starting DNS server") @@ -93,7 +93,7 @@ func mockSandbox( templateId, buildId, sandboxId string, - dns *dns.DNS, + dns *dns.OrchDNS, keepAlive time.Duration, networkPool *network.Pool, templateCache *template.Cache, diff --git a/packages/orchestrator/cmd/mock-snapshot/mock.go b/packages/orchestrator/cmd/mock-snapshot/mock.go index 45aadec24..498b7082e 100644 --- a/packages/orchestrator/cmd/mock-snapshot/mock.go +++ b/packages/orchestrator/cmd/mock-snapshot/mock.go @@ -43,7 +43,7 @@ func main() { cancel() }() - dnsServer := dns.New() + dnsServer := dns.New(func(sandboxID string) error { return nil }) go func() { log.Printf("Starting DNS server") @@ -103,7 +103,7 @@ func mockSnapshot( templateId, buildId, sandboxId string, - dns *dns.DNS, + dns *dns.OrchDNS, keepAlive time.Duration, networkPool *network.Pool, templateCache *template.Cache, diff --git a/packages/orchestrator/internal/dns/server.go b/packages/orchestrator/internal/dns/server.go index 8f065d291..48757b93e 100644 --- a/packages/orchestrator/internal/dns/server.go +++ b/packages/orchestrator/internal/dns/server.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "net" + "net/http" "strings" "sync" @@ -14,36 +15,43 @@ import ( const ttl = 0 -type DNS struct { - mu sync.Mutex - records *smap.Map[string] +const defaultRoutingIP = "127.0.0.1" +const defaultErrorPort = 3003 + +type SandboxErrorChecker func(sandboxID string) error + +type OrchDNS struct { + mu sync.Mutex + records *smap.Map[string] + sandboxErrorChecker SandboxErrorChecker } -func New() *DNS { - return &DNS{ - records: smap.New[string](), +func New(sandboxErrorChecker SandboxErrorChecker) *OrchDNS { + return &OrchDNS{ + records: smap.New[string](), + sandboxErrorChecker: sandboxErrorChecker, } } -func (d *DNS) Add(sandboxID, ip string) { +func (d *OrchDNS) Add(sandboxID, ip string) { d.records.Insert(d.hostname(sandboxID), ip) } -func (d *DNS) Remove(sandboxID, ip string) { +func (d *OrchDNS) Remove(sandboxID, ip string) { d.records.RemoveCb(d.hostname(sandboxID), func(key string, v string, exists bool) bool { return v == ip }) } -func (d *DNS) get(hostname string) (string, bool) { +func (d *OrchDNS) get(hostname string) (string, bool) { return d.records.Get(hostname) } -func (*DNS) hostname(sandboxID string) string { +func (*OrchDNS) hostname(sandboxID string) string { return fmt.Sprintf("%s.", sandboxID) } -func (d *DNS) handleDNSRequest(w resolver.ResponseWriter, r *resolver.Msg) { +func (d *OrchDNS) handleDNSRequest(w resolver.ResponseWriter, r *resolver.Msg) { m := new(resolver.Msg) m.SetReply(r) m.Compress = false @@ -51,20 +59,24 @@ func (d *DNS) handleDNSRequest(w resolver.ResponseWriter, r *resolver.Msg) { for _, q := range m.Question { if q.Qtype == resolver.TypeA { + a := &resolver.A{ + Hdr: resolver.RR_Header{ + Name: q.Name, + Rrtype: resolver.TypeA, + Class: resolver.ClassINET, + Ttl: ttl, + }, + } + sandboxID := strings.Split(q.Name, "-")[0] ip, found := d.get(sandboxID) if found { - a := &resolver.A{ - Hdr: resolver.RR_Header{ - Name: q.Name, - Rrtype: resolver.TypeA, - Class: resolver.ClassINET, - Ttl: ttl, - }, - A: net.ParseIP(ip).To4(), + a.A = net.ParseIP(ip).To4() + } else { + err := d.sandboxErrorChecker(sandboxID) + if err != nil { + a.A = net.ParseIP(defaultRoutingIP).To4() } - - m.Answer = append(m.Answer, a) } } } @@ -75,7 +87,7 @@ func (d *DNS) handleDNSRequest(w resolver.ResponseWriter, r *resolver.Msg) { } } -func (d *DNS) Start(address string, port int) error { +func (d *OrchDNS) Start(address string, port int) error { mux := resolver.NewServeMux() mux.HandleFunc(".", d.handleDNSRequest) @@ -87,5 +99,43 @@ func (d *DNS) Start(address string, port int) error { return fmt.Errorf("failed to start DNS server: %w", err) } + err = d.startErrorServer(defaultRoutingIP, defaultErrorPort) + if err != nil { + return fmt.Errorf("failed to start error HTTP server: %w", err) + } + + return nil +} + +func (d *OrchDNS) startErrorServer(address string, port int) error { + mux := http.NewServeMux() + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + parts := strings.Split(r.Host, "-") + sandboxID := "" + if len(parts) >= 2 { + sandboxID = parts[1] + } + + errMsg := "Sandbox does not exist." + + if err := d.sandboxErrorChecker(sandboxID); err != nil { + errMsg = err.Error() + } + + w.WriteHeader(http.StatusBadGateway) + w.Write([]byte(errMsg)) + }) + + server := &http.Server{ + Addr: fmt.Sprintf("%s:%d", address, port), + Handler: mux, + } + + err := server.ListenAndServe() + if err != nil { + return fmt.Errorf("failed to start error HTTP server: %w", err) + } + return nil } diff --git a/packages/orchestrator/internal/sandbox/checks.go b/packages/orchestrator/internal/sandbox/checks.go index 09538985b..0354c6084 100644 --- a/packages/orchestrator/internal/sandbox/checks.go +++ b/packages/orchestrator/internal/sandbox/checks.go @@ -2,7 +2,6 @@ package sandbox import ( "context" - "encoding/json" "fmt" "io" "net/http" @@ -81,46 +80,6 @@ func (s *Sandbox) Healthcheck(ctx context.Context, alwaysReport bool) { } } -func (s *Sandbox) GetMetrics(ctx context.Context) (SandboxMetrics, error) { - address := fmt.Sprintf("http://%s:%d/metrics", s.Slot.HostIP(), consts.DefaultEnvdServerPort) - - request, err := http.NewRequestWithContext(ctx, "GET", address, nil) - if err != nil { - return SandboxMetrics{}, err - } - - response, err := httpClient.Do(request) - if err != nil { - return SandboxMetrics{}, err - } - defer response.Body.Close() - - if response.StatusCode != http.StatusOK { - err = fmt.Errorf("unexpected status code: %d", response.StatusCode) - return SandboxMetrics{}, err - } - - var metrics SandboxMetrics - err = json.NewDecoder(response.Body).Decode(&metrics) - if err != nil { - return SandboxMetrics{}, err - } - - return metrics, nil -} - -func (s *Sandbox) LogMetrics(ctx context.Context) { - if isGTEVersion(s.Config.EnvdVersion, minEnvdVersionForMetrcis) { - metrics, err := s.GetMetrics(ctx) - if err != nil { - s.Logger.Warnf("failed to get metrics: %s", err) - } else { - s.Logger.Metrics( - metrics.MemTotalMiB, metrics.MemUsedMiB, metrics.CPUCount, metrics.CPUUsedPercent) - } - } -} - func isGTEVersion(curVersion, minVersion string) bool { if len(curVersion) > 0 && curVersion[0] != 'v' { curVersion = "v" + curVersion diff --git a/packages/orchestrator/internal/sandbox/metrics.go b/packages/orchestrator/internal/sandbox/metrics.go index 637610447..03f92c56a 100644 --- a/packages/orchestrator/internal/sandbox/metrics.go +++ b/packages/orchestrator/internal/sandbox/metrics.go @@ -1,5 +1,14 @@ package sandbox +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/e2b-dev/infra/packages/shared/pkg/consts" +) + type SandboxMetrics struct { Timestamp int64 `json:"ts"` // Unix Timestamp in UTC CPUCount uint32 `json:"cpu_count"` // Total CPU cores @@ -7,3 +16,43 @@ type SandboxMetrics struct { MemTotalMiB uint64 `json:"mem_total_mib"` // Total virtual memory in MiB MemUsedMiB uint64 `json:"mem_used_mib"` // Used virtual memory in MiB } + +func (s *Sandbox) GetMetrics(ctx context.Context) (SandboxMetrics, error) { + address := fmt.Sprintf("http://%s:%d/metrics", s.Slot.HostIP(), consts.DefaultEnvdServerPort) + + request, err := http.NewRequestWithContext(ctx, "GET", address, nil) + if err != nil { + return SandboxMetrics{}, err + } + + response, err := httpClient.Do(request) + if err != nil { + return SandboxMetrics{}, err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + err = fmt.Errorf("unexpected status code: %d", response.StatusCode) + return SandboxMetrics{}, err + } + + var metrics SandboxMetrics + err = json.NewDecoder(response.Body).Decode(&metrics) + if err != nil { + return SandboxMetrics{}, err + } + + return metrics, nil +} + +func (s *Sandbox) LogMetrics(ctx context.Context) { + if isGTEVersion(s.Config.EnvdVersion, minEnvdVersionForMetrcis) { + metrics, err := s.GetMetrics(ctx) + if err != nil { + s.Logger.Warnf("failed to get metrics: %s", err) + } else { + s.Logger.Metrics( + metrics.MemTotalMiB, metrics.MemUsedMiB, metrics.CPUCount, metrics.CPUUsedPercent) + } + } +} diff --git a/packages/orchestrator/internal/sandbox/sandbox.go b/packages/orchestrator/internal/sandbox/sandbox.go index b2093d2a3..4c7b0c0f8 100644 --- a/packages/orchestrator/internal/sandbox/sandbox.go +++ b/packages/orchestrator/internal/sandbox/sandbox.go @@ -62,7 +62,7 @@ type Sandbox struct { func NewSandbox( ctx context.Context, tracer trace.Tracer, - dns *dns.DNS, + dns *dns.OrchDNS, networkPool *network.Pool, templateCache *template.Cache, config *orchestrator.SandboxConfig, diff --git a/packages/orchestrator/internal/server/main.go b/packages/orchestrator/internal/server/main.go index 46f44a457..41f881246 100644 --- a/packages/orchestrator/internal/server/main.go +++ b/packages/orchestrator/internal/server/main.go @@ -5,8 +5,10 @@ import ( "fmt" "log" "sync" + "time" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/recovery" + "github.com/jellydator/ttlcache/v3" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace" @@ -24,14 +26,16 @@ import ( ) const ServiceName = "orchestrator" +const ExitErrorExpiration = 60 * time.Second type server struct { orchestrator.UnimplementedSandboxServiceServer - sandboxes *smap.Map[*sandbox.Sandbox] - dns *dns.DNS - tracer trace.Tracer - networkPool *network.Pool - templateCache *template.Cache + sandboxes *smap.Map[*sandbox.Sandbox] + sandboxExitErrors *ttlcache.Cache[string, error] + dns *dns.OrchDNS + tracer trace.Tracer + networkPool *network.Pool + templateCache *template.Cache pauseMu sync.Mutex } @@ -39,7 +43,13 @@ type server struct { func New() (*grpc.Server, error) { ctx := context.Background() - dnsServer := dns.New() + sandboxExitErrors := ttlcache.New(ttlcache.WithTTL[string, error](ExitErrorExpiration)) + sandboxErrorChecker := func(sandboxID string) error { + item := sandboxExitErrors.Get(sandboxID) + return item.Value() + } + + dnsServer := dns.New(sandboxErrorChecker) go func() { log.Printf("Starting DNS server") @@ -67,11 +77,12 @@ func New() (*grpc.Server, error) { ) orchestrator.RegisterSandboxServiceServer(s, &server{ - tracer: otel.Tracer(ServiceName), - dns: dnsServer, - sandboxes: smap.New[*sandbox.Sandbox](), - networkPool: networkPool, - templateCache: templateCache, + tracer: otel.Tracer(ServiceName), + dns: dnsServer, + sandboxes: smap.New[*sandbox.Sandbox](), + sandboxExitErrors: sandboxExitErrors, + networkPool: networkPool, + templateCache: templateCache, }) grpc_health_v1.RegisterHealthServer(s, health.NewServer()) diff --git a/packages/orchestrator/internal/server/sandboxes.go b/packages/orchestrator/internal/server/sandboxes.go index e95997b32..292844c73 100644 --- a/packages/orchestrator/internal/server/sandboxes.go +++ b/packages/orchestrator/internal/server/sandboxes.go @@ -77,6 +77,7 @@ func (s *server) Create(ctx context.Context, req *orchestrator.SandboxCreateRequ waitErr := sbx.Wait() if waitErr != nil { fmt.Fprintf(os.Stderr, "failed to wait for Sandbox: %v\n", waitErr) + s.sandboxExitErrors.Set(req.Sandbox.SandboxId, waitErr, 0) } cleanupErr := cleanup.Run()