diff --git a/pkg/clients/npm/npm.go b/pkg/clients/npm/npm.go index de680bf..eea1735 100644 --- a/pkg/clients/npm/npm.go +++ b/pkg/clients/npm/npm.go @@ -6,6 +6,9 @@ import ( "fmt" "net/http" "net/url" + "strings" + "sync" + "time" "github.com/deepspace2/plugnpin/pkg/clients" "github.com/deepspace2/plugnpin/pkg/logging" @@ -15,25 +18,36 @@ var log = logging.GetLogger() type Client struct { http.Client - baseURL string - identity string - secret string - token string -} - -var headers map[string]string = map[string]string{ - "content-type": "application/json", + baseURL string + headers map[string]string + identity string + secret string + token string + tokenExpireTime time.Time + mu sync.Mutex } func NewClient(baseURL, identity, secret string) *Client { return &Client{ - Client: http.Client{}, - baseURL: fmt.Sprintf("%v/api", baseURL), + Client: http.Client{}, + baseURL: fmt.Sprintf("%v/api", baseURL), + headers: map[string]string{ + "content-type": "application/json", + }, identity: identity, secret: secret, } } +func parseTokenExpireTime(timeStr string) (time.Time, error) { + return time.Parse(time.RFC3339Nano, timeStr) +} + +func (n *Client) hasTokenExpired() bool { + now := time.Now().UTC() + return now.Compare(n.tokenExpireTime) >= 0 +} + func (n *Client) Login() error { loginPayload := LoginRequest{ Identity: n.identity, @@ -44,20 +58,27 @@ func (n *Client) Login() error { return err } payloadString := string(payloadBytes) - loginResponseString, statusCode, err := clients.Post(&n.Client, n.baseURL+"/tokens", headers, &payloadString) + loginResponseString, statusCode, err := clients.Post(&n.Client, n.baseURL+"/tokens", n.headers, &payloadString) if err != nil { return err } - var resp Token + var resp LoginResponse err = json.Unmarshal([]byte(loginResponseString), &resp) if statusCode >= 400 || err != nil || resp.Token == "" { var loginError ErrorResponse json.Unmarshal([]byte(loginResponseString), &loginError) return errors.New(loginError.Error.Message) } + + tokenExpireTime, err := parseTokenExpireTime(resp.Expires) + if err != nil { + return fmt.Errorf("failed to parse token expiry time '%v': %v", resp.Expires, err) + } + n.tokenExpireTime = tokenExpireTime + n.token = resp.Token - headers["authorization"] = "Bearer " + n.token + n.headers["authorization"] = "Bearer " + n.token return nil } @@ -67,13 +88,11 @@ func (n *Client) GetIP() string { } func (n *Client) getProxyHosts() (map[string]int, error) { - proxyHostsString, statusCode, err := clients.Get(&n.Client, n.baseURL+"/nginx/proxy-hosts", headers) - if statusCode == 401 { - n.refreshToken() - return n.getProxyHosts() - } else if statusCode >= 400 { + proxyHostsString, statusCode, err := n.makeRequest(http.MethodGet, n.baseURL+"/nginx/proxy-hosts", nil) + if err != nil || statusCode >= 400 { return nil, err } + var proxyHosts []ProxyHostReply existingProxyHostsMap := map[string]int{} json.Unmarshal([]byte(proxyHostsString), &proxyHosts) @@ -92,12 +111,49 @@ func (n *Client) refreshToken() error { return n.Login() } +func (n *Client) makeRequest(method, url string, payload *string) (string, int, error) { + n.mu.Lock() + defer n.mu.Unlock() + + if n.hasTokenExpired() { + if err := n.refreshToken(); err != nil { + return "", 0, fmt.Errorf("pre-emptive token refresh failed: %v", err) + } + } + + doRequest := func() (string, int, error) { + switch method { + case http.MethodGet: + return clients.Get(&n.Client, url, n.headers) + case http.MethodPost: + return clients.Post(&n.Client, url, n.headers, payload) + case http.MethodDelete: + return clients.Delete(&n.Client, url, n.headers) + default: + return "", 0, fmt.Errorf("unsupported http method: %s", method) + } + } + + resp, statusCode, err := doRequest() + + var errorResponse ErrorResponse + _ = json.Unmarshal([]byte(resp), &errorResponse) + isTokenExpiredError := strings.Contains(errorResponse.Error.Message, "Token has expired") + + if statusCode == http.StatusUnauthorized || (statusCode >= 400 && isTokenExpiredError) { + log.Info("Received auth-related error, attempting reactive token refresh and retry.") + if refreshErr := n.refreshToken(); refreshErr != nil { + return resp, statusCode, fmt.Errorf("request failed with auth error, and subsequent token refresh also failed: %v", refreshErr) + } + resp, statusCode, err = doRequest() + } + + return resp, statusCode, err +} + func (n *Client) getCertificates() (Certificates, error) { - resp, statusCode, err := clients.Get(&n.Client, n.baseURL+"/nginx/certificates", headers) - if statusCode == 401 { - n.refreshToken() - return n.getCertificates() - } else if statusCode >= 400 { + resp, statusCode, err := n.makeRequest(http.MethodGet, n.baseURL+"/nginx/certificates", nil) + if err != nil || statusCode >= 400 { return nil, err } @@ -109,6 +165,7 @@ func (n *Client) getCertificates() (Certificates, error) { func (n *Client) GetCertificateIDByName(name string) *int { certificates, err := n.getCertificates() if err != nil { + log.Error("Failed to get certificates", "error", err) return nil } for _, certificate := range certificates { @@ -137,20 +194,12 @@ func (n *Client) AddProxyHost(host ProxyHost) error { } payloadString := string(payloadBytes) - resp, statusCode, err := clients.Post(&n.Client, n.baseURL+"/nginx/proxy-hosts", headers, &payloadString) + resp, statusCode, err := n.makeRequest(http.MethodPost, n.baseURL+"/nginx/proxy-hosts", &payloadString) if err != nil { return err } - if statusCode == 401 { - err := n.refreshToken() - if err != nil { - return err - } - _, _, err = clients.Post(&n.Client, n.baseURL+"/nginx/proxy-hosts", headers, &payloadString) - if err != nil { - return err - } - } else if statusCode >= 400 { + + if statusCode >= 400 { var errorResponse ErrorResponse json.Unmarshal([]byte(resp), &errorResponse) return errors.New(errorResponse.Error.Message) @@ -169,21 +218,12 @@ func (n *Client) DeleteProxyHost(domain string) error { } url := fmt.Sprintf("%v/nginx/proxy-hosts/%v", n.baseURL, hostID) - resp, statusCode, err := clients.Delete(&n.Client, url, headers) + resp, statusCode, err := n.makeRequest(http.MethodDelete, url, nil) if err != nil { return err } - if statusCode == 401 { - err := n.refreshToken() - if err != nil { - return err - } - _, _, err = clients.Delete(&n.Client, url, headers) - if err != nil { - return err - } - } else if statusCode >= 400 { + if statusCode >= 400 { var errorResponse ErrorResponse json.Unmarshal([]byte(resp), &errorResponse) return errors.New(errorResponse.Error.Message) diff --git a/pkg/clients/npm/npm_test.go b/pkg/clients/npm/npm_test.go index 1d73052..11f67fa 100644 --- a/pkg/clients/npm/npm_test.go +++ b/pkg/clients/npm/npm_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -30,8 +31,7 @@ func TestLogin(t *testing.T) { assert.Equal(t, "test-password", req.Secret) w.WriteHeader(http.StatusOK) - // The type is called Token, so we send back a Token object - json.NewEncoder(w).Encode(Token{Token: "test-jwt-token"}) + json.NewEncoder(w).Encode(LoginResponse{Token: "test-jwt-token", Expires: time.Now().Add(24 * time.Hour).Format(time.RFC3339Nano)}) }) client, server := setupTestServer(handler) @@ -74,6 +74,8 @@ func TestAddProxyHost(t *testing.T) { client, server := setupTestServer(handler) client.token = testToken // Pre-authorize client + client.tokenExpireTime = time.Now().Add(24 * time.Hour) + client.headers["authorization"] = "Bearer " + testToken defer server.Close() hostToAdd := ProxyHost{ @@ -98,6 +100,8 @@ func TestAddProxyHost(t *testing.T) { client, server := setupTestServer(handler) client.token = testToken // Pre-authorize client + client.tokenExpireTime = time.Now().Add(24 * time.Hour) + client.headers["authorization"] = "Bearer " + testToken defer server.Close() // Try to add the same host that already exists. @@ -139,6 +143,8 @@ func TestDeleteProxyHost(t *testing.T) { client, server := setupTestServer(handler) client.token = testToken // Pre-authorize client + client.tokenExpireTime = time.Now().Add(24 * time.Hour) + client.headers["authorization"] = "Bearer " + testToken defer server.Close() err := client.DeleteProxyHost("existing-host.com") @@ -162,6 +168,7 @@ func TestDeleteProxyHost(t *testing.T) { client, server := setupTestServer(handler) client.token = testToken // Pre-authorize client + client.tokenExpireTime = time.Now().Add(24 * time.Hour) defer server.Close() err := client.DeleteProxyHost("non-existing-host.com") diff --git a/pkg/clients/npm/types.go b/pkg/clients/npm/types.go index 30bbef3..791a919 100644 --- a/pkg/clients/npm/types.go +++ b/pkg/clients/npm/types.go @@ -1,5 +1,10 @@ package npm +type LoginResponse struct { + Expires string `json:"expires"` + Token string `json:"token"` +} + type LoginRequest struct { Identity string `json:"identity"` Secret string `json:"secret"` @@ -12,10 +17,6 @@ type ErrorResponse struct { } `json:"error"` } -type Token struct { - Token string `json:"token"` -} - type ProxyHostReply struct { AccessListID int `json:"access_list_id"` AdvancedConfig string `json:"advanced_config"` diff --git a/pkg/clients/pihole/pihole.go b/pkg/clients/pihole/pihole.go index 46f2d6e..9b4da9d 100644 --- a/pkg/clients/pihole/pihole.go +++ b/pkg/clients/pihole/pihole.go @@ -16,8 +16,9 @@ var log = logging.GetLogger() type Client struct { http.Client - baseURL string - sid string + baseURL string + password string + sid string } var headers map[string]string = map[string]string{ @@ -27,9 +28,9 @@ var headers map[string]string = map[string]string{ func NewClient(baseURL string) *Client { return &Client{ - http.Client{}, - fmt.Sprintf("%v/api", baseURL), - "", + Client: http.Client{}, + baseURL: fmt.Sprintf("%v/api", baseURL), + sid: "", } } @@ -46,6 +47,7 @@ func (p *Client) Login(password string) error { return errors.New(resp.Session.Message) } + p.password = password p.sid = resp.Session.Sid return nil } @@ -125,6 +127,10 @@ func (p *Client) AddDnsRecord(domain, ip string) error { if err != nil { return err } + if statusCode == 401 { + p.refreshAuth() + return p.AddDnsRecord(domain, ip) + } if statusCode >= 400 { var errorResponse ErrorResponse json.Unmarshal([]byte(resp), &errorResponse) @@ -170,6 +176,10 @@ func (p *Client) DeleteDnsRecord(domain string) error { if err != nil { return err } + if statusCode == 401 { + p.refreshAuth() + return p.DeleteDnsRecord(domain) + } if statusCode >= 400 { var errorResponse ErrorResponse json.Unmarshal([]byte(resp), &errorResponse) @@ -254,6 +264,10 @@ func (p *Client) AddCNameRecord(domain, target string) error { if err != nil { return err } + if statusCode == 401 { + p.refreshAuth() + return p.AddCNameRecord(domain, target) + } if statusCode >= 400 { var errorResponse ErrorResponse json.Unmarshal([]byte(resp), &errorResponse) @@ -299,6 +313,10 @@ func (p *Client) DeleteCNameRecord(domain, target string) error { if err != nil { return err } + if statusCode == 401 { + p.refreshAuth() + return p.DeleteCNameRecord(domain, target) + } if statusCode >= 400 { var errorResponse ErrorResponse json.Unmarshal([]byte(resp), &errorResponse) @@ -307,3 +325,8 @@ func (p *Client) DeleteCNameRecord(domain, target string) error { return nil } + +func (p *Client) refreshAuth() { + log.Info("Refreshing Pi-Hole authentication") + p.Login(p.password) +}