From c26ab190e73ca76d346e6fe4087e62b0cdf3efd8 Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Tue, 4 Oct 2022 14:35:10 +0300 Subject: [PATCH 1/9] Pull request: imp-json-resp Merge in DNS/adguard-home from imp-json-resp to master Squashed commit of the following: commit 44532b6fa551815e5ea876e7320ce0a73c32b6fb Author: Ainar Garipov Date: Fri Sep 30 15:59:58 2022 +0300 all: imp json resp --- internal/aghhttp/aghhttp.go | 11 ++++++-- internal/dhcpd/http_unix.go | 13 +--------- internal/dhcpd/http_windows.go | 11 ++------ internal/filtering/blocked.go | 16 ++---------- internal/filtering/http.go | 27 +++----------------- internal/filtering/rewrites.go | 8 +----- internal/filtering/safebrowsing.go | 19 +++++--------- internal/filtering/safesearch.go | 15 +++-------- internal/querylog/http.go | 40 +++++------------------------- internal/stats/http.go | 15 ++--------- 10 files changed, 35 insertions(+), 140 deletions(-) diff --git a/internal/aghhttp/aghhttp.go b/internal/aghhttp/aghhttp.go index f03ebf7d967..bde0112aec9 100644 --- a/internal/aghhttp/aghhttp.go +++ b/internal/aghhttp/aghhttp.go @@ -62,9 +62,16 @@ func WriteTextPlainDeprecated(w http.ResponseWriter, r *http.Request) (isPlainTe } // WriteJSONResponse sets the content-type header in w.Header() to -// "application/json", encodes resp to w, calls Error on any returned error, and -// returns it as well. +// "application/json", writes a header with a "200 OK" status, encodes resp to +// w, calls [Error] on any returned error, and returns it as well. func WriteJSONResponse(w http.ResponseWriter, r *http.Request, resp any) (err error) { + return WriteJSONResponseCode(w, r, http.StatusOK, resp) +} + +// WriteJSONResponseCode is like [WriteJSONResponse] but adds the ability to +// redefine the status code. +func WriteJSONResponseCode(w http.ResponseWriter, r *http.Request, code int, resp any) (err error) { + w.WriteHeader(code) w.Header().Set(HdrNameContentType, HdrValApplicationJSON) err = json.NewEncoder(w).Encode(resp) if err != nil { diff --git a/internal/dhcpd/http_unix.go b/internal/dhcpd/http_unix.go index de06431f666..ab3ce318076 100644 --- a/internal/dhcpd/http_unix.go +++ b/internal/dhcpd/http_unix.go @@ -78,18 +78,7 @@ func (s *server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) { status.Leases = s.Leases(LeasesDynamic) status.StaticLeases = s.Leases(LeasesStatic) - w.Header().Set("Content-Type", "application/json") - - err := json.NewEncoder(w).Encode(status) - if err != nil { - aghhttp.Error( - r, - w, - http.StatusInternalServerError, - "Unable to marshal DHCP status json: %s", - err, - ) - } + _ = aghhttp.WriteJSONResponse(w, r, status) } func (s *server) enableDHCP(ifaceName string) (code int, err error) { diff --git a/internal/dhcpd/http_windows.go b/internal/dhcpd/http_windows.go index 5f7f73c1d0c..fda72d4876d 100644 --- a/internal/dhcpd/http_windows.go +++ b/internal/dhcpd/http_windows.go @@ -3,11 +3,10 @@ package dhcpd import ( - "encoding/json" "net/http" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghos" - "github.com/AdguardTeam/golibs/log" ) // jsonError is a generic JSON error response. @@ -25,15 +24,9 @@ type jsonError struct { // TODO(a.garipov): Either take the logger from the server after we've // refactored logging or make this not a method of *Server. func (s *server) notImplemented(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusNotImplemented) - - err := json.NewEncoder(w).Encode(&jsonError{ + _ = aghhttp.WriteJSONResponseCode(w, r, http.StatusNotImplemented, &jsonError{ Message: aghos.Unsupported("dhcp").Error(), }) - if err != nil { - log.Debug("writing 501 json response: %s", err) - } } // registerHandlers sets the handlers for DHCP HTTP API that always respond with diff --git a/internal/filtering/blocked.go b/internal/filtering/blocked.go index 489def36703..b32cb01e7c0 100644 --- a/internal/filtering/blocked.go +++ b/internal/filtering/blocked.go @@ -453,13 +453,7 @@ func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string) { } func (d *DNSFilter) handleBlockedServicesAvailableServices(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - err := json.NewEncoder(w).Encode(serviceIDs) - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "encoding available services: %s", err) - - return - } + _ = aghhttp.WriteJSONResponse(w, r, serviceIDs) } func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Request) { @@ -467,13 +461,7 @@ func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Req list := d.Config.BlockedServices d.confLock.RUnlock() - w.Header().Set("Content-Type", "application/json") - err := json.NewEncoder(w).Encode(list) - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "encoding services: %s", err) - - return - } + _ = aghhttp.WriteJSONResponse(w, r, list) } func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) { diff --git a/internal/filtering/http.go b/internal/filtering/http.go index 50890f9376e..5c311c4318e 100644 --- a/internal/filtering/http.go +++ b/internal/filtering/http.go @@ -301,14 +301,7 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques return } - w.Header().Set("Content-Type", "application/json") - - err = json.NewEncoder(w).Encode(resp) - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err) - - return - } + _ = aghhttp.WriteJSONResponse(w, r, resp) } type filterJSON struct { @@ -361,17 +354,7 @@ func (d *DNSFilter) handleFilteringStatus(w http.ResponseWriter, r *http.Request resp.UserRules = d.UserRules d.filtersMu.RUnlock() - jsonVal, err := json.Marshal(resp) - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err) - - return - } - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(jsonVal) - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "http write: %s", err) - } + _ = aghhttp.WriteJSONResponse(w, r, resp) } // Set filtering configuration @@ -473,11 +456,7 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) { } } - w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode(resp) - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "encoding response: %s", err) - } + _ = aghhttp.WriteJSONResponse(w, r, resp) } // RegisterFilteringHandlers - register handlers diff --git a/internal/filtering/rewrites.go b/internal/filtering/rewrites.go index 8f0d5ebfcc7..2c09728f943 100644 --- a/internal/filtering/rewrites.go +++ b/internal/filtering/rewrites.go @@ -240,13 +240,7 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) { } d.confLock.Unlock() - w.Header().Set("Content-Type", "application/json") - err := json.NewEncoder(w).Encode(arr) - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encode: %s", err) - - return - } + _ = aghhttp.WriteJSONResponse(w, r, arr) } func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) { diff --git a/internal/filtering/safebrowsing.go b/internal/filtering/safebrowsing.go index fe844977277..c6b7c34cc32 100644 --- a/internal/filtering/safebrowsing.go +++ b/internal/filtering/safebrowsing.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "encoding/binary" "encoding/hex" - "encoding/json" "fmt" "net" "net/http" @@ -381,17 +380,13 @@ func (d *DNSFilter) handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Req } func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - err := json.NewEncoder(w).Encode(&struct { + resp := &struct { Enabled bool `json:"enabled"` }{ Enabled: d.Config.SafeBrowsingEnabled, - }) - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) - - return } + + _ = aghhttp.WriteJSONResponse(w, r, resp) } func (d *DNSFilter) handleParentalEnable(w http.ResponseWriter, r *http.Request) { @@ -405,13 +400,11 @@ func (d *DNSFilter) handleParentalDisable(w http.ResponseWriter, r *http.Request } func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - err := json.NewEncoder(w).Encode(&struct { + resp := &struct { Enabled bool `json:"enabled"` }{ Enabled: d.Config.ParentalEnabled, - }) - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) } + + _ = aghhttp.WriteJSONResponse(w, r, resp) } diff --git a/internal/filtering/safesearch.go b/internal/filtering/safesearch.go index df2d2108fd2..8b3dcb9b430 100644 --- a/internal/filtering/safesearch.go +++ b/internal/filtering/safesearch.go @@ -5,7 +5,6 @@ import ( "context" "encoding/binary" "encoding/gob" - "encoding/json" "fmt" "net" "net/http" @@ -146,21 +145,13 @@ func (d *DNSFilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Reque } func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - err := json.NewEncoder(w).Encode(&struct { + resp := &struct { Enabled bool `json:"enabled"` }{ Enabled: d.Config.SafeSearchEnabled, - }) - if err != nil { - aghhttp.Error( - r, - w, - http.StatusInternalServerError, - "Unable to write response json: %s", - err, - ) } + + _ = aghhttp.WriteJSONResponse(w, r, resp) } var safeSearchDomains = map[string]string{ diff --git a/internal/querylog/http.go b/internal/querylog/http.go index 11f62d0d369..1fab138eb12 100644 --- a/internal/querylog/http.go +++ b/internal/querylog/http.go @@ -1,7 +1,6 @@ package querylog import ( - "encoding/json" "fmt" "net" "net/http" @@ -48,24 +47,7 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) { // convert log entries to JSON data := l.entriesToJSON(entries, oldest) - jsonVal, err := json.Marshal(data) - if err != nil { - aghhttp.Error( - r, - w, - http.StatusInternalServerError, - "Couldn't marshal data into json: %s", - err, - ) - - return - } - - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(jsonVal) - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) - } + _ = aghhttp.WriteJSONResponse(w, r, data) } func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) { @@ -74,23 +56,13 @@ func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) { // Get configuration func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) { - resp := qlogConfig{} - resp.Enabled = l.conf.Enabled - resp.Interval = l.conf.RotationIvl.Hours() / 24 - resp.AnonymizeClientIP = l.conf.AnonymizeClientIP - - jsonVal, err := json.Marshal(resp) - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err) - - return + resp := qlogConfig{ + Enabled: l.conf.Enabled, + Interval: l.conf.RotationIvl.Hours() / 24, + AnonymizeClientIP: l.conf.AnonymizeClientIP, } - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(jsonVal) - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "http write: %s", err) - } + _ = aghhttp.WriteJSONResponse(w, r, resp) } // AnonymizeIP masks ip to anonymize the client if the ip is a valid one. diff --git a/internal/stats/http.go b/internal/stats/http.go index ae980bf3e4c..b06a7cdc9a7 100644 --- a/internal/stats/http.go +++ b/internal/stats/http.go @@ -55,12 +55,7 @@ func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) { return } - w.Header().Set("Content-Type", "application/json") - - err := json.NewEncoder(w).Encode(resp) - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err) - } + _ = aghhttp.WriteJSONResponse(w, r, resp) } // configResp is the response to the GET /control/stats_info. @@ -71,13 +66,7 @@ type configResp struct { // handleStatsInfo handles requests to the GET /control/stats_info endpoint. func (s *StatsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) { resp := configResp{IntervalDays: atomic.LoadUint32(&s.limitHours) / 24} - - w.Header().Set("Content-Type", "application/json") - - err := json.NewEncoder(w).Encode(resp) - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err) - } + _ = aghhttp.WriteJSONResponse(w, r, resp) } // handleStatsConfig handles requests to the POST /control/stats_config From fe8be3701f90ac5b83f0308c5e8c3d2c80fbe0b1 Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Tue, 4 Oct 2022 16:02:55 +0300 Subject: [PATCH 2/9] Pull request: websvc-config-manager Merge in DNS/adguard-home from websvc-config-manager to master Squashed commit of the following: commit 2143b47c6528030dfe059172888fddf9061e42da Author: Ainar Garipov Date: Tue Oct 4 14:50:47 2022 +0300 next: add config manager --- Makefile | 4 +- internal/aghchan/aghchan.go | 33 ++++ internal/aghnet/hostscontainer_test.go | 14 +- internal/aghtest/interface.go | 85 +++++--- internal/aghtest/interface_test.go | 4 +- internal/{v1 => next}/agh/agh.go | 0 internal/{v1 => next}/cmd/cmd.go | 15 +- internal/{v1 => next}/cmd/signal.go | 2 +- internal/{v1 => next}/dnssvc/dnssvc.go | 46 ++++- internal/{v1 => next}/dnssvc/dnssvc_test.go | 2 +- internal/next/websvc/dns.go | 84 ++++++++ internal/next/websvc/dns_test.go | 68 +++++++ internal/next/websvc/http.go | 109 ++++++++++ internal/next/websvc/http_test.go | 62 ++++++ internal/next/websvc/json.go | 143 ++++++++++++++ internal/next/websvc/json_test.go | 114 +++++++++++ internal/{v1 => next}/websvc/middleware.go | 0 internal/next/websvc/path.go | 11 ++ internal/next/websvc/settings.go | 42 ++++ internal/next/websvc/settings_test.go | 74 +++++++ internal/{v1 => next}/websvc/system.go | 6 +- internal/{v1 => next}/websvc/system_test.go | 7 +- internal/next/websvc/waitlistener.go | 31 +++ .../next/websvc/waitlistener_internal_test.go | 46 +++++ internal/{v1 => next}/websvc/websvc.go | 160 +++++++++++---- internal/next/websvc/websvc_internal_test.go | 6 + internal/next/websvc/websvc_test.go | 187 ++++++++++++++++++ internal/v1/websvc/json.go | 61 ------ internal/v1/websvc/path.go | 8 - internal/v1/websvc/websvc_test.go | 93 --------- internal/version/version.go | 16 +- main.go | 4 +- main_v1.go => main_next.go | 6 +- openapi/v1.yaml | 18 +- scripts/make/go-lint.sh | 4 +- 35 files changed, 1286 insertions(+), 279 deletions(-) create mode 100644 internal/aghchan/aghchan.go rename internal/{v1 => next}/agh/agh.go (100%) rename internal/{v1 => next}/cmd/cmd.go (80%) rename internal/{v1 => next}/cmd/signal.go (96%) rename internal/{v1 => next}/dnssvc/dnssvc.go (77%) rename internal/{v1 => next}/dnssvc/dnssvc_test.go (97%) create mode 100644 internal/next/websvc/dns.go create mode 100644 internal/next/websvc/dns_test.go create mode 100644 internal/next/websvc/http.go create mode 100644 internal/next/websvc/http_test.go create mode 100644 internal/next/websvc/json.go create mode 100644 internal/next/websvc/json_test.go rename internal/{v1 => next}/websvc/middleware.go (100%) create mode 100644 internal/next/websvc/path.go create mode 100644 internal/next/websvc/settings.go create mode 100644 internal/next/websvc/settings_test.go rename internal/{v1 => next}/websvc/system.go (87%) rename internal/{v1 => next}/websvc/system_test.go (82%) create mode 100644 internal/next/websvc/waitlistener.go create mode 100644 internal/next/websvc/waitlistener_internal_test.go rename internal/{v1 => next}/websvc/websvc.go (52%) create mode 100644 internal/next/websvc/websvc_internal_test.go create mode 100644 internal/next/websvc/websvc_test.go delete mode 100644 internal/v1/websvc/json.go delete mode 100644 internal/v1/websvc/path.go delete mode 100644 internal/v1/websvc/websvc_test.go rename main_v1.go => main_next.go (79%) diff --git a/Makefile b/Makefile index b4823bb73d1..cca890174ad 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ YARN_INSTALL_FLAGS = $(YARN_FLAGS) --network-timeout 120000 --silent\ --ignore-engines --ignore-optional --ignore-platform\ --ignore-scripts -V1API = 0 +NEXTAPI = 0 # Macros for the build-release target. If FRONTEND_PREBUILT is 0, the # default, the macro $(BUILD_RELEASE_DEPS_$(FRONTEND_PREBUILT)) expands @@ -63,7 +63,7 @@ ENV = env\ PATH="$${PWD}/bin:$$( "$(GO.MACRO)" env GOPATH )/bin:$${PATH}"\ RACE='$(RACE)'\ SIGN='$(SIGN)'\ - V1API='$(V1API)'\ + NEXTAPI='$(NEXTAPI)'\ VERBOSE='$(VERBOSE)'\ VERSION='$(VERSION)'\ diff --git a/internal/aghchan/aghchan.go b/internal/aghchan/aghchan.go new file mode 100644 index 00000000000..1da1790a3f7 --- /dev/null +++ b/internal/aghchan/aghchan.go @@ -0,0 +1,33 @@ +// Package aghchan contains channel utilities. +package aghchan + +import ( + "fmt" + "time" +) + +// Receive returns an error if it cannot receive a value form c before timeout +// runs out. +func Receive[T any](c <-chan T, timeout time.Duration) (v T, ok bool, err error) { + var zero T + timeoutCh := time.After(timeout) + select { + case <-timeoutCh: + // TODO(a.garipov): Consider implementing [errors.Aser] for + // os.ErrTimeout. + return zero, false, fmt.Errorf("did not receive after %s", timeout) + case v, ok = <-c: + return v, ok, nil + } +} + +// MustReceive panics if it cannot receive a value form c before timeout runs +// out. +func MustReceive[T any](c <-chan T, timeout time.Duration) (v T, ok bool) { + v, ok, err := Receive(c, timeout) + if err != nil { + panic(err) + } + + return v, ok +} diff --git a/internal/aghnet/hostscontainer_test.go b/internal/aghnet/hostscontainer_test.go index 1f75a3c9e20..d2637d8581b 100644 --- a/internal/aghnet/hostscontainer_test.go +++ b/internal/aghnet/hostscontainer_test.go @@ -10,9 +10,9 @@ import ( "testing/fstest" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghchan" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter" @@ -163,15 +163,9 @@ func TestHostsContainer_refresh(t *testing.T) { checkRefresh := func(t *testing.T, want *HostsRecord) { t.Helper() - var ok bool - var upd *netutil.IPMap - select { - case upd, ok = <-hc.Upd(): - require.True(t, ok) - require.NotNil(t, upd) - case <-time.After(1 * time.Second): - t.Fatal("did not receive after 1s") - } + upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second) + require.True(t, ok) + require.NotNil(t, upd) assert.Equal(t, 1, upd.Len()) diff --git a/internal/aghtest/interface.go b/internal/aghtest/interface.go index 2de9d372b08..7aae35ee3ee 100644 --- a/internal/aghtest/interface.go +++ b/internal/aghtest/interface.go @@ -1,6 +1,7 @@ package aghtest import ( + "context" "io/fs" "net" @@ -15,6 +16,8 @@ import ( // Standard Library +// Package fs + // type check var _ fs.FS = &FS{} @@ -58,6 +61,8 @@ func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) { return fsys.OnStat(name) } +// Package net + // type check var _ net.Listener = (*Listener)(nil) @@ -83,32 +88,10 @@ func (l *Listener) Close() (err error) { return l.OnClose() } -// Module dnsproxy - -// type check -var _ upstream.Upstream = (*UpstreamMock)(nil) - -// UpstreamMock is a mock [upstream.Upstream] implementation for tests. -// -// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and -// rename it to just Upstream. -type UpstreamMock struct { - OnAddress func() (addr string) - OnExchange func(req *dns.Msg) (resp *dns.Msg, err error) -} - -// Address implements the [upstream.Upstream] interface for *UpstreamMock. -func (u *UpstreamMock) Address() (addr string) { - return u.OnAddress() -} - -// Exchange implements the [upstream.Upstream] interface for *UpstreamMock. -func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { - return u.OnExchange(req) -} - // Module AdGuardHome +// Package aghos + // type check var _ aghos.FSWatcher = (*FSWatcher)(nil) @@ -133,3 +116,57 @@ func (w *FSWatcher) Add(name string) (err error) { func (w *FSWatcher) Close() (err error) { return w.OnClose() } + +// Package websvc + +// ServiceWithConfig is a mock [websvc.ServiceWithConfig] implementation for +// tests. +type ServiceWithConfig[ConfigType any] struct { + OnStart func() (err error) + OnShutdown func(ctx context.Context) (err error) + OnConfig func() (c ConfigType) +} + +// Start implements the [websvc.ServiceWithConfig] interface for +// *ServiceWithConfig. +func (s *ServiceWithConfig[_]) Start() (err error) { + return s.OnStart() +} + +// Shutdown implements the [websvc.ServiceWithConfig] interface for +// *ServiceWithConfig. +func (s *ServiceWithConfig[_]) Shutdown(ctx context.Context) (err error) { + return s.OnShutdown(ctx) +} + +// Config implements the [websvc.ServiceWithConfig] interface for +// *ServiceWithConfig. +func (s *ServiceWithConfig[ConfigType]) Config() (c ConfigType) { + return s.OnConfig() +} + +// Module dnsproxy + +// Package upstream + +// type check +var _ upstream.Upstream = (*UpstreamMock)(nil) + +// UpstreamMock is a mock [upstream.Upstream] implementation for tests. +// +// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and +// rename it to just Upstream. +type UpstreamMock struct { + OnAddress func() (addr string) + OnExchange func(req *dns.Msg) (resp *dns.Msg, err error) +} + +// Address implements the [upstream.Upstream] interface for *UpstreamMock. +func (u *UpstreamMock) Address() (addr string) { + return u.OnAddress() +} + +// Exchange implements the [upstream.Upstream] interface for *UpstreamMock. +func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { + return u.OnExchange(req) +} diff --git a/internal/aghtest/interface_test.go b/internal/aghtest/interface_test.go index 5a465c2c873..bd2c0823e84 100644 --- a/internal/aghtest/interface_test.go +++ b/internal/aghtest/interface_test.go @@ -1,9 +1,9 @@ package aghtest_test import ( - "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" ) // type check -var _ aghos.FSWatcher = (*aghtest.FSWatcher)(nil) +var _ websvc.ServiceWithConfig[struct{}] = (*aghtest.ServiceWithConfig[struct{}])(nil) diff --git a/internal/v1/agh/agh.go b/internal/next/agh/agh.go similarity index 100% rename from internal/v1/agh/agh.go rename to internal/next/agh/agh.go diff --git a/internal/v1/cmd/cmd.go b/internal/next/cmd/cmd.go similarity index 80% rename from internal/v1/cmd/cmd.go rename to internal/next/cmd/cmd.go index 2f61509ba6d..5b329abf4a6 100644 --- a/internal/v1/cmd/cmd.go +++ b/internal/next/cmd/cmd.go @@ -11,29 +11,32 @@ import ( "net/netip" "time" - "github.com/AdguardTeam/AdGuardHome/internal/v1/websvc" + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/AdguardTeam/golibs/log" ) // Main is the entry point of application. func Main(clientBuildFS fs.FS) { - // # Initial Configuration + // Initial Configuration start := time.Now() rand.Seed(start.UnixNano()) // TODO(a.garipov): Set up logging. - // # Web Service + // Web Service // TODO(a.garipov): Use in the Web service. _ = clientBuildFS // TODO(a.garipov): Make configurable. web := websvc.New(&websvc.Config{ - Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:3001")}, - Start: start, - Timeout: 60 * time.Second, + // TODO(a.garipov): Use an actual implementation. + ConfigManager: nil, + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:3001")}, + Start: start, + Timeout: 60 * time.Second, + ForceHTTPS: false, }) err := web.Start() diff --git a/internal/v1/cmd/signal.go b/internal/next/cmd/signal.go similarity index 96% rename from internal/v1/cmd/signal.go rename to internal/next/cmd/signal.go index b66075f6a64..122f3f2c7dd 100644 --- a/internal/v1/cmd/signal.go +++ b/internal/next/cmd/signal.go @@ -4,7 +4,7 @@ import ( "os" "github.com/AdguardTeam/AdGuardHome/internal/aghos" - "github.com/AdguardTeam/AdGuardHome/internal/v1/agh" + "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/golibs/log" ) diff --git a/internal/v1/dnssvc/dnssvc.go b/internal/next/dnssvc/dnssvc.go similarity index 77% rename from internal/v1/dnssvc/dnssvc.go rename to internal/next/dnssvc/dnssvc.go index ffe5b080449..f25fa294fe3 100644 --- a/internal/v1/dnssvc/dnssvc.go +++ b/internal/next/dnssvc/dnssvc.go @@ -9,9 +9,10 @@ import ( "fmt" "net" "net/netip" + "sync/atomic" "time" - "github.com/AdguardTeam/AdGuardHome/internal/v1/agh" + "github.com/AdguardTeam/AdGuardHome/internal/next/agh" // TODO(a.garipov): Add a “dnsproxy proxy” package to shield us from changes // and replacement of module dnsproxy. "github.com/AdguardTeam/dnsproxy/proxy" @@ -47,6 +48,14 @@ type Config struct { // Service is the AdGuard Home DNS service. A nil *Service is a valid // [agh.Service] that does nothing. type Service struct { + // running is an atomic boolean value. Keep it the first value in the + // struct to ensure atomic alignment. 0 means that the service is not + // running, 1 means that it is running. + // + // TODO(a.garipov): Use [atomic.Bool] in Go 1.19 or get rid of it + // completely. + running uint64 + proxy *proxy.Proxy bootstraps []string upstreams []string @@ -160,6 +169,17 @@ func (svc *Service) Start() (err error) { return nil } + defer func() { + // TODO(a.garipov): [proxy.Proxy.Start] doesn't actually have any way to + // tell when all servers are actually up, so at best this is merely an + // assumption. + if err != nil { + atomic.StoreUint64(&svc.running, 0) + } else { + atomic.StoreUint64(&svc.running, 1) + } + }() + return svc.proxy.Start() } @@ -173,13 +193,27 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) { return svc.proxy.Stop() } -// Config returns the current configuration of the web service. +// Config returns the current configuration of the web service. Config must not +// be called simultaneously with Start. If svc was initialized with ":0" +// addresses, addrs will not return the actual bound ports until Start is +// finished. func (svc *Service) Config() (c *Config) { // TODO(a.garipov): Do we need to get the TCP addresses separately? - udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP) - addrs := make([]netip.AddrPort, len(udpAddrs)) - for i, a := range udpAddrs { - addrs[i] = a.(*net.UDPAddr).AddrPort() + + var addrs []netip.AddrPort + if atomic.LoadUint64(&svc.running) == 1 { + udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP) + addrs = make([]netip.AddrPort, len(udpAddrs)) + for i, a := range udpAddrs { + addrs[i] = a.(*net.UDPAddr).AddrPort() + } + } else { + conf := svc.proxy.Config + udpAddrs := conf.UDPListenAddr + addrs = make([]netip.AddrPort, len(udpAddrs)) + for i, a := range udpAddrs { + addrs[i] = a.AddrPort() + } } c = &Config{ diff --git a/internal/v1/dnssvc/dnssvc_test.go b/internal/next/dnssvc/dnssvc_test.go similarity index 97% rename from internal/v1/dnssvc/dnssvc_test.go rename to internal/next/dnssvc/dnssvc_test.go index 5bc3b5621a3..8205897c73a 100644 --- a/internal/v1/dnssvc/dnssvc_test.go +++ b/internal/next/dnssvc/dnssvc_test.go @@ -7,7 +7,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" - "github.com/AdguardTeam/AdGuardHome/internal/v1/dnssvc" + "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/miekg/dns" "github.com/stretchr/testify/assert" diff --git a/internal/next/websvc/dns.go b/internal/next/websvc/dns.go new file mode 100644 index 00000000000..8846813d424 --- /dev/null +++ b/internal/next/websvc/dns.go @@ -0,0 +1,84 @@ +package websvc + +import ( + "encoding/json" + "fmt" + "net/http" + "net/netip" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" +) + +// DNS Settings Handlers + +// ReqPatchSettingsDNS describes the request to the PATCH /api/v1/settings/dns +// HTTP API. +type ReqPatchSettingsDNS struct { + // TODO(a.garipov): Add more as we go. + + Addresses []netip.AddrPort `json:"addresses"` + BootstrapServers []string `json:"bootstrap_servers"` + UpstreamServers []string `json:"upstream_servers"` + UpstreamTimeout JSONDuration `json:"upstream_timeout"` +} + +// HTTPAPIDNSSettings are the DNS settings as used by the HTTP API. See the +// DnsSettings object in the OpenAPI specification. +type HTTPAPIDNSSettings struct { + // TODO(a.garipov): Add more as we go. + + Addresses []netip.AddrPort `json:"addresses"` + BootstrapServers []string `json:"bootstrap_servers"` + UpstreamServers []string `json:"upstream_servers"` + UpstreamTimeout JSONDuration `json:"upstream_timeout"` +} + +// handlePatchSettingsDNS is the handler for the PATCH /api/v1/settings/dns HTTP +// API. +func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Request) { + req := &ReqPatchSettingsDNS{ + Addresses: []netip.AddrPort{}, + BootstrapServers: []string{}, + UpstreamServers: []string{}, + } + + // TODO(a.garipov): Validate nulls and proper JSON patch. + + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + writeJSONErrorResponse(w, r, fmt.Errorf("decoding: %w", err)) + + return + } + + newConf := &dnssvc.Config{ + Addresses: req.Addresses, + BootstrapServers: req.BootstrapServers, + UpstreamServers: req.UpstreamServers, + UpstreamTimeout: time.Duration(req.UpstreamTimeout), + } + + ctx := r.Context() + err = svc.confMgr.UpdateDNS(ctx, newConf) + if err != nil { + writeJSONErrorResponse(w, r, fmt.Errorf("updating: %w", err)) + + return + } + + newSvc := svc.confMgr.DNS() + err = newSvc.Start() + if err != nil { + writeJSONErrorResponse(w, r, fmt.Errorf("starting new service: %w", err)) + + return + } + + writeJSONOKResponse(w, r, &HTTPAPIDNSSettings{ + Addresses: newConf.Addresses, + BootstrapServers: newConf.BootstrapServers, + UpstreamServers: newConf.UpstreamServers, + UpstreamTimeout: JSONDuration(newConf.UpstreamTimeout), + }) +} diff --git a/internal/next/websvc/dns_test.go b/internal/next/websvc/dns_test.go new file mode 100644 index 00000000000..f774c3d87dd --- /dev/null +++ b/internal/next/websvc/dns_test.go @@ -0,0 +1,68 @@ +package websvc_test + +import ( + "context" + "encoding/json" + "net/http" + "net/netip" + "net/url" + "sync/atomic" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestService_HandlePatchSettingsDNS(t *testing.T) { + wantDNS := &websvc.HTTPAPIDNSSettings{ + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:53")}, + BootstrapServers: []string{"1.0.0.1"}, + UpstreamServers: []string{"1.1.1.1"}, + UpstreamTimeout: websvc.JSONDuration(2 * time.Second), + } + + // TODO(a.garipov): Use [atomic.Bool] in Go 1.19. + var numStarted uint64 + confMgr := newConfigManager() + confMgr.onDNS = func() (s websvc.ServiceWithConfig[*dnssvc.Config]) { + return &aghtest.ServiceWithConfig[*dnssvc.Config]{ + OnStart: func() (err error) { + atomic.AddUint64(&numStarted, 1) + + return nil + }, + OnShutdown: func(_ context.Context) (err error) { panic("not implemented") }, + OnConfig: func() (c *dnssvc.Config) { panic("not implemented") }, + } + } + confMgr.onUpdateDNS = func(ctx context.Context, c *dnssvc.Config) (err error) { + return nil + } + + _, addr := newTestServer(t, confMgr) + u := &url.URL{ + Scheme: "http", + Host: addr.String(), + Path: websvc.PathV1SettingsDNS, + } + + req := jobj{ + "addresses": wantDNS.Addresses, + "bootstrap_servers": wantDNS.BootstrapServers, + "upstream_servers": wantDNS.UpstreamServers, + "upstream_timeout": wantDNS.UpstreamTimeout, + } + + respBody := httpPatch(t, u, req, http.StatusOK) + resp := &websvc.HTTPAPIDNSSettings{} + err := json.Unmarshal(respBody, resp) + require.NoError(t, err) + + assert.Equal(t, uint64(1), numStarted) + assert.Equal(t, wantDNS, resp) + assert.Equal(t, wantDNS, resp) +} diff --git a/internal/next/websvc/http.go b/internal/next/websvc/http.go new file mode 100644 index 00000000000..b58eecb9499 --- /dev/null +++ b/internal/next/websvc/http.go @@ -0,0 +1,109 @@ +package websvc + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/netip" + "time" + + "github.com/AdguardTeam/golibs/log" +) + +// HTTP Settings Handlers + +// ReqPatchSettingsHTTP describes the request to the PATCH /api/v1/settings/http +// HTTP API. +type ReqPatchSettingsHTTP struct { + // TODO(a.garipov): Add more as we go. + // + // TODO(a.garipov): Add wait time. + + Addresses []netip.AddrPort `json:"addresses"` + SecureAddresses []netip.AddrPort `json:"secure_addresses"` + Timeout JSONDuration `json:"timeout"` +} + +// HTTPAPIHTTPSettings are the HTTP settings as used by the HTTP API. See the +// HttpSettings object in the OpenAPI specification. +type HTTPAPIHTTPSettings struct { + // TODO(a.garipov): Add more as we go. + + Addresses []netip.AddrPort `json:"addresses"` + SecureAddresses []netip.AddrPort `json:"secure_addresses"` + Timeout JSONDuration `json:"timeout"` + ForceHTTPS bool `json:"force_https"` +} + +// handlePatchSettingsHTTP is the handler for the PATCH /api/v1/settings/http +// HTTP API. +func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Request) { + req := &ReqPatchSettingsHTTP{} + + // TODO(a.garipov): Validate nulls and proper JSON patch. + + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + writeJSONErrorResponse(w, r, fmt.Errorf("decoding: %w", err)) + + return + } + + newConf := &Config{ + ConfigManager: svc.confMgr, + TLS: svc.tls, + Addresses: req.Addresses, + SecureAddresses: req.SecureAddresses, + Timeout: time.Duration(req.Timeout), + ForceHTTPS: svc.forceHTTPS, + } + + writeJSONOKResponse(w, r, &HTTPAPIHTTPSettings{ + Addresses: newConf.Addresses, + SecureAddresses: newConf.SecureAddresses, + Timeout: JSONDuration(newConf.Timeout), + ForceHTTPS: newConf.ForceHTTPS, + }) + + cancelUpd := func() {} + updCtx := context.Background() + + ctx := r.Context() + if deadline, ok := ctx.Deadline(); ok { + updCtx, cancelUpd = context.WithDeadline(updCtx, deadline) + } + + // Launch the new HTTP service in a separate goroutine to let this handler + // finish and thus, this server to shutdown. + go func() { + defer cancelUpd() + + updErr := svc.confMgr.UpdateWeb(updCtx, newConf) + if updErr != nil { + writeJSONErrorResponse(w, r, fmt.Errorf("updating: %w", updErr)) + + return + } + + // TODO(a.garipov): Consider better ways to do this. + const maxUpdDur = 10 * time.Second + updStart := time.Now() + var newSvc ServiceWithConfig[*Config] + for newSvc = svc.confMgr.Web(); newSvc == svc; { + if time.Since(updStart) >= maxUpdDur { + log.Error("websvc: failed to update svc after %s", maxUpdDur) + + return + } + + log.Debug("websvc: waiting for new websvc to be configured") + time.Sleep(1 * time.Second) + } + + updErr = newSvc.Start() + if updErr != nil { + log.Error("websvc: new svc failed to start with error: %s", updErr) + } + }() +} diff --git a/internal/next/websvc/http_test.go b/internal/next/websvc/http_test.go new file mode 100644 index 00000000000..baf384da296 --- /dev/null +++ b/internal/next/websvc/http_test.go @@ -0,0 +1,62 @@ +package websvc_test + +import ( + "context" + "crypto/tls" + "encoding/json" + "net/http" + "net/netip" + "net/url" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestService_HandlePatchSettingsHTTP(t *testing.T) { + wantWeb := &websvc.HTTPAPIHTTPSettings{ + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:80")}, + SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:443")}, + Timeout: websvc.JSONDuration(10 * time.Second), + ForceHTTPS: false, + } + + confMgr := newConfigManager() + confMgr.onWeb = func() (s websvc.ServiceWithConfig[*websvc.Config]) { + return websvc.New(&websvc.Config{ + TLS: &tls.Config{ + Certificates: []tls.Certificate{{}}, + }, + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")}, + SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")}, + Timeout: 5 * time.Second, + ForceHTTPS: true, + }) + } + confMgr.onUpdateWeb = func(ctx context.Context, c *websvc.Config) (err error) { + return nil + } + + _, addr := newTestServer(t, confMgr) + u := &url.URL{ + Scheme: "http", + Host: addr.String(), + Path: websvc.PathV1SettingsHTTP, + } + + req := jobj{ + "addresses": wantWeb.Addresses, + "secure_addresses": wantWeb.SecureAddresses, + "timeout": wantWeb.Timeout, + "force_https": wantWeb.ForceHTTPS, + } + + respBody := httpPatch(t, u, req, http.StatusOK) + resp := &websvc.HTTPAPIHTTPSettings{} + err := json.Unmarshal(respBody, resp) + require.NoError(t, err) + + assert.Equal(t, wantWeb, resp) +} diff --git a/internal/next/websvc/json.go b/internal/next/websvc/json.go new file mode 100644 index 00000000000..fa2010a8ada --- /dev/null +++ b/internal/next/websvc/json.go @@ -0,0 +1,143 @@ +package websvc + +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" + "github.com/AdguardTeam/golibs/log" +) + +// JSON Utilities + +// nsecPerMsec is the number of nanoseconds in a millisecond. +const nsecPerMsec = float64(time.Millisecond / time.Nanosecond) + +// JSONDuration is a time.Duration that can be decoded from JSON and encoded +// into JSON according to our API conventions. +type JSONDuration time.Duration + +// type check +var _ json.Marshaler = JSONDuration(0) + +// MarshalJSON implements the json.Marshaler interface for JSONDuration. err is +// always nil. +func (d JSONDuration) MarshalJSON() (b []byte, err error) { + msec := float64(time.Duration(d)) / nsecPerMsec + b = strconv.AppendFloat(nil, msec, 'f', -1, 64) + + return b, nil +} + +// type check +var _ json.Unmarshaler = (*JSONDuration)(nil) + +// UnmarshalJSON implements the json.Marshaler interface for *JSONDuration. +func (d *JSONDuration) UnmarshalJSON(b []byte) (err error) { + if d == nil { + return fmt.Errorf("json duration is nil") + } + + msec, err := strconv.ParseFloat(string(b), 64) + if err != nil { + return fmt.Errorf("parsing json time: %w", err) + } + + *d = JSONDuration(int64(msec * nsecPerMsec)) + + return nil +} + +// JSONTime is a time.Time that can be decoded from JSON and encoded into JSON +// according to our API conventions. +type JSONTime time.Time + +// type check +var _ json.Marshaler = JSONTime{} + +// MarshalJSON implements the json.Marshaler interface for JSONTime. err is +// always nil. +func (t JSONTime) MarshalJSON() (b []byte, err error) { + msec := float64(time.Time(t).UnixNano()) / nsecPerMsec + b = strconv.AppendFloat(nil, msec, 'f', -1, 64) + + return b, nil +} + +// type check +var _ json.Unmarshaler = (*JSONTime)(nil) + +// UnmarshalJSON implements the json.Marshaler interface for *JSONTime. +func (t *JSONTime) UnmarshalJSON(b []byte) (err error) { + if t == nil { + return fmt.Errorf("json time is nil") + } + + msec, err := strconv.ParseFloat(string(b), 64) + if err != nil { + return fmt.Errorf("parsing json time: %w", err) + } + + *t = JSONTime(time.Unix(0, int64(msec*nsecPerMsec)).UTC()) + + return nil +} + +// writeJSONOKResponse writes headers with the code 200 OK, encodes v into w, +// and logs any errors it encounters. r is used to get additional information +// from the request. +func writeJSONOKResponse(w http.ResponseWriter, r *http.Request, v any) { + writeJSONResponse(w, r, v, http.StatusOK) +} + +// writeJSONResponse writes headers with code, encodes v into w, and logs any +// errors it encounters. r is used to get additional information from the +// request. +func writeJSONResponse(w http.ResponseWriter, r *http.Request, v any, code int) { + // TODO(a.garipov): Put some of these to a middleware. + h := w.Header() + h.Set(aghhttp.HdrNameContentType, aghhttp.HdrValApplicationJSON) + h.Set(aghhttp.HdrNameServer, aghhttp.UserAgent()) + + w.WriteHeader(code) + + err := json.NewEncoder(w).Encode(v) + if err != nil { + log.Error("websvc: writing resp to %s %s: %s", r.Method, r.URL.Path, err) + } +} + +// ErrorCode is the error code as used by the HTTP API. See the ErrorCode +// definition in the OpenAPI specification. +type ErrorCode string + +// ErrorCode constants. +// +// TODO(a.garipov): Expand and document codes. +const ( + // ErrorCodeTMP000 is the temporary error code used for all errors. + ErrorCodeTMP000 = "" +) + +// HTTPAPIErrorResp is the error response as used by the HTTP API. See the +// BadRequestResp, InternalServerErrorResp, and similar objects in the OpenAPI +// specification. +type HTTPAPIErrorResp struct { + Code ErrorCode `json:"code"` + Msg string `json:"msg"` +} + +// writeJSONErrorResponse encodes err as a JSON error into w, and logs any +// errors it encounters. r is used to get additional information from the +// request. +func writeJSONErrorResponse(w http.ResponseWriter, r *http.Request, err error) { + log.Error("websvc: %s %s: %s", r.Method, r.URL.Path, err) + + writeJSONResponse(w, r, &HTTPAPIErrorResp{ + Code: ErrorCodeTMP000, + Msg: err.Error(), + }, http.StatusUnprocessableEntity) +} diff --git a/internal/next/websvc/json_test.go b/internal/next/websvc/json_test.go new file mode 100644 index 00000000000..90874958a60 --- /dev/null +++ b/internal/next/websvc/json_test.go @@ -0,0 +1,114 @@ +package websvc_test + +import ( + "encoding/json" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testJSONTime is the JSON time for tests. +var testJSONTime = websvc.JSONTime(time.Unix(1_234_567_890, 123_456_000).UTC()) + +// testJSONTimeStr is the string with the JSON encoding of testJSONTime. +const testJSONTimeStr = "1234567890123.456" + +func TestJSONTime_MarshalJSON(t *testing.T) { + testCases := []struct { + name string + wantErrMsg string + in websvc.JSONTime + want []byte + }{{ + name: "unix_zero", + wantErrMsg: "", + in: websvc.JSONTime(time.Unix(0, 0)), + want: []byte("0"), + }, { + name: "empty", + wantErrMsg: "", + in: websvc.JSONTime{}, + want: []byte("-6795364578871.345"), + }, { + name: "time", + wantErrMsg: "", + in: testJSONTime, + want: []byte(testJSONTimeStr), + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := tc.in.MarshalJSON() + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + + assert.Equal(t, tc.want, got) + }) + } + + t.Run("json", func(t *testing.T) { + in := &struct { + A websvc.JSONTime + }{ + A: testJSONTime, + } + + got, err := json.Marshal(in) + require.NoError(t, err) + + assert.Equal(t, []byte(`{"A":`+testJSONTimeStr+`}`), got) + }) +} + +func TestJSONTime_UnmarshalJSON(t *testing.T) { + testCases := []struct { + name string + wantErrMsg string + want websvc.JSONTime + data []byte + }{{ + name: "time", + wantErrMsg: "", + want: testJSONTime, + data: []byte(testJSONTimeStr), + }, { + name: "bad", + wantErrMsg: `parsing json time: strconv.ParseFloat: parsing "{}": ` + + `invalid syntax`, + want: websvc.JSONTime{}, + data: []byte(`{}`), + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var got websvc.JSONTime + err := got.UnmarshalJSON(tc.data) + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + + assert.Equal(t, tc.want, got) + }) + } + + t.Run("nil", func(t *testing.T) { + err := (*websvc.JSONTime)(nil).UnmarshalJSON([]byte("0")) + require.Error(t, err) + + msg := err.Error() + assert.Equal(t, "json time is nil", msg) + }) + + t.Run("json", func(t *testing.T) { + want := testJSONTime + var got struct { + A websvc.JSONTime + } + + err := json.Unmarshal([]byte(`{"A":`+testJSONTimeStr+`}`), &got) + require.NoError(t, err) + + assert.Equal(t, want, got.A) + }) +} diff --git a/internal/v1/websvc/middleware.go b/internal/next/websvc/middleware.go similarity index 100% rename from internal/v1/websvc/middleware.go rename to internal/next/websvc/middleware.go diff --git a/internal/next/websvc/path.go b/internal/next/websvc/path.go new file mode 100644 index 00000000000..e38a1d605ed --- /dev/null +++ b/internal/next/websvc/path.go @@ -0,0 +1,11 @@ +package websvc + +// Path constants +const ( + PathHealthCheck = "/health-check" + + PathV1SettingsAll = "/api/v1/settings/all" + PathV1SettingsDNS = "/api/v1/settings/dns" + PathV1SettingsHTTP = "/api/v1/settings/http" + PathV1SystemInfo = "/api/v1/system/info" +) diff --git a/internal/next/websvc/settings.go b/internal/next/websvc/settings.go new file mode 100644 index 00000000000..b6c5a80ad68 --- /dev/null +++ b/internal/next/websvc/settings.go @@ -0,0 +1,42 @@ +package websvc + +import ( + "net/http" +) + +// All Settings Handlers + +// RespGetV1SettingsAll describes the response of the GET /api/v1/settings/all +// HTTP API. +type RespGetV1SettingsAll struct { + // TODO(a.garipov): Add more as we go. + + DNS *HTTPAPIDNSSettings `json:"dns"` + HTTP *HTTPAPIHTTPSettings `json:"http"` +} + +// handleGetSettingsAll is the handler for the GET /api/v1/settings/all HTTP +// API. +func (svc *Service) handleGetSettingsAll(w http.ResponseWriter, r *http.Request) { + dnsSvc := svc.confMgr.DNS() + dnsConf := dnsSvc.Config() + + webSvc := svc.confMgr.Web() + httpConf := webSvc.Config() + + // TODO(a.garipov): Add all currently supported parameters. + writeJSONOKResponse(w, r, &RespGetV1SettingsAll{ + DNS: &HTTPAPIDNSSettings{ + Addresses: dnsConf.Addresses, + BootstrapServers: dnsConf.BootstrapServers, + UpstreamServers: dnsConf.UpstreamServers, + UpstreamTimeout: JSONDuration(dnsConf.UpstreamTimeout), + }, + HTTP: &HTTPAPIHTTPSettings{ + Addresses: httpConf.Addresses, + SecureAddresses: httpConf.SecureAddresses, + Timeout: JSONDuration(httpConf.Timeout), + ForceHTTPS: httpConf.ForceHTTPS, + }, + }) +} diff --git a/internal/next/websvc/settings_test.go b/internal/next/websvc/settings_test.go new file mode 100644 index 00000000000..dadb4b55ea3 --- /dev/null +++ b/internal/next/websvc/settings_test.go @@ -0,0 +1,74 @@ +package websvc_test + +import ( + "crypto/tls" + "encoding/json" + "net/http" + "net/netip" + "net/url" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestService_HandleGetSettingsAll(t *testing.T) { + // TODO(a.garipov): Add all currently supported parameters. + + wantDNS := &websvc.HTTPAPIDNSSettings{ + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:53")}, + BootstrapServers: []string{"94.140.14.140", "94.140.14.141"}, + UpstreamServers: []string{"94.140.14.14", "1.1.1.1"}, + UpstreamTimeout: websvc.JSONDuration(1 * time.Second), + } + + wantWeb := &websvc.HTTPAPIHTTPSettings{ + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")}, + SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")}, + Timeout: websvc.JSONDuration(5 * time.Second), + ForceHTTPS: true, + } + + confMgr := newConfigManager() + confMgr.onDNS = func() (s websvc.ServiceWithConfig[*dnssvc.Config]) { + c, err := dnssvc.New(&dnssvc.Config{ + Addresses: wantDNS.Addresses, + UpstreamServers: wantDNS.UpstreamServers, + BootstrapServers: wantDNS.BootstrapServers, + UpstreamTimeout: time.Duration(wantDNS.UpstreamTimeout), + }) + require.NoError(t, err) + + return c + } + + confMgr.onWeb = func() (s websvc.ServiceWithConfig[*websvc.Config]) { + return websvc.New(&websvc.Config{ + TLS: &tls.Config{ + Certificates: []tls.Certificate{{}}, + }, + Addresses: wantWeb.Addresses, + SecureAddresses: wantWeb.SecureAddresses, + Timeout: time.Duration(wantWeb.Timeout), + ForceHTTPS: true, + }) + } + + _, addr := newTestServer(t, confMgr) + u := &url.URL{ + Scheme: "http", + Host: addr.String(), + Path: websvc.PathV1SettingsAll, + } + + body := httpGet(t, u, http.StatusOK) + resp := &websvc.RespGetV1SettingsAll{} + err := json.Unmarshal(body, resp) + require.NoError(t, err) + + assert.Equal(t, wantDNS, resp.DNS) + assert.Equal(t, wantWeb, resp.HTTP) +} diff --git a/internal/v1/websvc/system.go b/internal/next/websvc/system.go similarity index 87% rename from internal/v1/websvc/system.go rename to internal/next/websvc/system.go index 47d0c63cb0c..fbf60fe4d55 100644 --- a/internal/v1/websvc/system.go +++ b/internal/next/websvc/system.go @@ -16,20 +16,20 @@ type RespGetV1SystemInfo struct { Channel string `json:"channel"` OS string `json:"os"` NewVersion string `json:"new_version,omitempty"` - Start jsonTime `json:"start"` + Start JSONTime `json:"start"` Version string `json:"version"` } // handleGetV1SystemInfo is the handler for the GET /api/v1/system/info HTTP // API. func (svc *Service) handleGetV1SystemInfo(w http.ResponseWriter, r *http.Request) { - writeJSONResponse(w, r, &RespGetV1SystemInfo{ + writeJSONOKResponse(w, r, &RespGetV1SystemInfo{ Arch: runtime.GOARCH, Channel: version.Channel(), OS: runtime.GOOS, // TODO(a.garipov): Fill this when we have an updater. NewVersion: "", - Start: jsonTime(svc.start), + Start: JSONTime(svc.start), Version: version.Version(), }) } diff --git a/internal/v1/websvc/system_test.go b/internal/next/websvc/system_test.go similarity index 82% rename from internal/v1/websvc/system_test.go rename to internal/next/websvc/system_test.go index 49579ca5833..acbdcba2a4c 100644 --- a/internal/v1/websvc/system_test.go +++ b/internal/next/websvc/system_test.go @@ -8,16 +8,17 @@ import ( "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/v1/websvc" + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestService_handleGetV1SystemInfo(t *testing.T) { - _, addr := newTestServer(t) + confMgr := newConfigManager() + _, addr := newTestServer(t, confMgr) u := &url.URL{ Scheme: "http", - Host: addr, + Host: addr.String(), Path: websvc.PathV1SystemInfo, } diff --git a/internal/next/websvc/waitlistener.go b/internal/next/websvc/waitlistener.go new file mode 100644 index 00000000000..8ab562693af --- /dev/null +++ b/internal/next/websvc/waitlistener.go @@ -0,0 +1,31 @@ +package websvc + +import ( + "net" + "sync" +) + +// Wait Listener + +// waitListener is a wrapper around a listener that also calls wg.Done() on the +// first call to Accept. It is useful in situations where it is important to +// catch the precise moment of the first call to Accept, for example when +// starting an HTTP server. +// +// TODO(a.garipov): Move to aghnet? +type waitListener struct { + net.Listener + + firstAcceptWG *sync.WaitGroup + firstAcceptOnce sync.Once +} + +// type check +var _ net.Listener = (*waitListener)(nil) + +// Accept implements the [net.Listener] interface for *waitListener. +func (l *waitListener) Accept() (conn net.Conn, err error) { + l.firstAcceptOnce.Do(l.firstAcceptWG.Done) + + return l.Listener.Accept() +} diff --git a/internal/next/websvc/waitlistener_internal_test.go b/internal/next/websvc/waitlistener_internal_test.go new file mode 100644 index 00000000000..e151341bc6a --- /dev/null +++ b/internal/next/websvc/waitlistener_internal_test.go @@ -0,0 +1,46 @@ +package websvc + +import ( + "net" + "sync" + "sync/atomic" + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/aghchan" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/stretchr/testify/assert" +) + +func TestWaitListener_Accept(t *testing.T) { + // TODO(a.garipov): use atomic.Bool in Go 1.19. + var numAcceptCalls uint32 + var l net.Listener = &aghtest.Listener{ + OnAccept: func() (conn net.Conn, err error) { + atomic.AddUint32(&numAcceptCalls, 1) + + return nil, nil + }, + OnAddr: func() (addr net.Addr) { panic("not implemented") }, + OnClose: func() (err error) { panic("not implemented") }, + } + + wg := &sync.WaitGroup{} + wg.Add(1) + + done := make(chan struct{}) + go aghchan.MustReceive(done, testTimeout) + + go func() { + var wrapper net.Listener = &waitListener{ + Listener: l, + firstAcceptWG: wg, + } + + _, _ = wrapper.Accept() + }() + + wg.Wait() + close(done) + + assert.Equal(t, uint32(1), atomic.LoadUint32(&numAcceptCalls)) +} diff --git a/internal/v1/websvc/websvc.go b/internal/next/websvc/websvc.go similarity index 52% rename from internal/v1/websvc/websvc.go rename to internal/next/websvc/websvc.go index bbaac005f97..75f7d001f69 100644 --- a/internal/v1/websvc/websvc.go +++ b/internal/next/websvc/websvc.go @@ -1,4 +1,7 @@ -// Package websvc contains the AdGuard Home web service. +// Package websvc contains the AdGuard Home HTTP API service. +// +// NOTE: Packages other than cmd must not import this package, as it imports +// most other packages. // // TODO(a.garipov): Add tests. package websvc @@ -14,18 +17,46 @@ import ( "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/v1/agh" + "github.com/AdguardTeam/AdGuardHome/internal/next/agh" + "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" httptreemux "github.com/dimfeld/httptreemux/v5" ) +// ServiceWithConfig is an extension of the [agh.Service] interface for services +// that can return their configuration. +// +// TODO(a.garipov): Consider removing this generic interface if we figure out +// how to make it testable in a better way. +type ServiceWithConfig[ConfigType any] interface { + agh.Service + + Config() (c ConfigType) +} + +// ConfigManager is the configuration manager interface. +type ConfigManager interface { + DNS() (svc ServiceWithConfig[*dnssvc.Config]) + Web() (svc ServiceWithConfig[*Config]) + + UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) + UpdateWeb(ctx context.Context, c *Config) (err error) +} + // Config is the AdGuard Home web service configuration structure. type Config struct { + // ConfigManager is used to show information about services as well as + // dynamically reconfigure them. + ConfigManager ConfigManager + // TLS is the optional TLS configuration. If TLS is not nil, // SecureAddresses must not be empty. TLS *tls.Config + // Start is the time of start of AdGuard Home. + Start time.Time + // Addresses are the addresses on which to serve the plain HTTP API. Addresses []netip.AddrPort @@ -33,40 +64,48 @@ type Config struct { // SecureAddresses is not empty, TLS must not be nil. SecureAddresses []netip.AddrPort - // Start is the time of start of AdGuard Home. - Start time.Time - // Timeout is the timeout for all server operations. Timeout time.Duration + + // ForceHTTPS tells if all requests to Addresses should be redirected to a + // secure address instead. + // + // TODO(a.garipov): Use; define rules, which address to redirect to. + ForceHTTPS bool } // Service is the AdGuard Home web service. A nil *Service is a valid // [agh.Service] that does nothing. type Service struct { - tls *tls.Config - servers []*http.Server - start time.Time - timeout time.Duration + confMgr ConfigManager + tls *tls.Config + start time.Time + servers []*http.Server + timeout time.Duration + forceHTTPS bool } // New returns a new properly initialized *Service. If c is nil, svc is a nil -// *Service that does nothing. +// *Service that does nothing. The fields of c must not be modified after +// calling New. func New(c *Config) (svc *Service) { if c == nil { return nil } svc = &Service{ - tls: c.TLS, - start: c.Start, - timeout: c.Timeout, + confMgr: c.ConfigManager, + tls: c.TLS, + start: c.Start, + timeout: c.Timeout, + forceHTTPS: c.ForceHTTPS, } mux := newMux(svc) for _, a := range c.Addresses { addr := a.String() - errLog := log.StdLog("websvc: http: "+addr, log.ERROR) + errLog := log.StdLog("websvc: plain http: "+addr, log.ERROR) svc.servers = append(svc.servers, &http.Server{ Addr: addr, Handler: mux, @@ -111,6 +150,21 @@ func newMux(svc *Service) (mux *httptreemux.ContextMux) { method: http.MethodGet, path: PathHealthCheck, isJSON: false, + }, { + handler: svc.handleGetSettingsAll, + method: http.MethodGet, + path: PathV1SettingsAll, + isJSON: true, + }, { + handler: svc.handlePatchSettingsDNS, + method: http.MethodPatch, + path: PathV1SettingsDNS, + isJSON: true, + }, { + handler: svc.handlePatchSettingsHTTP, + method: http.MethodPatch, + path: PathV1SettingsHTTP, + isJSON: true, }, { handler: svc.handleGetV1SystemInfo, method: http.MethodGet, @@ -119,29 +173,41 @@ func newMux(svc *Service) (mux *httptreemux.ContextMux) { }} for _, r := range routes { - var h http.HandlerFunc if r.isJSON { - // TODO(a.garipov): Consider using httptreemux's MiddlewareFunc. - h = jsonMw(r.handler) + mux.Handle(r.method, r.path, jsonMw(r.handler)) } else { - h = r.handler + mux.Handle(r.method, r.path, r.handler) } - - mux.Handle(r.method, r.path, h) } return mux } -// Addrs returns all addresses on which this server serves the HTTP API. Addrs -// must not be called until Start returns. -func (svc *Service) Addrs() (addrs []string) { - addrs = make([]string, 0, len(svc.servers)) +// addrs returns all addresses on which this server serves the HTTP API. addrs +// must not be called simultaneously with Start. If svc was initialized with +// ":0" addresses, addrs will not return the actual bound ports until Start is +// finished. +func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) { for _, srv := range svc.servers { - addrs = append(addrs, srv.Addr) + addrPort, err := netip.ParseAddrPort(srv.Addr) + if err != nil { + // Technically shouldn't happen, since all servers must have a valid + // address. + panic(fmt.Errorf("websvc: server %q: bad address: %w", srv.Addr, err)) + } + + // srv.Serve will set TLSConfig to an almost empty value, so, instead of + // relying only on the nilness of TLSConfig, check the length of the + // certificates field as well. + if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 { + addrs = append(addrs, addrPort) + } else { + secureAddrs = append(secureAddrs, addrPort) + } + } - return addrs + return addrs, secureAddrs } // handleGetHealthCheck is the handler for the GET /health-check HTTP API. @@ -149,9 +215,6 @@ func (svc *Service) handleGetHealthCheck(w http.ResponseWriter, _ *http.Request) _, _ = io.WriteString(w, "OK") } -// unit is a convenient alias for struct{}. -type unit = struct{} - // type check var _ agh.Service = (*Service)(nil) @@ -163,11 +226,9 @@ func (svc *Service) Start() (err error) { return nil } - srvs := svc.servers - wg := &sync.WaitGroup{} - wg.Add(len(srvs)) - for _, srv := range srvs { + wg.Add(len(svc.servers)) + for _, srv := range svc.servers { go serve(srv, wg) } @@ -181,11 +242,14 @@ func serve(srv *http.Server, wg *sync.WaitGroup) { addr := srv.Addr defer log.OnPanic(addr) + var proto string var l net.Listener var err error if srv.TLSConfig == nil { + proto = "http" l, err = net.Listen("tcp", addr) } else { + proto = "https" l, err = tls.Listen("tcp", addr, srv.TLSConfig) } if err != nil { @@ -196,8 +260,12 @@ func serve(srv *http.Server, wg *sync.WaitGroup) { // would mean that a random available port was automatically chosen. srv.Addr = l.Addr().String() - log.Info("websvc: starting srv http://%s", srv.Addr) - wg.Done() + log.Info("websvc: starting srv %s://%s", proto, srv.Addr) + + l = &waitListener{ + Listener: l, + firstAcceptWG: wg, + } err = srv.Serve(l) if err != nil && !errors.Is(err, http.ErrServerClosed) { @@ -221,8 +289,28 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) { } if len(errs) > 0 { - return errors.List("shutting down") + return errors.List("shutting down", errs...) } return nil } + +// Config returns the current configuration of the web service. Config must not +// be called simultaneously with Start. If svc was initialized with ":0" +// addresses, addrs will not return the actual bound ports until Start is +// finished. +func (svc *Service) Config() (c *Config) { + c = &Config{ + ConfigManager: svc.confMgr, + TLS: svc.tls, + // Leave Addresses and SecureAddresses empty and get the actual + // addresses that include the :0 ones later. + Start: svc.start, + Timeout: svc.timeout, + ForceHTTPS: svc.forceHTTPS, + } + + c.Addresses, c.SecureAddresses = svc.addrs() + + return c +} diff --git a/internal/next/websvc/websvc_internal_test.go b/internal/next/websvc/websvc_internal_test.go new file mode 100644 index 00000000000..3509b193410 --- /dev/null +++ b/internal/next/websvc/websvc_internal_test.go @@ -0,0 +1,6 @@ +package websvc + +import "time" + +// testTimeout is the common timeout for tests. +const testTimeout = 1 * time.Second diff --git a/internal/next/websvc/websvc_test.go b/internal/next/websvc/websvc_test.go new file mode 100644 index 00000000000..dbce77d58a2 --- /dev/null +++ b/internal/next/websvc/websvc_test.go @@ -0,0 +1,187 @@ +package websvc_test + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/netip" + "net/url" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMain(m *testing.M) { + aghtest.DiscardLogOutput(m) +} + +// testTimeout is the common timeout for tests. +const testTimeout = 1 * time.Second + +// testStart is the server start value for tests. +var testStart = time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) + +// type check +var _ websvc.ConfigManager = (*configManager)(nil) + +// configManager is a [websvc.ConfigManager] for tests. +type configManager struct { + onDNS func() (svc websvc.ServiceWithConfig[*dnssvc.Config]) + onWeb func() (svc websvc.ServiceWithConfig[*websvc.Config]) + + onUpdateDNS func(ctx context.Context, c *dnssvc.Config) (err error) + onUpdateWeb func(ctx context.Context, c *websvc.Config) (err error) +} + +// DNS implements the [websvc.ConfigManager] interface for *configManager. +func (m *configManager) DNS() (svc websvc.ServiceWithConfig[*dnssvc.Config]) { + return m.onDNS() +} + +// Web implements the [websvc.ConfigManager] interface for *configManager. +func (m *configManager) Web() (svc websvc.ServiceWithConfig[*websvc.Config]) { + return m.onWeb() +} + +// UpdateDNS implements the [websvc.ConfigManager] interface for *configManager. +func (m *configManager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) { + return m.onUpdateDNS(ctx, c) +} + +// UpdateWeb implements the [websvc.ConfigManager] interface for *configManager. +func (m *configManager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) { + return m.onUpdateWeb(ctx, c) +} + +// newConfigManager returns a *configManager all methods of which panic. +func newConfigManager() (m *configManager) { + return &configManager{ + onDNS: func() (svc websvc.ServiceWithConfig[*dnssvc.Config]) { panic("not implemented") }, + onWeb: func() (svc websvc.ServiceWithConfig[*websvc.Config]) { panic("not implemented") }, + onUpdateDNS: func(_ context.Context, _ *dnssvc.Config) (err error) { + panic("not implemented") + }, + onUpdateWeb: func(_ context.Context, _ *websvc.Config) (err error) { + panic("not implemented") + }, + } +} + +// newTestServer creates and starts a new web service instance as well as its +// sole address. It also registers a cleanup procedure, which shuts the +// instance down. +// +// TODO(a.garipov): Use svc or remove it. +func newTestServer( + t testing.TB, + confMgr websvc.ConfigManager, +) (svc *websvc.Service, addr netip.AddrPort) { + t.Helper() + + c := &websvc.Config{ + ConfigManager: confMgr, + TLS: nil, + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")}, + SecureAddresses: nil, + Timeout: testTimeout, + Start: testStart, + ForceHTTPS: false, + } + + svc = websvc.New(c) + + err := svc.Start() + require.NoError(t, err) + t.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + err = svc.Shutdown(ctx) + require.NoError(t, err) + }) + + c = svc.Config() + require.NotNil(t, c) + require.Len(t, c.Addresses, 1) + + return svc, c.Addresses[0] +} + +// jobj is a utility alias for JSON objects. +type jobj map[string]any + +// httpGet is a helper that performs an HTTP GET request and returns the body of +// the response as well as checks that the status code is correct. +// +// TODO(a.garipov): Add helpers for other methods. +func httpGet(t testing.TB, u *url.URL, wantCode int) (body []byte) { + t.Helper() + + req, err := http.NewRequest(http.MethodGet, u.String(), nil) + require.NoErrorf(t, err, "creating req") + + httpCli := &http.Client{ + Timeout: testTimeout, + } + resp, err := httpCli.Do(req) + require.NoErrorf(t, err, "performing req") + require.Equal(t, wantCode, resp.StatusCode) + + testutil.CleanupAndRequireSuccess(t, resp.Body.Close) + + body, err = io.ReadAll(resp.Body) + require.NoErrorf(t, err, "reading body") + + return body +} + +// httpPatch is a helper that performs an HTTP PATCH request with JSON-encoded +// reqBody as the request body and returns the body of the response as well as +// checks that the status code is correct. +// +// TODO(a.garipov): Add helpers for other methods. +func httpPatch(t testing.TB, u *url.URL, reqBody any, wantCode int) (body []byte) { + t.Helper() + + b, err := json.Marshal(reqBody) + require.NoErrorf(t, err, "marshaling reqBody") + + req, err := http.NewRequest(http.MethodPatch, u.String(), bytes.NewReader(b)) + require.NoErrorf(t, err, "creating req") + + httpCli := &http.Client{ + Timeout: testTimeout, + } + resp, err := httpCli.Do(req) + require.NoErrorf(t, err, "performing req") + require.Equal(t, wantCode, resp.StatusCode) + + testutil.CleanupAndRequireSuccess(t, resp.Body.Close) + + body, err = io.ReadAll(resp.Body) + require.NoErrorf(t, err, "reading body") + + return body +} + +func TestService_Start_getHealthCheck(t *testing.T) { + confMgr := newConfigManager() + _, addr := newTestServer(t, confMgr) + u := &url.URL{ + Scheme: "http", + Host: addr.String(), + Path: websvc.PathHealthCheck, + } + + body := httpGet(t, u, http.StatusOK) + + assert.Equal(t, []byte("OK"), body) +} diff --git a/internal/v1/websvc/json.go b/internal/v1/websvc/json.go deleted file mode 100644 index ef84211b409..00000000000 --- a/internal/v1/websvc/json.go +++ /dev/null @@ -1,61 +0,0 @@ -package websvc - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "strconv" - "time" - - "github.com/AdguardTeam/golibs/log" -) - -// JSON Utilities - -// jsonTime is a time.Time that can be decoded from JSON and encoded into JSON -// according to our API conventions. -type jsonTime time.Time - -// type check -var _ json.Marshaler = jsonTime{} - -// nsecPerMsec is the number of nanoseconds in a millisecond. -const nsecPerMsec = float64(time.Millisecond / time.Nanosecond) - -// MarshalJSON implements the json.Marshaler interface for jsonTime. err is -// always nil. -func (t jsonTime) MarshalJSON() (b []byte, err error) { - msec := float64(time.Time(t).UnixNano()) / nsecPerMsec - b = strconv.AppendFloat(nil, msec, 'f', 3, 64) - - return b, nil -} - -// type check -var _ json.Unmarshaler = (*jsonTime)(nil) - -// UnmarshalJSON implements the json.Marshaler interface for *jsonTime. -func (t *jsonTime) UnmarshalJSON(b []byte) (err error) { - if t == nil { - return fmt.Errorf("json time is nil") - } - - msec, err := strconv.ParseFloat(string(b), 64) - if err != nil { - return fmt.Errorf("parsing json time: %w", err) - } - - *t = jsonTime(time.Unix(0, int64(msec*nsecPerMsec)).UTC()) - - return nil -} - -// writeJSONResponse encodes v into w and logs any errors it encounters. r is -// used to get additional information from the request. -func writeJSONResponse(w io.Writer, r *http.Request, v any) { - err := json.NewEncoder(w).Encode(v) - if err != nil { - log.Error("websvc: writing resp to %s %s: %s", r.Method, r.URL.Path, err) - } -} diff --git a/internal/v1/websvc/path.go b/internal/v1/websvc/path.go deleted file mode 100644 index cfd67fd9b06..00000000000 --- a/internal/v1/websvc/path.go +++ /dev/null @@ -1,8 +0,0 @@ -package websvc - -// Path constants -const ( - PathHealthCheck = "/health-check" - - PathV1SystemInfo = "/api/v1/system/info" -) diff --git a/internal/v1/websvc/websvc_test.go b/internal/v1/websvc/websvc_test.go deleted file mode 100644 index de4a9f5db15..00000000000 --- a/internal/v1/websvc/websvc_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package websvc_test - -import ( - "context" - "io" - "net/http" - "net/netip" - "net/url" - "testing" - "time" - - "github.com/AdguardTeam/AdGuardHome/internal/v1/websvc" - "github.com/AdguardTeam/golibs/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -const testTimeout = 1 * time.Second - -// testStart is the server start value for tests. -var testStart = time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) - -// newTestServer creates and starts a new web service instance as well as its -// sole address. It also registers a cleanup procedure, which shuts the -// instance down. -// -// TODO(a.garipov): Use svc or remove it. -func newTestServer(t testing.TB) (svc *websvc.Service, addr string) { - t.Helper() - - c := &websvc.Config{ - TLS: nil, - Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")}, - SecureAddresses: nil, - Timeout: testTimeout, - Start: testStart, - } - - svc = websvc.New(c) - - err := svc.Start() - require.NoError(t, err) - t.Cleanup(func() { - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - t.Cleanup(cancel) - - err = svc.Shutdown(ctx) - require.NoError(t, err) - }) - - addrs := svc.Addrs() - require.Len(t, addrs, 1) - - return svc, addrs[0] -} - -// httpGet is a helper that performs an HTTP GET request and returns the body of -// the response as well as checks that the status code is correct. -// -// TODO(a.garipov): Add helpers for other methods. -func httpGet(t testing.TB, u *url.URL, wantCode int) (body []byte) { - t.Helper() - - req, err := http.NewRequest(http.MethodGet, u.String(), nil) - require.NoErrorf(t, err, "creating req") - - httpCli := &http.Client{ - Timeout: testTimeout, - } - resp, err := httpCli.Do(req) - require.NoErrorf(t, err, "performing req") - require.Equal(t, wantCode, resp.StatusCode) - - testutil.CleanupAndRequireSuccess(t, resp.Body.Close) - - body, err = io.ReadAll(resp.Body) - require.NoErrorf(t, err, "reading body") - - return body -} - -func TestService_Start_getHealthCheck(t *testing.T) { - _, addr := newTestServer(t) - u := &url.URL{ - Scheme: "http", - Host: addr, - Path: websvc.PathHealthCheck, - } - - body := httpGet(t, u, http.StatusOK) - - assert.Equal(t, []byte("OK"), body) -} diff --git a/internal/version/version.go b/internal/version/version.go index 2091d859e12..ca78efffc51 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -63,14 +63,6 @@ func Version() (v string) { return version } -// Constants defining the format of module information string. -const ( - modInfoAtSep = "@" - modInfoDevSep = " " - modInfoSumLeft = " (sum: " - modInfoSumRight = ")" -) - // fmtModule returns formatted information about module. The result looks like: // // github.com/Username/module@v1.2.3 (sum: someHASHSUM=) @@ -87,14 +79,16 @@ func fmtModule(m *debug.Module) (formatted string) { stringutil.WriteToBuilder(b, m.Path) if ver := m.Version; ver != "" { - sep := modInfoAtSep + sep := "@" if ver == "(devel)" { - sep = modInfoDevSep + sep = " " } + stringutil.WriteToBuilder(b, sep, ver) } + if sum := m.Sum; sum != "" { - stringutil.WriteToBuilder(b, modInfoSumLeft, sum, modInfoSumRight) + stringutil.WriteToBuilder(b, "(sum: ", sum, ")") } return b.String() diff --git a/main.go b/main.go index 03ad2f03352..615a8a8636f 100644 --- a/main.go +++ b/main.go @@ -1,5 +1,5 @@ -//go:build !v1 -// +build !v1 +//go:build !next +// +build !next package main diff --git a/main_v1.go b/main_next.go similarity index 79% rename from main_v1.go rename to main_next.go index 6b5f3deaf23..0006e87bf1c 100644 --- a/main_v1.go +++ b/main_next.go @@ -1,12 +1,12 @@ -//go:build v1 -// +build v1 +//go:build next +// +build next package main import ( "embed" - "github.com/AdguardTeam/AdGuardHome/internal/v1/cmd" + "github.com/AdguardTeam/AdGuardHome/internal/next/cmd" ) // Embed the prebuilt client here since we strive to keep .go files inside the diff --git a/openapi/v1.yaml b/openapi/v1.yaml index 77eb1a0997d..adab6d4d2fd 100644 --- a/openapi/v1.yaml +++ b/openapi/v1.yaml @@ -2289,7 +2289,7 @@ 'upstream_servers': - '1.1.1.1' - '8.8.8.8' - 'upstream_timeout': '1s' + 'upstream_timeout': 1000 'required': - 'addresses' - 'blocking_mode' @@ -2397,8 +2397,9 @@ 'type': 'array' 'upstream_timeout': 'description': > - Upstream request timeout, as a human readable duration. - 'type': 'string' + Upstream request timeout, in milliseconds. + 'format': 'double' + 'type': 'number' 'type': 'object' 'DnsType': @@ -3505,14 +3506,16 @@ 'addresses': - '127.0.0.1:80' - '192.168.1.1:80' + 'force_https': true 'secure_addresses': - '127.0.0.1:443' - '192.168.1.1:443' - 'force_https': true + 'timeout': 10000 'required': - 'addresses' - - 'secure_addresses' - 'force_https' + - 'secure_addresses' + - 'timeout' 'HttpSettingsPatch': 'description': > @@ -3539,6 +3542,11 @@ 'items': 'type': 'string' 'type': 'array' + 'timeout': + 'description': > + HTTP request timeout, in milliseconds. + 'format': 'double' + 'type': 'number' 'type': 'object' 'InternalServerErrorResp': diff --git a/scripts/make/go-lint.sh b/scripts/make/go-lint.sh index 2cdcc90d486..e04af72570f 100644 --- a/scripts/make/go-lint.sh +++ b/scripts/make/go-lint.sh @@ -136,11 +136,11 @@ underscores() { -e '_freebsd.go'\ -e '_linux.go'\ -e '_little.go'\ + -e '_next.go'\ -e '_openbsd.go'\ -e '_others.go'\ -e '_test.go'\ -e '_unix.go'\ - -e '_v1.go'\ -e '_windows.go' \ -v\ | sed -e 's/./\t\0/' @@ -229,7 +229,7 @@ gocyclo --over 13 ./internal/filtering/ # Apply stricter standards to new or somewhat refactored code. gocyclo --over 10 ./internal/aghio/ ./internal/aghnet/ ./internal/aghos/\ ./internal/aghtest/ ./internal/dnsforward/ ./internal/stats/\ - ./internal/tools/ ./internal/updater/ ./internal/v1/ ./internal/version/\ + ./internal/tools/ ./internal/updater/ ./internal/next/ ./internal/version/\ ./main.go ineffassign ./... From f557339ca04f0005d559b564ae4a44a3701953cf Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Tue, 4 Oct 2022 16:36:38 +0300 Subject: [PATCH 3/9] Pull request: imp-cache-label Merge in DNS/adguard-home from imp-cache-label to master Squashed commit of the following: commit 10f62aa078b5306525578e22476835ee2e7bac66 Merge: 08c2de0e fe8be370 Author: Ainar Garipov Date: Tue Oct 4 16:30:43 2022 +0300 Merge branch 'master' into imp-cache-label commit 08c2de0edbb1138b47d1a02d6630aa99b7ddcec9 Author: Ainar Garipov Date: Tue Oct 4 16:19:36 2022 +0300 client: imp label commit e66fbbe3cc6f929ff26fe3d7b8e14acc95f5c0ff Author: Ainar Garipov Date: Tue Oct 4 16:17:15 2022 +0300 client: imp upstream example commit d073f71cc5df4ba5f7de7ed08ad1215f7a198539 Author: Ainar Garipov Date: Tue Oct 4 15:44:58 2022 +0300 client: imp upstreams commit b78d06db645a9f496bed699f4d4bf8c7396148f3 Author: Ainar Garipov Date: Tue Oct 4 14:59:30 2022 +0300 client: imp cache size label --- CHANGELOG.md | 2 +- client/src/__locales/en.json | 3 ++- .../components/Settings/Dns/Upstream/Examples.js | 16 ++++++++++++++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71b93339080..be4d47b0536 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,7 +49,7 @@ experimental and may break or change in the future. explicitly enabled by setting the new property `dns.serve_http3` in the configuration file to `true`. - DNS-over-HTTP upstreams can now upgrade to HTTP/3 if the new configuration - file property `use_http3_upstreams` is set to `true`. + file property `dns.use_http3_upstreams` is set to `true`. - Upstreams with forced DNS-over-HTTP/3 and no fallback to prior HTTP versions using the `h3://` scheme. diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index e059c9f4466..b986dea1c49 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -215,6 +215,7 @@ "example_upstream_udp": "regular DNS (over UDP, hostname);", "example_upstream_dot": "encrypted <0>DNS-over-TLS;", "example_upstream_doh": "encrypted <0>DNS-over-HTTPS;", + "example_upstream_doh3": "encrypted DNS-over-HTTPS with forced <0>HTTP/3 and no fallback to HTTP/2 or below;", "example_upstream_doq": "encrypted <0>DNS-over-QUIC;", "example_upstream_sdns": "<0>DNS Stamps for <1>DNSCrypt or <2>DNS-over-HTTPS resolvers;", "example_upstream_tcp": "regular DNS (over TCP);", @@ -605,7 +606,7 @@ "blocklist": "Blocklist", "milliseconds_abbreviation": "ms", "cache_size": "Cache size", - "cache_size_desc": "DNS cache size (in bytes).", + "cache_size_desc": "DNS cache size (in bytes). To disable caching, leave empty.", "cache_ttl_min_override": "Override minimum TTL", "cache_ttl_max_override": "Override maximum TTL", "enter_cache_size": "Enter cache size (bytes)", diff --git a/client/src/components/Settings/Dns/Upstream/Examples.js b/client/src/components/Settings/Dns/Upstream/Examples.js index c17e9456b6d..a975e4440be 100644 --- a/client/src/components/Settings/Dns/Upstream/Examples.js +++ b/client/src/components/Settings/Dns/Upstream/Examples.js @@ -57,6 +57,22 @@ const Examples = (props) => ( example_upstream_doh +
  • + h3://unfiltered.adguard-dns.com/dns-query: + HTTP/3 + , + ]} + > + example_upstream_doh3 + +
  • quic://unfiltered.adguard-dns.com: Date: Wed, 5 Oct 2022 17:07:08 +0300 Subject: [PATCH 4/9] Pull request: refactor-opts Updates #2893. Squashed commit of the following: commit c7027abd1088e27569367f3450e9225ff605b43d Author: Ainar Garipov Date: Wed Oct 5 16:54:23 2022 +0300 home: imp docs commit 86a5b0aca916a7db608eba8263ecdc6ca79c8043 Author: Ainar Garipov Date: Wed Oct 5 16:50:44 2022 +0300 home: refactor opts more commit 74c5989d1edf8d007dec847f4aaa0d7a0d24dc38 Author: Ainar Garipov Date: Wed Oct 5 15:17:26 2022 +0300 home: refactor option parsing --- internal/dhcpd/http_unix.go | 37 +- internal/home/auth_test.go | 5 - internal/home/home.go | 136 ++++--- internal/home/home_test.go | 12 + internal/home/options.go | 648 +++++++++++++++++++--------------- internal/home/options_test.go | 58 +-- internal/home/service.go | 2 +- scripts/make/go-lint.sh | 3 +- 8 files changed, 501 insertions(+), 400 deletions(-) create mode 100644 internal/home/home_test.go diff --git a/internal/dhcpd/http_unix.go b/internal/dhcpd/http_unix.go index ab3ce318076..e6b1f8fc6d0 100644 --- a/internal/dhcpd/http_unix.go +++ b/internal/dhcpd/http_unix.go @@ -235,22 +235,7 @@ func (s *server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { return } - if conf.Enabled != aghalg.NBNull { - s.conf.Enabled = conf.Enabled == aghalg.NBTrue - } - - if conf.InterfaceName != "" { - s.conf.InterfaceName = conf.InterfaceName - } - - if srv4 != nil { - s.srv4 = srv4 - } - - if srv6 != nil { - s.srv6 = srv6 - } - + s.setConfFromJSON(conf, srv4, srv6) s.conf.ConfigModified() err = s.dbLoad() @@ -269,6 +254,26 @@ func (s *server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { } } +// setConfFromJSON sets configuration parameters in s from the new configuration +// decoded from JSON. +func (s *server) setConfFromJSON(conf *dhcpServerConfigJSON, srv4, srv6 DHCPServer) { + if conf.Enabled != aghalg.NBNull { + s.conf.Enabled = conf.Enabled == aghalg.NBTrue + } + + if conf.InterfaceName != "" { + s.conf.InterfaceName = conf.InterfaceName + } + + if srv4 != nil { + s.srv4 = srv4 + } + + if srv6 != nil { + s.srv6 = srv6 + } +} + type netInterfaceJSON struct { Name string `json:"name"` HardwareAddr string `json:"hardware_address"` diff --git a/internal/home/auth_test.go b/internal/home/auth_test.go index 1bf387530b9..46767f7d9df 100644 --- a/internal/home/auth_test.go +++ b/internal/home/auth_test.go @@ -12,16 +12,11 @@ import ( "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestMain(m *testing.M) { - aghtest.DiscardLogOutput(m) -} - func TestNewSessionToken(t *testing.T) { // Successful case. token, err := newSessionToken() diff --git a/internal/home/home.go b/internal/home/home.go index 42c44249554..289c1c643ec 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -97,9 +97,15 @@ var Context homeContext // Main is the entry point func Main(clientBuildFS fs.FS) { - // config can be specified, which reads options from there, but other command line flags have to override config values - // therefore, we must do it manually instead of using a lib - args := loadOptions() + initCmdLineOpts() + + // The configuration file path can be overridden, but other command-line + // options have to override config values. Therefore, do it manually + // instead of using package flag. + // + // TODO(a.garipov): The comment above is most likely false. Replace with + // package flag. + opts := loadCmdLineOpts() Context.appSignalChannel = make(chan os.Signal) signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) @@ -120,26 +126,18 @@ func Main(clientBuildFS fs.FS) { } }() - if args.serviceControlAction != "" { - handleServiceControlAction(args, clientBuildFS) + if opts.serviceControlAction != "" { + handleServiceControlAction(opts, clientBuildFS) return } // run the protection - run(args, clientBuildFS) + run(opts, clientBuildFS) } -func setupContext(args options) { - Context.runningAsService = args.runningAsService - Context.disableUpdate = args.disableUpdate || - version.Channel() == version.ChannelDevelopment - - Context.firstRun = detectFirstRun() - if Context.firstRun { - log.Info("This is the first time AdGuard Home is launched") - checkPermissions() - } +func setupContext(opts options) { + setupContextFlags(opts) switch version.Channel() { case version.ChannelEdge, version.ChannelDevelopment: @@ -174,13 +172,13 @@ func setupContext(args options) { os.Exit(1) } - if args.checkConfig { + if opts.checkConfig { log.Info("configuration file is ok") os.Exit(0) } - if !args.noEtcHosts && config.Clients.Sources.HostsFile { + if !opts.noEtcHosts && config.Clients.Sources.HostsFile { err = setupHostsContainer() fatalOnError(err) } @@ -189,6 +187,24 @@ func setupContext(args options) { Context.mux = http.NewServeMux() } +// setupContextFlags sets global flags and prints their status to the log. +func setupContextFlags(opts options) { + Context.firstRun = detectFirstRun() + if Context.firstRun { + log.Info("This is the first time AdGuard Home is launched") + checkPermissions() + } + + Context.runningAsService = opts.runningAsService + // Don't print the runningAsService flag, since that has already been done + // in [run]. + + Context.disableUpdate = opts.disableUpdate || version.Channel() == version.ChannelDevelopment + if Context.disableUpdate { + log.Info("AdGuard Home updates are disabled") + } +} + // logIfUnsupported logs a formatted warning if the error is one of the // unsupported errors and returns nil. If err is nil, logIfUnsupported returns // nil. Otherwise, it returns err. @@ -270,7 +286,7 @@ func setupHostsContainer() (err error) { return nil } -func setupConfig(args options) (err error) { +func setupConfig(opts options) (err error) { config.DNS.DnsfilterConf.EtcHosts = Context.etcHosts config.DNS.DnsfilterConf.ConfigModified = onConfigModified config.DNS.DnsfilterConf.HTTPRegister = httpRegister @@ -312,9 +328,9 @@ func setupConfig(args options) (err error) { Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb) - if args.bindPort != 0 { + if opts.bindPort != 0 { tcpPorts := aghalg.UniqChecker[tcpPort]{} - addPorts(tcpPorts, tcpPort(args.bindPort), tcpPort(config.BetaBindPort)) + addPorts(tcpPorts, tcpPort(opts.bindPort), tcpPort(config.BetaBindPort)) udpPorts := aghalg.UniqChecker[udpPort]{} addPorts(udpPorts, udpPort(config.DNS.Port)) @@ -336,23 +352,23 @@ func setupConfig(args options) (err error) { return fmt.Errorf("validating udp ports: %w", err) } - config.BindPort = args.bindPort + config.BindPort = opts.bindPort } // override bind host/port from the console - if args.bindHost != nil { - config.BindHost = args.bindHost + if opts.bindHost != nil { + config.BindHost = opts.bindHost } - if len(args.pidFile) != 0 && writePIDFile(args.pidFile) { - Context.pidFileName = args.pidFile + if len(opts.pidFile) != 0 && writePIDFile(opts.pidFile) { + Context.pidFileName = opts.pidFile } return nil } -func initWeb(args options, clientBuildFS fs.FS) (web *Web, err error) { +func initWeb(opts options, clientBuildFS fs.FS) (web *Web, err error) { var clientFS, clientBetaFS fs.FS - if args.localFrontend { + if opts.localFrontend { log.Info("warning: using local frontend files") clientFS = os.DirFS("build/static") @@ -400,24 +416,24 @@ func fatalOnError(err error) { } // run configures and starts AdGuard Home. -func run(args options, clientBuildFS fs.FS) { +func run(opts options, clientBuildFS fs.FS) { // configure config filename - initConfigFilename(args) + initConfigFilename(opts) // configure working dir and config path - initWorkingDir(args) + initWorkingDir(opts) // configure log level and output - configureLogger(args) + configureLogger(opts) // Print the first message after logger is configured. log.Info(version.Full()) log.Debug("current working directory is %s", Context.workDir) - if args.runningAsService { + if opts.runningAsService { log.Info("AdGuard Home is running as a service") } - setupContext(args) + setupContext(opts) err := configureOS(config) fatalOnError(err) @@ -427,7 +443,7 @@ func run(args options, clientBuildFS fs.FS) { // but also avoid relying on automatic Go init() function filtering.InitModule() - err = setupConfig(args) + err = setupConfig(opts) fatalOnError(err) if !Context.firstRun { @@ -456,7 +472,7 @@ func run(args options, clientBuildFS fs.FS) { } sessFilename := filepath.Join(Context.getDataDir(), "sessions.db") - GLMode = args.glinetMode + GLMode = opts.glinetMode var rateLimiter *authRateLimiter if config.AuthAttempts > 0 && config.AuthBlockMin > 0 { rateLimiter = newAuthRateLimiter( @@ -483,7 +499,7 @@ func run(args options, clientBuildFS fs.FS) { log.Fatalf("Can't initialize TLS module") } - Context.web, err = initWeb(args, clientBuildFS) + Context.web, err = initWeb(opts, clientBuildFS) fatalOnError(err) if !Context.firstRun { @@ -575,10 +591,10 @@ func writePIDFile(fn string) bool { return true } -func initConfigFilename(args options) { +func initConfigFilename(opts options) { // config file path can be overridden by command-line arguments: - if args.configFilename != "" { - Context.configFilename = args.configFilename + if opts.confFilename != "" { + Context.configFilename = opts.confFilename } else { // Default config file name Context.configFilename = "AdGuardHome.yaml" @@ -587,15 +603,15 @@ func initConfigFilename(args options) { // initWorkingDir initializes the workDir // if no command-line arguments specified, we use the directory where our binary file is located -func initWorkingDir(args options) { +func initWorkingDir(opts options) { execPath, err := os.Executable() if err != nil { panic(err) } - if args.workDir != "" { + if opts.workDir != "" { // If there is a custom config file, use it's directory as our working dir - Context.workDir = args.workDir + Context.workDir = opts.workDir } else { Context.workDir = filepath.Dir(execPath) } @@ -609,15 +625,15 @@ func initWorkingDir(args options) { } // configureLogger configures logger level and output -func configureLogger(args options) { +func configureLogger(opts options) { ls := getLogSettings() // command-line arguments can override config settings - if args.verbose || config.Verbose { + if opts.verbose || config.Verbose { ls.Verbose = true } - if args.logFile != "" { - ls.File = args.logFile + if opts.logFile != "" { + ls.File = opts.logFile } else if config.File != "" { ls.File = config.File } @@ -638,7 +654,7 @@ func configureLogger(args options) { // happen pretty quickly. log.SetFlags(log.LstdFlags | log.Lmicroseconds) - if args.runningAsService && ls.File == "" && runtime.GOOS == "windows" { + if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" { // When running as a Windows service, use eventlog by default if nothing // else is configured. Otherwise, we'll simply lose the log output. ls.File = configSyslog @@ -728,25 +744,29 @@ func exitWithError() { os.Exit(64) } -// loadOptions reads command line arguments and initializes configuration -func loadOptions() options { - o, f, err := parse(os.Args[0], os.Args[1:]) - +// loadCmdLineOpts reads command line arguments and initializes configuration +// from them. If there is an error or an effect, loadCmdLineOpts processes them +// and exits. +func loadCmdLineOpts() (opts options) { + opts, eff, err := parseCmdOpts(os.Args[0], os.Args[1:]) if err != nil { log.Error(err.Error()) - _ = printHelp(os.Args[0]) + printHelp(os.Args[0]) + exitWithError() - } else if f != nil { - err = f() + } + + if eff != nil { + err = eff() if err != nil { log.Error(err.Error()) exitWithError() - } else { - os.Exit(0) } + + os.Exit(0) } - return o + return opts } // printWebAddrs prints addresses built from proto, addr, and an appropriate diff --git a/internal/home/home_test.go b/internal/home/home_test.go new file mode 100644 index 00000000000..1a611588f33 --- /dev/null +++ b/internal/home/home_test.go @@ -0,0 +1,12 @@ +package home + +import ( + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" +) + +func TestMain(m *testing.M) { + aghtest.DiscardLogOutput(m) + initCmdLineOpts() +} diff --git a/internal/home/options.go b/internal/home/options.go index 6f5a4d8d62a..531a0fd4146 100644 --- a/internal/home/options.go +++ b/internal/home/options.go @@ -5,122 +5,149 @@ import ( "net" "os" "strconv" + "strings" "github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/stringutil" ) -// options passed from command-line arguments +// TODO(a.garipov): Replace with package flag. + +// options represents the command-line options. type options struct { - verbose bool // is verbose logging enabled - configFilename string // path to the config file - workDir string // path to the working directory where we will store the filters data and the querylog - bindHost net.IP // host address to bind HTTP server on - bindPort int // port to serve HTTP pages on - logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog - pidFile string // File name to save PID to - checkConfig bool // Check configuration and exit - disableUpdate bool // If set, don't check for updates - - // service control action (see service.ControlAction array + "status" command) - serviceControlAction string + // confFilename is the path to the configuration file. + confFilename string - // runningAsService flag is set to true when options are passed from the service runner - runningAsService bool + // workDir is the path to the working directory where AdGuard Home stores + // filter data, the query log, and other data. + workDir string - glinetMode bool // Activate GL-Inet compatibility mode + // logFile is the path to the log file. If empty, AdGuard Home writes to + // stdout; if "syslog", to syslog. + logFile string - // noEtcHosts flag should be provided when /etc/hosts file shouldn't be - // used. - noEtcHosts bool - - // localFrontend forces AdGuard Home to use the frontend files from disk - // rather than the ones that have been compiled into the binary. - localFrontend bool -} + // pidFile is the file name for the file to which the PID is saved. + pidFile string -// functions used for their side-effects -type effect func() error + // serviceControlAction is the service action to perform. See + // [service.ControlAction] and [handleServiceControlAction]. + serviceControlAction string -type arg struct { - description string // a short, English description of the argument - longName string // the name of the argument used after '--' - shortName string // the name of the argument used after '-' + // bindHost is the address on which to serve the HTTP UI. + bindHost net.IP - // only one of updateWithValue, updateNoValue, and effect should be present + // bindPort is the port on which to serve the HTTP UI. + bindPort int - updateWithValue func(o options, v string) (options, error) // the mutator for arguments with parameters - updateNoValue func(o options) (options, error) // the mutator for arguments without parameters - effect func(o options, exec string) (f effect, err error) // the side-effect closure generator + // checkConfig is true if the current invocation is only required to check + // the configuration file and exit. + checkConfig bool - serialize func(o options) []string // the re-serialization function back to arguments (return nil for omit) -} + // disableUpdate, if set, makes AdGuard Home not check for updates. + disableUpdate bool -// {type}SliceOrNil functions check their parameter of type {type} -// against its zero value and return nil if the parameter value is -// zero otherwise they return a string slice of the parameter + // verbose shows if verbose logging is enabled. + verbose bool -func ipSliceOrNil(ip net.IP) []string { - if ip == nil { - return nil - } + // runningAsService flag is set to true when options are passed from the + // service runner + // + // TODO(a.garipov): Perhaps this could be determined by a non-empty + // serviceControlAction? + runningAsService bool - return []string{ip.String()} -} + // glinetMode shows if the GL-Inet compatibility mode is enabled. + glinetMode bool -func stringSliceOrNil(s string) []string { - if s == "" { - return nil - } + // noEtcHosts flag should be provided when /etc/hosts file shouldn't be + // used. + noEtcHosts bool - return []string{s} + // localFrontend forces AdGuard Home to use the frontend files from disk + // rather than the ones that have been compiled into the binary. + localFrontend bool } -func intSliceOrNil(i int) []string { - if i == 0 { - return nil - } - - return []string{strconv.Itoa(i)} +// initCmdLineOpts completes initialization of the global command-line option +// slice. It must only be called once. +func initCmdLineOpts() { + // The --help option cannot be put directly into cmdLineOpts, because that + // causes initialization cycle due to printHelp referencing cmdLineOpts. + cmdLineOpts = append(cmdLineOpts, cmdLineOpt{ + updateWithValue: nil, + updateNoValue: nil, + effect: func(o options, exec string) (effect, error) { + return func() error { printHelp(exec); exitWithError(); return nil }, nil + }, + serialize: func(o options) (val string, ok bool) { return "", false }, + description: "Print this help.", + longName: "help", + shortName: "", + }) } -func boolSliceOrNil(b bool) []string { - if b { - return []string{} - } - - return nil -} +// effect is the type for functions used for their side-effects. +type effect func() (err error) -var args []arg +// cmdLineOpt contains the data for a single command-line option. Only one of +// updateWithValue, updateNoValue, and effect must be present. +type cmdLineOpt struct { + updateWithValue func(o options, v string) (updated options, err error) + updateNoValue func(o options) (updated options, err error) + effect func(o options, exec string) (eff effect, err error) -var configArg = arg{ - "Path to the config file.", - "config", "c", - func(o options, v string) (options, error) { o.configFilename = v; return o, nil }, - nil, - nil, - func(o options) []string { return stringSliceOrNil(o.configFilename) }, -} + // serialize is a function that encodes the option into a slice of + // command-line arguments, if necessary. If ok is false, this option should + // be skipped. + serialize func(o options) (val string, ok bool) -var workDirArg = arg{ - "Path to the working directory.", - "work-dir", "w", - func(o options, v string) (options, error) { o.workDir = v; return o, nil }, nil, nil, - func(o options) []string { return stringSliceOrNil(o.workDir) }, + description string + longName string + shortName string } -var hostArg = arg{ - "Host address to bind HTTP server on.", - "host", "h", - func(o options, v string) (options, error) { o.bindHost = net.ParseIP(v); return o, nil }, nil, nil, - func(o options) []string { return ipSliceOrNil(o.bindHost) }, -} +// cmdLineOpts are all command-line options of AdGuard Home. +var cmdLineOpts = []cmdLineOpt{{ + updateWithValue: func(o options, v string) (options, error) { + o.confFilename = v + return o, nil + }, + updateNoValue: nil, + effect: nil, + serialize: func(o options) (val string, ok bool) { + return o.confFilename, o.confFilename != "" + }, + description: "Path to the config file.", + longName: "config", + shortName: "c", +}, { + updateWithValue: func(o options, v string) (options, error) { o.workDir = v; return o, nil }, + updateNoValue: nil, + effect: nil, + serialize: func(o options) (val string, ok bool) { return o.workDir, o.workDir != "" }, + description: "Path to the working directory.", + longName: "work-dir", + shortName: "w", +}, { + updateWithValue: func(o options, v string) (options, error) { + o.bindHost = net.ParseIP(v) + return o, nil + }, + updateNoValue: nil, + effect: nil, + serialize: func(o options) (val string, ok bool) { + if o.bindHost == nil { + return "", false + } -var portArg = arg{ - "Port to serve HTTP pages on.", - "port", "p", - func(o options, v string) (options, error) { + return o.bindHost.String(), true + }, + description: "Host address to bind HTTP server on.", + longName: "host", + shortName: "h", +}, { + updateWithValue: func(o options, v string) (options, error) { var err error var p int minPort, maxPort := 0, 1<<16-1 @@ -131,108 +158,81 @@ var portArg = arg{ } else { o.bindPort = p } + return o, err - }, nil, nil, - func(o options) []string { return intSliceOrNil(o.bindPort) }, -} + }, + updateNoValue: nil, + effect: nil, + serialize: func(o options) (val string, ok bool) { + if o.bindPort == 0 { + return "", false + } -var serviceArg = arg{ - "Service control action: status, install, uninstall, start, stop, restart, reload (configuration).", - "service", "s", - func(o options, v string) (options, error) { + return strconv.Itoa(o.bindPort), true + }, + description: "Port to serve HTTP pages on.", + longName: "port", + shortName: "p", +}, { + updateWithValue: func(o options, v string) (options, error) { o.serviceControlAction = v return o, nil - }, nil, nil, - func(o options) []string { return stringSliceOrNil(o.serviceControlAction) }, -} - -var logfileArg = arg{ - "Path to log file. If empty: write to stdout; if 'syslog': write to system log.", - "logfile", "l", - func(o options, v string) (options, error) { o.logFile = v; return o, nil }, nil, nil, - func(o options) []string { return stringSliceOrNil(o.logFile) }, -} - -var pidfileArg = arg{ - "Path to a file where PID is stored.", - "pidfile", "", - func(o options, v string) (options, error) { o.pidFile = v; return o, nil }, nil, nil, - func(o options) []string { return stringSliceOrNil(o.pidFile) }, -} - -var checkConfigArg = arg{ - "Check configuration and exit.", - "check-config", "", - nil, func(o options) (options, error) { o.checkConfig = true; return o, nil }, nil, - func(o options) []string { return boolSliceOrNil(o.checkConfig) }, -} - -var noCheckUpdateArg = arg{ - "Don't check for updates.", - "no-check-update", "", - nil, func(o options) (options, error) { o.disableUpdate = true; return o, nil }, nil, - func(o options) []string { return boolSliceOrNil(o.disableUpdate) }, -} - -var disableMemoryOptimizationArg = arg{ - "Deprecated. Disable memory optimization.", - "no-mem-optimization", "", - nil, nil, func(_ options, _ string) (f effect, err error) { - log.Info("warning: using --no-mem-optimization flag has no effect and is deprecated") - - return nil, nil }, - func(o options) []string { return nil }, -} - -var verboseArg = arg{ - "Enable verbose output.", - "verbose", "v", - nil, func(o options) (options, error) { o.verbose = true; return o, nil }, nil, - func(o options) []string { return boolSliceOrNil(o.verbose) }, -} - -var glinetArg = arg{ - "Run in GL-Inet compatibility mode.", - "glinet", "", - nil, func(o options) (options, error) { o.glinetMode = true; return o, nil }, nil, - func(o options) []string { return boolSliceOrNil(o.glinetMode) }, -} - -var versionArg = arg{ - description: "Show the version and exit. Show more detailed version description with -v.", - longName: "version", + updateNoValue: nil, + effect: nil, + serialize: func(o options) (val string, ok bool) { + return o.serviceControlAction, o.serviceControlAction != "" + }, + description: `Service control action: status, install (as a service), ` + + `uninstall (as a service), start, stop, restart, reload (configuration).`, + longName: "service", + shortName: "s", +}, { + updateWithValue: func(o options, v string) (options, error) { o.logFile = v; return o, nil }, + updateNoValue: nil, + effect: nil, + serialize: func(o options) (val string, ok bool) { return o.logFile, o.logFile != "" }, + description: `Path to log file. If empty, write to stdout; ` + + `if "syslog", write to system log.`, + longName: "logfile", + shortName: "l", +}, { + updateWithValue: func(o options, v string) (options, error) { o.pidFile = v; return o, nil }, + updateNoValue: nil, + effect: nil, + serialize: func(o options) (val string, ok bool) { return o.pidFile, o.pidFile != "" }, + description: "Path to a file where PID is stored.", + longName: "pidfile", + shortName: "", +}, { + updateWithValue: nil, + updateNoValue: func(o options) (options, error) { o.checkConfig = true; return o, nil }, + effect: nil, + serialize: func(o options) (val string, ok bool) { return "", o.checkConfig }, + description: "Check configuration and exit.", + longName: "check-config", shortName: "", +}, { + updateWithValue: nil, + updateNoValue: func(o options) (options, error) { o.disableUpdate = true; return o, nil }, + effect: nil, + serialize: func(o options) (val string, ok bool) { return "", o.disableUpdate }, + description: "Don't check for updates.", + longName: "no-check-update", + shortName: "", +}, { updateWithValue: nil, updateNoValue: nil, - effect: func(o options, exec string) (effect, error) { - return func() error { - if o.verbose { - fmt.Println(version.Verbose()) - } else { - fmt.Println(version.Full()) - } - os.Exit(0) - - return nil - }, nil - }, - serialize: func(o options) []string { return nil }, -} + effect: func(_ options, _ string) (f effect, err error) { + log.Info("warning: using --no-mem-optimization flag has no effect and is deprecated") -var helpArg = arg{ - "Print this help.", - "help", "", - nil, nil, func(o options, exec string) (effect, error) { - return func() error { _ = printHelp(exec); os.Exit(64); return nil }, nil + return nil, nil }, - func(o options) []string { return nil }, -} - -var noEtcHostsArg = arg{ - description: "Deprecated. Do not use the OS-provided hosts.", - longName: "no-etc-hosts", - shortName: "", + serialize: func(o options) (val string, ok bool) { return "", false }, + description: "Deprecated. Disable memory optimization.", + longName: "no-mem-optimization", + shortName: "", +}, { updateWithValue: nil, updateNoValue: func(o options) (options, error) { o.noEtcHosts = true; return o, nil }, effect: func(_ options, _ string) (f effect, err error) { @@ -242,146 +242,216 @@ var noEtcHostsArg = arg{ return nil, nil }, - serialize: func(o options) []string { return boolSliceOrNil(o.noEtcHosts) }, -} - -var localFrontendArg = arg{ + serialize: func(o options) (val string, ok bool) { return "", o.noEtcHosts }, + description: "Deprecated. Do not use the OS-provided hosts.", + longName: "no-etc-hosts", + shortName: "", +}, { + updateWithValue: nil, + updateNoValue: func(o options) (options, error) { o.localFrontend = true; return o, nil }, + effect: nil, + serialize: func(o options) (val string, ok bool) { return "", o.localFrontend }, description: "Use local frontend directories.", longName: "local-frontend", shortName: "", +}, { updateWithValue: nil, - updateNoValue: func(o options) (options, error) { o.localFrontend = true; return o, nil }, + updateNoValue: func(o options) (options, error) { o.verbose = true; return o, nil }, effect: nil, - serialize: func(o options) []string { return boolSliceOrNil(o.localFrontend) }, -} + serialize: func(o options) (val string, ok bool) { return "", o.verbose }, + description: "Enable verbose output.", + longName: "verbose", + shortName: "v", +}, { + updateWithValue: nil, + updateNoValue: func(o options) (options, error) { o.glinetMode = true; return o, nil }, + effect: nil, + serialize: func(o options) (val string, ok bool) { return "", o.glinetMode }, + description: "Run in GL-Inet compatibility mode.", + longName: "glinet", + shortName: "", +}, { + updateWithValue: nil, + updateNoValue: nil, + effect: func(o options, exec string) (effect, error) { + return func() error { + if o.verbose { + fmt.Println(version.Verbose()) + } else { + fmt.Println(version.Full()) + } -func init() { - args = []arg{ - configArg, - workDirArg, - hostArg, - portArg, - serviceArg, - logfileArg, - pidfileArg, - checkConfigArg, - noCheckUpdateArg, - disableMemoryOptimizationArg, - noEtcHostsArg, - localFrontendArg, - verboseArg, - glinetArg, - versionArg, - helpArg, - } -} + os.Exit(0) -func getUsageLines(exec string, args []arg) []string { - usage := []string{ - "Usage:", - "", - fmt.Sprintf("%s [options]", exec), - "", - "Options:", - } - for _, arg := range args { + return nil + }, nil + }, + serialize: func(o options) (val string, ok bool) { return "", false }, + description: "Show the version and exit. Show more detailed version description with -v.", + longName: "version", + shortName: "", +}} + +// printHelp prints the entire help message. It exits with an error code if +// there are any I/O errors. +func printHelp(exec string) { + b := &strings.Builder{} + + stringutil.WriteToBuilder( + b, + "Usage:\n\n", + fmt.Sprintf("%s [options]\n\n", exec), + "Options:\n", + ) + + var err error + for _, opt := range cmdLineOpts { val := "" - if arg.updateWithValue != nil { + if opt.updateWithValue != nil { val = " VALUE" } - if arg.shortName != "" { - usage = append(usage, fmt.Sprintf(" -%s, %-30s %s", - arg.shortName, - "--"+arg.longName+val, - arg.description)) + + longDesc := opt.longName + val + if opt.shortName != "" { + _, err = fmt.Fprintf(b, " -%s, --%-28s %s\n", opt.shortName, longDesc, opt.description) } else { - usage = append(usage, fmt.Sprintf(" %-34s %s", - "--"+arg.longName+val, - arg.description)) + _, err = fmt.Fprintf(b, " --%-32s %s\n", longDesc, opt.description) } - } - return usage -} -func printHelp(exec string) error { - for _, line := range getUsageLines(exec, args) { - _, err := fmt.Println(line) if err != nil { - return err + // The only error here can be from incorrect Fprintf usage, which is + // a programmer error. + panic(err) } } - return nil -} -func argMatches(a arg, v string) bool { - return v == "--"+a.longName || (a.shortName != "" && v == "-"+a.shortName) + _, err = fmt.Print(b) + if err != nil { + // Exit immediately, since not being able to print out a help message + // essentially means that the I/O is very broken at the moment. + exitWithError() + } } -func parse(exec string, ss []string) (o options, f effect, err error) { - for i := 0; i < len(ss); i++ { - v := ss[i] - knownParam := false - for _, arg := range args { - if argMatches(arg, v) { - if arg.updateWithValue != nil { - if i+1 >= len(ss) { - return o, f, fmt.Errorf("got %s without argument", v) - } - i++ - o, err = arg.updateWithValue(o, ss[i]) - if err != nil { - return - } - } else if arg.updateNoValue != nil { - o, err = arg.updateNoValue(o) - if err != nil { - return - } - } else if arg.effect != nil { - var eff effect - eff, err = arg.effect(o, exec) - if err != nil { - return - } - if eff != nil { - prevf := f - f = func() (ferr error) { - if prevf != nil { - ferr = prevf() - } - if ferr == nil { - ferr = eff() - } - return ferr - } - } +// parseCmdOpts parses the command-line arguments into options and effects. +func parseCmdOpts(cmdName string, args []string) (o options, eff effect, err error) { + // Don't use range since the loop changes the loop variable. + argsLen := len(args) + for i := 0; i < len(args); i++ { + arg := args[i] + isKnown := false + for _, opt := range cmdLineOpts { + isKnown = argMatches(opt, arg) + if !isKnown { + continue + } + + if opt.updateWithValue != nil { + i++ + if i >= argsLen { + return o, eff, fmt.Errorf("got %s without argument", arg) } - knownParam = true - break + + o, err = opt.updateWithValue(o, args[i]) + } else { + o, eff, err = updateOptsNoValue(o, eff, opt, cmdName) } + + if err != nil { + return o, eff, fmt.Errorf("applying option %s: %w", arg, err) + } + + break } - if !knownParam { - return o, f, fmt.Errorf("unknown option %v", v) + + if !isKnown { + return o, eff, fmt.Errorf("unknown option %s", arg) + } + } + + return o, eff, err +} + +// argMatches returns true if arg matches command-line option opt. +func argMatches(opt cmdLineOpt, arg string) (ok bool) { + if arg == "" || arg[0] != '-' { + return false + } + + arg = arg[1:] + if arg == "" { + return false + } + + return (opt.shortName != "" && arg == opt.shortName) || + (arg[0] == '-' && arg[1:] == opt.longName) +} + +// updateOptsNoValue sets values or effects from opt into o or prev. +func updateOptsNoValue( + o options, + prev effect, + opt cmdLineOpt, + cmdName string, +) (updated options, chained effect, err error) { + if opt.updateNoValue != nil { + o, err = opt.updateNoValue(o) + if err != nil { + return o, prev, err } + + return o, prev, nil + } + + next, err := opt.effect(o, cmdName) + if err != nil { + return o, prev, err } - return + chained = chainEffect(prev, next) + + return o, chained, nil } -func shortestFlag(a arg) string { - if a.shortName != "" { - return "-" + a.shortName +// chainEffect chans the next effect after the prev one. If prev is nil, eff +// only calls next. If next is nil, eff is prev; if prev is nil, eff is next. +func chainEffect(prev, next effect) (eff effect) { + if prev == nil { + return next + } else if next == nil { + return prev + } + + eff = func() (err error) { + err = prev() + if err != nil { + return err + } + + return next() } - return "--" + a.longName + + return eff } -func serialize(o options) []string { - ss := []string{} - for _, arg := range args { - s := arg.serialize(o) - if s != nil { - ss = append(ss, append([]string{shortestFlag(arg)}, s...)...) +// optsToArgs converts command line options into a list of arguments. +func optsToArgs(o options) (args []string) { + for _, opt := range cmdLineOpts { + val, ok := opt.serialize(o) + if !ok { + continue + } + + if opt.shortName != "" { + args = append(args, "-"+opt.shortName) + } else { + args = append(args, "--"+opt.longName) + } + + if val != "" { + args = append(args, val) } } - return ss + + return args } diff --git a/internal/home/options_test.go b/internal/home/options_test.go index 21972b0a4ea..7954c0e47bc 100644 --- a/internal/home/options_test.go +++ b/internal/home/options_test.go @@ -12,7 +12,7 @@ import ( func testParseOK(t *testing.T, ss ...string) options { t.Helper() - o, _, err := parse("", ss) + o, _, err := parseCmdOpts("", ss) require.NoError(t, err) return o @@ -21,7 +21,7 @@ func testParseOK(t *testing.T, ss ...string) options { func testParseErr(t *testing.T, descr string, ss ...string) { t.Helper() - _, _, err := parse("", ss) + _, _, err := parseCmdOpts("", ss) require.Error(t, err) } @@ -38,11 +38,11 @@ func TestParseVerbose(t *testing.T) { } func TestParseConfigFilename(t *testing.T) { - assert.Equal(t, "", testParseOK(t).configFilename, "empty is no config filename") - assert.Equal(t, "path", testParseOK(t, "-c", "path").configFilename, "-c is config filename") + assert.Equal(t, "", testParseOK(t).confFilename, "empty is no config filename") + assert.Equal(t, "path", testParseOK(t, "-c", "path").confFilename, "-c is config filename") testParseParamMissing(t, "-c") - assert.Equal(t, "path", testParseOK(t, "--config", "path").configFilename, "--config is config filename") + assert.Equal(t, "path", testParseOK(t, "--config", "path").confFilename, "--config is config filename") testParseParamMissing(t, "--config") } @@ -103,7 +103,7 @@ func TestParseDisableUpdate(t *testing.T) { // TODO(e.burkov): Remove after v0.108.0. func TestParseDisableMemoryOptimization(t *testing.T) { - o, eff, err := parse("", []string{"--no-mem-optimization"}) + o, eff, err := parseCmdOpts("", []string{"--no-mem-optimization"}) require.NoError(t, err) assert.Nil(t, eff) @@ -130,73 +130,73 @@ func TestParseUnknown(t *testing.T) { testParseErr(t, "unknown dash", "-") } -func TestSerialize(t *testing.T) { +func TestOptsToArgs(t *testing.T) { testCases := []struct { name string + args []string opts options - ss []string }{{ name: "empty", + args: []string{}, opts: options{}, - ss: []string{}, }, { name: "config_filename", - opts: options{configFilename: "path"}, - ss: []string{"-c", "path"}, + args: []string{"-c", "path"}, + opts: options{confFilename: "path"}, }, { name: "work_dir", + args: []string{"-w", "path"}, opts: options{workDir: "path"}, - ss: []string{"-w", "path"}, }, { name: "bind_host", + args: []string{"-h", "1.2.3.4"}, opts: options{bindHost: net.IP{1, 2, 3, 4}}, - ss: []string{"-h", "1.2.3.4"}, }, { name: "bind_port", + args: []string{"-p", "666"}, opts: options{bindPort: 666}, - ss: []string{"-p", "666"}, }, { name: "log_file", + args: []string{"-l", "path"}, opts: options{logFile: "path"}, - ss: []string{"-l", "path"}, }, { name: "pid_file", + args: []string{"--pidfile", "path"}, opts: options{pidFile: "path"}, - ss: []string{"--pidfile", "path"}, }, { name: "disable_update", + args: []string{"--no-check-update"}, opts: options{disableUpdate: true}, - ss: []string{"--no-check-update"}, }, { name: "control_action", + args: []string{"-s", "run"}, opts: options{serviceControlAction: "run"}, - ss: []string{"-s", "run"}, }, { name: "glinet_mode", + args: []string{"--glinet"}, opts: options{glinetMode: true}, - ss: []string{"--glinet"}, }, { name: "multiple", - opts: options{ - serviceControlAction: "run", - configFilename: "config", - workDir: "work", - pidFile: "pid", - disableUpdate: true, - }, - ss: []string{ + args: []string{ "-c", "config", "-w", "work", "-s", "run", "--pidfile", "pid", "--no-check-update", }, + opts: options{ + serviceControlAction: "run", + confFilename: "config", + workDir: "work", + pidFile: "pid", + disableUpdate: true, + }, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - result := serialize(tc.opts) - assert.ElementsMatch(t, tc.ss, result) + result := optsToArgs(tc.opts) + assert.ElementsMatch(t, tc.args, result) }) } } diff --git a/internal/home/service.go b/internal/home/service.go index e52f9799aae..3aece1f2715 100644 --- a/internal/home/service.go +++ b/internal/home/service.go @@ -197,7 +197,7 @@ func handleServiceControlAction(opts options, clientBuildFS fs.FS) { DisplayName: serviceDisplayName, Description: serviceDescription, WorkingDirectory: pwd, - Arguments: serialize(runOpts), + Arguments: optsToArgs(runOpts), } configureService(svcConfig) diff --git a/scripts/make/go-lint.sh b/scripts/make/go-lint.sh index e04af72570f..8c462d5bad8 100644 --- a/scripts/make/go-lint.sh +++ b/scripts/make/go-lint.sh @@ -223,8 +223,7 @@ govulncheck ./... # Apply more lax standards to the code we haven't properly refactored yet. gocyclo --over 17 ./internal/querylog/ -gocyclo --over 15 ./internal/home/ ./internal/dhcpd -gocyclo --over 13 ./internal/filtering/ +gocyclo --over 13 ./internal/dhcpd ./internal/filtering/ ./internal/home/ # Apply stricter standards to new or somewhat refactored code. gocyclo --over 10 ./internal/aghio/ ./internal/aghnet/ ./internal/aghos/\ From 330ac303242970d79b557e611bec3227b0af5a7b Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Wed, 5 Oct 2022 18:11:09 +0300 Subject: [PATCH 5/9] Pull request: 3418-clientid-doh Closes #3418. Squashed commit of the following: commit 8a1180f8ef03d30ea3ae6a3e3121ddcac513f45b Author: Ainar Garipov Date: Wed Oct 5 17:26:22 2022 +0300 all: imp docs, tests commit 9629c69b39540db119044f2f79c1c4ed39de911f Author: Ainar Garipov Date: Wed Oct 5 15:34:33 2022 +0300 dnsforward: accept clientids from doh client srvname --- CHANGELOG.md | 10 +++++ internal/dnsforward/clientid.go | 67 +++++++++++++++++----------- internal/dnsforward/clientid_test.go | 67 +++++++++++++++++++++++++--- 3 files changed, 112 insertions(+), 32 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index be4d47b0536..551a3db1ff5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,16 @@ and this project adheres to ## [v0.108.0] - TBA (APPROX.) --> +## Added + +- The ability to put [ClientIDs][clientid] into DNS-over-HTTPS hostnames as + opposed to URL paths ([#3418]). Note that AdGuard Home checks the server name + only if the URL does not contain a ClientID. + +[#3418]: https://github.com/AdguardTeam/AdGuardHome/issues/3418 + +[clientid]: https://github.com/AdguardTeam/AdGuardHome/wiki/Clients#clientid + +## Security + +- Go version has been updated to prevent the possibility of exploiting the + CVE-2022-2879, CVE-2022-2880, and CVE-2022-41715 Go vulnerabilities fixed in + [Go 1.18.7][go-1.18.7]. + ## Added - The ability to put [ClientIDs][clientid] into DNS-over-HTTPS hostnames as @@ -23,7 +29,8 @@ and this project adheres to [#3418]: https://github.com/AdguardTeam/AdGuardHome/issues/3418 -[clientid]: https://github.com/AdguardTeam/AdGuardHome/wiki/Clients#clientid +[go-1.18.7]: https://groups.google.com/g/golang-announce/c/xtuG5faxtaU +[clientid]: https://github.com/AdguardTeam/AdGuardHome/wiki/Clients#clientid @@ -173,7 +180,7 @@ See also the [v0.107.12 GitHub milestone][ms-v0.107.12]. ### Security -- Go version was updated to prevent the possibility of exploiting the +- Go version has been updated to prevent the possibility of exploiting the CVE-2022-27664 and CVE-2022-32190 Go vulnerabilities fixed in [Go 1.18.6][go-1.18.6]. @@ -294,7 +301,7 @@ See also the [v0.107.9 GitHub milestone][ms-v0.107.9]. ### Security -- Go version was updated to prevent the possibility of exploiting the +- Go version has been updated to prevent the possibility of exploiting the CVE-2022-32189 Go vulnerability fixed in [Go 1.18.5][go-1.18.5]. Go 1.17 support has also been removed, as it has reached end of life and will not receive security updates. @@ -337,7 +344,7 @@ See also the [v0.107.8 GitHub milestone][ms-v0.107.8]. ### Security -- Go version was updated to prevent the possibility of exploiting the +- Go version has been updated to prevent the possibility of exploiting the CVE-2022-1705, CVE-2022-32148, CVE-2022-30631, and other Go vulnerabilities fixed in [Go 1.17.12][go-1.17.12]. @@ -373,7 +380,7 @@ See also the [v0.107.7 GitHub milestone][ms-v0.107.7]. ### Security -- Go version was updated to prevent the possibility of exploiting the +- Go version has been updated to prevent the possibility of exploiting the [CVE-2022-29526], [CVE-2022-30634], [CVE-2022-30629], [CVE-2022-30580], and [CVE-2022-29804] Go vulnerabilities. - Enforced password strength policy ([#3503]). @@ -530,7 +537,7 @@ See also the [v0.107.6 GitHub milestone][ms-v0.107.6]. ### Security - `User-Agent` HTTP header removed from outgoing DNS-over-HTTPS requests. -- Go version was updated to prevent the possibility of exploiting the +- Go version has been updated to prevent the possibility of exploiting the [CVE-2022-24675], [CVE-2022-27536], and [CVE-2022-28327] Go vulnerabilities. ### Added @@ -585,7 +592,7 @@ were resolved. ### Security -- Go version was updated to prevent the possibility of exploiting the +- Go version has been updated to prevent the possibility of exploiting the [CVE-2022-24921] Go vulnerability. [CVE-2022-24921]: https://www.cvedetails.com/cve/CVE-2022-24921 @@ -598,7 +605,7 @@ See also the [v0.107.4 GitHub milestone][ms-v0.107.4]. ### Security -- Go version was updated to prevent the possibility of exploiting the +- Go version has been updated to prevent the possibility of exploiting the [CVE-2022-23806], [CVE-2022-23772], and [CVE-2022-23773] Go vulnerabilities. ### Fixed diff --git a/bamboo-specs/release.yaml b/bamboo-specs/release.yaml index ddd957348a2..4232b734e80 100644 --- a/bamboo-specs/release.yaml +++ b/bamboo-specs/release.yaml @@ -7,7 +7,7 @@ # Make sure to sync any changes with the branch overrides below. 'variables': 'channel': 'edge' - 'dockerGo': 'adguard/golang-ubuntu:5.1' + 'dockerGo': 'adguard/golang-ubuntu:5.2' 'stages': - 'Build frontend': @@ -322,7 +322,7 @@ # need to build a few of these. 'variables': 'channel': 'beta' - 'dockerGo': 'adguard/golang-ubuntu:5.1' + 'dockerGo': 'adguard/golang-ubuntu:5.2' # release-vX.Y.Z branches are the branches from which the actual final release # is built. - '^release-v[0-9]+\.[0-9]+\.[0-9]+': @@ -337,4 +337,4 @@ # are the ones that actually get released. 'variables': 'channel': 'release' - 'dockerGo': 'adguard/golang-ubuntu:5.1' + 'dockerGo': 'adguard/golang-ubuntu:5.2' diff --git a/bamboo-specs/test.yaml b/bamboo-specs/test.yaml index fe26bd10f67..81796e1f273 100644 --- a/bamboo-specs/test.yaml +++ b/bamboo-specs/test.yaml @@ -5,7 +5,7 @@ 'key': 'AHBRTSPECS' 'name': 'AdGuard Home - Build and run tests' 'variables': - 'dockerGo': 'adguard/golang-ubuntu:5.1' + 'dockerGo': 'adguard/golang-ubuntu:5.2' 'stages': - 'Tests': From f1dd33346a8580bc9493e85dfc0c3e8b15032a1e Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Fri, 7 Oct 2022 17:05:01 +0300 Subject: [PATCH 7/9] Pull request: upd-chlog Merge in DNS/adguard-home from upd-chlog to master Squashed commit of the following: commit 8885f3f2291947d76203873dce0ccfd5c270fa7f Author: Ainar Garipov Date: Fri Oct 7 16:56:38 2022 +0300 all: upd chlog --- CHANGELOG.md | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b826cea4d7c..a8e31314e91 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,12 +15,6 @@ and this project adheres to ## [v0.108.0] - TBA (APPROX.) --> -## Security - -- Go version has been updated to prevent the possibility of exploiting the - CVE-2022-2879, CVE-2022-2880, and CVE-2022-41715 Go vulnerabilities fixed in - [Go 1.18.7][go-1.18.7]. - ## Added - The ability to put [ClientIDs][clientid] into DNS-over-HTTPS hostnames as @@ -29,21 +23,35 @@ and this project adheres to [#3418]: https://github.com/AdguardTeam/AdGuardHome/issues/3418 -[go-1.18.7]: https://groups.google.com/g/golang-announce/c/xtuG5faxtaU [clientid]: https://github.com/AdguardTeam/AdGuardHome/wiki/Clients#clientid +## [v0.107.16] - 2022-10-07 + +This is a security update. There is no GitHub milestone, since no GitHub issues +were resolved. + +## Security + +- Go version has been updated to prevent the possibility of exploiting the + CVE-2022-2879, CVE-2022-2880, and CVE-2022-41715 Go vulnerabilities fixed in + [Go 1.18.7][go-1.18.7]. + +[go-1.18.7]: https://groups.google.com/g/golang-announce/c/xtuG5faxtaU + + + ## [v0.107.15] - 2022-10-03 See also the [v0.107.15 GitHub milestone][ms-v0.107.15]. @@ -1342,11 +1350,12 @@ See also the [v0.104.2 GitHub milestone][ms-v0.104.2]. -[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.15...HEAD +[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.16...HEAD +[v0.107.16]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.15...v0.107.16 [v0.107.15]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.14...v0.107.15 [v0.107.14]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.13...v0.107.14 [v0.107.13]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.12...v0.107.13 From f5602d9c46fedf850f70ba08dc3cc73113e75483 Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Mon, 10 Oct 2022 14:05:24 +0300 Subject: [PATCH 8/9] Pull request: hup-reload Merge in DNS/adguard-home from hup-reload to master Squashed commit of the following: commit 5cd4ab85bdc7544a4eded2a61f5a5571175daa44 Author: Ainar Garipov Date: Fri Oct 7 19:58:17 2022 +0300 next: imp signal hdlr commit 8fd18e749fec46982d26fc408e661bd802586c37 Merge: a8780455 f1dd3334 Author: Ainar Garipov Date: Fri Oct 7 19:46:48 2022 +0300 Merge branch 'master' into hup-reload commit a87804550e15d7fe3d9ded2e5a736c395f96febd Merge: 349dbe54 960a7a75 Author: Ainar Garipov Date: Fri Oct 7 15:49:23 2022 +0300 Merge branch 'master' into hup-reload commit 349dbe54fe27eeaf56776c73c3cc5649018d4c60 Author: Ainar Garipov Date: Fri Oct 7 15:43:52 2022 +0300 next: imp docs, names commit 7287a86d283489127453009267911003cea5227e Author: Ainar Garipov Date: Fri Oct 7 13:39:44 2022 +0300 WIP all: impl dynamic reconfiguration --- internal/aghos/filewalker_internal_test.go | 6 +- internal/aghos/os.go | 10 + internal/aghos/os_unix.go | 8 + internal/aghos/os_windows.go | 8 + internal/aghtest/interface.go | 17 +- internal/aghtest/interface_test.go | 8 +- internal/next/agh/agh.go | 40 +++- internal/next/cmd/cmd.go | 33 ++-- internal/next/cmd/signal.go | 68 ++++++- internal/next/configmgr/config.go | 40 ++++ internal/next/configmgr/configmgr.go | 205 +++++++++++++++++++++ internal/next/websvc/dns_test.go | 3 +- internal/next/websvc/http.go | 3 +- internal/next/websvc/http_test.go | 3 +- internal/next/websvc/settings_test.go | 5 +- internal/next/websvc/websvc.go | 15 +- internal/next/websvc/websvc_test.go | 13 +- scripts/make/go-build.sh | 4 +- 18 files changed, 418 insertions(+), 71 deletions(-) create mode 100644 internal/next/configmgr/config.go create mode 100644 internal/next/configmgr/configmgr.go diff --git a/internal/aghos/filewalker_internal_test.go b/internal/aghos/filewalker_internal_test.go index bb162812fdb..732afc9bac2 100644 --- a/internal/aghos/filewalker_internal_test.go +++ b/internal/aghos/filewalker_internal_test.go @@ -15,11 +15,11 @@ import ( // errFSOpen. type errFS struct{} -// errFSOpen is returned from errGlobFS.Open. +// errFSOpen is returned from errFS.Open. const errFSOpen errors.Error = "test open error" -// Open implements the fs.FS interface for *errGlobFS. fsys is always nil and -// err is always errFSOpen. +// Open implements the fs.FS interface for *errFS. fsys is always nil and err +// is always errFSOpen. func (efs *errFS) Open(name string) (fsys fs.File, err error) { return nil, errFSOpen } diff --git a/internal/aghos/os.go b/internal/aghos/os.go index b39ecbbd857..26201df2277 100644 --- a/internal/aghos/os.go +++ b/internal/aghos/os.go @@ -175,11 +175,21 @@ func RootDirFS() (fsys fs.FS) { return os.DirFS("") } +// NotifyReconfigureSignal notifies c on receiving reconfigure signals. +func NotifyReconfigureSignal(c chan<- os.Signal) { + notifyReconfigureSignal(c) +} + // NotifyShutdownSignal notifies c on receiving shutdown signals. func NotifyShutdownSignal(c chan<- os.Signal) { notifyShutdownSignal(c) } +// IsReconfigureSignal returns true if sig is a reconfigure signal. +func IsReconfigureSignal(sig os.Signal) (ok bool) { + return isReconfigureSignal(sig) +} + // IsShutdownSignal returns true if sig is a shutdown signal. func IsShutdownSignal(sig os.Signal) (ok bool) { return isShutdownSignal(sig) diff --git a/internal/aghos/os_unix.go b/internal/aghos/os_unix.go index da8ee912d8c..7e04f0c0c82 100644 --- a/internal/aghos/os_unix.go +++ b/internal/aghos/os_unix.go @@ -9,10 +9,18 @@ import ( "golang.org/x/sys/unix" ) +func notifyReconfigureSignal(c chan<- os.Signal) { + signal.Notify(c, unix.SIGHUP) +} + func notifyShutdownSignal(c chan<- os.Signal) { signal.Notify(c, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM) } +func isReconfigureSignal(sig os.Signal) (ok bool) { + return sig == unix.SIGHUP +} + func isShutdownSignal(sig os.Signal) (ok bool) { switch sig { case diff --git a/internal/aghos/os_windows.go b/internal/aghos/os_windows.go index c79a603fd8c..d22c1fdd4db 100644 --- a/internal/aghos/os_windows.go +++ b/internal/aghos/os_windows.go @@ -39,12 +39,20 @@ func isOpenWrt() (ok bool) { return false } +func notifyReconfigureSignal(c chan<- os.Signal) { + signal.Notify(c, windows.SIGHUP) +} + func notifyShutdownSignal(c chan<- os.Signal) { // syscall.SIGTERM is processed automatically. See go doc os/signal, // section Windows. signal.Notify(c, os.Interrupt) } +func isReconfigureSignal(sig os.Signal) (ok bool) { + return sig == windows.SIGHUP +} + func isShutdownSignal(sig os.Signal) (ok bool) { switch sig { case diff --git a/internal/aghtest/interface.go b/internal/aghtest/interface.go index 7aae35ee3ee..ea91988904b 100644 --- a/internal/aghtest/interface.go +++ b/internal/aghtest/interface.go @@ -6,6 +6,7 @@ import ( "net" "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/miekg/dns" ) @@ -88,7 +89,7 @@ func (l *Listener) Close() (err error) { return l.OnClose() } -// Module AdGuardHome +// Module adguard-home // Package aghos @@ -117,29 +118,31 @@ func (w *FSWatcher) Close() (err error) { return w.OnClose() } -// Package websvc +// Package agh -// ServiceWithConfig is a mock [websvc.ServiceWithConfig] implementation for -// tests. +// type check +var _ agh.ServiceWithConfig[struct{}] = (*ServiceWithConfig[struct{}])(nil) + +// ServiceWithConfig is a mock [agh.ServiceWithConfig] implementation for tests. type ServiceWithConfig[ConfigType any] struct { OnStart func() (err error) OnShutdown func(ctx context.Context) (err error) OnConfig func() (c ConfigType) } -// Start implements the [websvc.ServiceWithConfig] interface for +// Start implements the [agh.ServiceWithConfig] interface for // *ServiceWithConfig. func (s *ServiceWithConfig[_]) Start() (err error) { return s.OnStart() } -// Shutdown implements the [websvc.ServiceWithConfig] interface for +// Shutdown implements the [agh.ServiceWithConfig] interface for // *ServiceWithConfig. func (s *ServiceWithConfig[_]) Shutdown(ctx context.Context) (err error) { return s.OnShutdown(ctx) } -// Config implements the [websvc.ServiceWithConfig] interface for +// Config implements the [agh.ServiceWithConfig] interface for // *ServiceWithConfig. func (s *ServiceWithConfig[ConfigType]) Config() (c ConfigType) { return s.OnConfig() diff --git a/internal/aghtest/interface_test.go b/internal/aghtest/interface_test.go index bd2c0823e84..9141d132c54 100644 --- a/internal/aghtest/interface_test.go +++ b/internal/aghtest/interface_test.go @@ -1,9 +1,3 @@ package aghtest_test -import ( - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" - "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" -) - -// type check -var _ websvc.ServiceWithConfig[struct{}] = (*aghtest.ServiceWithConfig[struct{}])(nil) +// Put interface checks that cause import cycles here. diff --git a/internal/next/agh/agh.go b/internal/next/agh/agh.go index 212da4d685b..52855524734 100644 --- a/internal/next/agh/agh.go +++ b/internal/next/agh/agh.go @@ -1,6 +1,4 @@ // Package agh contains common entities and interfaces of AdGuard Home. -// -// TODO(a.garipov): Move to the upper-level internal/. package agh import "context" @@ -23,11 +21,43 @@ type Service interface { // type check var _ Service = EmptyService{} -// EmptyService is a Service that does nothing. +// EmptyService is a [Service] that does nothing. +// +// TODO(a.garipov): Remove if unnecessary. type EmptyService struct{} -// Start implements the Service interface for EmptyService. +// Start implements the [Service] interface for EmptyService. func (EmptyService) Start() (err error) { return nil } -// Shutdown implements the Service interface for EmptyService. +// Shutdown implements the [Service] interface for EmptyService. func (EmptyService) Shutdown(_ context.Context) (err error) { return nil } + +// ServiceWithConfig is an extension of the [Service] interface for services +// that can return their configuration. +// +// TODO(a.garipov): Consider removing this generic interface if we figure out +// how to make it testable in a better way. +type ServiceWithConfig[ConfigType any] interface { + Service + + Config() (c ConfigType) +} + +// type check +var _ ServiceWithConfig[struct{}] = (*EmptyServiceWithConfig[struct{}])(nil) + +// EmptyServiceWithConfig is a ServiceWithConfig that does nothing. Its Config +// method returns Conf. +// +// TODO(a.garipov): Remove if unnecessary. +type EmptyServiceWithConfig[ConfigType any] struct { + EmptyService + + Conf ConfigType +} + +// Config implements the [ServiceWithConfig] interface for +// *EmptyServiceWithConfig. +func (s *EmptyServiceWithConfig[ConfigType]) Config() (conf ConfigType) { + return s.Conf +} diff --git a/internal/next/cmd/cmd.go b/internal/next/cmd/cmd.go index 5b329abf4a6..d2cc9c809f2 100644 --- a/internal/next/cmd/cmd.go +++ b/internal/next/cmd/cmd.go @@ -8,10 +8,11 @@ import ( "context" "io/fs" "math/rand" - "net/netip" + "os" "time" - "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/AdguardTeam/AdGuardHome/internal/next/configmgr" + "github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/golibs/log" ) @@ -24,26 +25,32 @@ func Main(clientBuildFS fs.FS) { // TODO(a.garipov): Set up logging. + log.Info("starting adguard home, version %s, pid %d", version.Version(), os.Getpid()) + // Web Service // TODO(a.garipov): Use in the Web service. _ = clientBuildFS - // TODO(a.garipov): Make configurable. - web := websvc.New(&websvc.Config{ - // TODO(a.garipov): Use an actual implementation. - ConfigManager: nil, - Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:3001")}, - Start: start, - Timeout: 60 * time.Second, - ForceHTTPS: false, - }) - - err := web.Start() + // TODO(a.garipov): Set up configuration file name. + const confFile = "AdGuardHome.1.yaml" + + confMgr, err := configmgr.New(confFile, start) + fatalOnError(err) + + web := confMgr.Web() + err = web.Start() + fatalOnError(err) + + dns := confMgr.DNS() + err = dns.Start() fatalOnError(err) sigHdlr := newSignalHandler( + confFile, + start, web, + dns, ) go sigHdlr.handle() diff --git a/internal/next/cmd/signal.go b/internal/next/cmd/signal.go index 122f3f2c7dd..640d090b4d8 100644 --- a/internal/next/cmd/signal.go +++ b/internal/next/cmd/signal.go @@ -2,18 +2,26 @@ package cmd import ( "os" + "time" "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/next/agh" + "github.com/AdguardTeam/AdGuardHome/internal/next/configmgr" "github.com/AdguardTeam/golibs/log" ) // signalHandler processes incoming signals and shuts services down. type signalHandler struct { + // signal is the channel to which OS signals are sent. signal chan os.Signal - // services are the services that are shut down before application - // exiting. + // confFile is the path to the configuration file. + confFile string + + // start is the time at which AdGuard Home has been started. + start time.Time + + // services are the services that are shut down before application exiting. services []agh.Service } @@ -24,12 +32,51 @@ func (h *signalHandler) handle() { for sig := range h.signal { log.Info("sighdlr: received signal %q", sig) - if aghos.IsShutdownSignal(sig) { - h.shutdown() + if aghos.IsReconfigureSignal(sig) { + h.reconfigure() + } else if aghos.IsShutdownSignal(sig) { + status := h.shutdown() + log.Info("sighdlr: exiting with status %d", status) + + os.Exit(status) } } } +// reconfigure rereads the configuration file and updates and restarts services. +func (h *signalHandler) reconfigure() { + log.Info("sighdlr: reconfiguring adguard home") + + status := h.shutdown() + if status != statusSuccess { + log.Info("sighdlr: reconfiruging: exiting with status %d", status) + + os.Exit(status) + } + + // TODO(a.garipov): This is a very rough way to do it. Some services can be + // reconfigured without the full shutdown, and the error handling is + // currently not the best. + + confMgr, err := configmgr.New(h.confFile, h.start) + fatalOnError(err) + + web := confMgr.Web() + err = web.Start() + fatalOnError(err) + + dns := confMgr.DNS() + err = dns.Start() + fatalOnError(err) + + h.services = []agh.Service{ + dns, + web, + } + + log.Info("sighdlr: successfully reconfigured adguard home") +} + // Exit status constants. const ( statusSuccess = 0 @@ -37,11 +84,11 @@ const ( ) // shutdown gracefully shuts down all services. -func (h *signalHandler) shutdown() { +func (h *signalHandler) shutdown() (status int) { ctx, cancel := ctxWithDefaultTimeout() defer cancel() - status := statusSuccess + status = statusSuccess log.Info("sighdlr: shutting down services") for i, service := range h.services { @@ -52,19 +99,20 @@ func (h *signalHandler) shutdown() { } } - log.Info("sighdlr: shutting down adguard home") - - os.Exit(status) + return status } // newSignalHandler returns a new signalHandler that shuts down svcs. -func newSignalHandler(svcs ...agh.Service) (h *signalHandler) { +func newSignalHandler(confFile string, start time.Time, svcs ...agh.Service) (h *signalHandler) { h = &signalHandler{ signal: make(chan os.Signal, 1), + confFile: confFile, + start: start, services: svcs, } aghos.NotifyShutdownSignal(h.signal) + aghos.NotifyReconfigureSignal(h.signal) return h } diff --git a/internal/next/configmgr/config.go b/internal/next/configmgr/config.go new file mode 100644 index 00000000000..d11d8c1a79b --- /dev/null +++ b/internal/next/configmgr/config.go @@ -0,0 +1,40 @@ +package configmgr + +import ( + "net/netip" + + "github.com/AdguardTeam/golibs/timeutil" +) + +// Configuration Structures + +// config is the top-level on-disk configuration structure. +type config struct { + DNS *dnsConfig `yaml:"dns"` + HTTP *httpConfig `yaml:"http"` + // TODO(a.garipov): Use. + SchemaVersion int `yaml:"schema_version"` + // TODO(a.garipov): Use. + DebugPprof bool `yaml:"debug_pprof"` + Verbose bool `yaml:"verbose"` +} + +// dnsConfig is the on-disk DNS configuration. +// +// TODO(a.garipov): Validate. +type dnsConfig struct { + Addresses []netip.AddrPort `yaml:"addresses"` + BootstrapDNS []string `yaml:"bootstrap_dns"` + UpstreamDNS []string `yaml:"upstream_dns"` + UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"` +} + +// httpConfig is the on-disk web API configuration. +// +// TODO(a.garipov): Validate. +type httpConfig struct { + Addresses []netip.AddrPort `yaml:"addresses"` + SecureAddresses []netip.AddrPort `yaml:"secure_addresses"` + Timeout timeutil.Duration `yaml:"timeout"` + ForceHTTPS bool `yaml:"force_https"` +} diff --git a/internal/next/configmgr/configmgr.go b/internal/next/configmgr/configmgr.go new file mode 100644 index 00000000000..5b0422743e6 --- /dev/null +++ b/internal/next/configmgr/configmgr.go @@ -0,0 +1,205 @@ +// Package configmgr defines the AdGuard Home on-disk configuration entities and +// configuration manager. +package configmgr + +import ( + "context" + "fmt" + "os" + "sync" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/next/agh" + "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" + "gopkg.in/yaml.v3" +) + +// Configuration Manager + +// Manager handles full and partial changes in the configuration, persisting +// them to disk if necessary. +type Manager struct { + // updMu makes sure that at most one reconfiguration is performed at a time. + // updMu protects all fields below. + updMu *sync.RWMutex + + // dns is the DNS service. + dns *dnssvc.Service + + // Web is the Web API service. + web *websvc.Service + + // current is the current configuration. + current *config + + // fileName is the name of the configuration file. + fileName string +} + +// New creates a new *Manager that persists changes to the file pointed to by +// fileName. It reads the configuration file and populates the service fields. +// start is the startup time of AdGuard Home. +func New(fileName string, start time.Time) (m *Manager, err error) { + defer func() { err = errors.Annotate(err, "reading config") }() + + conf := &config{} + f, err := os.Open(fileName) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return nil, err + } + defer func() { err = errors.WithDeferred(err, f.Close()) }() + + err = yaml.NewDecoder(f).Decode(conf) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return nil, err + } + + // TODO(a.garipov): Move into a separate function and add other logging + // settings. + if conf.Verbose { + log.SetLevel(log.DEBUG) + } + + // TODO(a.garipov): Validate the configuration structure. Return an error + // if it's incorrect. + + m = &Manager{ + updMu: &sync.RWMutex{}, + current: conf, + fileName: fileName, + } + + // TODO(a.garipov): Get the context with the timeout from the arguments? + const assemblyTimeout = 5 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), assemblyTimeout) + defer cancel() + + err = m.assemble(ctx, conf, start) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return nil, err + } + + return m, nil +} + +// assemble creates all services and puts them into the corresponding fields. +// The fields of conf must not be modified after calling assemble. +func (m *Manager) assemble(ctx context.Context, conf *config, start time.Time) (err error) { + dnsConf := &dnssvc.Config{ + Addresses: conf.DNS.Addresses, + BootstrapServers: conf.DNS.BootstrapDNS, + UpstreamServers: conf.DNS.UpstreamDNS, + UpstreamTimeout: conf.DNS.UpstreamTimeout.Duration, + } + err = m.updateDNS(ctx, dnsConf) + if err != nil { + return fmt.Errorf("assembling dnssvc: %w", err) + } + + webSvcConf := &websvc.Config{ + ConfigManager: m, + // TODO(a.garipov): Fill from config file. + TLS: nil, + Start: start, + Addresses: conf.HTTP.Addresses, + SecureAddresses: conf.HTTP.SecureAddresses, + Timeout: conf.HTTP.Timeout.Duration, + ForceHTTPS: conf.HTTP.ForceHTTPS, + } + + err = m.updateWeb(ctx, webSvcConf) + if err != nil { + return fmt.Errorf("assembling websvc: %w", err) + } + + return nil +} + +// DNS returns the current DNS service. It is safe for concurrent use. +func (m *Manager) DNS() (dns agh.ServiceWithConfig[*dnssvc.Config]) { + m.updMu.RLock() + defer m.updMu.RUnlock() + + return m.dns +} + +// UpdateDNS implements the [websvc.ConfigManager] interface for *Manager. The +// fields of c must not be modified after calling UpdateDNS. +func (m *Manager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) { + m.updMu.Lock() + defer m.updMu.Unlock() + + // TODO(a.garipov): Update and write the configuration file. Return an + // error if something went wrong. + + err = m.updateDNS(ctx, c) + if err != nil { + return fmt.Errorf("reassembling dnssvc: %w", err) + } + + return nil +} + +// updateDNS recreates the DNS service. m.updMu is expected to be locked. +func (m *Manager) updateDNS(ctx context.Context, c *dnssvc.Config) (err error) { + if prev := m.dns; prev != nil { + err = prev.Shutdown(ctx) + if err != nil { + return fmt.Errorf("shutting down dns svc: %w", err) + } + } + + svc, err := dnssvc.New(c) + if err != nil { + return fmt.Errorf("creating dns svc: %w", err) + } + + m.dns = svc + + return nil +} + +// Web returns the current web service. It is safe for concurrent use. +func (m *Manager) Web() (web agh.ServiceWithConfig[*websvc.Config]) { + m.updMu.RLock() + defer m.updMu.RUnlock() + + return m.web +} + +// UpdateWeb implements the [websvc.ConfigManager] interface for *Manager. The +// fields of c must not be modified after calling UpdateWeb. +func (m *Manager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) { + m.updMu.Lock() + defer m.updMu.Unlock() + + // TODO(a.garipov): Update and write the configuration file. Return an + // error if something went wrong. + + err = m.updateWeb(ctx, c) + if err != nil { + return fmt.Errorf("reassembling websvc: %w", err) + } + + return nil +} + +// updateWeb recreates the web service. m.upd is expected to be locked. +func (m *Manager) updateWeb(ctx context.Context, c *websvc.Config) (err error) { + if prev := m.web; prev != nil { + err = prev.Shutdown(ctx) + if err != nil { + return fmt.Errorf("shutting down web svc: %w", err) + } + } + + m.web = websvc.New(c) + + return nil +} diff --git a/internal/next/websvc/dns_test.go b/internal/next/websvc/dns_test.go index f774c3d87dd..d0efec8734a 100644 --- a/internal/next/websvc/dns_test.go +++ b/internal/next/websvc/dns_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/stretchr/testify/assert" @@ -28,7 +29,7 @@ func TestService_HandlePatchSettingsDNS(t *testing.T) { // TODO(a.garipov): Use [atomic.Bool] in Go 1.19. var numStarted uint64 confMgr := newConfigManager() - confMgr.onDNS = func() (s websvc.ServiceWithConfig[*dnssvc.Config]) { + confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) { return &aghtest.ServiceWithConfig[*dnssvc.Config]{ OnStart: func() (err error) { atomic.AddUint64(&numStarted, 1) diff --git a/internal/next/websvc/http.go b/internal/next/websvc/http.go index b58eecb9499..c6107cd0501 100644 --- a/internal/next/websvc/http.go +++ b/internal/next/websvc/http.go @@ -8,6 +8,7 @@ import ( "net/netip" "time" + "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/golibs/log" ) @@ -89,7 +90,7 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque // TODO(a.garipov): Consider better ways to do this. const maxUpdDur = 10 * time.Second updStart := time.Now() - var newSvc ServiceWithConfig[*Config] + var newSvc agh.ServiceWithConfig[*Config] for newSvc = svc.confMgr.Web(); newSvc == svc; { if time.Since(updStart) >= maxUpdDur { log.Error("websvc: failed to update svc after %s", maxUpdDur) diff --git a/internal/next/websvc/http_test.go b/internal/next/websvc/http_test.go index baf384da296..d79be735d88 100644 --- a/internal/next/websvc/http_test.go +++ b/internal/next/websvc/http_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -24,7 +25,7 @@ func TestService_HandlePatchSettingsHTTP(t *testing.T) { } confMgr := newConfigManager() - confMgr.onWeb = func() (s websvc.ServiceWithConfig[*websvc.Config]) { + confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) { return websvc.New(&websvc.Config{ TLS: &tls.Config{ Certificates: []tls.Certificate{{}}, diff --git a/internal/next/websvc/settings_test.go b/internal/next/websvc/settings_test.go index dadb4b55ea3..3dfc63fc8a2 100644 --- a/internal/next/websvc/settings_test.go +++ b/internal/next/websvc/settings_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/stretchr/testify/assert" @@ -33,7 +34,7 @@ func TestService_HandleGetSettingsAll(t *testing.T) { } confMgr := newConfigManager() - confMgr.onDNS = func() (s websvc.ServiceWithConfig[*dnssvc.Config]) { + confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) { c, err := dnssvc.New(&dnssvc.Config{ Addresses: wantDNS.Addresses, UpstreamServers: wantDNS.UpstreamServers, @@ -45,7 +46,7 @@ func TestService_HandleGetSettingsAll(t *testing.T) { return c } - confMgr.onWeb = func() (s websvc.ServiceWithConfig[*websvc.Config]) { + confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) { return websvc.New(&websvc.Config{ TLS: &tls.Config{ Certificates: []tls.Certificate{{}}, diff --git a/internal/next/websvc/websvc.go b/internal/next/websvc/websvc.go index 75f7d001f69..054228897a1 100644 --- a/internal/next/websvc/websvc.go +++ b/internal/next/websvc/websvc.go @@ -24,21 +24,10 @@ import ( httptreemux "github.com/dimfeld/httptreemux/v5" ) -// ServiceWithConfig is an extension of the [agh.Service] interface for services -// that can return their configuration. -// -// TODO(a.garipov): Consider removing this generic interface if we figure out -// how to make it testable in a better way. -type ServiceWithConfig[ConfigType any] interface { - agh.Service - - Config() (c ConfigType) -} - // ConfigManager is the configuration manager interface. type ConfigManager interface { - DNS() (svc ServiceWithConfig[*dnssvc.Config]) - Web() (svc ServiceWithConfig[*Config]) + DNS() (svc agh.ServiceWithConfig[*dnssvc.Config]) + Web() (svc agh.ServiceWithConfig[*Config]) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) UpdateWeb(ctx context.Context, c *Config) (err error) diff --git a/internal/next/websvc/websvc_test.go b/internal/next/websvc/websvc_test.go index dbce77d58a2..39ab30389df 100644 --- a/internal/next/websvc/websvc_test.go +++ b/internal/next/websvc/websvc_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/AdguardTeam/golibs/testutil" @@ -34,20 +35,20 @@ var _ websvc.ConfigManager = (*configManager)(nil) // configManager is a [websvc.ConfigManager] for tests. type configManager struct { - onDNS func() (svc websvc.ServiceWithConfig[*dnssvc.Config]) - onWeb func() (svc websvc.ServiceWithConfig[*websvc.Config]) + onDNS func() (svc agh.ServiceWithConfig[*dnssvc.Config]) + onWeb func() (svc agh.ServiceWithConfig[*websvc.Config]) onUpdateDNS func(ctx context.Context, c *dnssvc.Config) (err error) onUpdateWeb func(ctx context.Context, c *websvc.Config) (err error) } // DNS implements the [websvc.ConfigManager] interface for *configManager. -func (m *configManager) DNS() (svc websvc.ServiceWithConfig[*dnssvc.Config]) { +func (m *configManager) DNS() (svc agh.ServiceWithConfig[*dnssvc.Config]) { return m.onDNS() } // Web implements the [websvc.ConfigManager] interface for *configManager. -func (m *configManager) Web() (svc websvc.ServiceWithConfig[*websvc.Config]) { +func (m *configManager) Web() (svc agh.ServiceWithConfig[*websvc.Config]) { return m.onWeb() } @@ -64,8 +65,8 @@ func (m *configManager) UpdateWeb(ctx context.Context, c *websvc.Config) (err er // newConfigManager returns a *configManager all methods of which panic. func newConfigManager() (m *configManager) { return &configManager{ - onDNS: func() (svc websvc.ServiceWithConfig[*dnssvc.Config]) { panic("not implemented") }, - onWeb: func() (svc websvc.ServiceWithConfig[*websvc.Config]) { panic("not implemented") }, + onDNS: func() (svc agh.ServiceWithConfig[*dnssvc.Config]) { panic("not implemented") }, + onWeb: func() (svc agh.ServiceWithConfig[*websvc.Config]) { panic("not implemented") }, onUpdateDNS: func(_ context.Context, _ *dnssvc.Config) (err error) { panic("not implemented") }, diff --git a/scripts/make/go-build.sh b/scripts/make/go-build.sh index c998a61168d..8d993d66fb5 100644 --- a/scripts/make/go-build.sh +++ b/scripts/make/go-build.sh @@ -124,11 +124,11 @@ GO111MODULE='on' export CGO_ENABLED GO111MODULE # Build the new binary if requested. -if [ "${V1API:-0}" -eq '0' ] +if [ "${NEXTAPI:-0}" -eq '0' ] then tags_flags='--tags=' else - tags_flags='--tags=v1' + tags_flags='--tags=next' fi readonly tags_flags From 0eba31ca031a2e712f6a6d00c355bc28635820bb Mon Sep 17 00:00:00 2001 From: Ildar Kamalov Date: Mon, 10 Oct 2022 17:49:19 +0300 Subject: [PATCH 9/9] Pull request: 4815 fix table view of the query log modal Updates #4815 Squashed commit of the following: commit a547c546a2b3cdbfb6988c910d8a970e0189ae5a Merge: 3c1e745d f5602d9c Author: Ildar Kamalov Date: Mon Oct 10 17:40:38 2022 +0300 Merge branch 'master' into 4815-tablet-view-fix commit 3c1e745dc2e34a62be8264ad003b5e6c155bb241 Author: Ildar Kamalov Date: Mon Oct 10 16:50:10 2022 +0300 fix mobile view commit a1d0b36473982854eecf1d96cf5a7033059e7720 Author: Ildar Kamalov Date: Sun Oct 9 17:57:14 2022 +0300 client: fix styles commit f34f928e1dbeef5ed37a0de3515be8d12f2241f6 Author: Ildar Kamalov Date: Sun Oct 9 16:59:23 2022 +0300 client: fix table view of query log modal --- .../src/components/Logs/Cells/ClientCell.js | 2 +- .../src/components/Logs/Cells/IconTooltip.css | 27 ++++++++++++++++--- client/src/components/Logs/Cells/index.js | 6 ++--- client/src/components/Logs/Logs.css | 4 +-- client/src/components/Logs/index.js | 19 +++++++------ 5 files changed, 41 insertions(+), 17 deletions(-) diff --git a/client/src/components/Logs/Cells/ClientCell.js b/client/src/components/Logs/Cells/ClientCell.js index 669f1c0acdc..9467f14ebe3 100644 --- a/client/src/components/Logs/Cells/ClientCell.js +++ b/client/src/components/Logs/Cells/ClientCell.js @@ -121,7 +121,7 @@ const ClientCell = ({ {options.map(({ name, onClick, disabled }) => ( ; const blockClientButton =