Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 85 additions & 45 deletions pkg/clients/npm/npm.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import (
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/deepspace2/plugnpin/pkg/clients"
"github.com/deepspace2/plugnpin/pkg/logging"
Expand All @@ -15,25 +18,36 @@ var log = logging.GetLogger()

type Client struct {
http.Client
baseURL string
identity string
secret string
token string
}

var headers map[string]string = map[string]string{
"content-type": "application/json",
baseURL string
headers map[string]string
identity string
secret string
token string
tokenExpireTime time.Time
mu sync.Mutex
}

func NewClient(baseURL, identity, secret string) *Client {
return &Client{
Client: http.Client{},
baseURL: fmt.Sprintf("%v/api", baseURL),
Client: http.Client{},
baseURL: fmt.Sprintf("%v/api", baseURL),
headers: map[string]string{
"content-type": "application/json",
},
identity: identity,
secret: secret,
}
}

func parseTokenExpireTime(timeStr string) (time.Time, error) {
return time.Parse(time.RFC3339Nano, timeStr)
}

func (n *Client) hasTokenExpired() bool {
now := time.Now().UTC()
return now.Compare(n.tokenExpireTime) >= 0
}

func (n *Client) Login() error {
loginPayload := LoginRequest{
Identity: n.identity,
Expand All @@ -44,20 +58,27 @@ func (n *Client) Login() error {
return err
}
payloadString := string(payloadBytes)
loginResponseString, statusCode, err := clients.Post(&n.Client, n.baseURL+"/tokens", headers, &payloadString)
loginResponseString, statusCode, err := clients.Post(&n.Client, n.baseURL+"/tokens", n.headers, &payloadString)
if err != nil {
return err
}

var resp Token
var resp LoginResponse
err = json.Unmarshal([]byte(loginResponseString), &resp)
if statusCode >= 400 || err != nil || resp.Token == "" {
var loginError ErrorResponse
json.Unmarshal([]byte(loginResponseString), &loginError)
return errors.New(loginError.Error.Message)
}

tokenExpireTime, err := parseTokenExpireTime(resp.Expires)
if err != nil {
return fmt.Errorf("failed to parse token expiry time '%v': %v", resp.Expires, err)
}
n.tokenExpireTime = tokenExpireTime

n.token = resp.Token
headers["authorization"] = "Bearer " + n.token
n.headers["authorization"] = "Bearer " + n.token
return nil
}

Expand All @@ -67,13 +88,11 @@ func (n *Client) GetIP() string {
}

func (n *Client) getProxyHosts() (map[string]int, error) {
proxyHostsString, statusCode, err := clients.Get(&n.Client, n.baseURL+"/nginx/proxy-hosts", headers)
if statusCode == 401 {
n.refreshToken()
return n.getProxyHosts()
} else if statusCode >= 400 {
proxyHostsString, statusCode, err := n.makeRequest(http.MethodGet, n.baseURL+"/nginx/proxy-hosts", nil)
if err != nil || statusCode >= 400 {
return nil, err
}

var proxyHosts []ProxyHostReply
existingProxyHostsMap := map[string]int{}
json.Unmarshal([]byte(proxyHostsString), &proxyHosts)
Expand All @@ -92,12 +111,49 @@ func (n *Client) refreshToken() error {
return n.Login()
}

func (n *Client) makeRequest(method, url string, payload *string) (string, int, error) {
n.mu.Lock()
defer n.mu.Unlock()

if n.hasTokenExpired() {
if err := n.refreshToken(); err != nil {
return "", 0, fmt.Errorf("pre-emptive token refresh failed: %v", err)
}
}

doRequest := func() (string, int, error) {
switch method {
case http.MethodGet:
return clients.Get(&n.Client, url, n.headers)
case http.MethodPost:
return clients.Post(&n.Client, url, n.headers, payload)
case http.MethodDelete:
return clients.Delete(&n.Client, url, n.headers)
default:
return "", 0, fmt.Errorf("unsupported http method: %s", method)
}
}

resp, statusCode, err := doRequest()

var errorResponse ErrorResponse
_ = json.Unmarshal([]byte(resp), &errorResponse)
isTokenExpiredError := strings.Contains(errorResponse.Error.Message, "Token has expired")

if statusCode == http.StatusUnauthorized || (statusCode >= 400 && isTokenExpiredError) {
log.Info("Received auth-related error, attempting reactive token refresh and retry.")
if refreshErr := n.refreshToken(); refreshErr != nil {
return resp, statusCode, fmt.Errorf("request failed with auth error, and subsequent token refresh also failed: %v", refreshErr)
}
resp, statusCode, err = doRequest()
}

return resp, statusCode, err
}

func (n *Client) getCertificates() (Certificates, error) {
resp, statusCode, err := clients.Get(&n.Client, n.baseURL+"/nginx/certificates", headers)
if statusCode == 401 {
n.refreshToken()
return n.getCertificates()
} else if statusCode >= 400 {
resp, statusCode, err := n.makeRequest(http.MethodGet, n.baseURL+"/nginx/certificates", nil)
if err != nil || statusCode >= 400 {
return nil, err
}

Expand All @@ -109,6 +165,7 @@ func (n *Client) getCertificates() (Certificates, error) {
func (n *Client) GetCertificateIDByName(name string) *int {
certificates, err := n.getCertificates()
if err != nil {
log.Error("Failed to get certificates", "error", err)
return nil
}
for _, certificate := range certificates {
Expand Down Expand Up @@ -137,20 +194,12 @@ func (n *Client) AddProxyHost(host ProxyHost) error {
}

payloadString := string(payloadBytes)
resp, statusCode, err := clients.Post(&n.Client, n.baseURL+"/nginx/proxy-hosts", headers, &payloadString)
resp, statusCode, err := n.makeRequest(http.MethodPost, n.baseURL+"/nginx/proxy-hosts", &payloadString)
if err != nil {
return err
}
if statusCode == 401 {
err := n.refreshToken()
if err != nil {
return err
}
_, _, err = clients.Post(&n.Client, n.baseURL+"/nginx/proxy-hosts", headers, &payloadString)
if err != nil {
return err
}
} else if statusCode >= 400 {

if statusCode >= 400 {
var errorResponse ErrorResponse
json.Unmarshal([]byte(resp), &errorResponse)
return errors.New(errorResponse.Error.Message)
Expand All @@ -169,21 +218,12 @@ func (n *Client) DeleteProxyHost(domain string) error {
}

url := fmt.Sprintf("%v/nginx/proxy-hosts/%v", n.baseURL, hostID)
resp, statusCode, err := clients.Delete(&n.Client, url, headers)
resp, statusCode, err := n.makeRequest(http.MethodDelete, url, nil)
if err != nil {
return err
}

if statusCode == 401 {
err := n.refreshToken()
if err != nil {
return err
}
_, _, err = clients.Delete(&n.Client, url, headers)
if err != nil {
return err
}
} else if statusCode >= 400 {
if statusCode >= 400 {
var errorResponse ErrorResponse
json.Unmarshal([]byte(resp), &errorResponse)
return errors.New(errorResponse.Error.Message)
Expand Down
11 changes: 9 additions & 2 deletions pkg/clients/npm/npm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
)
Expand All @@ -30,8 +31,7 @@ func TestLogin(t *testing.T) {
assert.Equal(t, "test-password", req.Secret)

w.WriteHeader(http.StatusOK)
// The type is called Token, so we send back a Token object
json.NewEncoder(w).Encode(Token{Token: "test-jwt-token"})
json.NewEncoder(w).Encode(LoginResponse{Token: "test-jwt-token", Expires: time.Now().Add(24 * time.Hour).Format(time.RFC3339Nano)})
})

client, server := setupTestServer(handler)
Expand Down Expand Up @@ -74,6 +74,8 @@ func TestAddProxyHost(t *testing.T) {

client, server := setupTestServer(handler)
client.token = testToken // Pre-authorize client
client.tokenExpireTime = time.Now().Add(24 * time.Hour)
client.headers["authorization"] = "Bearer " + testToken
defer server.Close()

hostToAdd := ProxyHost{
Expand All @@ -98,6 +100,8 @@ func TestAddProxyHost(t *testing.T) {

client, server := setupTestServer(handler)
client.token = testToken // Pre-authorize client
client.tokenExpireTime = time.Now().Add(24 * time.Hour)
client.headers["authorization"] = "Bearer " + testToken
defer server.Close()

// Try to add the same host that already exists.
Expand Down Expand Up @@ -139,6 +143,8 @@ func TestDeleteProxyHost(t *testing.T) {

client, server := setupTestServer(handler)
client.token = testToken // Pre-authorize client
client.tokenExpireTime = time.Now().Add(24 * time.Hour)
client.headers["authorization"] = "Bearer " + testToken
defer server.Close()

err := client.DeleteProxyHost("existing-host.com")
Expand All @@ -162,6 +168,7 @@ func TestDeleteProxyHost(t *testing.T) {

client, server := setupTestServer(handler)
client.token = testToken // Pre-authorize client
client.tokenExpireTime = time.Now().Add(24 * time.Hour)
defer server.Close()

err := client.DeleteProxyHost("non-existing-host.com")
Expand Down
9 changes: 5 additions & 4 deletions pkg/clients/npm/types.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package npm

type LoginResponse struct {
Expires string `json:"expires"`
Token string `json:"token"`
}

type LoginRequest struct {
Identity string `json:"identity"`
Secret string `json:"secret"`
Expand All @@ -12,10 +17,6 @@ type ErrorResponse struct {
} `json:"error"`
}

type Token struct {
Token string `json:"token"`
}

type ProxyHostReply struct {
AccessListID int `json:"access_list_id"`
AdvancedConfig string `json:"advanced_config"`
Expand Down
33 changes: 28 additions & 5 deletions pkg/clients/pihole/pihole.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ var log = logging.GetLogger()

type Client struct {
http.Client
baseURL string
sid string
baseURL string
password string
sid string
}

var headers map[string]string = map[string]string{
Expand All @@ -27,9 +28,9 @@ var headers map[string]string = map[string]string{

func NewClient(baseURL string) *Client {
return &Client{
http.Client{},
fmt.Sprintf("%v/api", baseURL),
"",
Client: http.Client{},
baseURL: fmt.Sprintf("%v/api", baseURL),
sid: "",
}
}

Expand All @@ -46,6 +47,7 @@ func (p *Client) Login(password string) error {
return errors.New(resp.Session.Message)
}

p.password = password
p.sid = resp.Session.Sid
return nil
}
Expand Down Expand Up @@ -125,6 +127,10 @@ func (p *Client) AddDnsRecord(domain, ip string) error {
if err != nil {
return err
}
if statusCode == 401 {
p.refreshAuth()
return p.AddDnsRecord(domain, ip)
}
if statusCode >= 400 {
var errorResponse ErrorResponse
json.Unmarshal([]byte(resp), &errorResponse)
Expand Down Expand Up @@ -170,6 +176,10 @@ func (p *Client) DeleteDnsRecord(domain string) error {
if err != nil {
return err
}
if statusCode == 401 {
p.refreshAuth()
return p.DeleteDnsRecord(domain)
}
if statusCode >= 400 {
var errorResponse ErrorResponse
json.Unmarshal([]byte(resp), &errorResponse)
Expand Down Expand Up @@ -254,6 +264,10 @@ func (p *Client) AddCNameRecord(domain, target string) error {
if err != nil {
return err
}
if statusCode == 401 {
p.refreshAuth()
return p.AddCNameRecord(domain, target)
}
if statusCode >= 400 {
var errorResponse ErrorResponse
json.Unmarshal([]byte(resp), &errorResponse)
Expand Down Expand Up @@ -299,6 +313,10 @@ func (p *Client) DeleteCNameRecord(domain, target string) error {
if err != nil {
return err
}
if statusCode == 401 {
p.refreshAuth()
return p.DeleteCNameRecord(domain, target)
}
if statusCode >= 400 {
var errorResponse ErrorResponse
json.Unmarshal([]byte(resp), &errorResponse)
Expand All @@ -307,3 +325,8 @@ func (p *Client) DeleteCNameRecord(domain, target string) error {

return nil
}

func (p *Client) refreshAuth() {
log.Info("Refreshing Pi-Hole authentication")
p.Login(p.password)
}