Skip to content

Commit

Permalink
Merge pull request #5 from uhthomas/safely-construct-urls
Browse files Browse the repository at this point in the history
Safely construct urls
  • Loading branch information
crutonjohn authored Jun 12, 2024
2 parents f979ce1 + 720fe63 commit 128e915
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 30 deletions.
41 changes: 22 additions & 19 deletions internal/opnsense-unbound/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"path"
"strings"

log "github.com/sirupsen/logrus"
Expand All @@ -20,17 +22,16 @@ const emptyJSONObject = "{}"
type httpClient struct {
*Config
*http.Client
baseURL *url.URL
}

const (
opnsenseUnboundServicePath = "%s/api/unbound/service/%s"
opnsenseUnboundSettingsPath = "%s/api/unbound/settings/%s"
// Hacky, but nice to have the delete as an explicit constant since it's destructive
opnsenseUnboundSettingsPathDelete = "%s/api/unbound/settings/delHostOverride/%s"
)

// newOpnsenseClient creates a new DNS provider client.
func newOpnsenseClient(config *Config) (*httpClient, error) {
u, err := url.Parse(config.Host)
if err != nil {
return nil, fmt.Errorf("parse url: %w", err)
}
u.Path = path.Join(u.Path, "api/unbound")

// Create the HTTP client
client := &httpClient{
Expand All @@ -40,6 +41,7 @@ func newOpnsenseClient(config *Config) (*httpClient, error) {
TLSClientConfig: &tls.Config{InsecureSkipVerify: config.SkipTLSVerify},
},
},
baseURL: u,
}

if err := client.login(); err != nil {
Expand All @@ -51,11 +53,10 @@ func newOpnsenseClient(config *Config) (*httpClient, error) {

// login performs a basic call to validate credentials
func (c *httpClient) login() error {

// Perform the test call by getting service status
resp, err := c.doRequest(
http.MethodGet,
FormatUrl(opnsenseUnboundServicePath, c.Config.Host, "status"),
"service/status",
nil,
)
if err != nil {
Expand All @@ -76,9 +77,13 @@ func (c *httpClient) login() error {

// doRequest makes an HTTP request to the Opnsense firewall.
func (c *httpClient) doRequest(method, path string, body io.Reader) (*http.Response, error) {
log.Debugf("doRequest: making %s request to %s", method, path)
u := c.baseURL.ResolveReference(&url.URL{
Path: path,
})

log.Debugf("doRequest: making %s request to %s", method, u)

req, err := http.NewRequest(method, path, body)
req, err := http.NewRequest(method, u.String(), body)
if err != nil {
return nil, err
}
Expand All @@ -90,10 +95,10 @@ func (c *httpClient) doRequest(method, path string, body io.Reader) (*http.Respo
return nil, err
}

log.Debugf("doRequest: response code from %s request to %s: %d", method, path, resp.StatusCode)
log.Debugf("doRequest: response code from %s request to %s: %d", method, u, resp.StatusCode)

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("doRequest: %s request to %s was not successful: %d", method, path, resp.StatusCode)
return nil, fmt.Errorf("doRequest: %s request to %s was not successful: %d", method, u, resp.StatusCode)
}

return resp, nil
Expand All @@ -102,10 +107,9 @@ func (c *httpClient) doRequest(method, path string, body io.Reader) (*http.Respo
// GetHostOverrides retrieves the list of HostOverrides from the Opnsense Firewall's Unbound API.
// These are equivalent to A or AAAA records
func (c *httpClient) GetHostOverrides() ([]DNSRecord, error) {

resp, err := c.doRequest(
http.MethodGet,
FormatUrl(opnsenseUnboundSettingsPath, c.Config.Host, "searchHostOverride"),
"settings/searchHostOverride",
nil,
)
if err != nil {
Expand All @@ -125,7 +129,6 @@ func (c *httpClient) GetHostOverrides() ([]DNSRecord, error) {

// CreateHostOverride creates a new DNS A or AAAA record in the Opnsense Firewall's Unbound API.
func (c *httpClient) CreateHostOverride(endpoint *endpoint.Endpoint) (*DNSRecord, error) {

log.Debugf("create: Try pulling pre-existing Unbound %s record: %s", endpoint.RecordType, endpoint.DNSName)
lookup, err := c.lookupHostOverrideIdentifier(endpoint.DNSName, endpoint.RecordType)
if err != nil {
Expand Down Expand Up @@ -155,7 +158,7 @@ func (c *httpClient) CreateHostOverride(endpoint *endpoint.Endpoint) (*DNSRecord
log.Debugf("create: POST: %s", string(jsonBody))
resp, err := c.doRequest(
http.MethodPost,
FormatUrl(opnsenseUnboundSettingsPath, c.Config.Host, "addHostOverride"),
"settings/addHostOverride",
bytes.NewReader(jsonBody),
)
if err != nil {
Expand Down Expand Up @@ -189,7 +192,7 @@ func (c *httpClient) DeleteHostOverride(endpoint *endpoint.Endpoint) error {
log.Debugf("delete: Sending POST %s", lookup.Uuid)
if _, err = c.doRequest(
http.MethodPost,
FormatUrl(opnsenseUnboundSettingsPathDelete, c.Config.Host, lookup.Uuid),
path.Join("settings/delHostOverride", lookup.Uuid),
strings.NewReader(emptyJSONObject),
); err != nil {
return err
Expand Down Expand Up @@ -223,7 +226,7 @@ func (c *httpClient) ReconfigureUnbound() error {
// Perform the reconfigure
resp, err := c.doRequest(
http.MethodPost,
FormatUrl(opnsenseUnboundServicePath, c.Config.Host, "reconfigure"),
"service/reconfigure",
strings.NewReader(emptyJSONObject),
)
if err != nil {
Expand Down
11 changes: 0 additions & 11 deletions internal/opnsense-unbound/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,6 @@ package opnsense

import "strings"

// FormatUrl formats a URL with the given parameters.
func FormatUrl(path string, params ...string) string {
segments := strings.Split(path, "%s")
for i, param := range params {
if param != "" {
segments[i] += param
}
}
return strings.Join(segments, "")
}

// UnboundFQDNSplitter splits a DNSName into two parts,
// [0] Being the top level hostname
// [1] Being the subdomain/domain
Expand Down

0 comments on commit 128e915

Please sign in to comment.