Skip to content

Commit

Permalink
Move db network forward functions to ClusterTx
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Hipp <thomashipp@gmail.com>
  • Loading branch information
monstermunchkin committed Dec 12, 2023
1 parent dd2fc19 commit 80f1063
Show file tree
Hide file tree
Showing 7 changed files with 268 additions and 136 deletions.
8 changes: 7 additions & 1 deletion cmd/incusd/network_allocations.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,13 @@ func networkAllocationsGet(d *Daemon, r *http.Request) response.Response {
}
}

forwards, err := d.db.Cluster.GetNetworkForwards(r.Context(), n.ID(), false)
var forwards map[int64]*api.NetworkForward

d.db.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {

Check failure on line 184 in cmd/incusd/network_allocations.go

View workflow job for this annotation

GitHub Actions / Code

Error return value of `d.db.Cluster.Transaction` is not checked (errcheck)
forwards, err = tx.GetNetworkForwards(ctx, n.ID(), false)

return err
})
if err != nil {
return response.SmartError(fmt.Errorf("Failed getting forwards for network %q in project %q: %w", networkName, projectName, err))
}
Expand Down
34 changes: 30 additions & 4 deletions cmd/incusd/network_forwards.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"encoding/json"
"fmt"
"net/http"
Expand All @@ -10,6 +11,7 @@ import (

"github.com/lxc/incus/internal/server/auth"
clusterRequest "github.com/lxc/incus/internal/server/cluster/request"
"github.com/lxc/incus/internal/server/db"
"github.com/lxc/incus/internal/server/lifecycle"
"github.com/lxc/incus/internal/server/network"
"github.com/lxc/incus/internal/server/project"
Expand Down Expand Up @@ -160,7 +162,13 @@ func networkForwardsGet(d *Daemon, r *http.Request) response.Response {
memberSpecific := false // Get forwards for all cluster members.

if localUtil.IsRecursionRequest(r) {
records, err := s.DB.Cluster.GetNetworkForwards(r.Context(), n.ID(), memberSpecific)
var records map[int64]*api.NetworkForward

err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {
records, err = tx.GetNetworkForwards(ctx, n.ID(), memberSpecific)

return err
})
if err != nil {
return response.SmartError(fmt.Errorf("Failed loading network forwards: %w", err))
}
Expand All @@ -173,7 +181,13 @@ func networkForwardsGet(d *Daemon, r *http.Request) response.Response {
return response.SyncResponse(true, forwards)
}

listenAddresses, err := s.DB.Cluster.GetNetworkForwardListenAddresses(n.ID(), memberSpecific)
var listenAddresses map[int64]string

err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {
listenAddresses, err = tx.GetNetworkForwardListenAddresses(ctx, n.ID(), memberSpecific)

return err
})
if err != nil {
return response.SmartError(fmt.Errorf("Failed loading network forwards: %w", err))
}
Expand Down Expand Up @@ -425,7 +439,13 @@ func networkForwardGet(d *Daemon, r *http.Request) response.Response {
targetMember := request.QueryParam(r, "target")
memberSpecific := targetMember != ""

_, forward, err := s.DB.Cluster.GetNetworkForward(r.Context(), n.ID(), memberSpecific, listenAddress)
var forward *api.NetworkForward

err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {
_, forward, err = tx.GetNetworkForward(ctx, n.ID(), memberSpecific, listenAddress)

return err
})
if err != nil {
return response.SmartError(err)
}
Expand Down Expand Up @@ -550,7 +570,13 @@ func networkForwardPut(d *Daemon, r *http.Request) response.Response {
memberSpecific := targetMember != ""

if r.Method == http.MethodPatch {
_, forward, err := s.DB.Cluster.GetNetworkForward(r.Context(), n.ID(), memberSpecific, listenAddress)
var forward *api.NetworkForward

err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {
_, forward, err = tx.GetNetworkForward(ctx, n.ID(), memberSpecific, listenAddress)

return err
})
if err != nil {
return response.SmartError(err)
}
Expand Down
181 changes: 78 additions & 103 deletions internal/server/db/network_forwards.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
// CreateNetworkForward creates a new Network Forward.
// If memberSpecific is true, then the forward is associated to the current member, rather than being associated to
// all members.
func (c *Cluster) CreateNetworkForward(networkID int64, memberSpecific bool, info *api.NetworkForwardsPost) (int64, error) {
func (c *ClusterTx) CreateNetworkForward(ctx context.Context, networkID int64, memberSpecific bool, info *api.NetworkForwardsPost) (int64, error) {
var err error
var forwardID int64
var nodeID any
Expand All @@ -36,30 +36,23 @@ func (c *Cluster) CreateNetworkForward(networkID int64, memberSpecific bool, inf
}
}

err = c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
// Insert a new Network forward record.
result, err := tx.tx.Exec(`
// Insert a new Network forward record.
result, err := c.tx.ExecContext(ctx, `
INSERT INTO networks_forwards
(network_id, node_id, listen_address, description, ports)
VALUES (?, ?, ?, ?, ?)
`, networkID, nodeID, info.ListenAddress, info.Description, string(portsJSON))
if err != nil {
return err
}

forwardID, err = result.LastInsertId()
if err != nil {
return err
}
if err != nil {
return -1, err
}

// Save config.
err = networkForwardConfigAdd(tx.tx, forwardID, info.Config)
if err != nil {
return err
}
forwardID, err = result.LastInsertId()
if err != nil {
return -1, err
}

return nil
})
// Save config.
err = networkForwardConfigAdd(c.tx, forwardID, info.Config)
if err != nil {
return -1, err
}
Expand Down Expand Up @@ -95,7 +88,7 @@ func networkForwardConfigAdd(tx *sql.Tx, forwardID int64, config map[string]stri
}

// UpdateNetworkForward updates an existing Network Forward.
func (c *Cluster) UpdateNetworkForward(networkID int64, forwardID int64, info *api.NetworkForwardPut) error {
func (c *ClusterTx) UpdateNetworkForward(ctx context.Context, networkID int64, forwardID int64, info *api.NetworkForwardPut) error {
var err error
var portsJSON []byte

Expand All @@ -106,39 +99,32 @@ func (c *Cluster) UpdateNetworkForward(networkID int64, forwardID int64, info *a
}
}

err = c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
// Update existing Network forward record.
res, err := tx.tx.Exec(`
// Update existing Network forward record.
res, err := c.tx.ExecContext(ctx, `
UPDATE networks_forwards
SET description = ?, ports = ?
WHERE network_id = ? and id = ?
`, info.Description, string(portsJSON), networkID, forwardID)
if err != nil {
return err
}

rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}
if err != nil {
return err
}

if rowsAffected <= 0 {
return api.StatusErrorf(http.StatusNotFound, "Network forward not found")
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}

// Save config.
_, err = tx.tx.Exec("DELETE FROM networks_forwards_config WHERE network_forward_id=?", forwardID)
if err != nil {
return err
}
if rowsAffected <= 0 {
return api.StatusErrorf(http.StatusNotFound, "Network forward not found")
}

err = networkForwardConfigAdd(tx.tx, forwardID, info.Config)
if err != nil {
return err
}
// Save config.
_, err = c.tx.ExecContext(ctx, "DELETE FROM networks_forwards_config WHERE network_forward_id=?", forwardID)
if err != nil {
return err
}

return nil
})
err = networkForwardConfigAdd(c.tx, forwardID, info.Config)
if err != nil {
return err
}
Expand All @@ -147,34 +133,32 @@ func (c *Cluster) UpdateNetworkForward(networkID int64, forwardID int64, info *a
}

// DeleteNetworkForward deletes an existing Network Forward.
func (c *Cluster) DeleteNetworkForward(networkID int64, forwardID int64) error {
return c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
// Delete existing Network forward record.
res, err := tx.tx.Exec(`
func (c *ClusterTx) DeleteNetworkForward(ctx context.Context, networkID int64, forwardID int64) error {
// Delete existing Network forward record.
res, err := c.tx.ExecContext(ctx, `
DELETE FROM networks_forwards
WHERE network_id = ? and id = ?
`, networkID, forwardID)
if err != nil {
return err
}
if err != nil {
return err
}

rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}

if rowsAffected <= 0 {
return api.StatusErrorf(http.StatusNotFound, "Network forward not found")
}
if rowsAffected <= 0 {
return api.StatusErrorf(http.StatusNotFound, "Network forward not found")
}

return nil
})
return nil
}

// GetNetworkForward returns the Network Forward ID and info for the given network ID and listen address.
// If memberSpecific is true, then the search is restricted to forwards that belong to this member or belong to
// all members.
func (c *Cluster) GetNetworkForward(ctx context.Context, networkID int64, memberSpecific bool, listenAddress string) (int64, *api.NetworkForward, error) {
func (c *ClusterTx) GetNetworkForward(ctx context.Context, networkID int64, memberSpecific bool, listenAddress string) (int64, *api.NetworkForward, error) {
forwards, err := c.GetNetworkForwards(ctx, networkID, memberSpecific, listenAddress)
if (err == nil && len(forwards) <= 0) || errors.Is(err, sql.ErrNoRows) {
return -1, nil, api.StatusErrorf(http.StatusNotFound, "Network forward not found")
Expand Down Expand Up @@ -225,7 +209,7 @@ func networkForwardConfig(ctx context.Context, tx *ClusterTx, forwardID int64, f
// on Forward ID.
// If memberSpecific is true, then the search is restricted to forwards that belong to this member or belong to
// all members.
func (c *Cluster) GetNetworkForwardListenAddresses(networkID int64, memberSpecific bool) (map[int64]string, error) {
func (c *ClusterTx) GetNetworkForwardListenAddresses(ctx context.Context, networkID int64, memberSpecific bool) (map[int64]string, error) {
var q *strings.Builder = &strings.Builder{}
args := []any{networkID}

Expand All @@ -244,21 +228,19 @@ func (c *Cluster) GetNetworkForwardListenAddresses(networkID int64, memberSpecif

forwards := make(map[int64]string)

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
return query.Scan(ctx, tx.Tx(), q.String(), func(scan func(dest ...any) error) error {
var forwardID int64 = int64(-1)
var listenAddress string
err := query.Scan(ctx, c.tx, q.String(), func(scan func(dest ...any) error) error {
var forwardID int64 = int64(-1)
var listenAddress string

err := scan(&forwardID, &listenAddress)
if err != nil {
return err
}
err := scan(&forwardID, &listenAddress)
if err != nil {
return err
}

forwards[forwardID] = listenAddress
forwards[forwardID] = listenAddress

return nil
}, args...)
})
return nil
}, args...)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -377,7 +359,7 @@ func (c *ClusterTx) GetProjectNetworkForwardListenAddressesOnMember(ctx context.
// GetNetworkForwards returns map of Network Forwards for the given network ID keyed on Forward ID.
// If memberSpecific is true, then the search is restricted to forwards that belong to this member or belong to
// all members. Can optionally retrieve only specific network forwards by listen address.
func (c *Cluster) GetNetworkForwards(ctx context.Context, networkID int64, memberSpecific bool, listenAddresses ...string) (map[int64]*api.NetworkForward, error) {
func (c *ClusterTx) GetNetworkForwards(ctx context.Context, networkID int64, memberSpecific bool, listenAddresses ...string) (map[int64]*api.NetworkForward, error) {
var q *strings.Builder = &strings.Builder{}
args := []any{networkID}

Expand Down Expand Up @@ -408,46 +390,39 @@ func (c *Cluster) GetNetworkForwards(ctx context.Context, networkID int64, membe
var err error
forwards := make(map[int64]*api.NetworkForward)

err = c.Transaction(ctx, func(ctx context.Context, tx *ClusterTx) error {
err = query.Scan(ctx, tx.Tx(), q.String(), func(scan func(dest ...any) error) error {
var forwardID int64 = int64(-1)
var portsJSON string
var forward api.NetworkForward

err := scan(&forwardID, &forward.ListenAddress, &forward.Description, &forward.Location, &portsJSON)
if err != nil {
return err
}
err = query.Scan(ctx, c.tx, q.String(), func(scan func(dest ...any) error) error {
var forwardID int64 = int64(-1)
var portsJSON string
var forward api.NetworkForward

forward.Ports = []api.NetworkForwardPort{}
if portsJSON != "" {
err = json.Unmarshal([]byte(portsJSON), &forward.Ports)
if err != nil {
return fmt.Errorf("Failed unmarshalling ports: %w", err)
}
}

forwards[forwardID] = &forward

return nil
}, args...)
err := scan(&forwardID, &forward.ListenAddress, &forward.Description, &forward.Location, &portsJSON)
if err != nil {
return err
}

// Populate config.
for forwardID := range forwards {
err = networkForwardConfig(ctx, tx, forwardID, forwards[forwardID])
forward.Ports = []api.NetworkForwardPort{}
if portsJSON != "" {
err = json.Unmarshal([]byte(portsJSON), &forward.Ports)
if err != nil {
return err
return fmt.Errorf("Failed unmarshalling ports: %w", err)
}
}

forwards[forwardID] = &forward

return nil
})
}, args...)
if err != nil {
return nil, err
}

// Populate config.
for forwardID := range forwards {
err = networkForwardConfig(ctx, c, forwardID, forwards[forwardID])
if err != nil {
return nil, err
}
}

return forwards, nil
}
9 changes: 8 additions & 1 deletion internal/server/device/nic_bridged.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package device
import (
"bufio"
"bytes"
"context"
"encoding/binary"
"encoding/hex"
"fmt"
Expand Down Expand Up @@ -611,7 +612,13 @@ func (d *nicBridged) Start() (*deviceConfig.RunConfig, error) {
}

if brNetfilterEnabled {
listenAddresses, err := d.state.DB.Cluster.GetNetworkForwardListenAddresses(d.network.ID(), true)
var listenAddresses map[int64]string

err = d.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
listenAddresses, err = tx.GetNetworkForwardListenAddresses(ctx, d.network.ID(), true)

return err
})
if err != nil {
return nil, fmt.Errorf("Failed loading network forwards: %w", err)
}
Expand Down
Loading

0 comments on commit 80f1063

Please sign in to comment.