diff --git a/air_example.toml b/air_example.toml index 8540084f..5f3f2128 100644 --- a/air_example.toml +++ b/air_example.toml @@ -73,3 +73,9 @@ clean_on_exit = true [screen] clear_on_rebuild = true keep_scroll = true + +# Enable live-reloading on the browser. This is useful when developing UI applications. +[proxy] + enabled = true + proxy_port = 8090 + app_port = 8080 diff --git a/runner/config.go b/runner/config.go index 5cffa99b..9ef4498d 100644 --- a/runner/config.go +++ b/runner/config.go @@ -31,6 +31,7 @@ type Config struct { Log cfgLog `toml:"log"` Misc cfgMisc `toml:"misc"` Screen cfgScreen `toml:"screen"` + Proxy cfgProxy `toml:"proxy"` } type cfgBuild struct { @@ -96,6 +97,12 @@ type cfgScreen struct { KeepScroll bool `toml:"keep_scroll"` } +type cfgProxy struct { + Enabled bool `toml:"enabled"` + Port int `toml:"proxy_port"` + AppPort int `toml:"app_port"` +} + type sliceTransformer struct{} func (t sliceTransformer) Transformer(typ reflect.Type) func(dst, src reflect.Value) error { @@ -350,10 +357,9 @@ func (c *Config) killDelay() time.Duration { // interpret as milliseconds if less than the value of 1 millisecond if c.Build.KillDelay < time.Millisecond { return c.Build.KillDelay * time.Millisecond - } else { - // normalize kill delay to milliseconds - return time.Duration(c.Build.KillDelay.Milliseconds()) * time.Millisecond } + // normalize kill delay to milliseconds + return time.Duration(c.Build.KillDelay.Milliseconds()) * time.Millisecond } func (c *Config) binPath() string { diff --git a/runner/engine.go b/runner/engine.go index b30c9c2a..71f43225 100644 --- a/runner/engine.go +++ b/runner/engine.go @@ -18,6 +18,7 @@ import ( // Engine ... type Engine struct { config *Config + proxy *Proxy logger *logger watcher filenotify.FileWatcher debugMode bool @@ -48,6 +49,7 @@ func NewEngineWithConfig(cfg *Config, debugMode bool) (*Engine, error) { } e := Engine{ config: cfg, + proxy: NewProxy(&cfg.Proxy), logger: logger, watcher: watcher, debugMode: debugMode, @@ -310,6 +312,13 @@ func (e *Engine) isModified(filename string) bool { // Endless loop and never return func (e *Engine) start() { + if e.config.Proxy.Enabled { + go func() { + e.mainLog("Proxy server listening on %s", e.proxy.server.Addr) + e.proxy.Run() + }() + } + e.running = true firstRunCh := make(chan bool, 1) firstRunCh <- true @@ -347,6 +356,9 @@ func (e *Engine) start() { } } + if e.config.Proxy.Enabled { + e.proxy.Reload() + } e.mainLog("%s has changed", e.config.rel(filename)) case <-firstRunCh: // go down @@ -535,6 +547,9 @@ func (e *Engine) runBin() error { cmd, stdout, stderr, _ := e.startCmd(command) processExit := make(chan struct{}) e.mainDebug("running process pid %v", cmd.Process.Pid) + if e.proxy.config.Enabled { + e.proxy.Reload() + } wg.Add(1) atomic.AddUint64(&e.round, 1) @@ -579,6 +594,11 @@ func (e *Engine) cleanup() { e.mainLog("cleaning...") defer e.mainLog("see you again~") + if e.config.Proxy.Enabled { + e.mainDebug("powering down the proxy...") + e.proxy.Stop() + } + e.withLock(func() { close(e.binStopCh) e.binStopCh = make(chan bool) diff --git a/runner/proxy.go b/runner/proxy.go new file mode 100644 index 00000000..74478382 --- /dev/null +++ b/runner/proxy.go @@ -0,0 +1,155 @@ +package runner + +import ( + "bytes" + "errors" + "fmt" + "io" + "log" + "net/http" + "strconv" + "strings" + "syscall" + "time" +) + +type Reloader interface { + AddSubscriber() *Subscriber + RemoveSubscriber(id int) + Reload() + Stop() +} + +type Proxy struct { + server *http.Server + config *cfgProxy + stream Reloader +} + +func NewProxy(cfg *cfgProxy) *Proxy { + p := &Proxy{ + config: cfg, + server: &http.Server{ + Addr: fmt.Sprintf("localhost:%d", cfg.Port), + }, + stream: NewProxyStream(), + } + return p +} + +func (p *Proxy) Run() { + http.HandleFunc("/", p.proxyHandler) + http.HandleFunc("/internal/reload", p.reloadHandler) + log.Fatal(p.server.ListenAndServe()) +} + +func (p *Proxy) Stop() { + p.server.Close() + p.stream.Stop() +} + +func (p *Proxy) Reload() { + p.stream.Reload() +} + +func (p *Proxy) injectLiveReload(origURL string, respBody io.ReadCloser) string { + buf := new(bytes.Buffer) + if _, err := buf.ReadFrom(respBody); err != nil { + panic("failed to convert request body to bytes buffer") + } + s := buf.String() + + body := strings.LastIndex(s, "") + if body == -1 { + panic("invalid html") + } + script := ` + + ` + parsedScript := fmt.Sprintf(script, p.config.Port, origURL) + + s = s[:body] + parsedScript + s[body:] + return s +} + +func (p *Proxy) proxyHandler(w http.ResponseWriter, r *http.Request) { + url := fmt.Sprintf("http://localhost:%d", p.config.AppPort) + req, err := http.NewRequest(r.Method, url, r.Body) + if err != nil { + panic(err) + } + req.Header.Set("X-Forwarded-For", r.RemoteAddr) + + client := &http.Client{} + var resp *http.Response + for i := 0; i < 10; i++ { + resp, err = client.Do(req) + if err == nil { + break + } + if !errors.Is(err, syscall.ECONNREFUSED) { + log.Fatalf("failed to call http://localhost:%d, err: %+v", p.config.AppPort, err) + } + time.Sleep(100 * time.Millisecond) + } + defer resp.Body.Close() + + // copy all headers except Content-Length + for k, vv := range resp.Header { + for _, v := range vv { + if k == "Content-Length" { + continue + } + w.Header().Add(k, v) + } + } + w.WriteHeader(resp.StatusCode) + + if strings.Contains(resp.Header.Get("Content-Type"), "text/html") { + s := p.injectLiveReload(r.URL.String(), resp.Body) + w.Header().Set("Content-Length", strconv.Itoa((len([]byte(s))))) + if _, err := io.WriteString(w, s); err != nil { + panic("failed to write injected payload") + } + } else { + w.Header().Set("Content-Length", resp.Header.Get("Content-Length")) + if _, err := io.Copy(w, resp.Body); err != nil { + panic("failed to write normal payload") + } + } +} + +func (p *Proxy) reloadHandler(w http.ResponseWriter, r *http.Request) { + flusher, err := w.(http.Flusher) + if !err { + http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + sub := p.stream.AddSubscriber() + go func() { + <-r.Context().Done() + p.stream.RemoveSubscriber(sub.id) + }() + + w.WriteHeader(http.StatusOK) + flusher.Flush() + + for range sub.reloadCh { + fmt.Fprintf(w, "data: reload\n\n") + flusher.Flush() + } +} diff --git a/runner/proxy_stream.go b/runner/proxy_stream.go new file mode 100644 index 00000000..412bee3f --- /dev/null +++ b/runner/proxy_stream.go @@ -0,0 +1,50 @@ +package runner + +import ( + "sync" +) + +type ProxyStream struct { + sync.Mutex + subscribers map[int]*Subscriber + count int +} + +type Subscriber struct { + id int + reloadCh chan struct{} +} + +func NewProxyStream() *ProxyStream { + return &ProxyStream{subscribers: make(map[int]*Subscriber)} +} + +func (stream *ProxyStream) Stop() { + for id := range stream.subscribers { + stream.RemoveSubscriber(id) + } + stream.count = 0 +} + +func (stream *ProxyStream) AddSubscriber() *Subscriber { + stream.Lock() + defer stream.Unlock() + stream.count++ + + sub := &Subscriber{id: stream.count, reloadCh: make(chan struct{})} + stream.subscribers[stream.count] = sub + return sub +} + +func (stream *ProxyStream) RemoveSubscriber(id int) { + stream.Lock() + defer stream.Unlock() + close(stream.subscribers[id].reloadCh) + delete(stream.subscribers, id) +} + +func (stream *ProxyStream) Reload() { + for _, sub := range stream.subscribers { + sub.reloadCh <- struct{}{} + } +} diff --git a/runner/proxy_stream_test.go b/runner/proxy_stream_test.go new file mode 100644 index 00000000..daf536e3 --- /dev/null +++ b/runner/proxy_stream_test.go @@ -0,0 +1,66 @@ +package runner + +import ( + "sync" + "testing" +) + +func find(s map[int]*Subscriber, id int) bool { + for _, sub := range s { + if sub.id == id { + return true + } + } + return false +} + +func TestProxyStream(t *testing.T) { + stream := NewProxyStream() + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + _ = stream.AddSubscriber() + }(i) + } + wg.Wait() + + if got, exp := len(stream.subscribers), 10; got != exp { + t.Errorf("expected %d but got %d", exp, got) + } + + go func() { + stream.Reload() + }() + + reloadCount := 0 + for _, sub := range stream.subscribers { + wg.Add(1) + go func(sub *Subscriber) { + defer wg.Done() + <-sub.reloadCh + reloadCount++ + }(sub) + } + wg.Wait() + + if got, exp := reloadCount, 10; got != exp { + t.Errorf("expected %d but got %d", exp, got) + } + + stream.RemoveSubscriber(2) + stream.AddSubscriber() + if got, exp := find(stream.subscribers, 2), false; got != exp { + t.Errorf("expected subscriber found to be %t but got %t", exp, got) + } + if got, exp := find(stream.subscribers, 11), true; got != exp { + t.Errorf("expected subscriber found to be %t but got %t", exp, got) + } + + stream.Stop() + if got, exp := len(stream.subscribers), 0; got != exp { + t.Errorf("expected %d but got %d", exp, got) + } +} diff --git a/runner/proxy_test.go b/runner/proxy_test.go new file mode 100644 index 00000000..e4b80264 --- /dev/null +++ b/runner/proxy_test.go @@ -0,0 +1,132 @@ +package runner + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strconv" + "sync" + "testing" +) + +type reloader struct { + subCh chan struct{} + reloadCh chan struct{} +} + +func (r *reloader) AddSubscriber() *Subscriber { + r.subCh <- struct{}{} + return &Subscriber{reloadCh: r.reloadCh} +} + +func (r *reloader) RemoveSubscriber(_ int) { + close(r.subCh) +} + +func (r *reloader) Reload() {} +func (r *reloader) Stop() {} + +func setupAppServer(t *testing.T) (srv *httptest.Server, port int) { + srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "thin air") + })) + mockURL, err := url.Parse(srv.URL) + if err != nil { + t.Fatal(err) + } + port, err = strconv.Atoi(mockURL.Port()) + if err != nil { + t.Fatal(err) + } + return srv, port +} + +func TestNewProxy(t *testing.T) { + _ = os.Unsetenv(airWd) + cfg := &cfgProxy{ + Enabled: true, + Port: 1111, + AppPort: 2222, + } + proxy := NewProxy(cfg) + if proxy.config == nil { + t.Fatal("Config should not be nil") + } + if proxy.server == nil { + t.Fatal("watcher should not be nil") + } +} + +func TestProxy_proxyHandler(t *testing.T) { + srv, appPort := setupAppServer(t) + defer srv.Close() + + cfg := &cfgProxy{ + Enabled: true, + Port: 8090, + AppPort: appPort, + } + proxy := NewProxy(cfg) + + req := httptest.NewRequest("GET", "http://localhost:8090/", nil) + rec := httptest.NewRecorder() + + proxy.proxyHandler(rec, req) + resp := rec.Result() + bodyBytes, _ := io.ReadAll(resp.Body) + if got, exp := string(bodyBytes), "thin air"; got != exp { + t.Errorf("expected %q but got %q", exp, got) + } +} + +func TestProxy_reloadHandler(t *testing.T) { + srv, appPort := setupAppServer(t) + defer srv.Close() + + reloader := &reloader{subCh: make(chan struct{}), reloadCh: make(chan struct{})} + cfg := &cfgProxy{ + Enabled: true, + Port: 8090, + AppPort: appPort, + } + proxy := &Proxy{ + config: cfg, + server: &http.Server{ + Addr: fmt.Sprintf("localhost:%d", cfg.Port), + }, + stream: reloader, + } + + req := httptest.NewRequest("GET", "http://localhost:8090/internal/reload", nil) + rec := httptest.NewRecorder() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + proxy.reloadHandler(rec, req) + }() + + // wait for subscriber to be added + <-reloader.subCh + + // send a reload event and wait for http response + reloader.reloadCh <- struct{}{} + close(reloader.reloadCh) + wg.Wait() + + if !rec.Flushed { + t.Errorf("request should have been flushed") + } + + resp := rec.Result() + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Errorf("reading body: %v", err) + } + if got, exp := string(bodyBytes), "data: reload\n\n"; got != exp { + t.Errorf("expected %q but got %q", exp, got) + } +}