Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send ports forwarded to control server #2392

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/portforward/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ type Service interface {
Start(ctx context.Context) (runError <-chan error, err error)
Stop() (err error)
GetPortsForwarded() (ports []uint16)
SetPortsForwarded(ctx context.Context, ports []uint16) (err error)
}

type Routing interface {
Expand Down
13 changes: 13 additions & 0 deletions internal/portforward/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,19 @@ func (l *Loop) GetPortsForwarded() (ports []uint16) {
return l.service.GetPortsForwarded()
}

func (l *Loop) SetPortsForwarded(ports []uint16) (err error) {
if l.service == nil {
return
}
Comment on lines +161 to +163
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could set the ports somehow, even if the service is not started. The ports could then be injected to the service when we create it. A bit of a futuristic approach about when we could do all kind of modifications live 😄

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that might be beyond me for now. 😅

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem, let's keep this unresolved and I'll jump at implementing it later 😉


err = l.service.SetPortsForwarded(l.runCtx, ports)
if err != nil {
return err
}

return nil
}

func ptrTo[T any](value T) *T {
return &value
}
41 changes: 41 additions & 0 deletions internal/portforward/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package service

import (
"context"
"fmt"
"net/http"
"sync"
)
Expand Down Expand Up @@ -47,3 +48,43 @@ func (s *Service) GetPortsForwarded() (ports []uint16) {
copy(ports, s.ports)
return ports
}

func (s *Service) SetPortsForwarded(ctx context.Context, ports []uint16) (err error) {
for i, port := range s.ports {
err := s.portAllower.RemoveAllowedPort(ctx, port)
if err != nil {
for j := range i {
_ = s.portAllower.SetAllowedPort(ctx, s.ports[j], s.settings.Interface)
}
return fmt.Errorf("removing allowed port: %w", err)
}
}

for i, port := range ports {
err := s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface)
if err != nil {
for j := 0; j < i; j++ {
_ = s.portAllower.RemoveAllowedPort(ctx, s.ports[j])
}
for _, port := range s.ports {
_ = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface)
}
return fmt.Errorf("setting allowed port: %w", err)
}
}

err = s.writePortForwardedFile(ports)
if err != nil {
_ = s.cleanup()
return err
}

s.portMutex.RLock()
defer s.portMutex.RUnlock()
s.ports = make([]uint16, len(ports))
copy(s.ports, ports)

s.logger.Info("updated: " + portsToString(s.ports))

return nil
}
4 changes: 2 additions & 2 deletions internal/server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
func newHandler(ctx context.Context, logger infoWarner, logging bool,
buildInfo models.BuildInformation,
vpnLooper VPNLooper,
pfGetter PortForwardedGetter,
pf PortForwarding,
unboundLooper DNSLoop,
updaterLooper UpdaterLooper,
publicIPLooper PublicIPLoop,
Expand All @@ -21,7 +21,7 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool,
handler := &handler{}

vpn := newVPNHandler(ctx, vpnLooper, storage, ipv6Supported, logger)
openvpn := newOpenvpnHandler(ctx, vpnLooper, pfGetter, logger)
openvpn := newOpenvpnHandler(ctx, vpnLooper, pf, logger)
dns := newDNSHandler(ctx, unboundLooper, logger)
updater := newUpdaterHandler(ctx, updaterLooper, logger)
publicip := newPublicIPHandler(publicIPLooper, logger)
Expand Down
3 changes: 2 additions & 1 deletion internal/server/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ type DNSLoop interface {
GetStatus() (status models.LoopStatus)
}

type PortForwardedGetter interface {
type PortForwarding interface {
GetPortsForwarded() (ports []uint16)
SetPortsForwarded(ports []uint16) (err error)
}

type PublicIPLoop interface {
Expand Down
33 changes: 30 additions & 3 deletions internal/server/openvpn.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"

Expand All @@ -11,19 +12,19 @@ import (
)

func newOpenvpnHandler(ctx context.Context, looper VPNLooper,
pfGetter PortForwardedGetter, w warner) http.Handler {
portForwarding PortForwarding, w warner) http.Handler {
return &openvpnHandler{
ctx: ctx,
looper: looper,
pf: pfGetter,
pf: portForwarding,
warner: w,
}
}

type openvpnHandler struct {
ctx context.Context //nolint:containedctx
looper VPNLooper
pf PortForwardedGetter
pf PortForwarding
warner warner
}

Expand All @@ -50,6 +51,8 @@ func (h *openvpnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
h.getPortForwarded(w)
case http.MethodPut:
h.setPortForwarded(w, r)
default:
errMethodNotSupported(w, r.Method)
}
Expand Down Expand Up @@ -141,3 +144,27 @@ func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) {
w.WriteHeader(http.StatusInternalServerError)
}
}

func (h *openvpnHandler) setPortForwarded(w http.ResponseWriter, r *http.Request) {
var data portsWrapper

decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&data); err != nil {
h.warner.Warn(fmt.Sprintf("failed setting forwarded ports: %s", err))
http.Error(w, "failed setting forwarded ports", http.StatusBadRequest)
return
}

if err := h.pf.SetPortsForwarded(data.Ports); err != nil {
h.warner.Warn(fmt.Sprintf("failed setting forwarded ports: %s", err))
http.Error(w, "failed setting forwarded ports", http.StatusInternalServerError)
return
}

encoder := json.NewEncoder(w)
err := encoder.Encode(h.pf.GetPortsForwarded())
if err != nil {
h.warner.Warn(err.Error())
w.WriteHeader(http.StatusInternalServerError)
}
}
4 changes: 2 additions & 2 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ import (

func New(ctx context.Context, address string, logEnabled bool, logger Logger,
buildInfo models.BuildInformation, openvpnLooper VPNLooper,
pfGetter PortForwardedGetter, unboundLooper DNSLoop,
pf PortForwarding, unboundLooper DNSLoop,
updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage,
ipv6Supported bool) (
server *httpserver.Server, err error) {
handler := newHandler(ctx, logger, logEnabled, buildInfo,
openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper,
openvpnLooper, pf, unboundLooper, updaterLooper, publicIPLooper,
storage, ipv6Supported)

httpServerSettings := httpserver.Settings{
Expand Down