diff --git a/go.mod b/go.mod index 506266d1a..9d73eac8e 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/AlecAivazis/survey/v2 v2.3.7 github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240424194431-3612a5a6fb4c github.com/alecthomas/kingpin/v2 v2.4.0 + github.com/amnezia-vpn/amneziawg-go v0.2.8 github.com/apex/log v1.9.0 github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 github.com/cloudflare/circl v1.3.8 @@ -81,6 +82,7 @@ require ( github.com/segmentio/fasthash v1.0.3 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/testify v1.9.0 // indirect + github.com/tevino/abool/v2 v2.1.0 // indirect github.com/xhit/go-str2duration/v2 v2.1.0 // indirect gitlab.com/yawning/edwards25519-extra v0.0.0-20231005122941-2149dcafc266 // indirect go.uber.org/mock v0.4.0 // indirect @@ -88,8 +90,9 @@ require ( golang.org/x/exp/typeparams v0.0.0-20230522175609-2e198f4a06a1 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/time v0.5.0 // indirect + golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - gvisor.dev/gvisor v0.0.0-20230922204349-b3f36d574a7f // indirect + gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect ) require ( diff --git a/go.sum b/go.sum index c54eab98f..0fcac7a9e 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ github.com/alecthomas/kingpin/v2 v2.4.0 h1:f48lwail6p8zpO1bC4TxtqACaGqHYA22qkHjH github.com/alecthomas/kingpin/v2 v2.4.0/go.mod h1:0gyi0zQnjuFk8xrkNKamJoyUo382HRL7ATRpFZCw6tE= github.com/alecthomas/units v0.0.0-20231202071711-9a357b53e9c9 h1:ez/4by2iGztzR4L0zgAOR8lTQK9VlyBVVd7G4omaOQs= github.com/alecthomas/units v0.0.0-20231202071711-9a357b53e9c9/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE= +github.com/amnezia-vpn/amneziawg-go v0.2.8 h1:J8PPx+hylx5nNZ5U1+ECFj9noGkcm2ThmSV9rBNDgy8= +github.com/amnezia-vpn/amneziawg-go v0.2.8/go.mod h1:12g0XRbFeGbpXvuCmBOV21YxLWSFnUFJnwgrzyHBUyk= github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/apex/log v1.9.0 h1:FHtw/xuaM8AgmvDDTI9fiwoAL25Sq2cxojnZICUU8l0= @@ -531,6 +533,8 @@ github.com/templexxx/cpu v0.1.0/go.mod h1:w7Tb+7qgcAlIyX4NhLuDKt78AHA5SzPmq0Wj6H github.com/templexxx/xorsimd v0.4.1/go.mod h1:W+ffZz8jJMH2SXwuKu9WhygqBMbFnp14G2fqEr8qaNo= github.com/templexxx/xorsimd v0.4.2 h1:ocZZ+Nvu65LGHmCLZ7OoCtg8Fx8jnHKK37SjvngUoVI= github.com/templexxx/xorsimd v0.4.2/go.mod h1:HgwaPoDREdi6OnULpSfxhzaiiSUY4Fi3JPn1wpt28NI= +github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c= +github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY= github.com/tj/assert v0.0.0-20171129193455-018094318fb0/go.mod h1:mZ9/Rh9oLWpLLDRpvE+3b7gP/C2YyLFYxNmcLnPTMe0= github.com/tj/assert v0.0.3 h1:Df/BlaZ20mq6kuai7f5z2TvPFiwC3xaWJSDQNiIS3Rk= github.com/tj/assert v0.0.3/go.mod h1:Ne6X72Q+TB1AteidzQncjw9PabbMp4PBMZ1k+vd1Pvk= @@ -757,6 +761,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= +golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= @@ -800,6 +806,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gvisor.dev/gvisor v0.0.0-20230922204349-b3f36d574a7f h1:w4K7S8+VKrhX67mFdUymQUsGVbEElPCN0v7U0DoLpUw= gvisor.dev/gvisor v0.0.0-20230922204349-b3f36d574a7f/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/internal/experiment/wireguard/config.go b/internal/experiment/wireguard/config.go new file mode 100644 index 000000000..ae129deeb --- /dev/null +++ b/internal/experiment/wireguard/config.go @@ -0,0 +1,176 @@ +package wireguard + +import ( + "crypto/sha1" + "encoding/base64" + "encoding/hex" + "fmt" + "io" +) + +var ( + // defaultNameserver is the dns server using for resolving names inside the wg tunnel. + defaultNameserver = "8.8.8.8" +) + +// Config contains the experiment config. +// +// This contains all the settings that user can set to modify the behaviour +// of this experiment. By tagging these variables with `ooni:"..."`, we allow +// miniooni's -O flag to find them and set them. +type Config struct { + Verbose bool `ooni:"Use extra-verbose mode in wireguard logs"` + + // These flags modify what sensitive information is stored in the report and submitted to the backend. + PublicTarget bool `ooni:"Treat the target endpoint as public data (if true, it will be included in the report)"` + PublicAmneziaParameters bool `ooni:"Treat the AmneziaWG advanced security parameters as public data"` + + // Safe_XXX options are not sent to the backend for archival by default. + SafeRemote string `ooni:"Remote to connect to using WireGuard"` + SafeIP string `ooni:"Allocated IP for this peer"` + + // Keys are base-64 encoded + SafePrivateKey string `ooni:"Private key to connect to remote (base64)"` + SafePublicKey string `ooni:"Public key of the remote (base64)"` + SafePresharedKey string `ooni:"Pre-shared key for authentication (base64)"` + + // Optional obfuscation parameters for AmneziaWG + SafeJc string `ooni:"jc"` + SafeJmin string `ooni:"jmin"` + SafeJmax string `ooni:"jmax"` + SafeS1 string `ooni:"s1"` + SafeS2 string `ooni:"s2"` + SafeH1 string `ooni:"h1"` + SafeH2 string `ooni:"h2"` + SafeH3 string `ooni:"h3"` + SafeH4 string `ooni:"h4"` +} + +type wireguardOptions struct { + // common wireguard parameters + endpoint string + ip string + ns string + + // keys are hex-encoded + pubKey string + privKey string + presharedKey string + + // optional parameters for AmneziaWG nodes + jc string + jmin string + jmax string + s1 string + s2 string + h1 string + h2 string + h3 string + h4 string +} + +// amneziaValues returns an array with all the amnezia-specific configuration +// parameters. +func (wo *wireguardOptions) amneziaValues() []string { + return []string{ + wo.jc, wo.jmin, wo.jmax, + wo.s1, wo.s2, + wo.h1, wo.h2, wo.h3, wo.h4, + } +} + +// validate returns true if this looks like a sensible wireguard configuration. +func (wo *wireguardOptions) validate() bool { + if wo.endpoint == "" || wo.ip == "" || wo.pubKey == "" || wo.privKey == "" || wo.presharedKey == "" { + return false + } + if isAnyFilled(wo.amneziaValues()...) { + return !isAnyEmpty(wo.amneziaValues()...) + } + return true +} + +// isAmneziaFlavored returns true if none of the mandatory amnezia fields are empty. +func (wo *wireguardOptions) isAmneziaFlavored() bool { + return !isAnyEmpty(wo.amneziaValues()...) +} + +// amneziaConfigHash is a hash representation of the custom parameters in this amneziaWG node. +// intended to be used if PublicAmneziaParameters=false, so that we can verify that we're testing +// the same node. +func (wo *wireguardOptions) configurationHash() string { + if !wo.isAmneziaFlavored() { + return "" + } + return sha1Sum(append(wo.amneziaValues(), wo.endpoint)...) +} + +func sha1Sum(strings ...string) string { + hasher := sha1.New() + for _, str := range strings { + io.WriteString(hasher, str) + } + return fmt.Sprintf("%x", hasher.Sum(nil)) +} + +func newWireguardOptionsFromConfig(c *Config) (*wireguardOptions, error) { + o := &wireguardOptions{} + + pub, err := base64.StdEncoding.DecodeString(c.SafePublicKey) + if err != nil { + return nil, fmt.Errorf("%w: cannot decode public key", ErrInvalidInput) + } + pubHex := hex.EncodeToString(pub) + o.pubKey = pubHex + + priv, err := base64.StdEncoding.DecodeString(c.SafePrivateKey) + if err != nil { + return nil, fmt.Errorf("%w: cannot decode private key", ErrInvalidInput) + } + privHex := hex.EncodeToString(priv) + o.privKey = privHex + + psk, err := base64.StdEncoding.DecodeString(c.SafePresharedKey) + if err != nil { + return nil, fmt.Errorf("%w: cannot decode pre-shared key", ErrInvalidInput) + } + pskHex := hex.EncodeToString(psk) + o.presharedKey = pskHex + + // TODO(ainghazal): reconcile this with Input if c.PublicTarget=true + o.endpoint = c.SafeRemote + + o.ip = c.SafeIP + + // amnezia parameters + o.jc = c.SafeJc + o.jmin = c.SafeJmin + o.jmax = c.SafeJmax + o.s1 = c.SafeS1 + o.s2 = c.SafeS2 + o.h1 = c.SafeH1 + o.h2 = c.SafeH2 + o.h3 = c.SafeH3 + o.h4 = c.SafeH4 + + o.ns = defaultNameserver + return o, nil +} + +func isAnyFilled(fields ...string) bool { + for _, f := range fields { + if f != "" { + return true + } + } + return false +} + +func isAnyEmpty(fields ...string) bool { + for _, f := range fields { + if f == "" { + return true + } + } + return false +} diff --git a/internal/experiment/wireguard/config_test.go b/internal/experiment/wireguard/config_test.go new file mode 100644 index 000000000..37c9089e5 --- /dev/null +++ b/internal/experiment/wireguard/config_test.go @@ -0,0 +1,227 @@ +package wireguard + +import ( + "errors" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func Test_wireguardOptions(t *testing.T) { + t.Run("amnezia values are the expected set", func(t *testing.T) { + wc := wireguardOptions{ + jc: "1", + jmin: "2", + jmax: "3", + s1: "4", + s2: "5", + h1: "6", + h2: "7", + h3: "8", + h4: "9", + } + expected := []string{"1", "2", "3", "4", "5", "6", "7", "8", "9"} + if diff := cmp.Diff(wc.amneziaValues(), expected); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("validate() is true when the mandatory fields are filled", func(t *testing.T) { + wc := wireguardOptions{ + endpoint: "1.1.1.1:8020", + ip: "10.1.2.8", + pubKey: "foobar", + privKey: "foobar", + presharedKey: "foobar", + } + if wc.validate() != true { + t.Fatal("expected options to be valid") + } + }) + + t.Run("validate() is false when one mandatory field is missing", func(t *testing.T) { + wc := wireguardOptions{ + endpoint: "1.1.1.1:8020", + pubKey: "foobar", + privKey: "foobar", + presharedKey: "foobar", + } + if wc.validate() != false { + t.Fatal("expected options not to be valid") + } + }) + + t.Run("validate() is true when the all amnezia fields are filled", func(t *testing.T) { + wc := wireguardOptions{ + endpoint: "1.1.1.1:8020", + ip: "10.1.2.8", + pubKey: "foobar", + privKey: "foobar", + presharedKey: "foobar", + jc: "1", + jmin: "2", + jmax: "3", + s1: "4", + s2: "5", + h1: "6", + h2: "7", + h3: "8", + h4: "9", + } + if wc.validate() != true { + t.Fatal("expected options to be valid") + } + }) + + t.Run("validate() is false when any of the amnezia fields is missing", func(t *testing.T) { + wc := wireguardOptions{ + endpoint: "1.1.1.1:8020", + ip: "10.1.2.8", + pubKey: "foobar", + privKey: "foobar", + presharedKey: "foobar", + jc: "1", + jmin: "2", + jmax: "3", + s1: "4", + s2: "5", + h1: "6", + h2: "7", + h3: "8", + h4: "", + } + if wc.validate() != false { + t.Fatal("expected options not to be valid") + } + }) + + t.Run("isAmneziaFlavored() is true when none of the amnezia fields is missing", func(t *testing.T) { + wc := wireguardOptions{ + endpoint: "1.1.1.1:8020", + ip: "10.1.2.8", + pubKey: "foobar", + privKey: "foobar", + presharedKey: "foobar", + jc: "1", + jmin: "2", + jmax: "3", + s1: "4", + s2: "5", + h1: "6", + h2: "7", + h3: "8", + h4: "9", + } + if wc.isAmneziaFlavored() != true { + t.Fatal("expected to be amnezia flavored") + } + }) + + t.Run("configurationHash() is empty for non-amnezia values", func(t *testing.T) { + wc := wireguardOptions{ + endpoint: "1.1.1.1:8020", + } + expected := "" + if diff := cmp.Diff(wc.configurationHash(), expected); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("get the expected configurationHash()", func(t *testing.T) { + wc := wireguardOptions{ + endpoint: "1.1.1.1:8020", + jc: "1", + jmin: "2", + jmax: "3", + s1: "4", + s2: "5", + h1: "6", + h2: "7", + h3: "8", + h4: "9", + } + expected := "adb00b0ab179bfbdf9835bc124cbc7ab7e59bd8b" + if diff := cmp.Diff(wc.configurationHash(), expected); diff != "" { + t.Fatal(diff) + } + }) +} + +func Test_newWireguardOptionsFromConfig(t *testing.T) { + t.Run("good config does not fail", func(t *testing.T) { + c := &Config{ + SafePublicKey: "ZGVhZGJlZWY=", + SafePrivateKey: "ZGVhZGJlZWY=", + SafePresharedKey: "ZGVhZGJlZWY=", + SafeRemote: "1.2.3.4:8080", + } + + opts, err := newWireguardOptionsFromConfig(c) + if !errors.Is(err, nil) { + t.Fatal("did not expect error") + } + + hexExpected := "6465616462656566" // deadbeef + + if diff := cmp.Diff(opts.pubKey, hexExpected); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(opts.privKey, hexExpected); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(opts.presharedKey, hexExpected); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("bad pubkey fails", func(t *testing.T) { + c := &Config{ + SafePublicKey: "ZGVhZGJlZWY", + SafePrivateKey: "ZGVhZGJlZWY=", + SafePresharedKey: "ZGVhZGJlZWY=", + SafeRemote: "1.2.3.4:8080", + } + + opts, err := newWireguardOptionsFromConfig(c) + if opts != nil { + t.Fatal("did not expect anything other than nil") + } + if !errors.Is(err, ErrInvalidInput) { + t.Fatal("not the error we expected") + } + }) + + t.Run("bad privkey fails", func(t *testing.T) { + c := &Config{ + SafePublicKey: "ZGVhZGJlZWY=", + SafePrivateKey: "ZGVhZGJlZWY", + SafePresharedKey: "ZGVhZGJlZWY=", + SafeRemote: "1.2.3.4:8080", + } + + opts, err := newWireguardOptionsFromConfig(c) + if opts != nil { + t.Fatal("did not expect anything other than nil") + } + if !errors.Is(err, ErrInvalidInput) { + t.Fatal("not the error we expected") + } + }) + + t.Run("bad preshared key fails", func(t *testing.T) { + c := &Config{ + SafePublicKey: "ZGVhZGJlZWY=", + SafePrivateKey: "ZGVhZGJlZWY=", + SafePresharedKey: "ZGVhZGJlZWY", + SafeRemote: "1.2.3.4:8080", + } + + opts, err := newWireguardOptionsFromConfig(c) + if opts != nil { + t.Fatal("did not expect anything other than nil") + } + if !errors.Is(err, ErrInvalidInput) { + t.Fatal("not the error we expected") + } + }) +} diff --git a/internal/experiment/wireguard/doc.go b/internal/experiment/wireguard/doc.go new file mode 100644 index 000000000..e0cc371ce --- /dev/null +++ b/internal/experiment/wireguard/doc.go @@ -0,0 +1,2 @@ +// Package wireguard contains the wireguard experiment. +package wireguard diff --git a/internal/experiment/wireguard/richerinput.go b/internal/experiment/wireguard/richerinput.go new file mode 100644 index 000000000..05b0d2e80 --- /dev/null +++ b/internal/experiment/wireguard/richerinput.go @@ -0,0 +1,89 @@ +package wireguard + +import ( + "context" + + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/targetloading" +) + +// Target is a richer-input target that this experiment should measure. +type Target struct { + // Options contains the configuration. + Options *Config + + // URL is the input URL. + URL string +} + +var _ model.ExperimentTarget = &Target{} + +// Category implements [model.ExperimentTarget]. +func (t *Target) Category() string { + return model.DefaultCategoryCode +} + +// Country implements [model.ExperimentTarget]. +func (t *Target) Country() string { + return model.DefaultCountryCode +} + +// Input implements [model.ExperimentTarget]. +func (t *Target) Input() string { + return t.URL +} + +// String implements [model.ExperimentTarget]. +func (t *Target) String() string { + return t.URL +} + +// NewLoader constructs a new [model.ExperimentTargerLoader] instance. +// +// This function PANICS if options is not an instance of [*openvpn.Config]. +func NewLoader(loader *targetloading.Loader, gopts any) model.ExperimentTargetLoader { + // Panic if we cannot convert the options to the expected type. + // + // We do not expect a panic here because the type is managed by the registry package. + options := gopts.(*Config) + + // Construct the proper loader instance. + return &targetLoader{ + loader: loader, + options: options, + session: loader.Session, + } +} + +// targetLoader loads targets for this experiment. +type targetLoader struct { + loader *targetloading.Loader + options *Config + session targetloading.Session +} + +// Load implements model.ExperimentTargetLoader. +func (tl *targetLoader) Load(ctx context.Context) ([]model.ExperimentTarget, error) { + // TODO(ainghazal): implement remote loading when backend is ready. + + // Attempt to load the static inputs from CLI and files + inputs, err := targetloading.LoadStatic(tl.loader) + + // Handle the case where we couldn't load from CLI or files + if err != nil { + return nil, err + } + + // Build the list of targets that we should measure. + var targets []model.ExperimentTarget + + for _, input := range inputs { + targets = append(targets, + &Target{ + Options: tl.options, + URL: input, + }) + + } + return targets, nil +} diff --git a/internal/experiment/wireguard/richerinput_test.go b/internal/experiment/wireguard/richerinput_test.go new file mode 100644 index 000000000..872901351 --- /dev/null +++ b/internal/experiment/wireguard/richerinput_test.go @@ -0,0 +1,141 @@ +package wireguard + +import ( + "context" + "errors" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/mocks" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/targetloading" +) + +func TestTarget(t *testing.T) { + target := &Target{ + URL: "wg://unknown.corp", + } + + t.Run("Category", func(t *testing.T) { + if target.Category() != model.DefaultCategoryCode { + t.Fatal("invalid Category") + } + }) + + t.Run("Country", func(t *testing.T) { + if target.Country() != model.DefaultCountryCode { + t.Fatal("invalid Country") + } + }) + + t.Run("Input", func(t *testing.T) { + if target.Input() != "wg://unknown.corp" { + t.Fatal("invalid Input") + } + }) + + t.Run("String", func(t *testing.T) { + if target.String() != "wg://unknown.corp" { + t.Fatal("invalid String") + } + }) +} + +func TestNewLoader(t *testing.T) { + // create the pointers we expect to see + child := &targetloading.Loader{} + options := &Config{} + + // create the loader and cast it to its private type + loader := NewLoader(child, options).(*targetLoader) + + // make sure the loader is okay + if child != loader.loader { + t.Fatal("invalid loader pointer") + } + + // make sure the options are okay + if options != loader.options { + t.Fatal("invalid options pointer") + } +} + +func TestTargetLoaderLoad(t *testing.T) { + // testcase is a test case implemented by this function + type testcase struct { + // name is the test case name + name string + + // options contains the options to use + options *Config + + // loader is the loader to use + loader *targetloading.Loader + + // expectErr is the error we expect + expectErr error + + // expectResults contains the expected results + expectTargets []model.ExperimentTarget + } + + cases := []testcase{ + + { + name: "with options and inputs", + options: &Config{ + SafeRemote: "1.1.1.1:443", + }, + loader: &targetloading.Loader{ + ExperimentName: "wireguard", + InputPolicy: model.InputNone, + Logger: model.DiscardLogger, + Session: &mocks.Session{}, + StaticInputs: []string{ + "wg://unknown.corp/1.1.1.1", + }, + }, + expectErr: nil, + expectTargets: []model.ExperimentTarget{ + &Target{ + URL: "wg://unknown.corp/1.1.1.1", + Options: &Config{ + SafeRemote: "1.1.1.1:443", + }, + }, + }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // create a target loader using the given config + tl := &targetLoader{ + loader: tc.loader, + options: tc.options, + } + + // load targets + targets, err := tl.Load(context.Background()) + + // make sure error is consistent + switch { + case err == nil && tc.expectErr == nil: + // fallthrough + + case err != nil && tc.expectErr != nil: + if !errors.Is(err, tc.expectErr) { + t.Fatal("unexpected error", err) + } + // fallthrough + + default: + t.Fatal("expected", tc.expectErr, "got", err) + } + + // make sure the targets are consistent + if diff := cmp.Diff(tc.expectTargets, targets); diff != "" { + t.Fatal(diff) + } + }) + } +} diff --git a/internal/experiment/wireguard/testkeys.go b/internal/experiment/wireguard/testkeys.go new file mode 100644 index 000000000..b1b50f0bd --- /dev/null +++ b/internal/experiment/wireguard/testkeys.go @@ -0,0 +1,30 @@ +package wireguard + +// TestKeys contains the experiment's result. +// +// This is what will end up into the Measurement.TestKeys field +// when you run this experiment. +// +// In other words, the variables in this struct will be +// the specific results of this experiment. +type TestKeys struct { + Success bool `json:"success"` + Endpoint string `json:"endpoint"` + EndpointASN string `json:"endpoint_asn,omitempty"` + EndpointID string `json:"endpoint_id,omitempty"` + Failure *string `json:"failure"` + NetworkEvents []*Event `json:"network_events"` + URLGet []*URLGetResult `json:"urlget"` +} + +// URLGetResult is the result of fetching a URL via the wireguard tunnel, +// using the standard library. +type URLGetResult struct { + ByteCount int `json:"bytes,omitempty"` + Error string `json:"error,omitempty"` + Failure *string `json:"failure"` + StatusCode int `json:"status_code"` + T0 float64 `json:"t0"` + T float64 `json:"t"` + URL string `json:"url"` +} diff --git a/internal/experiment/wireguard/urlget.go b/internal/experiment/wireguard/urlget.go new file mode 100644 index 000000000..eae2f9a5e --- /dev/null +++ b/internal/experiment/wireguard/urlget.go @@ -0,0 +1,67 @@ +package wireguard + +import ( + "context" + "net/http" + "time" + + "github.com/ooni/probe-cli/v3/internal/measurexlite" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/netxlite" +) + +const ( + // defaultURLGetTarget is the web page that the experiment will fetch by default. + defaultURLGetTarget = "https://info.cern.ch/" +) + +// urlget implements an straightforward urlget experiment using the standard library. +// By default we pass the wireguard tunnel DialContext to the `http.Transport` on the `http.Client` creation. +func (m *Measurer) urlget(ctx context.Context, url string, zeroTime time.Time, logger model.Logger) *URLGetResult { + if m.dialContextFn == nil { + m.dialContextFn = m.tnet.DialContext + } + if m.httpClient == nil { + m.httpClient = &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + DialContext: m.dialContextFn, + TLSHandshakeTimeout: 30 * time.Second, + }} + } + + start := time.Since(zeroTime).Seconds() + r, err := m.httpClient.Get(url) + if err != nil { + logger.Warnf("urlget error: %v", err.Error()) + return newURLResultFromError(url, zeroTime, start, err) + } + body, err := netxlite.ReadAllContext(ctx, r.Body) + if err != nil { + logger.Warnf("urlget error: %v", err.Error()) + return newURLResultFromError(url, zeroTime, start, err) + } + defer r.Body.Close() + + return newURLResultWithStatusCode(url, zeroTime, start, r.StatusCode, body) +} + +func newURLResultFromError(url string, zeroTime time.Time, start float64, err error) *URLGetResult { + return &URLGetResult{ + URL: url, + T0: start, + T: time.Since(zeroTime).Seconds(), + Failure: measurexlite.NewFailure(err), + Error: err.Error(), + } +} + +func newURLResultWithStatusCode(url string, zeroTime time.Time, start float64, statusCode int, body []byte) *URLGetResult { + return &URLGetResult{ + ByteCount: len(body), + URL: url, + T0: start, + T: time.Since(zeroTime).Seconds(), + StatusCode: statusCode, + } +} diff --git a/internal/experiment/wireguard/urlget_test.go b/internal/experiment/wireguard/urlget_test.go new file mode 100644 index 000000000..f044a6411 --- /dev/null +++ b/internal/experiment/wireguard/urlget_test.go @@ -0,0 +1,117 @@ +package wireguard + +import ( + "context" + "errors" + "math" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/ooni/probe-cli/v3/internal/model" +) + +type failingHttpClient struct{} + +func (c *failingHttpClient) Get(string) (*http.Response, error) { + return nil, errors.New("some error") +} + +func Test_urlget(t *testing.T) { + t.Run("dummy server gets a URLGetResult, with no error", func(t *testing.T) { + expected := "dummy data" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(expected)) + })) + defer srv.Close() + + m := &Measurer{} + m.dialContextFn = func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + r := m.urlget(context.Background(), srv.URL, time.Now(), model.DiscardLogger) + if r.StatusCode != 200 { + t.Fatal("expected statusCode==200") + } + }) + + t.Run("dummy server gets a URLGetResult with 500 status code", func(t *testing.T) { + expected := "dummy data" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + w.Write([]byte(expected)) + })) + defer srv.Close() + + m := &Measurer{} + m.dialContextFn = func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + r := m.urlget(context.Background(), srv.URL, time.Now(), model.DiscardLogger) + if r.StatusCode != 500 { + t.Fatal("expected statusCode==500") + } + }) + + t.Run("client returns error", func(t *testing.T) { + m := &Measurer{} + m.httpClient = &failingHttpClient{} + + r := m.urlget(context.Background(), "http://example.org", time.Now(), model.DiscardLogger) + expectedError := "unknown_failure: some error" + if *r.Failure != expectedError { + t.Fatal("expected error") + } + }) +} + +func Test_newURLResultFromError(t *testing.T) { + url := "https://example.org" + zeroTime := time.Now().Add(-1 * time.Second) + start := 0.1 + err := errors.New("some error") + + r := newURLResultFromError(url, zeroTime, start, err) + if r.URL != url { + t.Fatal("wrong url") + } + if r.T0 != start { + t.Fatal("wrong t0") + } + if math.Abs(r.T-1.0) > 0.01 { + t.Fatal("should be ~now, not", r.T) + } + if r.Error != err.Error() { + t.Fatal("wrong error") + } + expectedFailure := "unknown_failure: " + err.Error() + if *r.Failure != expectedFailure { + t.Fatal(*r.Failure) + } +} + +func Test_newURLResultWithStratusCode(t *testing.T) { + url := "https://example.org" + zeroTime := time.Now().Add(-1 * time.Second) + start := 0.1 + + r := newURLResultWithStatusCode(url, zeroTime, start, 200, []byte("potatoes")) + if r.URL != url { + t.Fatal("wrong url") + } + if r.T0 != start { + t.Fatal("wrong t0") + } + if math.Abs(r.T-1.0) > 0.01 { + t.Fatal("should be ~now, not", r.T) + } + if r.StatusCode != 200 { + t.Fatal("expected statusCode==200") + } + if r.ByteCount != 8 { + t.Fatal("expected byteCount=8") + } +} diff --git a/internal/experiment/wireguard/wireguard.go b/internal/experiment/wireguard/wireguard.go new file mode 100644 index 000000000..9171ceb14 --- /dev/null +++ b/internal/experiment/wireguard/wireguard.go @@ -0,0 +1,296 @@ +package wireguard + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "net/netip" + "strings" + "time" + + "github.com/ooni/probe-cli/v3/internal/measurexlite" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/targetloading" + + "github.com/amnezia-vpn/amneziawg-go/conn" + "github.com/amnezia-vpn/amneziawg-go/device" + "github.com/amnezia-vpn/amneziawg-go/tun" + "github.com/amnezia-vpn/amneziawg-go/tun/netstack" +) + +const ( + testName = "wireguard" + testVersion = "0.1.2" +) + +var ( + ErrInputRequired = targetloading.ErrInputRequired + ErrInvalidInputType = targetloading.ErrInvalidInputType + + // TODO(ainghazal): fix after adding this error into targetloading + ErrInvalidInput = errors.New("invalid input") +) + +type httpClient interface { + Get(string) (*http.Response, error) +} + +// Measurer performs the measurement. +type Measurer struct { + events *eventLogger + options *wireguardOptions + tnet *netstack.Net + + // used just for testing + dialContextFn func(context.Context, string, string) (net.Conn, error) + httpClient httpClient +} + +// NewExperimentMeasurer creates a new ExperimentMeasurer. +func NewExperimentMeasurer() model.ExperimentMeasurer { + return &Measurer{ + events: newEventLogger(), + options: &wireguardOptions{}, + } +} + +// ExperimentName implements model.ExperimentMeasurer.ExperimentName. +func (m *Measurer) ExperimentName() string { + return testName +} + +// ExperimentVersion implements model.ExperimentMeasurer.ExperimentVersion. +func (m *Measurer) ExperimentVersion() string { + return testVersion +} + +// Run implements model.ExperimentMeasurer.Run. +func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { + measurement := args.Measurement + sess := args.Session + zeroTime := measurement.MeasurementStartTimeSaved + + var err error + + // 0. fail if there is no richer input target. + if args.Target == nil { + return ErrInputRequired + } + + // 1. setup tunnel after parsing options + target, ok := args.Target.(*Target) + if !ok { + return ErrInvalidInputType + } + + config, input := target.Options, target.URL + if err := m.setupWireguardFromConfig(config); err != nil { + // A failure at this point means that we are not able + // to validate the minimal set of options that we need to probe an endpoint. + // We abort the experiment and submit nothing. + return err + } + + // 2. create tunnel + err = m.createTunnel(sess, zeroTime, config) + + testkeys := &TestKeys{ + Success: err == nil, + Failure: measurexlite.NewFailure(err), + URLGet: make([]*URLGetResult, 0), + } + + if config.PublicTarget { + testkeys.Endpoint = m.options.endpoint + } else { + // TODO(ainghazal): if the target is not public, + // we might want to substitute it with ASN. + testkeys.Endpoint = input + } + + testkeys.EndpointID = m.options.configurationHash() + if config.PublicAmneziaParameters { + // TODO(ainghazal): copy the parameters as testkeys + } + + // 3. use tunnel + if err == nil { + sess.Logger().Info("Using the wireguard tunnel.") + urlgetResult := m.urlget(ctx, defaultURLGetTarget, zeroTime, sess.Logger()) + testkeys.URLGet = append(testkeys.URLGet, urlgetResult) + testkeys.NetworkEvents = m.events.log() + } + + // 4. assign test keys + measurement.TestKeys = testkeys + sess.Logger().Infof("%s", "Wireguard experiment done.") + + // NOTE: important to return nil to submit measurement. + return nil +} + +func (m *Measurer) setupWireguardFromConfig(config *Config) error { + opts, err := newWireguardOptionsFromConfig(config) + if err != nil { + return err + } + if ok := opts.validate(); !ok { + return fmt.Errorf("%w: %s", ErrInvalidInput, "cannot validate wireguard options") + } + m.options = opts + return nil +} + +func (m *Measurer) createTunnel(sess model.ExperimentSession, zeroTime time.Time, config *Config) error { + sess.Logger().Info("wireguard: create tunnel") + sess.Logger().Infof("endpoint: %s", m.options.endpoint) + + _, tnet, err := m.configureWireguardInterface(sess.Logger(), m.events, zeroTime, config) + if err != nil { + return err + } + m.tnet = tnet + + sess.Logger().Info("wireguard: create tunnel done") + return nil +} + +func (m *Measurer) configureWireguardInterface( + logger model.Logger, + eventlogger *eventLogger, + zeroTime time.Time, + config *Config) (tun.Device, *netstack.Net, error) { + devTun, tnet, err := netstack.CreateNetTUN( + []netip.Addr{netip.MustParseAddr(m.options.ip)}, + []netip.Addr{netip.MustParseAddr(m.options.ns)}, + 1420) + if err != nil { + return nil, nil, err + } + + dev := device.NewDevice( + devTun, + conn.NewDefaultBind(), + newWireguardLogger(logger, eventlogger, config.Verbose, zeroTime, time.Since), + ) + + var ipcStr string + + opts := m.options + + ipcStr = `jc=` + opts.jc + ` +jmin=` + opts.jmin + ` +jmax=` + opts.jmax + ` +s1=` + opts.s1 + ` +s2=` + opts.s2 + ` +h1=` + opts.h1 + ` +h2=` + opts.h2 + ` +h3=` + opts.h3 + ` +h4=` + opts.h4 + ` +private_key=` + opts.privKey + ` +public_key=` + opts.pubKey + ` +preshared_key=` + opts.presharedKey + ` +endpoint=` + opts.endpoint + ` +allowed_ip=0.0.0.0/0 +` + dev.IpcSet(ipcStr) + + err = dev.Up() + if err != nil { + return nil, nil, err + } + return devTun, tnet, nil +} + +// +// logging utilities +// + +// Event is a network event obtained by parsing wireguard logs. +type Event struct { + EventType string `json:"operation"` + T float64 `json:"t"` +} + +func newEvent(etype string) *Event { + return &Event{ + EventType: etype, + } +} + +type eventLogger struct { + events []*Event +} + +func newEventLogger() *eventLogger { + return &eventLogger{events: make([]*Event, 0)} +} + +func (el *eventLogger) append(e *Event) { + el.events = append(el.events, e) +} + +func (el *eventLogger) log() []*Event { + return el.events +} + +const ( + LOG_KEEPALIVE = "Receiving keepalive packet" + LOG_SEND_HANDSHAKE = "Sending handshake initiation" + LOG_RECV_HANDSHAKE = "Received handshake response" + + EVT_RECV_KEEPALIVE = "RECV_KEEPALIVE" + EVT_SEND_HANDSHAKE_INIT = "SEND_HANDSHAKE_INIT" + EVT_RECV_HANDSHAKE_RESP = "RECV_HANDSHAKE_RESP" +) + +// newWireguardLogger looks at the strings logged by the wireguard +// implementation. It performs simple regex matching and then +// it appends the matchign Event in the passed eventLogger. +// This approach has some potential for brittleness (in the unlikely case +// that upstream wireguard codebase changes the emitted log lines), +// but adding typed log events to the wg codebase might prove to be a +// particularly time-consuming rewrite. +func newWireguardLogger( + logger model.Logger, + eventlogger *eventLogger, + verbose bool, + zeroTime time.Time, + sinceFn func(time.Time) time.Duration) *device.Logger { + verbosef := func(format string, args ...any) { + msg := fmt.Sprintf(format, args...) + + if verbose { + logger.Debugf(msg) + } + + // TODO(ainghazal): we might be interested in parsing additional events. + if strings.Contains(msg, LOG_KEEPALIVE) { + evt := newEvent(EVT_RECV_KEEPALIVE) + evt.T = sinceFn(zeroTime).Seconds() + eventlogger.append(evt) + return + } + if strings.Contains(msg, LOG_SEND_HANDSHAKE) { + evt := newEvent(EVT_SEND_HANDSHAKE_INIT) + evt.T = sinceFn(zeroTime).Seconds() + eventlogger.append(evt) + return + } + if strings.Contains(msg, LOG_RECV_HANDSHAKE) { + evt := newEvent(EVT_RECV_HANDSHAKE_RESP) + evt.T = sinceFn(zeroTime).Seconds() + eventlogger.append(evt) + return + } + } + errorf := func(format string, args ...any) { + logger.Warnf(format, args...) + } + return &device.Logger{ + Verbosef: verbosef, + Errorf: errorf, + } +} diff --git a/internal/experiment/wireguard/wireguard_test.go b/internal/experiment/wireguard/wireguard_test.go new file mode 100644 index 000000000..e0ad68752 --- /dev/null +++ b/internal/experiment/wireguard/wireguard_test.go @@ -0,0 +1,143 @@ +package wireguard + +import ( + "testing" + "time" + + "github.com/amnezia-vpn/amneziawg-go/device" + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model" +) + +func TestNewExperimentMeasurer(t *testing.T) { + m := NewExperimentMeasurer() + if m.ExperimentName() != "wireguard" { + t.Fatal("invalid ExperimentName") + } + if m.ExperimentVersion() != "0.1.2" { + t.Fatal("invalid ExperimentVersion") + } +} + +func TestNewEvent(t *testing.T) { + e := newEvent("foo") + if e.EventType != "foo" { + t.Fatal("expected type foo") + } + + e1 := newEvent("bar") + e2 := newEvent("baaz") + + log := newEventLogger() + log.append(e) + log.append(e1) + log.append(e2) + + if diff := cmp.Diff(log.log(), []*Event{e, e1, e2}); diff != "" { + t.Fatal(diff) + } +} + +func TestNewWireguardLogger(t *testing.T) { + wgLogger := func(events *eventLogger, t int) *device.Logger { + wgLogger := newWireguardLogger( + model.DiscardLogger, + events, + false, + time.Now(), + func(time.Time) time.Duration { + return time.Duration(t) * time.Second + }) + return wgLogger + } + + t.Run("keepalive packet", func(t *testing.T) { + eventLogger := newEventLogger() + logger := wgLogger(eventLogger, 2) + logger.Verbosef(LOG_KEEPALIVE) + evts := eventLogger.log() + if len(evts) != 1 { + t.Fatal("expected 1 event") + } + if evts[0].EventType != EVT_RECV_KEEPALIVE { + t.Fatal("expected RECV_KEEPALIVE") + } + if evts[0].T != 2.0 { + t.Fatal("expected T=2") + } + }) + t.Run("handshake send packet", func(t *testing.T) { + eventLogger := newEventLogger() + logger := wgLogger(eventLogger, 3) + logger.Verbosef(LOG_SEND_HANDSHAKE) + evts := eventLogger.log() + if len(evts) != 1 { + t.Fatal("expected 1 event") + } + if evts[0].EventType != EVT_SEND_HANDSHAKE_INIT { + t.Fatal("expected SEND_HANDSHAKE_INIT ") + } + if evts[0].T != 3.0 { + t.Fatal("expected T=3") + } + }) + t.Run("handshake recv packet", func(t *testing.T) { + eventLogger := newEventLogger() + logger := wgLogger(eventLogger, 4) + logger.Verbosef(LOG_RECV_HANDSHAKE) + evts := eventLogger.log() + if len(evts) != 1 { + t.Fatal("expected 1 event") + } + if evts[0].EventType != EVT_RECV_HANDSHAKE_RESP { + t.Fatal("expected RECV_HADNSHAKE_RESP ") + } + if evts[0].T != 4.0 { + t.Fatal("expected T=4") + } + }) + +} + +// TODO(cleanup) ---- +/* + +func TestSuccess(t *testing.T) { + m := NewExperimentMeasurer() + if m.ExperimentName() != "wireguard" { + t.Fatal("invalid ExperimentName") + } + if m.ExperimentVersion() != "0.1.1" { + t.Fatal("invalid ExperimentVersion") + } + ctx := context.Background() + sess := &mockable.Session{MockableLogger: log.Log} + callbacks := model.NewPrinterCallbacks(sess.Logger()) + measurement := new(model.Measurement) + args := &model.ExperimentArgs{ + Callbacks: callbacks, + Measurement: measurement, + Session: sess, + } + err := m.Run(ctx, args) + if err != nil { + t.Fatal(err) + } +} + +func TestFailure(t *testing.T) { + m := NewExperimentMeasurer() + ctx := context.Background() + sess := &mockable.Session{MockableLogger: log.Log} + callbacks := model.NewPrinterCallbacks(sess.Logger()) + args := &model.ExperimentArgs{ + Callbacks: callbacks, + Measurement: new(model.Measurement), + Session: sess, + } + err := m.Run(ctx, args) + if !errors.Is(err, example.ErrFailure) { + t.Fatal("expected an error here") + } +} +*/ diff --git a/internal/registry/factory_test.go b/internal/registry/factory_test.go index d4a3d8368..0668486d0 100644 --- a/internal/registry/factory_test.go +++ b/internal/registry/factory_test.go @@ -726,6 +726,11 @@ func TestNewFactory(t *testing.T) { enabledByDefault: true, inputPolicy: model.InputNone, }, + "wireguard": { + enabledByDefault: true, + inputPolicy: model.InputStrictlyRequired, + interruptible: true, + }, } // testCase is a test case checked by this func diff --git a/internal/registry/wireguard.go b/internal/registry/wireguard.go new file mode 100644 index 000000000..480e877d5 --- /dev/null +++ b/internal/registry/wireguard.go @@ -0,0 +1,29 @@ +package registry + +// +// Registers the `wireguard` experiment. +// + +import ( + "github.com/ooni/probe-cli/v3/internal/experiment/wireguard" + "github.com/ooni/probe-cli/v3/internal/model" +) + +func init() { + const canonicalName = "wireguard" + AllExperiments["wireguard"] = func() *Factory { + return &Factory{ + build: func(config interface{}) model.ExperimentMeasurer { + return wireguard.NewExperimentMeasurer() + }, + canonicalName: canonicalName, + config: &wireguard.Config{}, + enabledByDefault: true, + interruptible: true, + // TODO(ainghazal): when the backend is ready to hand us targets, + // we will use InputOrQueryBackend. + inputPolicy: model.InputStrictlyRequired, + newLoader: wireguard.NewLoader, + } + } +}