Skip to content

Commit 16be491

Browse files
authored
feat: add --wireguard-key-file param (#6)
1 parent 0658062 commit 16be491

File tree

4 files changed

+54
-8
lines changed

4 files changed

+54
-8
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,5 @@ go.work
1616

1717
# Build directory.
1818
build/
19+
20+
*.key

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ build: build/tunneld build/tunnel
3737
# architecture. You can change the architecture by setting GOOS and GOARCH
3838
# manually before calling this target.
3939
build/tunneld build/tunnel: build/%: $(shell find . -type f -name '*.go')
40-
go build \
40+
CGO_ENABLED=0 go build \
4141
-o "$@" \
4242
-tags urfave_cli_no_docs \
4343
-ldflags "-s -w -X 'github.com/coder/wgtunnel/buildinfo.tag=$(VERSION)'" \

cmd/tunneld/main.go

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,15 @@ func main() {
7474
&cli.StringFlag{
7575
Name: "wireguard-key",
7676
Aliases: []string{"wg-key"},
77-
Usage: "The private key for the wireguard server. It should be base64 encoded.",
77+
Usage: "The private key for the wireguard server. It should be base64 encoded. You can generate a key with `wg genkey`. Mutually exclusive with wireguard-key-file.",
7878
EnvVars: []string{"TUNNELD_WIREGUARD_KEY"},
7979
},
80+
&cli.StringFlag{
81+
Name: "wireguard-key-file",
82+
Aliases: []string{"wg-key-file"},
83+
Usage: "The file path containing the private key for the wireguard server. The contents should be base64 encoded. If the file does not exist, a key will be generated for you and written to the file. Mutually exclusive with wireguard-key.",
84+
EnvVars: []string{"TUNNELD_WIREGUARD_KEY_FILE"},
85+
},
8086
&cli.IntFlag{
8187
Name: "wireguard-mtu",
8288
Aliases: []string{"wg-mtu"},
@@ -127,24 +133,28 @@ func runApp(ctx *cli.Context) error {
127133
wireguardEndpoint = ctx.String("wireguard-endpoint")
128134
wireguardPort = ctx.Uint("wireguard-port")
129135
wireguardKey = ctx.String("wireguard-key")
136+
wireguardKeyFile = ctx.String("wireguard-key-file")
130137
wireguardMTU = ctx.Int("wireguard-mtu")
131138
wireguardServerIP = ctx.String("wireguard-server-ip")
132139
wireguardNetworkPrefix = ctx.String("wireguard-network-prefix")
133140
pprofListenAddress = ctx.String("pprof-listen-address")
134141
tracingHoneycombTeam = ctx.String("tracing-honeycomb-team")
135142
)
136143
if baseURL == "" {
137-
return xerrors.New("base-hostname is required. See --help for more information.")
144+
return xerrors.New("base-url is required. See --help for more information.")
138145
}
139146
if wireguardEndpoint == "" {
140147
return xerrors.New("wireguard-endpoint is required. See --help for more information.")
141148
}
142149
if wireguardPort < 1 || wireguardPort > 65535 {
143150
return xerrors.New("wireguard-port is required and must be between 1 and 65535. See --help for more information.")
144151
}
145-
if wireguardKey == "" {
152+
if wireguardKey == "" && wireguardKeyFile == "" {
146153
return xerrors.New("wireguard-key is required. See --help for more information.")
147154
}
155+
if wireguardKey != "" && wireguardKeyFile != "" {
156+
return xerrors.New("wireguard-key and wireguard-key-file are mutually exclusive. See --help for more information.")
157+
}
148158

149159
logger := slog.Make(sloghuman.Sink(os.Stderr)).Leveled(slog.LevelInfo)
150160
if verbose {
@@ -182,10 +192,6 @@ func runApp(ctx *cli.Context) error {
182192
if err != nil {
183193
return xerrors.Errorf("could not parse base-url %q: %w", baseURL, err)
184194
}
185-
wireguardKeyParsed, err := tunnelsdk.ParsePrivateKey(wireguardKey)
186-
if err != nil {
187-
return xerrors.Errorf("could not parse wireguard-key %q: %w", wireguardKey, err)
188-
}
189195
wireguardServerIPParsed, err := netip.ParseAddr(wireguardServerIP)
190196
if err != nil {
191197
return xerrors.Errorf("could not parse wireguard-server-ip %q: %w", wireguardServerIP, err)
@@ -195,6 +201,37 @@ func runApp(ctx *cli.Context) error {
195201
return xerrors.Errorf("could not parse wireguard-network-prefix %q: %w", wireguardNetworkPrefix, err)
196202
}
197203

204+
if wireguardKeyFile != "" {
205+
_, err = os.Stat(wireguardKeyFile)
206+
if xerrors.Is(err, os.ErrNotExist) {
207+
logger.Info(ctx.Context, "generating private key to file", slog.F("path", wireguardKeyFile))
208+
key, err := tunnelsdk.GeneratePrivateKey()
209+
if err != nil {
210+
return xerrors.Errorf("could not generate private key: %w", err)
211+
}
212+
213+
err = os.WriteFile(wireguardKeyFile, []byte(key.String()), 0600)
214+
if err != nil {
215+
return xerrors.Errorf("could not write base64-encoded private key to %q: %w", wireguardKeyFile, err)
216+
}
217+
} else if err != nil {
218+
return xerrors.Errorf("could not stat wireguard-key-file %q: %w", wireguardKeyFile, err)
219+
}
220+
221+
logger.Info(ctx.Context, "reading private key from file", slog.F("path", wireguardKeyFile))
222+
wireguardKeyBytes, err := os.ReadFile(wireguardKeyFile)
223+
if err != nil {
224+
return xerrors.Errorf("could not read wireguard-key-file %q: %w", wireguardKeyFile, err)
225+
}
226+
wireguardKey = string(wireguardKeyBytes)
227+
}
228+
229+
wireguardKeyParsed, err := tunnelsdk.ParsePrivateKey(wireguardKey)
230+
if err != nil {
231+
return xerrors.Errorf("could not parse wireguard-key %q: %w", wireguardKey, err)
232+
}
233+
logger.Info(ctx.Context, "parsed private key", slog.F("hash", wireguardKeyParsed.Hash()))
234+
198235
options := &tunneld.Options{
199236
BaseURL: baseURLParsed,
200237
WireguardEndpoint: wireguardEndpoint,

tunnelsdk/tunnel.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package tunnelsdk
22

33
import (
44
"context"
5+
"crypto/sha512"
56
"encoding/hex"
67
"fmt"
78
"net"
@@ -121,6 +122,12 @@ func (k Key) HexString() string {
121122
return hex.EncodeToString(k.k[:])
122123
}
123124

125+
// Hash returns the SHA512 hash of the key.
126+
func (k Key) Hash() string {
127+
hash := sha512.Sum512(k.k[:])
128+
return hex.EncodeToString(hash[:])
129+
}
130+
124131
// NoisePrivateKey returns the device.NoisePrivateKey for the key. If the key is
125132
// not a private key, an error is returned.
126133
func (k Key) NoisePrivateKey() (device.NoisePrivateKey, error) {

0 commit comments

Comments
 (0)