Skip to content

Commit f8dd0eb

Browse files
authored
feat: check if peer exists before proxying (#14)
1 parent c462cf0 commit f8dd0eb

File tree

13 files changed

+170
-29
lines changed

13 files changed

+170
-29
lines changed

.golangci.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,6 @@ linters:
211211
- asciicheck
212212
- bidichk
213213
- bodyclose
214-
- deadcode
215214
- dogsled
216215
- errcheck
217216
- errname
@@ -255,4 +254,3 @@ linters:
255254
- typecheck
256255
- unconvert
257256
- unused
258-
- varcheck

cmd/tunneld/main.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@ func main() {
123123
EnvVars: []string{"TUNNELD_TRACING_HONEYCOMB_TEAM"},
124124
},
125125
&cli.StringFlag{
126-
Name: "tracing-service-id",
127-
Usage: "The service ID to annotate all traces with that uniquely identifies this deployment.",
128-
EnvVars: []string{"TUNNELD_TRACING_SERVICE_ID"},
126+
Name: "tracing-instance-id",
127+
Usage: "The instance ID to annotate all traces with that uniquely identifies this deployment.",
128+
EnvVars: []string{"TUNNELD_TRACING_INSTANCE_ID"},
129129
},
130130
},
131131
Action: runApp,
@@ -152,7 +152,7 @@ func runApp(ctx *cli.Context) error {
152152
realIPHeader = ctx.String("real-ip-header")
153153
pprofListenAddress = ctx.String("pprof-listen-address")
154154
tracingHoneycombTeam = ctx.String("tracing-honeycomb-team")
155-
tracingServiceID = ctx.String("tracing-service-id")
155+
tracingInstanceID = ctx.String("tracing-instance-id")
156156
)
157157
if baseURL == "" {
158158
return xerrors.New("base-url is required. See --help for more information.")
@@ -185,7 +185,7 @@ func runApp(ctx *cli.Context) error {
185185

186186
// Create a new tracer provider with a batch span processor and the otlp
187187
// exporter.
188-
tp := newTraceProvider(exp, tracingServiceID)
188+
tp := newTraceProvider(exp, tracingInstanceID)
189189
otel.SetTracerProvider(tp)
190190
otel.SetTextMapPropagator(
191191
propagation.NewCompositeTextMapPropagator(

cmd/tunneld/tracing.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ func newHoneycombExporter(ctx context.Context, teamID string) (*otlptrace.Export
2626
return otlptrace.New(ctx, client)
2727
}
2828

29-
func newTraceProvider(exp *otlptrace.Exporter, serviceID string) *sdktrace.TracerProvider {
29+
func newTraceProvider(exp *otlptrace.Exporter, instanceID string) *sdktrace.TracerProvider {
3030
rsc := resource.NewWithAttributes(
3131
semconv.SchemaURL,
3232
semconv.ServiceNameKey.String("WireguardTunnel"),
33-
semconv.ServiceInstanceIDKey.String(serviceID),
33+
semconv.ServiceInstanceIDKey.String(instanceID),
3434
semconv.ServiceVersionKey.String(buildinfo.Version()),
3535
)
3636

go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ require (
1919
golang.org/x/mod v0.8.0
2020
golang.org/x/sync v0.1.0
2121
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2
22-
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675
22+
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
2323
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230215201556-9c5414ab4bde
2424
google.golang.org/grpc v1.53.0
2525
)
@@ -52,7 +52,7 @@ require (
5252
go.opentelemetry.io/proto/otlp v0.19.0 // indirect
5353
golang.org/x/crypto v0.6.0 // indirect
5454
golang.org/x/net v0.7.0 // indirect
55-
golang.org/x/sys v0.5.0 // indirect
55+
golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 // indirect
5656
golang.org/x/term v0.5.0 // indirect
5757
golang.org/x/text v0.7.0 // indirect
5858
golang.org/x/time v0.3.0 // indirect

go.sum

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -432,8 +432,8 @@ golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7w
432432
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
433433
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
434434
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
435-
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
436-
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
435+
golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 h1:260HNjMTPDya+jq5AM1zZLgG9pv9GASPAGiEEJUbRg4=
436+
golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
437437
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
438438
golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY=
439439
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -508,8 +508,8 @@ golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3j
508508
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8=
509509
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
510510
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
511-
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 h1:/J/RVnr7ng4fWPRH3xa4WtBJ1Jp+Auu4YNLmGiPv5QU=
512-
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675/go.mod h1:whfbyDBt09xhCYQWtO2+3UVjlaq6/9hDZrjg2ZE6SyA=
511+
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
512+
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
513513
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230215201556-9c5414ab4bde h1:ybF7AMzIUikL9x4LgwEmzhXtzRpKNqngme1VGDWz+Nk=
514514
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230215201556-9c5414ab4bde/go.mod h1:mQqgjkW8GQQcJQsbBvK890TKqUK1DfKWkuBGbOkuMHQ=
515515
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=

tunneld/api.go

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,24 @@ func (api *API) registerClient(req tunnelsdk.ClientRegisterRequest) (tunnelsdk.C
165165

166166
ip, urls := api.WireguardPublicKeyToIPAndURLs(req.PublicKey, req.Version)
167167

168+
api.pkeyCacheMu.Lock()
169+
api.pkeyCache[ip] = cachedPeer{
170+
key: req.PublicKey,
171+
lastHandshake: time.Now(),
172+
}
173+
api.pkeyCacheMu.Unlock()
174+
168175
exists := true
169176
if api.wgDevice.LookupPeer(req.PublicKey) == nil {
170177
exists = false
171178

179+
api.pkeyCacheMu.Lock()
180+
api.pkeyCache[ip] = cachedPeer{
181+
key: req.PublicKey,
182+
lastHandshake: time.Now(),
183+
}
184+
api.pkeyCacheMu.Unlock()
185+
172186
err := api.wgDevice.IpcSet(fmt.Sprintf(`public_key=%x
173187
allowed_ip=%s/128`,
174188
req.PublicKey,
@@ -186,6 +200,7 @@ allowed_ip=%s/128`,
186200

187201
return tunnelsdk.ClientRegisterResponse{
188202
Version: req.Version,
203+
ReregisterWait: api.PeerRegisterInterval,
189204
TunnelURLs: urlsStr,
190205
ClientIP: ip,
191206
ServerEndpoint: api.WireguardEndpoint,
@@ -205,6 +220,12 @@ func (api *API) handleTunnel(rw http.ResponseWriter, r *http.Request) {
205220
subdomainParts := strings.Split(subdomain, "-")
206221
user := subdomainParts[len(subdomainParts)-1]
207222

223+
span := trace.SpanFromContext(ctx)
224+
span.SetAttributes(
225+
attribute.Bool("proxy_request", true),
226+
attribute.String("user", user),
227+
)
228+
208229
ip, err := api.HostnameToWireguardIP(user)
209230
if err != nil {
210231
httpapi.Write(ctx, rw, http.StatusBadRequest, tunnelsdk.Response{
@@ -214,11 +235,17 @@ func (api *API) handleTunnel(rw http.ResponseWriter, r *http.Request) {
214235
return
215236
}
216237

217-
span := trace.SpanFromContext(ctx)
218-
span.SetAttributes(
219-
attribute.Bool("proxy_request", true),
220-
attribute.String("user", user),
221-
)
238+
api.pkeyCacheMu.RLock()
239+
pkey, ok := api.pkeyCache[ip]
240+
api.pkeyCacheMu.RUnlock()
241+
242+
if !ok || time.Since(pkey.lastHandshake) > api.PeerTimeout {
243+
httpapi.Write(ctx, rw, http.StatusBadGateway, tunnelsdk.Response{
244+
Message: "Peer is not connected.",
245+
Detail: "",
246+
})
247+
return
248+
}
222249

223250
// The transport on the reverse proxy uses this ctx value to know which
224251
// IP to dial. See tunneld.go.

tunneld/api_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ func Test_postClients(t *testing.T) {
100100
require.Equal(t, td.WireguardServerIP, res.ServerIP)
101101
require.Equal(t, td.WireguardKey.NoisePublicKey(), res.ServerPublicKey)
102102
require.Equal(t, td.WireguardMTU, res.WireguardMTU)
103+
require.Equal(t, td.PeerRegisterInterval, res.ReregisterWait)
103104

104105
// Register the same client again.
105106
res2, err := client.ClientRegister(context.Background(), tunnelsdk.ClientRegisterRequest{

tunneld/options.go

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ import (
1919
)
2020

2121
const (
22-
DefaultWireguardMTU = 1280
23-
DefaultPeerDialTimeout = 10 * time.Second
22+
DefaultWireguardMTU = 1280
23+
DefaultPeerDialTimeout = 10 * time.Second
24+
DefaultPeerPollDuration = 30 * time.Second
25+
DefaultPeerTimeout = 2 * time.Minute
2426
)
2527

2628
var (
@@ -70,6 +72,12 @@ type Options struct {
7072
// PeerDialTimeout is the timeout for dialing a peer on a request. Defaults
7173
// to 10 seconds.
7274
PeerDialTimeout time.Duration
75+
76+
// PeerRegisterInterval is how often the clients should re-register.
77+
PeerRegisterInterval time.Duration
78+
79+
// PeerTimeout is how long the server will wait before removing the peer.
80+
PeerTimeout time.Duration
7381
}
7482

7583
// Validate checks that the options are valid and populates default values for
@@ -127,6 +135,18 @@ func (options *Options) Validate() error {
127135
if options.PeerDialTimeout <= 0 {
128136
options.PeerDialTimeout = DefaultPeerDialTimeout
129137
}
138+
if options.PeerRegisterInterval <= 0 {
139+
options.PeerRegisterInterval = DefaultPeerPollDuration
140+
}
141+
if options.PeerTimeout <= 0 {
142+
options.PeerTimeout = DefaultPeerTimeout
143+
}
144+
if options.PeerRegisterInterval >= options.PeerTimeout {
145+
return xerrors.Errorf("PeerRegisterInterval(%s) must be less than PeerTimeout(%s)",
146+
options.PeerRegisterInterval.String(),
147+
options.PeerTimeout.String(),
148+
)
149+
}
130150

131151
return nil
132152
}

tunneld/options_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ func Test_Option(t *testing.T) {
4141
WireguardNetworkPrefix: netip.MustParsePrefix("feed::1/64"),
4242
RealIPHeader: "X-Real-Ip",
4343
PeerDialTimeout: 1 * time.Second,
44+
PeerRegisterInterval: time.Second,
45+
PeerTimeout: 2 * time.Second,
4446
}
4547

4648
clone := o

tunneld/tunneld.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"net"
77
"net/http"
88
"net/netip"
9+
"sync"
910
"time"
1011

1112
"go.opentelemetry.io/otel"
@@ -24,6 +25,14 @@ type API struct {
2425
wgNet *netstack.Net
2526
wgDevice *device.Device
2627
transport *http.Transport
28+
29+
pkeyCacheMu sync.RWMutex
30+
pkeyCache map[netip.Addr]cachedPeer
31+
}
32+
33+
type cachedPeer struct {
34+
key device.NoisePublicKey
35+
lastHandshake time.Time
2736
}
2837

2938
func New(options *Options) (*API, error) {
@@ -72,9 +81,10 @@ listen_port=%d`,
7281
}
7382

7483
return &API{
75-
Options: options,
76-
wgNet: wgNet,
77-
wgDevice: dev,
84+
Options: options,
85+
wgNet: wgNet,
86+
wgDevice: dev,
87+
pkeyCache: make(map[netip.Addr]cachedPeer),
7888
transport: &http.Transport{
7989
DialContext: func(ctx context.Context, network, addr string) (nc net.Conn, err error) {
8090
ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "(http.Transport).DialContext")

tunneld/tunneld_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package tunneld_test
22

33
import (
44
"context"
5+
"encoding/json"
56
"io"
67
"log"
78
"net"
@@ -337,6 +338,83 @@ func TestTimeout(t *testing.T) {
337338
require.Equal(t, http.StatusBadGateway, res.StatusCode)
338339
}
339340

341+
func TestPeerTimeout(t *testing.T) {
342+
t.Parallel()
343+
344+
td, client := createTestTunneld(t, &tunneld.Options{
345+
PeerTimeout: time.Second,
346+
PeerRegisterInterval: 100 * time.Millisecond,
347+
})
348+
require.NotNil(t, td)
349+
350+
// Start a tunnel.
351+
key, err := tunnelsdk.GeneratePrivateKey()
352+
require.NoError(t, err, "generate private key")
353+
tunnel, err := client.LaunchTunnel(context.Background(), tunnelsdk.TunnelConfig{
354+
Log: slogtest.
355+
Make(t, &slogtest.Options{IgnoreErrors: true}).
356+
Named("tunnel_client"),
357+
PrivateKey: key,
358+
})
359+
require.NoError(t, err, "launch tunnel")
360+
defer func() {
361+
_ = tunnel.Close()
362+
<-tunnel.Wait()
363+
}()
364+
365+
require.NotNil(t, tunnel.URL)
366+
require.Len(t, tunnel.OtherURLs, 1)
367+
require.NotEqual(t, tunnel.URL.String(), tunnel.OtherURLs[0].String())
368+
369+
serveTunnel(t, tunnel)
370+
waitForTunnelReady(t, client, tunnel)
371+
372+
// Successfully send a request to the peer.
373+
{
374+
u, err := tunnel.URL.Parse("/test/1")
375+
require.NoError(t, err)
376+
377+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
378+
defer cancel()
379+
380+
res, err := client.Request(ctx, http.MethodGet, u.String(), nil)
381+
if !assert.NoError(t, err) {
382+
return
383+
}
384+
defer res.Body.Close()
385+
assert.Equal(t, http.StatusOK, res.StatusCode)
386+
387+
body, err := io.ReadAll(res.Body)
388+
require.NoError(t, err)
389+
require.Equal(t, "hello world /test/1", string(body))
390+
}
391+
392+
err = tunnel.Close()
393+
require.NoError(t, err, "close tunnel")
394+
<-tunnel.Wait()
395+
396+
time.Sleep(td.PeerTimeout)
397+
398+
// The correct error should be returned after the peer goes away.
399+
{
400+
u, err := tunnel.URL.Parse("/test/1")
401+
require.NoError(t, err)
402+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
403+
defer cancel()
404+
405+
res, err := client.Request(ctx, http.MethodGet, u.String(), nil)
406+
require.NoError(t, err)
407+
defer res.Body.Close()
408+
409+
tres := tunnelsdk.Response{}
410+
err = json.NewDecoder(res.Body).Decode(&tres)
411+
require.NoError(t, err)
412+
413+
require.Equal(t, http.StatusBadGateway, res.StatusCode)
414+
require.Equal(t, "Peer is not connected.", tres.Message)
415+
}
416+
}
417+
340418
func freeUDPPort(t *testing.T) uint16 {
341419
t.Helper()
342420

tunnelsdk/api.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"net/http"
77
"net/netip"
8+
"time"
89

910
"golang.zx2c4.com/wireguard/device"
1011
)
@@ -20,7 +21,8 @@ type ClientRegisterRequest struct {
2021
}
2122

2223
type ClientRegisterResponse struct {
23-
Version TunnelVersion `json:"version"`
24+
Version TunnelVersion `json:"version"`
25+
ReregisterWait time.Duration `json:"reregister_wait"`
2426
// TunnelURLs contains a list of valid URLs that will be forwarded from the
2527
// server to this tunnel client once connected. The first URL is the
2628
// preferred URL, and the other URLs are provided for compatibility

0 commit comments

Comments
 (0)