diff --git a/README.md b/README.md index 6104790..fee6a95 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,7 @@ Measurement Commands: traceroute Run a traceroute test Additional Commands: + auth Authenticate with the Globalping API completion Generate the autocompletion script for the specified shell help Help about any command history Display the measurement history of your current session diff --git a/cmd/auth.go b/cmd/auth.go new file mode 100644 index 0000000..914fa81 --- /dev/null +++ b/cmd/auth.go @@ -0,0 +1,141 @@ +package cmd + +import ( + "errors" + "syscall" + + "github.com/jsdelivr/globalping-cli/globalping" + "github.com/spf13/cobra" +) + +func (r *Root) initAuth() { + authCmd := &cobra.Command{ + Use: "auth", + Short: "Authenticate with the Globalping API", + Long: "Authenticate with the Globalping API for higher measurements limits.", + } + + loginCmd := &cobra.Command{ + RunE: r.RunAuthLogin, + Use: "login", + Short: "Log in to your Globalping account", + Long: `Log in to your Globalping account for higher measurements limits.`, + } + + loginFlags := loginCmd.Flags() + loginFlags.Bool("with-token", false, "authenticate with a token read from stdin instead of the default browser-based flow") + + statusCmd := &cobra.Command{ + RunE: r.RunAuthStatus, + Use: "status", + Short: "Check the current authentication status", + Long: `Check the current authentication status.`, + } + + logoutCmd := &cobra.Command{ + RunE: r.RunAuthLogout, + Use: "logout", + Short: "Log out from your Globalping account", + Long: `Log out from your Globalping account.`, + } + + authCmd.AddCommand(loginCmd) + authCmd.AddCommand(statusCmd) + authCmd.AddCommand(logoutCmd) + + r.Cmd.AddCommand(authCmd) +} + +func (r *Root) RunAuthLogin(cmd *cobra.Command, args []string) error { + var err error + oldToken := r.storage.GetProfile().Token + withToken := cmd.Flags().Changed("with-token") + if withToken { + err := r.loginWithToken() + if err != nil { + return err + } + if oldToken != nil { + r.client.RevokeToken(oldToken.RefreshToken) + } + return nil + } + res, err := r.client.Authorize(func(e error) { + defer func() { + r.cancel <- syscall.SIGINT + }() + if e != nil { + err = e + r.Cmd.SilenceUsage = true + return + } + if oldToken != nil { + r.client.RevokeToken(oldToken.RefreshToken) + } + r.printer.Println("Success! You are now authenticated.") + }) + if err != nil { + return err + } + r.printer.Println("Please visit the following URL to authenticate:") + r.printer.Println(res.AuthorizeURL) + r.utils.OpenBrowser(res.AuthorizeURL) + r.printer.Println("\nCan't use the browser-based flow? Use \"globalping auth login --with-token\" to read a token from stdin instead.") + <-r.cancel + return err +} + +func (r *Root) RunAuthStatus(cmd *cobra.Command, args []string) error { + res, err := r.client.TokenIntrospection("") + if err != nil { + e, ok := err.(*globalping.AuthorizeError) + if ok && e.ErrorType == "not_authorized" { + r.printer.Println("Not logged in.") + return nil + } + return err + } + if res.Active { + r.printer.Printf("Logged in as %s.\n", res.Username) + } else { + r.printer.Println("Not logged in.") + } + return nil +} + +func (r *Root) RunAuthLogout(cmd *cobra.Command, args []string) error { + err := r.client.Logout() + if err != nil { + return err + } + r.printer.Println("You are now logged out.") + return nil +} + +func (r *Root) loginWithToken() error { + r.printer.Println("Please enter your token:") + token, err := r.printer.ReadPassword() + if err != nil { + return err + } + if token == "" { + return errors.New("empty token") + } + introspection, err := r.client.TokenIntrospection(token) + if err != nil { + return err + } + if !introspection.Active { + return errors.New("invalid token") + } + profile := r.storage.GetProfile() + profile.Token = &globalping.Token{ + AccessToken: token, + } + err = r.storage.SaveConfig() + if err != nil { + return errors.New("failed to save token") + } + r.printer.Printf("Logged in as %s.\n", introspection.Username) + return nil +} diff --git a/cmd/auth_test.go b/cmd/auth_test.go new file mode 100644 index 0000000..4259ed9 --- /dev/null +++ b/cmd/auth_test.go @@ -0,0 +1,157 @@ +package cmd + +import ( + "bytes" + "context" + "os" + "syscall" + "testing" + + "github.com/jsdelivr/globalping-cli/globalping" + "github.com/jsdelivr/globalping-cli/mocks" + "github.com/jsdelivr/globalping-cli/storage" + "github.com/jsdelivr/globalping-cli/view" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func Test_Auth_Login_WithToken(t *testing.T) { + t.Cleanup(sessionCleanup) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + gbMock := mocks.NewMockClient(ctrl) + + w := new(bytes.Buffer) + r := new(bytes.Buffer) + r.WriteString("token\n") + printer := view.NewPrinter(r, w, w) + ctx := createDefaultContext("") + _storage := storage.NewLocalStorage(".test_globalping-cli") + defer _storage.Remove() + err := _storage.Init() + if err != nil { + t.Fatal(err) + } + _storage.GetProfile().Token = &globalping.Token{ + AccessToken: "oldToken", + RefreshToken: "oldRefreshToken", + } + + root := NewRoot(printer, ctx, nil, nil, gbMock, nil, _storage) + + gbMock.EXPECT().TokenIntrospection("token").Return(&globalping.IntrospectionResponse{ + Active: true, + Username: "test", + }, nil) + gbMock.EXPECT().RevokeToken("oldRefreshToken").Return(nil) + + os.Args = []string{"globalping", "auth", "login", "--with-token"} + err = root.Cmd.ExecuteContext(context.TODO()) + assert.NoError(t, err) + + assert.Equal(t, `Please enter your token: +Logged in as test. +`, w.String()) + + profile := _storage.GetProfile() + assert.Equal(t, &storage.Profile{ + Token: &globalping.Token{ + AccessToken: "token", + }, + }, profile) +} + +func Test_Auth_Login(t *testing.T) { + t.Cleanup(sessionCleanup) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + gbMock := mocks.NewMockClient(ctrl) + utilsMock := mocks.NewMockUtils(ctrl) + + w := new(bytes.Buffer) + printer := view.NewPrinter(nil, w, w) + ctx := createDefaultContext("") + _storage := storage.NewLocalStorage(".test_globalping-cli") + defer _storage.Remove() + err := _storage.Init() + if err != nil { + t.Fatal(err) + } + _storage.GetProfile().Token = &globalping.Token{ + AccessToken: "oldToken", + RefreshToken: "oldRefreshToken", + } + + root := NewRoot(printer, ctx, nil, utilsMock, gbMock, nil, _storage) + + gbMock.EXPECT().Authorize(gomock.Any()).Do(func(_ any) { + root.cancel <- syscall.SIGINT + }).Return(&globalping.AuthorizeResponse{ + AuthorizeURL: "http://localhost", + }, nil) + utilsMock.EXPECT().OpenBrowser("http://localhost").Return(nil) + + os.Args = []string{"globalping", "auth", "login"} + err = root.Cmd.ExecuteContext(context.TODO()) + assert.NoError(t, err) + + assert.Equal(t, `Please visit the following URL to authenticate: +http://localhost + +Can't use the browser-based flow? Use "globalping auth login --with-token" to read a token from stdin instead. +`, w.String()) +} + +func Test_AuthStatus(t *testing.T) { + t.Cleanup(sessionCleanup) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + gbMock := mocks.NewMockClient(ctrl) + + w := new(bytes.Buffer) + printer := view.NewPrinter(nil, w, w) + ctx := createDefaultContext("") + + root := NewRoot(printer, ctx, nil, nil, gbMock, nil, nil) + + gbMock.EXPECT().TokenIntrospection("").Return(&globalping.IntrospectionResponse{ + Active: true, + Username: "test", + }, nil) + + os.Args = []string{"globalping", "auth", "status"} + err := root.Cmd.ExecuteContext(context.TODO()) + assert.NoError(t, err) + + assert.Equal(t, `Logged in as test. +`, w.String()) +} + +func Test_Logout(t *testing.T) { + t.Cleanup(sessionCleanup) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + gbMock := mocks.NewMockClient(ctrl) + + w := new(bytes.Buffer) + printer := view.NewPrinter(nil, w, w) + ctx := createDefaultContext("") + + root := NewRoot(printer, ctx, nil, nil, gbMock, nil, nil) + + gbMock.EXPECT().Logout().Return(nil) + + os.Args = []string{"globalping", "auth", "logout"} + err := root.Cmd.ExecuteContext(context.TODO()) + assert.NoError(t, err) + + assert.Equal(t, "You are now logged out.\n", w.String()) +} diff --git a/cmd/common.go b/cmd/common.go index 4784a3a..9804669 100644 --- a/cmd/common.go +++ b/cmd/common.go @@ -18,6 +18,7 @@ import ( "github.com/icza/backscanner" "github.com/jsdelivr/globalping-cli/globalping" "github.com/jsdelivr/globalping-cli/version" + "github.com/jsdelivr/globalping-cli/view" "github.com/shirou/gopsutil/process" ) @@ -114,6 +115,26 @@ func (r *Root) getLocations() ([]globalping.Locations, error) { return locations, nil } +func (r *Root) evaluateError(err error) { + if err == nil { + return + } + e, ok := err.(*globalping.MeasurementError) + if !ok { + return + } + if e.Code == globalping.StatusUnauthorizedWithTokenRefreshed { + r.Cmd.SilenceErrors = true + r.printer.ErrPrintln("Access token successfully refreshed. Try repeating the measurement.") + return + } + if e.Code == http.StatusTooManyRequests && r.ctx.MeasurementsCreated > 0 { + r.Cmd.SilenceErrors = true + r.printer.ErrPrintln(r.printer.Color("> "+e.Message, view.FGBrightYellow)) + return + } +} + type TargetQuery struct { Target string From string diff --git a/cmd/common_test.go b/cmd/common_test.go index 8675e67..1ae646c 100644 --- a/cmd/common_test.go +++ b/cmd/common_test.go @@ -28,7 +28,7 @@ func Test_UpdateContext(t *testing.T) { func test_updateContext_NoArg(t *testing.T) { ctx := createDefaultContext("ping") printer := view.NewPrinter(nil, nil, nil) - root := NewRoot(printer, ctx, nil, nil, nil, nil) + root := NewRoot(printer, ctx, nil, nil, nil, nil, nil) err := root.updateContext("test", []string{"1.1.1.1"}) assert.Equal(t, "test", ctx.Cmd) @@ -40,7 +40,7 @@ func test_updateContext_NoArg(t *testing.T) { func test_updateContext_Country(t *testing.T) { ctx := createDefaultContext("ping") printer := view.NewPrinter(nil, nil, nil) - root := NewRoot(printer, ctx, nil, nil, nil, nil) + root := NewRoot(printer, ctx, nil, nil, nil, nil, nil) err := root.updateContext("test", []string{"1.1.1.1", "from", "Germany"}) assert.Equal(t, "test", ctx.Cmd) @@ -53,7 +53,7 @@ func test_updateContext_Country(t *testing.T) { func test_updateContext_CountryWhitespace(t *testing.T) { ctx := createDefaultContext("ping") printer := view.NewPrinter(nil, nil, nil) - root := NewRoot(printer, ctx, nil, nil, nil, nil) + root := NewRoot(printer, ctx, nil, nil, nil, nil, nil) err := root.updateContext("test", []string{"1.1.1.1", "from", " Germany, France"}) assert.Equal(t, "test", ctx.Cmd) @@ -65,7 +65,7 @@ func test_updateContext_CountryWhitespace(t *testing.T) { func test_updateContext_NoTarget(t *testing.T) { ctx := createDefaultContext("ping") printer := view.NewPrinter(nil, nil, nil) - root := NewRoot(printer, ctx, nil, nil, nil, nil) + root := NewRoot(printer, ctx, nil, nil, nil, nil, nil) err := root.updateContext("test", []string{}) assert.Error(t, err) @@ -78,7 +78,7 @@ func test_updateContext_CIEnv(t *testing.T) { ctx := createDefaultContext("ping") printer := view.NewPrinter(nil, nil, nil) - root := NewRoot(printer, ctx, nil, nil, nil, nil) + root := NewRoot(printer, ctx, nil, nil, nil, nil, nil) err := root.updateContext("test", []string{"1.1.1.1"}) assert.Equal(t, "test", ctx.Cmd) @@ -92,7 +92,7 @@ func test_updateContext_TargetIsNotAHostname(t *testing.T) { ctx := createDefaultContext("ping") ctx.Ipv4 = true printer := view.NewPrinter(nil, nil, nil) - root := NewRoot(printer, ctx, nil, nil, nil, nil) + root := NewRoot(printer, ctx, nil, nil, nil, nil, nil) err := root.updateContext("ping", []string{"1.1.1.1"}) assert.EqualError(t, err, ErrTargetIPVersionNotAllowed.Error()) @@ -107,7 +107,7 @@ func test_updateContext_ResolverIsNotAHostname(t *testing.T) { ctx := createDefaultContext("dns") ctx.Ipv4 = true printer := view.NewPrinter(nil, nil, nil) - root := NewRoot(printer, ctx, nil, nil, nil, nil) + root := NewRoot(printer, ctx, nil, nil, nil, nil, nil) err := root.updateContext("dns", []string{"example.com", "@1.1.1.1"}) assert.EqualError(t, err, ErrResolverIPVersionNotAllowed.Error()) diff --git a/cmd/dns.go b/cmd/dns.go index 970d1b3..288c07b 100644 --- a/cmd/dns.go +++ b/cmd/dns.go @@ -102,6 +102,7 @@ func (r *Root) RunDNS(cmd *cobra.Command, args []string) error { res, err := r.client.CreateMeasurement(opts) if err != nil { cmd.SilenceUsage = silenceUsageOnCreateMeasurementError(err) + r.evaluateError(err) return err } @@ -109,7 +110,7 @@ func (r *Root) RunDNS(cmd *cobra.Command, args []string) error { hm := &view.HistoryItem{ Id: res.ID, Status: globalping.StatusInProgress, - StartedAt: r.time.Now(), + StartedAt: r.utils.Now(), } r.ctx.History.Push(hm) if r.ctx.RecordToSession { diff --git a/cmd/dns_test.go b/cmd/dns_test.go index eff8ead..1378f0f 100644 --- a/cmd/dns_test.go +++ b/cmd/dns_test.go @@ -37,13 +37,13 @@ func Test_Execute_DNS_Default(t *testing.T) { viewerMock := mocks.NewMockViewer(ctrl) viewerMock.EXPECT().Output(measurementID1, expectedOpts).Times(1).Return(nil) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("dns") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "dns", "jsdelivr.com", "from", "Berlin", @@ -101,13 +101,13 @@ func Test_Execute_DNS_IPv4(t *testing.T) { viewerMock := mocks.NewMockViewer(ctrl) viewerMock.EXPECT().Output(measurementID1, expectedOpts).Times(1).Return(nil) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("dns") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "dns", "jsdelivr.com", "from", "Berlin", @@ -141,13 +141,13 @@ func Test_Execute_DNS_IPv6(t *testing.T) { viewerMock := mocks.NewMockViewer(ctrl) viewerMock.EXPECT().Output(measurementID1, expectedOpts).Times(1).Return(nil) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("dns") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "dns", "jsdelivr.com", "from", "Berlin", diff --git a/cmd/history.go b/cmd/history.go index 1dde76f..e1b662b 100644 --- a/cmd/history.go +++ b/cmd/history.go @@ -100,7 +100,7 @@ func (r *Root) UpdateHistory() error { } index = fmt.Sprintf("%d", i) } - time := r.time.Now().Unix() + time := r.utils.Now().Unix() cmd := strings.Join(os.Args[1:], " ") f, err := os.OpenFile(getHistoryPath(), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { diff --git a/cmd/history_test.go b/cmd/history_test.go index 3d7b477..c902897 100644 --- a/cmd/history_test.go +++ b/cmd/history_test.go @@ -20,13 +20,13 @@ func Test_Execute_History_Default(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() ctx := createDefaultContext("ping") w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) - root := NewRoot(printer, ctx, nil, timeMock, nil, nil) + root := NewRoot(printer, ctx, nil, utilsMock, nil, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com"} ctx.History.Push(&view.HistoryItem{ diff --git a/cmd/http.go b/cmd/http.go index a296a4a..0930b8c 100644 --- a/cmd/http.go +++ b/cmd/http.go @@ -111,6 +111,7 @@ func (r *Root) RunHTTP(cmd *cobra.Command, args []string) error { res, err := r.client.CreateMeasurement(opts) if err != nil { cmd.SilenceUsage = silenceUsageOnCreateMeasurementError(err) + r.evaluateError(err) return err } @@ -118,7 +119,7 @@ func (r *Root) RunHTTP(cmd *cobra.Command, args []string) error { hm := &view.HistoryItem{ Id: res.ID, Status: globalping.StatusInProgress, - StartedAt: r.time.Now(), + StartedAt: r.utils.Now(), } r.ctx.History.Push(hm) if r.ctx.RecordToSession { diff --git a/cmd/http_test.go b/cmd/http_test.go index f16a682..d5f903f 100644 --- a/cmd/http_test.go +++ b/cmd/http_test.go @@ -39,13 +39,13 @@ func Test_Execute_HTTP_Default(t *testing.T) { viewerMock := mocks.NewMockViewer(ctrl) viewerMock.EXPECT().Output(measurementID1, expectedOpts).Times(1).Return(nil) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("http") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "http", "jsdelivr.com", "from", "Berlin", "--protocol", "HTTPS", @@ -112,13 +112,13 @@ func Test_Execute_HTTP_IPv4(t *testing.T) { viewerMock := mocks.NewMockViewer(ctrl) viewerMock.EXPECT().Output(measurementID1, expectedOpts).Times(1).Return(nil) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("http") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "http", "jsdelivr.com", "from", "Berlin", "--ipv4", @@ -155,13 +155,13 @@ func Test_Execute_HTTP_IPv6(t *testing.T) { viewerMock := mocks.NewMockViewer(ctrl) viewerMock.EXPECT().Output(measurementID1, expectedOpts).Times(1).Return(nil) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("http") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "http", "jsdelivr.com", "from", "Berlin", "--ipv6", @@ -251,7 +251,7 @@ func Test_ParseHttpHeaders_Invalid(t *testing.T) { func Test_BuildHttpMeasurementRequest_Full(t *testing.T) { ctx := createDefaultContext("http") printer := view.NewPrinter(nil, nil, nil) - root := NewRoot(printer, ctx, nil, nil, nil, nil) + root := NewRoot(printer, ctx, nil, nil, nil, nil, nil) ctx.Target = "https://example.com/my/path?x=123&yz=abc" ctx.From = "london" @@ -284,7 +284,7 @@ func Test_BuildHttpMeasurementRequest_Full(t *testing.T) { func Test_BuildHttpMeasurementRequest_HEAD(t *testing.T) { ctx := createDefaultContext("http") printer := view.NewPrinter(nil, nil, nil) - root := NewRoot(printer, ctx, nil, nil, nil, nil) + root := NewRoot(printer, ctx, nil, nil, nil, nil, nil) ctx.Target = "https://example.com/my/path?x=123&yz=abc" ctx.From = "london" diff --git a/cmd/install_probe_test.go b/cmd/install_probe_test.go index 94217e2..99fd8ba 100644 --- a/cmd/install_probe_test.go +++ b/cmd/install_probe_test.go @@ -28,7 +28,7 @@ func Test_Execute_Install_Probe_Docker(t *testing.T) { w := new(bytes.Buffer) printer := view.NewPrinter(reader, w, w) ctx := createDefaultContext("install-probe") - root := NewRoot(printer, ctx, nil, nil, nil, probeMock) + root := NewRoot(printer, ctx, nil, nil, nil, probeMock, nil) os.Args = []string{"globalping", "install-probe"} err := root.Cmd.ExecuteContext(context.TODO()) assert.NoError(t, err) diff --git a/cmd/mtr.go b/cmd/mtr.go index 504b470..8b1c4c8 100644 --- a/cmd/mtr.go +++ b/cmd/mtr.go @@ -93,6 +93,7 @@ func (r *Root) RunMTR(cmd *cobra.Command, args []string) error { res, err := r.client.CreateMeasurement(opts) if err != nil { cmd.SilenceUsage = silenceUsageOnCreateMeasurementError(err) + r.evaluateError(err) return err } @@ -100,7 +101,7 @@ func (r *Root) RunMTR(cmd *cobra.Command, args []string) error { hm := &view.HistoryItem{ Id: res.ID, Status: globalping.StatusInProgress, - StartedAt: r.time.Now(), + StartedAt: r.utils.Now(), } r.ctx.History.Push(hm) if r.ctx.RecordToSession { diff --git a/cmd/mtr_test.go b/cmd/mtr_test.go index b45931c..e75446a 100644 --- a/cmd/mtr_test.go +++ b/cmd/mtr_test.go @@ -33,13 +33,13 @@ func Test_Execute_MTR_Default(t *testing.T) { viewerMock := mocks.NewMockViewer(ctrl) viewerMock.EXPECT().Output(measurementID1, expectedOpts).Times(1).Return(nil) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("mtr") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "mtr", "jsdelivr.com", "from", "Berlin", "--limit", "2", @@ -92,13 +92,13 @@ func Test_Execute_MTR_IPv4(t *testing.T) { viewerMock := mocks.NewMockViewer(ctrl) viewerMock.EXPECT().Output(measurementID1, expectedOpts).Times(1).Return(nil) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("mtr") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "mtr", "jsdelivr.com", "from", "Berlin", "--ipv4", @@ -131,13 +131,13 @@ func Test_Execute_MTR_IPv6(t *testing.T) { viewerMock := mocks.NewMockViewer(ctrl) viewerMock.EXPECT().Output(measurementID1, expectedOpts).Times(1).Return(nil) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("mtr") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "mtr", "jsdelivr.com", "from", "Berlin", "--ipv6", diff --git a/cmd/ping.go b/cmd/ping.go index a7d0570..e8b2819 100644 --- a/cmd/ping.go +++ b/cmd/ping.go @@ -2,7 +2,6 @@ package cmd import ( "fmt" - "net/http" "syscall" "time" @@ -98,6 +97,7 @@ func (r *Root) RunPing(cmd *cobra.Command, args []string) error { hm, err := r.createMeasurement(opts) if err != nil { + r.evaluateError(err) return err } return r.viewer.Output(hm.Id, opts) @@ -119,13 +119,7 @@ func (r *Root) pingInfinite(opts *globalping.MeasurementCreate) error { <-r.cancel r.viewer.OutputSummary() - if err != nil && r.ctx.MeasurementsCreated > 0 { - e, ok := err.(*globalping.MeasurementError) - if ok && e.Code == http.StatusTooManyRequests { - r.Cmd.SilenceErrors = true - r.printer.ErrPrintf(r.printer.Color("> "+e.Message, view.FGBrightYellow) + "\n") - } - } + r.evaluateError(err) r.viewer.OutputShare() return err } @@ -133,7 +127,7 @@ func (r *Root) pingInfinite(opts *globalping.MeasurementCreate) error { func (r *Root) ping(opts *globalping.MeasurementCreate) error { var runErr error mbuf := NewMeasurementsBuffer(10) // 10 is the maximum number of measurements that can be in progress at the same time - r.ctx.RunSessionStartedAt = r.time.Now() + r.ctx.RunSessionStartedAt = r.utils.Now() for { mbuf.Restart() elapsedTime := time.Duration(0) @@ -164,14 +158,14 @@ func (r *Root) ping(opts *globalping.MeasurementCreate) error { } if runErr == nil && mbuf.CanAppend() { opts.Locations = []globalping.Locations{{Magic: r.ctx.History.Last().Id}} - start := r.time.Now() + start := r.utils.Now() hm, err := r.createMeasurement(opts) if err != nil { runErr = err // Return the error after all measurements have finished } else { mbuf.Append(hm) } - elapsedTime += r.time.Now().Sub(start) + elapsedTime += r.utils.Now().Sub(start) } el = mbuf.Next() } @@ -204,7 +198,7 @@ func (r *Root) createMeasurement(opts *globalping.MeasurementCreate) (*view.Hist hm := &view.HistoryItem{ Id: res.ID, Status: globalping.StatusInProgress, - StartedAt: r.time.Now(), + StartedAt: r.utils.Now(), } r.ctx.History.Push(hm) if r.ctx.RecordToSession { diff --git a/cmd/ping_test.go b/cmd/ping_test.go index 1e0a9f1..58eb34c 100644 --- a/cmd/ping_test.go +++ b/cmd/ping_test.go @@ -32,13 +32,13 @@ func Test_Execute_Ping_Default(t *testing.T) { viewerMock := mocks.NewMockViewer(ctrl) viewerMock.EXPECT().Output(measurementID1, expectedOpts).Times(1).Return(nil) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("ping") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com"} err := root.Cmd.ExecuteContext(context.TODO()) @@ -84,13 +84,13 @@ func Test_Execute_Ping_Locations_And_Session(t *testing.T) { c2 := viewerMock.EXPECT().Output(measurementID2, expectedOpts).Times(3).Return(nil).After(c1) viewerMock.EXPECT().Output(measurementID3, expectedOpts).Times(3).Return(nil).After(c2) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("ping") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "from", "Berlin,New York "} err := root.Cmd.ExecuteContext(context.TODO()) assert.NoError(t, err) @@ -101,7 +101,7 @@ func Test_Execute_Ping_Locations_And_Session(t *testing.T) { ctx = createDefaultContext("ping") expectedOpts.Locations = []globalping.Locations{{Magic: measurementID1}} - root = NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root = NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "from", "@-1"} err = root.Cmd.ExecuteContext(context.TODO()) assert.NoError(t, err) @@ -112,7 +112,7 @@ func Test_Execute_Ping_Locations_And_Session(t *testing.T) { ctx = createDefaultContext("ping") expectedOpts.Locations = []globalping.Locations{{Magic: measurementID1}} - root = NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root = NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "from", "last"} err = root.Cmd.ExecuteContext(context.TODO()) assert.NoError(t, err) @@ -122,7 +122,7 @@ func Test_Execute_Ping_Locations_And_Session(t *testing.T) { ctx = createDefaultContext("ping") expectedOpts.Locations = []globalping.Locations{{Magic: measurementID1}} - root = NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root = NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "from", "previous"} err = root.Cmd.ExecuteContext(context.TODO()) assert.NoError(t, err) @@ -133,7 +133,7 @@ func Test_Execute_Ping_Locations_And_Session(t *testing.T) { ctx = createDefaultContext("ping") expectedOpts.Locations = []globalping.Locations{{Magic: "world"}} expectedResponse.ID = measurementID2 - root = NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root = NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com"} err = root.Cmd.ExecuteContext(context.TODO()) assert.NoError(t, err) @@ -145,7 +145,7 @@ func Test_Execute_Ping_Locations_And_Session(t *testing.T) { ctx = createDefaultContext("ping") expectedOpts.Locations = []globalping.Locations{{Magic: measurementID1}} - root = NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root = NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "from", "@1"} err = root.Cmd.ExecuteContext(context.TODO()) assert.NoError(t, err) @@ -156,7 +156,7 @@ func Test_Execute_Ping_Locations_And_Session(t *testing.T) { ctx = createDefaultContext("ping") expectedOpts.Locations = []globalping.Locations{{Magic: measurementID1}} - root = NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root = NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "from", "first"} err = root.Cmd.ExecuteContext(context.TODO()) assert.NoError(t, err) @@ -167,7 +167,7 @@ func Test_Execute_Ping_Locations_And_Session(t *testing.T) { ctx = createDefaultContext("ping") expectedOpts.Locations = []globalping.Locations{{Magic: "world"}} expectedResponse.ID = measurementID3 - root = NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root = NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com"} err = root.Cmd.ExecuteContext(context.TODO()) assert.NoError(t, err) @@ -179,7 +179,7 @@ func Test_Execute_Ping_Locations_And_Session(t *testing.T) { ctx = createDefaultContext("ping") expectedOpts.Locations = []globalping.Locations{{Magic: measurementID2}} - root = NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root = NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "from", "@2"} err = root.Cmd.ExecuteContext(context.TODO()) assert.NoError(t, err) @@ -191,7 +191,7 @@ func Test_Execute_Ping_Locations_And_Session(t *testing.T) { ctx = createDefaultContext("ping") expectedOpts.Locations = []globalping.Locations{{Magic: measurementID1}} - root = NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root = NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "from", "@-3"} err = root.Cmd.ExecuteContext(context.TODO()) assert.NoError(t, err) @@ -207,7 +207,7 @@ func Test_Execute_Ping_Locations_And_Session(t *testing.T) { assert.Equal(t, expectedHistory, b) ctx = createDefaultContext("ping") - root = NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root = NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "from", "@-4"} err = root.Cmd.ExecuteContext(context.TODO()) assert.Error(t, err, ErrIndexOutOfRange) @@ -224,7 +224,7 @@ func Test_Execute_Ping_Locations_And_Session(t *testing.T) { w.Reset() ctx = createDefaultContext("ping") - root = NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root = NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "from", "@1"} err = root.Cmd.ExecuteContext(context.TODO()) assert.Error(t, err, ErrNoPreviousMeasurements) @@ -235,7 +235,7 @@ func Test_Execute_Ping_Locations_And_Session(t *testing.T) { w.Reset() ctx = createDefaultContext("ping") - root = NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root = NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "from", "@0"} err = root.Cmd.ExecuteContext(context.TODO()) assert.Error(t, err, ErrInvalidIndex) @@ -246,7 +246,7 @@ func Test_Execute_Ping_Locations_And_Session(t *testing.T) { w.Reset() ctx = createDefaultContext("ping") - root = NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root = NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "from", "@x"} err = root.Cmd.ExecuteContext(context.TODO()) assert.Error(t, err, ErrInvalidIndex) @@ -257,7 +257,7 @@ func Test_Execute_Ping_Locations_And_Session(t *testing.T) { w.Reset() ctx = createDefaultContext("ping") - root = NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root = NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "from", "@"} err = root.Cmd.ExecuteContext(context.TODO()) assert.Error(t, err, ErrInvalidIndex) @@ -334,8 +334,8 @@ func Test_Execute_Ping_Infinite(t *testing.T) { viewerMock.EXPECT().OutputSummary().Times(1) viewerMock.EXPECT().OutputShare().Times(1) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) @@ -344,7 +344,7 @@ func Test_Execute_Ping_Infinite(t *testing.T) { From: "world", Limit: 1, } - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "--infinite", "from", "Berlin"} go func() { @@ -447,13 +447,13 @@ func Test_Execute_Ping_Infinite_Output_Error(t *testing.T) { viewerMock.EXPECT().OutputSummary().Times(1) viewerMock.EXPECT().OutputShare().Times(1) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("ping") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "--infinite", "from", "Berlin"} err := root.Cmd.ExecuteContext(context.TODO()) assert.Equal(t, "error message", err.Error()) @@ -512,14 +512,14 @@ func Test_Execute_Ping_Infinite_Output_TooManyRequests_Error(t *testing.T) { viewerMock.EXPECT().OutputSummary().Times(1) viewerMock.EXPECT().OutputShare().Times(1) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) errW := new(bytes.Buffer) printer := view.NewPrinter(nil, w, errW) ctx := createDefaultContext("ping") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "from", "Berlin", "--infinite", "--share"} err := root.Cmd.ExecuteContext(context.TODO()) assert.Equal(t, "too many requests", err.Error()) @@ -566,13 +566,13 @@ func Test_Execute_Ping_IPv4(t *testing.T) { viewerMock := mocks.NewMockViewer(ctrl) viewerMock.EXPECT().Output(measurementID1, expectedOpts).Times(1).Return(nil) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("ping") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "--ipv4"} err := root.Cmd.ExecuteContext(context.TODO()) @@ -603,13 +603,13 @@ func Test_Execute_Ping_IPv6(t *testing.T) { viewerMock := mocks.NewMockViewer(ctrl) viewerMock.EXPECT().Output(measurementID1, expectedOpts).Times(1).Return(nil) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("ping") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "ping", "jsdelivr.com", "--ipv6"} err := root.Cmd.ExecuteContext(context.TODO()) diff --git a/cmd/root.go b/cmd/root.go index e63b798..6e2fe93 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -11,6 +11,7 @@ import ( "github.com/jsdelivr/globalping-cli/globalping" "github.com/jsdelivr/globalping-cli/globalping/probe" + "github.com/jsdelivr/globalping-cli/storage" "github.com/jsdelivr/globalping-cli/utils" "github.com/jsdelivr/globalping-cli/view" "github.com/spf13/cobra" @@ -22,7 +23,8 @@ type Root struct { viewer view.Viewer client globalping.Client probe probe.Probe - time utils.Time + utils utils.Utils + storage *storage.LocalStorage Cmd *cobra.Command cancel chan os.Signal } @@ -30,10 +32,16 @@ type Root struct { // Execute adds all child commands to the root command and sets flags appropriately. // This is called by main.main(). It only needs to happen once to the rootCmd. func Execute() { - utime := utils.NewTime() + _utils := utils.NewUtils() printer := view.NewPrinter(os.Stdin, os.Stdout, os.Stderr) config := utils.NewConfig() config.Load() + localStorage := storage.NewLocalStorage(".globalping-cli") + if err := localStorage.Init(); err != nil { + printer.ErrPrintf("Error: failed to initialize storage: %v\n", err) + os.Exit(1) + } + profile := localStorage.GetProfile() ctx := &view.Context{ APIMinInterval: config.GlobalpingAPIInterval, History: view.NewHistoryBuffer(10), @@ -42,13 +50,26 @@ func Execute() { } t := time.NewTicker(10 * time.Second) globalpingClient := globalping.NewClientWithCacheCleanup(globalping.Config{ - APIURL: config.GlobalpingAPIURL, - APIToken: config.GlobalpingToken, - UserAgent: getUserAgent(), + APIURL: config.GlobalpingAPIURL, + AuthURL: config.GlobalpingAuthURL, + DashboardURL: config.GlobalpingDashboardURL, + AuthAccessToken: config.GlobalpingToken, + AuthToken: profile.Token, + OnTokenRefresh: func(token *globalping.Token) { + profile.Token = token + err := localStorage.SaveConfig() + if err != nil { + printer.ErrPrintf("Error: failed to save config: %v\n", err) + os.Exit(1) + } + }, + AuthClientID: config.GlobalpingAuthClientID, + AuthClientSecret: config.GlobalpingAuthClientSecret, + UserAgent: getUserAgent(), }, t, 30) globalpingProbe := probe.NewProbe() - viewer := view.NewViewer(ctx, printer, utime, globalpingClient) - root := NewRoot(printer, ctx, viewer, utime, globalpingClient, globalpingProbe) + viewer := view.NewViewer(ctx, printer, _utils, globalpingClient) + root := NewRoot(printer, ctx, viewer, _utils, globalpingClient, globalpingProbe, localStorage) err := root.Cmd.Execute() if err != nil { @@ -60,17 +81,19 @@ func NewRoot( printer *view.Printer, ctx *view.Context, viewer view.Viewer, - time utils.Time, + utils utils.Utils, globalpingClient globalping.Client, globalpingProbe probe.Probe, + localStorage *storage.LocalStorage, ) *Root { root := &Root{ printer: printer, ctx: ctx, viewer: viewer, - time: time, + utils: utils, client: globalpingClient, probe: globalpingProbe, + storage: localStorage, cancel: make(chan os.Signal, 1), } @@ -116,6 +139,7 @@ For more information about the platform, tips, and best practices, visit our Git root.initInstallProbe() root.initVersion() root.initHistory() + root.initAuth() return root } diff --git a/cmd/traceroute.go b/cmd/traceroute.go index 9a7df80..da70210 100644 --- a/cmd/traceroute.go +++ b/cmd/traceroute.go @@ -94,6 +94,7 @@ func (r *Root) RunTraceroute(cmd *cobra.Command, args []string) error { res, err := r.client.CreateMeasurement(opts) if err != nil { cmd.SilenceUsage = silenceUsageOnCreateMeasurementError(err) + r.evaluateError(err) return err } @@ -101,7 +102,7 @@ func (r *Root) RunTraceroute(cmd *cobra.Command, args []string) error { hm := &view.HistoryItem{ Id: res.ID, Status: globalping.StatusInProgress, - StartedAt: r.time.Now(), + StartedAt: r.utils.Now(), } r.ctx.History.Push(hm) if r.ctx.RecordToSession { diff --git a/cmd/traceroute_test.go b/cmd/traceroute_test.go index c941632..66f0ee0 100644 --- a/cmd/traceroute_test.go +++ b/cmd/traceroute_test.go @@ -32,13 +32,13 @@ func Test_Execute_Traceroute_Default(t *testing.T) { viewerMock := mocks.NewMockViewer(ctrl) viewerMock.EXPECT().Output(measurementID1, expectedOpts).Times(1).Return(nil) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("traceroute") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "traceroute", "jsdelivr.com", "from", "Berlin", "--limit", "2", @@ -88,13 +88,13 @@ func Test_Execute_Traceroute_IPv4(t *testing.T) { viewerMock := mocks.NewMockViewer(ctrl) viewerMock.EXPECT().Output(measurementID1, expectedOpts).Times(1).Return(nil) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("traceroute") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "traceroute", "jsdelivr.com", "from", "Berlin", "--ipv4", @@ -126,13 +126,13 @@ func Test_Execute_Traceroute_IPv6(t *testing.T) { viewerMock := mocks.NewMockViewer(ctrl) viewerMock.EXPECT().Output(measurementID1, expectedOpts).Times(1).Return(nil) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime).AnyTimes() w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("traceroute") - root := NewRoot(printer, ctx, viewerMock, timeMock, gbMock, nil) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) os.Args = []string{"globalping", "traceroute", "jsdelivr.com", "from", "Berlin", "--ipv6", diff --git a/cmd/version_test.go b/cmd/version_test.go index 241aaec..2b243ad 100644 --- a/cmd/version_test.go +++ b/cmd/version_test.go @@ -15,7 +15,7 @@ func Test_Execute_Version_Default(t *testing.T) { version.Version = "1.0.0" w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) - root := NewRoot(printer, &view.Context{}, nil, nil, nil, nil) + root := NewRoot(printer, &view.Context{}, nil, nil, nil, nil, nil) os.Args = []string{"globalping", "version"} err := root.Cmd.ExecuteContext(context.TODO()) diff --git a/globalping/auth.go b/globalping/auth.go new file mode 100644 index 0000000..0ca540c --- /dev/null +++ b/globalping/auth.go @@ -0,0 +1,288 @@ +package globalping + +import ( + "context" + "encoding/json" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "golang.org/x/oauth2" +) + +type Token struct { + // AccessToken is the token that authorizes and authenticates + // the requests. + AccessToken string `json:"access_token"` + + // TokenType is the type of token. + // The Type method returns either this or "Bearer", the default. + TokenType string `json:"token_type,omitempty"` + + // RefreshToken is a token that's used by the application + // (as opposed to the user) to refresh the access token + // if it expires. + RefreshToken string `json:"refresh_token,omitempty"` + + // Expiry is the optional expiration time of the access token. + // + // If zero, TokenSource implementations will reuse the same + // token forever and RefreshToken or equivalent + // mechanisms for that TokenSource will not be used. + Expiry time.Time `json:"expiry,omitempty"` +} + +type AuthorizeError struct { + Code int `json:"-"` + ErrorType string `json:"error"` + Description string `json:"error_description"` +} + +func (e *AuthorizeError) Error() string { + return e.ErrorType + ": " + e.Description +} + +type AuthorizeResponse struct { + AuthorizeURL string + CallbackURL string +} + +func (c *client) Authorize(callback func(error)) (*AuthorizeResponse, error) { + pkce := oauth2.GenerateVerifier() + mux := http.NewServeMux() + server := &http.Server{ + Handler: mux, + } + callbackURL := "" + mux.HandleFunc("/callback", func(w http.ResponseWriter, req *http.Request) { + req.ParseForm() + token, err := c.exchange(req.Form, pkce, callbackURL) + if err != nil { + http.Redirect(w, req, c.dashboardURL+"/authorize/error", http.StatusFound) + } else { + http.Redirect(w, req, c.dashboardURL+"/authorize/success", http.StatusFound) + } + go func() { + server.Shutdown(req.Context()) + if err == nil { + c.token.Store(token) + if c.onTokenRefresh != nil { + c.onTokenRefresh(mapToken(token)) + } + } + callback(err) + }() + }) + var err error + var ln net.Listener + ports := []int{60000, 60010, 60020, 60030, 60040, 60100, 60110, 60120, 60130, 60140} + port := "" + for i := range ports { + port = strconv.Itoa(ports[i]) + ln, err = net.Listen("tcp", ":"+port) + if err == nil { + break + } + } + if err != nil { + return nil, err + } + go func() { + err := server.Serve(ln) + if err != nil && err != http.ErrServerClosed { + callback(&AuthorizeError{ErrorType: "failed to start server", Description: err.Error()}) + } + }() + callbackURL = "http://localhost:" + port + "/callback" + return &AuthorizeResponse{ + AuthorizeURL: c.oauth2.AuthCodeURL("", oauth2.S256ChallengeOption(pkce)), + CallbackURL: callbackURL, + }, nil +} + +func (c *client) TokenIntrospection(token string) (*IntrospectionResponse, error) { + if token == "" { + var err error + token, _, err = c.accessToken() + if err != nil { + return nil, &AuthorizeError{ + ErrorType: "not_authorized", + Description: err.Error(), + } + } + } + if token == "" { + return nil, &AuthorizeError{ + ErrorType: "not_authorized", + Description: "client is not authorized", + } + } + return c.introspection(token) +} + +func (c *client) Logout() error { + t := c.token.Load() + if t == nil { + return nil + } + err := c.RevokeToken(t.RefreshToken) + if err != nil { + return err + } + c.mu.Lock() + defer c.mu.Unlock() + c.tokenSource = nil + c.token.Store(nil) + if c.onTokenRefresh != nil { + c.onTokenRefresh(nil) + } + return nil +} + +func (c *client) exchange(form url.Values, pkce string, redirect string) (*oauth2.Token, error) { + if form.Get("error") != "" { + return nil, &AuthorizeError{ + ErrorType: form.Get("error"), + Description: form.Get("error_description"), + } + } + code := form.Get("code") + if code == "" { + return nil, &AuthorizeError{ + ErrorType: "missing_code", + Description: "missing code in response", + } + } + return c.oauth2.Exchange( + context.Background(), + code, + oauth2.VerifierOption(pkce), + oauth2.SetAuthURLParam("redirect_uri", redirect), + ) +} + +func (c *client) accessToken() (string, string, error) { + c.mu.RLock() + defer c.mu.RUnlock() + if c.tokenSource == nil { + return "", "", nil + } + token, err := c.tokenSource.Token() + if err != nil { + e, ok := err.(*oauth2.RetrieveError) + if ok && e.ErrorCode == "invalid_grant" && c.onTokenRefresh != nil { + c.onTokenRefresh(nil) + } + return "", "", err + } + curr := c.token.Load() + if curr != nil && token.AccessToken != curr.AccessToken { + c.token.Store(token) + if c.onTokenRefresh != nil { + c.onTokenRefresh(mapToken(token)) + } + } + return token.AccessToken, token.Type(), nil +} + +// https://datatracker.ietf.org/doc/html/rfc7662#section-2.1 +type IntrospectionResponse struct { + // Required fields + Active bool `json:"active"` + + // Optional fields + Scope string `json:"scope"` + ClientID string `json:"client_id"` + Username string `json:"username"` + TokenType string `json:"token_type"` + Exp int64 `json:"exp"` // Expiration Time. Unix timestamp + Iat int64 `json:"iat"` // Issued At. Unix timestamp + Nbf int64 `json:"nbf"` // Not to be used before. Unix timestamp + Sub string `json:"sub"` // Subject + Aud string `json:"aud"` // Audience + Iss string `json:"iss"` // Issuer + Jti string `json:"jti"` // JWT ID +} + +func (c *client) introspection(token string) (*IntrospectionResponse, error) { + form := url.Values{"token": {token}}.Encode() + req, err := http.NewRequest("POST", c.authURL+"/oauth/token/introspect", strings.NewReader(form)) + if err != nil { + return nil, &AuthorizeError{ + ErrorType: "introspection_failed", + Description: err.Error(), + } + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Content-Length", strconv.Itoa(len(form))) + resp, err := c.http.Do(req) + if err != nil { + return nil, &AuthorizeError{ + ErrorType: "introspection_failed", + Description: err.Error(), + } + } + if resp.StatusCode != http.StatusOK { + err := &AuthorizeError{ + Code: resp.StatusCode, + ErrorType: "introspection_failed", + Description: resp.Status, + } + json.NewDecoder(resp.Body).Decode(err) + return nil, err + } + ires := &IntrospectionResponse{} + err = json.NewDecoder(resp.Body).Decode(ires) + if err != nil { + return nil, &AuthorizeError{ + ErrorType: "introspection_failed", + Description: err.Error(), + } + } + return ires, nil +} + +func (c *client) RevokeToken(token string) error { + if token == "" { + return nil + } + form := url.Values{"token": {token}}.Encode() + req, err := http.NewRequest("POST", c.authURL+"/oauth/token/revoke", strings.NewReader(form)) + if err != nil { + return &AuthorizeError{ + ErrorType: "revoke_failed", + Description: err.Error(), + } + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Content-Length", strconv.Itoa(len(form))) + resp, err := c.http.Do(req) + if err != nil { + return &AuthorizeError{ + ErrorType: "revoke_failed", + Description: err.Error(), + } + } + if resp.StatusCode != http.StatusOK { + err := &AuthorizeError{ + Code: resp.StatusCode, + ErrorType: "revoke_failed", + Description: resp.Status, + } + json.NewDecoder(resp.Body).Decode(err) + return err + } + return nil +} + +func mapToken(t *oauth2.Token) *Token { + return &Token{ + AccessToken: t.AccessToken, + TokenType: t.TokenType, + RefreshToken: t.RefreshToken, + Expiry: t.Expiry, + } +} diff --git a/globalping/auth_test.go b/globalping/auth_test.go new file mode 100644 index 0000000..ddef713 --- /dev/null +++ b/globalping/auth_test.go @@ -0,0 +1,401 @@ +package globalping + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Test_Authorize(t *testing.T) { + succesCalled := false + expectedRedirectURI := "" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/authorize/error" { + t.Fatalf("unexpected request to %s", r.URL.Path) + return + } + if r.URL.Path == "/authorize/success" { + succesCalled = true + return + } + if r.URL.Path == "/oauth/token" { + if r.Method != http.MethodPost { + t.Fatalf("expected POST request, got %s", r.Method) + } + err := r.ParseForm() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "", r.Form.Get("client_id")) + assert.Equal(t, "", r.Form.Get("client_secret")) + assert.Equal(t, "authorization_code", r.Form.Get("grant_type")) + assert.Equal(t, "cod3", r.Form.Get("code")) + assert.Equal(t, expectedRedirectURI, r.Form.Get("redirect_uri")) + assert.Equal(t, 43, len(r.Form.Get("code_verifier"))) + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(getTokenJSON()) + if err != nil { + t.Fatal(err) + } + return + } + t.Fatalf("unexpected request to %s", r.URL.Path) + })) + defer server.Close() + client := NewClient(Config{ + AuthClientID: "", + AuthClientSecret: "", + AuthURL: server.URL, + DashboardURL: server.URL, + OnTokenRefresh: func(_token *Token) { + assert.Equal(t, &Token{ + AccessToken: "token", + TokenType: "bearer", + RefreshToken: "refresh", + Expiry: _token.Expiry, + }, _token) + }, + }) + res, err := client.Authorize(func(err error) { + assert.Nil(t, err) + }) + assert.Nil(t, err) + expectedRedirectURI = res.CallbackURL + u, err := url.Parse(res.AuthorizeURL) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, server.URL+"/oauth/authorize", u.Scheme+"://"+u.Host+u.Path) + assert.Equal(t, "", u.Query().Get("client_id")) + assert.Equal(t, 43, len(u.Query().Get("code_challenge"))) + assert.Equal(t, "S256", u.Query().Get("code_challenge_method")) + assert.Equal(t, "code", u.Query().Get("response_type")) + assert.Equal(t, "measurements", u.Query().Get("scope")) + + _, err = http.Post(res.CallbackURL+"?code=cod3", "application/x-www-form-urlencoded", nil) + if err != nil { + t.Fatal(err) + } + + assert.True(t, succesCalled, "/authorize/success not called") + +} + +func Test_TokenIntrospection(t *testing.T) { + now := time.Now() + introspectionRes := &IntrospectionResponse{ + Active: true, + Scope: "measurements", + ClientID: "", + Username: "user", + TokenType: "bearer", + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/oauth/token/introspect" { + if r.Method != http.MethodPost { + t.Fatalf("expected POST request, got %s", r.Method) + } + err := r.ParseForm() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "tok3n", r.Form.Get("token")) + + w.Header().Set("Content-Type", "application/json") + b, _ := json.Marshal(introspectionRes) + _, err = w.Write(b) + if err != nil { + t.Fatal(err) + } + return + } + t.Fatalf("unexpected request to %s", r.URL.Path) + })) + defer server.Close() + + onTokenRefreshCalled := false + client := NewClient(Config{ + AuthClientID: "", + AuthClientSecret: "", + AuthURL: server.URL, + DashboardURL: server.URL, + AuthToken: &Token{ + AccessToken: "tok3n", + Expiry: now.Add(time.Hour), + }, + OnTokenRefresh: func(_ *Token) { + onTokenRefreshCalled = true + }, + }) + res, err := client.TokenIntrospection("") + assert.Nil(t, err) + assert.Equal(t, introspectionRes, res) + + assert.False(t, onTokenRefreshCalled) +} + +func Test_TokenIntrospection_Token_Refreshed(t *testing.T) { + now := time.Now() + introspectionRes := &IntrospectionResponse{ + Active: true, + Scope: "measurements", + ClientID: "", + Username: "user", + TokenType: "bearer", + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/oauth/token/introspect" { + if r.Method != http.MethodPost { + t.Fatalf("expected POST request, got %s", r.Method) + } + err := r.ParseForm() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "new_token", r.Form.Get("token")) + + w.Header().Set("Content-Type", "application/json") + b, _ := json.Marshal(introspectionRes) + _, err = w.Write(b) + if err != nil { + t.Fatal(err) + } + return + } + if r.URL.Path == "/oauth/token" { + if r.Method != http.MethodPost { + t.Fatalf("expected POST request, got %s", r.Method) + } + err := r.ParseForm() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "", r.Form.Get("client_id")) + assert.Equal(t, "", r.Form.Get("client_secret")) + assert.Equal(t, "refresh_token", r.Form.Get("grant_type")) + assert.Equal(t, "refresh_tok3n", r.Form.Get("refresh_token")) + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write([]byte(`{"access_token":"new_token","token_type":"bearer","refresh_token":"new_refresh_token","expires_in":3600}`)) + if err != nil { + t.Fatal(err) + } + return + } + t.Fatalf("unexpected request to %s", r.URL.Path) + })) + defer server.Close() + + var token *Token + client := NewClient(Config{ + AuthClientID: "", + AuthClientSecret: "", + AuthURL: server.URL, + DashboardURL: server.URL, + AuthToken: &Token{ + AccessToken: "tok3n", + RefreshToken: "refresh_tok3n", + Expiry: now.Add(-time.Hour), + }, + OnTokenRefresh: func(_t *Token) { + token = _t + }, + }) + res, err := client.TokenIntrospection("") + assert.Nil(t, err) + assert.Equal(t, introspectionRes, res) + + assert.Equal(t, &Token{ + AccessToken: "new_token", + TokenType: "bearer", + RefreshToken: "new_refresh_token", + Expiry: token.Expiry, + }, token) +} + +func Test_TokenIntrospection_With_Token(t *testing.T) { + now := time.Now() + introspectionRes := &IntrospectionResponse{ + Active: true, + Scope: "measurements", + ClientID: "", + Username: "user", + TokenType: "bearer", + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/oauth/token/introspect" { + if r.Method != http.MethodPost { + t.Fatalf("expected POST request, got %s", r.Method) + } + err := r.ParseForm() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "tok3n", r.Form.Get("token")) + + w.Header().Set("Content-Type", "application/json") + b, _ := json.Marshal(introspectionRes) + _, err = w.Write(b) + if err != nil { + t.Fatal(err) + } + return + } + t.Fatalf("unexpected request to %s", r.URL.Path) + })) + defer server.Close() + + onTokenRefreshCalled := false + client := NewClient(Config{ + AuthClientID: "", + AuthClientSecret: "", + AuthURL: server.URL, + DashboardURL: server.URL, + AuthToken: &Token{ + AccessToken: "local_token", + Expiry: now.Add(time.Hour), + }, + OnTokenRefresh: func(_ *Token) { + onTokenRefreshCalled = true + }, + }) + res, err := client.TokenIntrospection("tok3n") + assert.Nil(t, err) + assert.Equal(t, introspectionRes, res) + + assert.False(t, onTokenRefreshCalled) +} + +func Test_Logout(t *testing.T) { + isCalled := false + now := time.Now() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + isCalled = true + if r.URL.Path == "/oauth/token/revoke" { + if r.Method != http.MethodPost { + t.Fatalf("expected POST request, got %s", r.Method) + } + err := r.ParseForm() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "refresh_tok3n", r.Form.Get("token")) + return + } + t.Fatalf("unexpected request to %s", r.URL.Path) + })) + defer server.Close() + + onTokenRefreshCalled := false + client := NewClient(Config{ + AuthClientID: "", + AuthClientSecret: "", + AuthURL: server.URL, + DashboardURL: server.URL, + AuthToken: &Token{ + AccessToken: "tok3n", + RefreshToken: "refresh_tok3n", + Expiry: now.Add(time.Hour), + }, + OnTokenRefresh: func(token *Token) { + onTokenRefreshCalled = true + assert.Nil(t, token) + }, + }) + err := client.Logout() + if err != nil { + t.Fatal(err) + } + assert.True(t, isCalled) + assert.True(t, onTokenRefreshCalled) +} + +func Test_RevokeToken(t *testing.T) { + isCalled := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + isCalled = true + if r.URL.Path == "/oauth/token/revoke" { + if r.Method != http.MethodPost { + t.Fatalf("expected POST request, got %s", r.Method) + } + err := r.ParseForm() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "refresh_tok3n", r.Form.Get("token")) + return + } + t.Fatalf("unexpected request to %s", r.URL.Path) + })) + defer server.Close() + + client := NewClient(Config{ + AuthClientID: "", + AuthClientSecret: "", + AuthURL: server.URL, + DashboardURL: server.URL, + }) + err := client.RevokeToken("refresh_tok3n") + assert.Nil(t, err) + assert.True(t, isCalled) +} + +func Test_Logout_No_RefreshToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatalf("unexpected request to %s", r.URL.Path) + })) + defer server.Close() + + onTokenRefreshCalled := false + client := NewClient(Config{ + AuthClientID: "", + AuthClientSecret: "", + AuthURL: server.URL, + DashboardURL: server.URL, + AuthToken: &Token{ + AccessToken: "tok3n", + }, + OnTokenRefresh: func(token *Token) { + onTokenRefreshCalled = true + assert.Nil(t, token) + }, + }) + err := client.Logout() + if err != nil { + t.Fatal(err) + } + assert.True(t, onTokenRefreshCalled) +} + +func Test_Logout_AccessToken_Is_Set(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatalf("unexpected request to %s", r.URL.Path) + })) + defer server.Close() + + onTokenRefreshCalled := false + client := NewClient(Config{ + AuthClientID: "", + AuthClientSecret: "", + AuthURL: server.URL, + DashboardURL: server.URL, + AuthAccessToken: "tok3n", + OnTokenRefresh: func(token *Token) { + onTokenRefreshCalled = true + }, + }) + err := client.Logout() + if err != nil { + t.Fatal(err) + } + assert.False(t, onTokenRefreshCalled) +} diff --git a/globalping/client.go b/globalping/client.go index 4468a72..79ff811 100644 --- a/globalping/client.go +++ b/globalping/client.go @@ -1,9 +1,13 @@ package globalping import ( + "context" "net/http" "sync" + "sync/atomic" "time" + + "golang.org/x/oauth2" ) type Client interface { @@ -19,11 +23,36 @@ type Client interface { // // https://www.jsdelivr.com/docs/api.globalping.io#get-/v1/measurements/-id- GetMeasurementRaw(id string) ([]byte, error) + // Returns a link to be used for authorization and listens for the authorization callback. + // + // onTokenRefresh will be called if the authorization is successful. + Authorize(callback func(error)) (*AuthorizeResponse, error) + // Returns the introspection response for the token. + // + // If the token is empty, the client's current token will be used. + TokenIntrospection(token string) (*IntrospectionResponse, error) + // Removes the current token from the client. It also revokes the tokens if the refresh token is available. + // + // onTokenRefresh will be called if the token is successfully removed. + Logout() error + + // Revokes the token. + RevokeToken(token string) error } type Config struct { - APIURL string - APIToken string + HTTPClient *http.Client // If set, this client will be used for API requests and authorization + + APIURL string + DashboardURL string + + AuthURL string + AuthClientID string + AuthClientSecret string + AuthAccessToken string // If set, this token will be used for API requests + AuthToken *Token + OnTokenRefresh func(*Token) + UserAgent string } @@ -34,12 +63,18 @@ type CacheEntry struct { } type client struct { - sync.RWMutex + mu sync.RWMutex http *http.Client cache map[string]*CacheEntry + oauth2 *oauth2.Config + token atomic.Pointer[oauth2.Token] + tokenSource oauth2.TokenSource + onTokenRefresh func(*Token) + apiURL string - apiToken string + authURL string + dashboardURL string apiResponseCacheExpireSeconds int64 userAgent string } @@ -48,15 +83,50 @@ type client struct { // The client will not have a cache cleanup goroutine, therefore cached responses will never be removed. // If you want a cache cleanup goroutine, use NewClientWithCacheCleanup. func NewClient(config Config) Client { - return &client{ - http: &http.Client{ - Timeout: 30 * time.Second, + c := &client{ + mu: sync.RWMutex{}, + oauth2: &oauth2.Config{ + ClientID: config.AuthClientID, + ClientSecret: config.AuthClientSecret, + Scopes: []string{"measurements"}, + Endpoint: oauth2.Endpoint{ + AuthURL: config.AuthURL + "/oauth/authorize", + TokenURL: config.AuthURL + "/oauth/token", + AuthStyle: oauth2.AuthStyleInParams, + }, }, - apiURL: config.APIURL, - apiToken: config.APIToken, - userAgent: config.UserAgent, - cache: map[string]*CacheEntry{}, + onTokenRefresh: config.OnTokenRefresh, + apiURL: config.APIURL, + authURL: config.AuthURL, + dashboardURL: config.DashboardURL, + userAgent: config.UserAgent, + cache: map[string]*CacheEntry{}, + } + if config.HTTPClient != nil { + c.http = config.HTTPClient + } else { + c.http = &http.Client{ + Timeout: 30 * time.Second, + } } + if config.AuthAccessToken != "" { + c.tokenSource = oauth2.StaticTokenSource(&oauth2.Token{AccessToken: config.AuthAccessToken}) + } else if config.AuthToken != nil { + t := &oauth2.Token{ + AccessToken: config.AuthToken.AccessToken, + TokenType: config.AuthToken.TokenType, + RefreshToken: config.AuthToken.RefreshToken, + Expiry: config.AuthToken.Expiry, + } + c.token.Store(t) + if config.AuthToken.RefreshToken == "" { + c.tokenSource = oauth2.StaticTokenSource(&oauth2.Token{AccessToken: config.AuthToken.AccessToken}) + } else { + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, c.http) + c.tokenSource = c.oauth2.TokenSource(ctx, t) + } + } + return c } // NewClientWithCacheCleanup creates a new client with a cache cleanup goroutine that runs every t. @@ -74,8 +144,8 @@ func NewClientWithCacheCleanup(config Config, t *time.Ticker, cacheExpireSeconds } func (c *client) getETag(id string) string { - c.RLock() - defer c.RUnlock() + c.mu.RLock() + defer c.mu.RUnlock() e, ok := c.cache[id] if !ok { return "" @@ -84,8 +154,8 @@ func (c *client) getETag(id string) string { } func (c *client) getCachedResponse(id string) []byte { - c.RLock() - defer c.RUnlock() + c.mu.RLock() + defer c.mu.RUnlock() e, ok := c.cache[id] if !ok { return nil @@ -94,8 +164,8 @@ func (c *client) getCachedResponse(id string) []byte { } func (c *client) cacheResponse(id string, etag string, resp []byte) { - c.Lock() - defer c.Unlock() + c.mu.Lock() + defer c.mu.Unlock() var expires int64 if c.apiResponseCacheExpireSeconds != 0 { expires = time.Now().Unix() + c.apiResponseCacheExpireSeconds @@ -115,8 +185,8 @@ func (c *client) cacheResponse(id string, etag string, resp []byte) { } func (c *client) cleanupCache() { - c.Lock() - defer c.Unlock() + c.mu.Lock() + defer c.mu.Unlock() now := time.Now().Unix() for k, v := range c.cache { if v.ExpireAt > 0 && v.ExpireAt < now { diff --git a/globalping/globalping.go b/globalping/measurements.go similarity index 93% rename from globalping/globalping.go rename to globalping/measurements.go index 7fed187..b260c5e 100644 --- a/globalping/globalping.go +++ b/globalping/measurements.go @@ -19,6 +19,10 @@ var ( noCreditsAuthErr = "You have run out of credits for this session. You can wait %s for the rate limit to reset or get higher limits by sponsoring us or hosting probes." ) +var ( + StatusUnauthorizedWithTokenRefreshed = 1000 +) + func (c *client) CreateMeasurement(measurement *MeasurementCreate) (*MeasurementCreateResponse, error) { postData, err := json.Marshal(measurement) if err != nil { @@ -33,8 +37,12 @@ func (c *client) CreateMeasurement(measurement *MeasurementCreate) (*Measurement req.Header.Set("Accept-Encoding", "br") req.Header.Set("Content-Type", "application/json") - if c.apiToken != "" { - req.Header.Set("Authorization", "Bearer "+c.apiToken) + token, tokenType, err := c.accessToken() + if err != nil { + return nil, &MeasurementError{Message: "failed to get token: " + err.Error()} + } + if token != "" { + req.Header.Set("Authorization", tokenType+" "+token) } resp, err := c.http.Do(req) @@ -44,14 +52,13 @@ func (c *client) CreateMeasurement(measurement *MeasurementCreate) (*Measurement defer resp.Body.Close() if resp.StatusCode != http.StatusAccepted { - var data MeasurementCreateError + var data MeasurementErrorResponse err = json.NewDecoder(resp.Body).Decode(&data) if err != nil { return nil, &MeasurementError{Message: "invalid error format returned - please report this bug"} } - err := &MeasurementError{ - Code: resp.StatusCode, - } + err := data.Error + err.Code = resp.StatusCode if resp.StatusCode == http.StatusBadRequest { resErr := "" for _, v := range data.Error.Params { @@ -65,8 +72,12 @@ func (c *client) CreateMeasurement(measurement *MeasurementCreate) (*Measurement return nil, err } - if resp.StatusCode == http.StatusUnauthorized { - err.Message = fmt.Sprintf("unauthorized: %s", data.Error.Message) + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + token, _, e := c.accessToken() + if e == nil && token != "" { + err.Code = StatusUnauthorizedWithTokenRefreshed + } + err.Message = "unauthorized: " + data.Error.Message return nil, err } @@ -81,7 +92,7 @@ func (c *client) CreateMeasurement(measurement *MeasurementCreate) (*Measurement creditsRemaining, _ := strconv.ParseInt(resp.Header.Get("X-Credits-Remaining"), 10, 64) requestCost, _ := strconv.ParseInt(resp.Header.Get("X-Request-Cost"), 10, 64) remaining := rateLimitRemaining + creditsRemaining - if c.apiToken == "" { + if token == "" { if remaining > 0 { err.Message = fmt.Sprintf(moreCreditsRequiredNoAuthErr, utils.Pluralize(remaining, "credit"), requestCost, utils.FormatSeconds(rateLimitReset)) return nil, err diff --git a/globalping/globalping_test.go b/globalping/measurements_test.go similarity index 90% rename from globalping/globalping_test.go rename to globalping/measurements_test.go index 152694e..8836b80 100644 --- a/globalping/globalping_test.go +++ b/globalping/measurements_test.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "os" "strings" "testing" @@ -14,30 +13,7 @@ import ( "github.com/stretchr/testify/assert" ) -// PostAPI tests -func TestPostAPI(t *testing.T) { - // Suppress error outputs - os.Stdout, _ = os.Open(os.DevNull) - for scenario, fn := range map[string]func(t *testing.T){ - "valid": testPostValid, - "authorized": testPostAuthorized, - "auth_error": testPostAuthorizedError, - "more_credits_no_auth_error": testPostMoreCreditsRequiredNoAuthError, - "more_credits_auth_error": testPostMoreCreditsRequiredAuthError, - "no_credits_no_auth_error": testPostNoCreditsNoAuthError, - "no_credits_auth_error": testPostNoCreditsAuthError, - "no_probes": testPostNoProbes, - "validation": testPostValidation, - "api_error": testPostInternalError, - } { - t.Run(scenario, func(t *testing.T) { - fn(t) - }) - } -} - -// Test a valid call of PostAPI -func testPostValid(t *testing.T) { +func Test_CreateMeasurement_Valid(t *testing.T) { server := generateServer(`{"id":"abcd","probesCount":1}`, http.StatusAccepted) defer server.Close() client := NewClient(Config{APIURL: server.URL}) @@ -50,12 +26,12 @@ func testPostValid(t *testing.T) { assert.NoError(t, err) } -func testPostAuthorized(t *testing.T) { +func Test_CreateMeasurement_Authorized(t *testing.T) { server := generateServerAuthorized(`{"id":"abcd","probesCount":1}`) defer server.Close() client := NewClient(Config{ - APIToken: "secret", - APIURL: server.URL, + AuthAccessToken: "secret", + APIURL: server.URL, }) opts := &MeasurementCreate{} @@ -66,7 +42,7 @@ func testPostAuthorized(t *testing.T) { assert.NoError(t, err) } -func testPostAuthorizedError(t *testing.T) { +func Test_CreateMeasurement_AuthorizedError(t *testing.T) { server := generateServerAuthorized(`{"id":"abcd","probesCount":1}`) defer server.Close() client := NewClient(Config{ @@ -80,7 +56,7 @@ func testPostAuthorizedError(t *testing.T) { assert.EqualError(t, err, "unauthorized: Unauthorized.") } -func testPostMoreCreditsRequiredNoAuthError(t *testing.T) { +func Test_CreateMeasurement_MoreCreditsRequiredNoAuthError(t *testing.T) { rateLimitReset := "61" server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-RateLimit-Remaining", "1") @@ -109,7 +85,7 @@ func testPostMoreCreditsRequiredNoAuthError(t *testing.T) { assert.EqualError(t, err, fmt.Sprintf(moreCreditsRequiredNoAuthErr, "2 credits", 3, "2 minutes")) } -func testPostMoreCreditsRequiredAuthError(t *testing.T) { +func Test_CreateMeasurement_MoreCreditsRequiredAuthError(t *testing.T) { rateLimitReset := "40" server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-RateLimit-Remaining", "0") @@ -129,8 +105,8 @@ func testPostMoreCreditsRequiredAuthError(t *testing.T) { defer server.Close() client := NewClient(Config{ - APIToken: "secret", - APIURL: server.URL, + AuthAccessToken: "secret", + APIURL: server.URL, }) opts := &MeasurementCreate{} @@ -142,7 +118,7 @@ func testPostMoreCreditsRequiredAuthError(t *testing.T) { assert.EqualError(t, err, fmt.Sprintf(moreCreditsRequiredAuthErr, "1 credit", 2, "1 second")) } -func testPostNoCreditsNoAuthError(t *testing.T) { +func Test_CreateMeasurement_NoCreditsNoAuthError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-RateLimit-Remaining", "0") w.Header().Set("X-RateLimit-Reset", "5") @@ -166,7 +142,7 @@ func testPostNoCreditsNoAuthError(t *testing.T) { assert.EqualError(t, err, fmt.Sprintf(noCreditsNoAuthErr, "5 seconds")) } -func testPostNoCreditsAuthError(t *testing.T) { +func Test_CreateMeasurement_NoCreditsAuthError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-RateLimit-Remaining", "0") w.Header().Set("X-RateLimit-Reset", "5") @@ -184,8 +160,8 @@ func testPostNoCreditsAuthError(t *testing.T) { defer server.Close() client := NewClient(Config{ - APIToken: "secret", - APIURL: server.URL, + AuthAccessToken: "secret", + APIURL: server.URL, }) opts := &MeasurementCreate{} _, err := client.CreateMeasurement(opts) @@ -193,7 +169,7 @@ func testPostNoCreditsAuthError(t *testing.T) { assert.EqualError(t, err, fmt.Sprintf(noCreditsAuthErr, "5 seconds")) } -func testPostNoProbes(t *testing.T) { +func Test_CreateMeasurement_NoProbes(t *testing.T) { server := generateServer(`{ "error": { "message": "No suitable probes found", @@ -207,11 +183,12 @@ func testPostNoProbes(t *testing.T) { assert.Equal(t, &MeasurementError{ Code: 422, + Type: "no_probes_found", Message: "no suitable probes found - please choose a different location", }, err) } -func testPostValidation(t *testing.T) { +func Test_CreateMeasurement_Validation(t *testing.T) { server := generateServer(`{ "error": { "message": "Validation Failed", @@ -228,12 +205,16 @@ func testPostValidation(t *testing.T) { assert.Equal(t, &MeasurementError{ Code: 400, + Type: "validation_error", Message: `invalid parameters - "target" does not match any of the allowed types`, + Params: map[string]interface{}{ + "target": "\"target\" does not match any of the allowed types", + }, }, err) } -func testPostInternalError(t *testing.T) { +func Test_CreateMeasurement_InternalError(t *testing.T) { server := generateServer(`{ "error": { "message": "Internal Server Error", @@ -247,24 +228,7 @@ func testPostInternalError(t *testing.T) { assert.EqualError(t, err, "internal server error - please try again later") } -// GetAPI tests -func TestGetAPI(t *testing.T) { - for scenario, fn := range map[string]func(t *testing.T){ - "valid": testGetValid, - "json": testGetJson, - "ping": testGetPing, - "traceroute": testGetTraceroute, - "dns": testGetDns, - "mtr": testGetMtr, - "http": testGetHttp, - } { - t.Run(scenario, func(t *testing.T) { - fn(t) - }) - } -} - -func testGetValid(t *testing.T) { +func Test_GetMeasurement_Valid(t *testing.T) { server := generateServer(`{"id":"abcd"}`, http.StatusOK) defer server.Close() client := NewClient(Config{APIURL: server.URL}) @@ -275,19 +239,7 @@ func testGetValid(t *testing.T) { assert.Equal(t, "abcd", res.ID) } -func testGetJson(t *testing.T) { - server := generateServer(`{"id":"abcd"}`, http.StatusOK) - defer server.Close() - client := NewClient(Config{APIURL: server.URL}) - res, err := client.GetMeasurementRaw("abcd") - if err != nil { - t.Error(err) - } - - assert.Equal(t, `{"id":"abcd"}`, string(res)) -} - -func testGetPing(t *testing.T) { +func Test_GetMeasurement_Ping(t *testing.T) { server := generateServer(`{ "id": "abcd", "type": "ping", @@ -368,7 +320,7 @@ func testGetPing(t *testing.T) { assert.Equal(t, float64(0), stats.Loss) } -func testGetTraceroute(t *testing.T) { +func Test_GetMeasurement_Traceroute(t *testing.T) { server := generateServer(`{ "id": "abcd", "type": "traceroute", @@ -456,7 +408,7 @@ func testGetTraceroute(t *testing.T) { assert.Equal(t, "1.1.1.1", res.Results[0].Result.ResolvedHostname) } -func testGetDns(t *testing.T) { +func Test_GetMeasurement_Dns(t *testing.T) { server := generateServer(`{ "id": "abcd", "type": "dns", @@ -536,7 +488,7 @@ func testGetDns(t *testing.T) { assert.Equal(t, float64(15), timings.Total) } -func testGetMtr(t *testing.T) { +func Test_GetMeasurement_Mtr(t *testing.T) { server := generateServer(`{ "id": "abcd", "type": "mtr", @@ -654,7 +606,7 @@ func testGetMtr(t *testing.T) { assert.IsType(t, json.RawMessage{}, res.Results[0].Result.TimingsRaw) } -func testGetHttp(t *testing.T) { +func Test_GetMeasurement_Http(t *testing.T) { server := generateServer(`{ "id": "abcd", "type": "http", @@ -768,7 +720,7 @@ func testGetHttp(t *testing.T) { assert.Equal(t, 19, timings.TCP) } -func TestFetchWithEtag(t *testing.T) { +func Test_GetMeasurement_WithEtag(t *testing.T) { id1 := "123abc" id2 := "567xyz" @@ -828,7 +780,7 @@ func TestFetchWithEtag(t *testing.T) { assert.Equal(t, 2, cacheMissCount) } -func TestFetchWithBrotli(t *testing.T) { +func Test_GetMeasurement_WithBrotli(t *testing.T) { id := "123abc" s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -858,7 +810,18 @@ func TestFetchWithBrotli(t *testing.T) { assert.Equal(t, id, m.ID) } -// Generate server for testing +func Test_GetMeasurementRaw_Json(t *testing.T) { + server := generateServer(`{"id":"abcd"}`, http.StatusOK) + defer server.Close() + client := NewClient(Config{APIURL: server.URL}) + res, err := client.GetMeasurementRaw("abcd") + if err != nil { + t.Error(err) + } + + assert.Equal(t, `{"id":"abcd"}`, string(res)) +} + func generateServer(json string, statusCode int) *httptest.Server { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(statusCode) diff --git a/globalping/models.go b/globalping/models.go index 0fbf9ab..52e51ec 100644 --- a/globalping/models.go +++ b/globalping/models.go @@ -51,27 +51,25 @@ type MeasurementCreate struct { } type MeasurementError struct { - Code int - Message string + Code int `json:"-"` + Message string `json:"message"` + Type string `json:"type"` + Params map[string]interface{} `json:"params,omitempty"` } func (e *MeasurementError) Error() string { return e.Message } +type MeasurementErrorResponse struct { + Error *MeasurementError `json:"error"` +} + type MeasurementCreateResponse struct { ID string `json:"id"` ProbesCount int `json:"probesCount"` } -type MeasurementCreateError struct { - Error struct { - Message string `json:"message"` - Type string `json:"type"` - Params map[string]interface{} `json:"params,omitempty"` - } `json:"error"` -} - type ProbeDetails struct { Continent string `json:"continent"` Region string `json:"region"` diff --git a/globalping/utils_test.go b/globalping/utils_test.go new file mode 100644 index 0000000..8da567c --- /dev/null +++ b/globalping/utils_test.go @@ -0,0 +1,10 @@ +package globalping + +func getTokenJSON() []byte { + return []byte(`{ +"access_token":"token", +"token_type":"bearer", +"refresh_token":"refresh", +"expires_in": 3600 +}`) +} diff --git a/go.mod b/go.mod index 16313e9..4741805 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.9.0 go.uber.org/mock v0.4.0 + golang.org/x/oauth2 v0.23.0 golang.org/x/term v0.18.0 ) diff --git a/go.sum b/go.sum index 9f59a1a..f1416ac 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/icza/backscanner v0.0.0-20240221180818-f23e3ba0e79f h1:EKPpaKkARuHjoV/ZKzk3vqbSJXULRSivDCQhL+tF77Y= github.com/icza/backscanner v0.0.0-20240221180818-f23e3ba0e79f/go.mod h1:GYeBD1CF7AqnKZK+UCytLcY3G+UKo0ByXX/3xfdNyqQ= github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6 h1:8UsGZ2rr2ksmEru6lToqnXgA8Mz1DP11X4zSJ159C3k= @@ -49,6 +51,8 @@ github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= +golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= diff --git a/mocks/gen_mocks.sh b/mocks/gen_mocks.sh index 8452e6f..4aee6ff 100755 --- a/mocks/gen_mocks.sh +++ b/mocks/gen_mocks.sh @@ -3,4 +3,4 @@ rm -rf mocks/mock_*.go bin/mockgen -source globalping/client.go -destination mocks/mock_client.go -package mocks bin/mockgen -source globalping/probe/probe.go -destination mocks/mock_probe.go -package mocks bin/mockgen -source view/viewer.go -destination mocks/mock_viewer.go -package mocks -bin/mockgen -source utils/time.go -destination mocks/mock_time.go -package mocks +bin/mockgen -source utils/utils.go -destination mocks/mock_utils.go -package mocks diff --git a/mocks/mock_client.go b/mocks/mock_client.go index e7849e7..877949d 100644 --- a/mocks/mock_client.go +++ b/mocks/mock_client.go @@ -39,6 +39,21 @@ func (m *MockClient) EXPECT() *MockClientMockRecorder { return m.recorder } +// Authorize mocks base method. +func (m *MockClient) Authorize(callback func(error)) (*globalping.AuthorizeResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Authorize", callback) + ret0, _ := ret[0].(*globalping.AuthorizeResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Authorize indicates an expected call of Authorize. +func (mr *MockClientMockRecorder) Authorize(callback any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Authorize", reflect.TypeOf((*MockClient)(nil).Authorize), callback) +} + // CreateMeasurement mocks base method. func (m *MockClient) CreateMeasurement(measurement *globalping.MeasurementCreate) (*globalping.MeasurementCreateResponse, error) { m.ctrl.T.Helper() @@ -83,3 +98,46 @@ func (mr *MockClientMockRecorder) GetMeasurementRaw(id any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMeasurementRaw", reflect.TypeOf((*MockClient)(nil).GetMeasurementRaw), id) } + +// Logout mocks base method. +func (m *MockClient) Logout() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Logout") + ret0, _ := ret[0].(error) + return ret0 +} + +// Logout indicates an expected call of Logout. +func (mr *MockClientMockRecorder) Logout() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logout", reflect.TypeOf((*MockClient)(nil).Logout)) +} + +// RevokeToken mocks base method. +func (m *MockClient) RevokeToken(token string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RevokeToken", token) + ret0, _ := ret[0].(error) + return ret0 +} + +// RevokeToken indicates an expected call of RevokeToken. +func (mr *MockClientMockRecorder) RevokeToken(token any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeToken", reflect.TypeOf((*MockClient)(nil).RevokeToken), token) +} + +// TokenIntrospection mocks base method. +func (m *MockClient) TokenIntrospection(token string) (*globalping.IntrospectionResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TokenIntrospection", token) + ret0, _ := ret[0].(*globalping.IntrospectionResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TokenIntrospection indicates an expected call of TokenIntrospection. +func (mr *MockClientMockRecorder) TokenIntrospection(token any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenIntrospection", reflect.TypeOf((*MockClient)(nil).TokenIntrospection), token) +} diff --git a/mocks/mock_time.go b/mocks/mock_time.go deleted file mode 100644 index 7d5fa74..0000000 --- a/mocks/mock_time.go +++ /dev/null @@ -1,54 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: utils/time.go -// -// Generated by this command: -// -// mockgen -source utils/time.go -destination mocks/mock_time.go -package mocks -// - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - time "time" - - gomock "go.uber.org/mock/gomock" -) - -// MockTime is a mock of Time interface. -type MockTime struct { - ctrl *gomock.Controller - recorder *MockTimeMockRecorder -} - -// MockTimeMockRecorder is the mock recorder for MockTime. -type MockTimeMockRecorder struct { - mock *MockTime -} - -// NewMockTime creates a new mock instance. -func NewMockTime(ctrl *gomock.Controller) *MockTime { - mock := &MockTime{ctrl: ctrl} - mock.recorder = &MockTimeMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTime) EXPECT() *MockTimeMockRecorder { - return m.recorder -} - -// Now mocks base method. -func (m *MockTime) Now() time.Time { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Now") - ret0, _ := ret[0].(time.Time) - return ret0 -} - -// Now indicates an expected call of Now. -func (mr *MockTimeMockRecorder) Now() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Now", reflect.TypeOf((*MockTime)(nil).Now)) -} diff --git a/mocks/mock_utils.go b/mocks/mock_utils.go new file mode 100644 index 0000000..71f7cf6 --- /dev/null +++ b/mocks/mock_utils.go @@ -0,0 +1,68 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: utils/utils.go +// +// Generated by this command: +// +// mockgen -source utils/utils.go -destination mocks/mock_utils.go -package mocks +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + time "time" + + gomock "go.uber.org/mock/gomock" +) + +// MockUtils is a mock of Utils interface. +type MockUtils struct { + ctrl *gomock.Controller + recorder *MockUtilsMockRecorder +} + +// MockUtilsMockRecorder is the mock recorder for MockUtils. +type MockUtilsMockRecorder struct { + mock *MockUtils +} + +// NewMockUtils creates a new mock instance. +func NewMockUtils(ctrl *gomock.Controller) *MockUtils { + mock := &MockUtils{ctrl: ctrl} + mock.recorder = &MockUtilsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUtils) EXPECT() *MockUtilsMockRecorder { + return m.recorder +} + +// Now mocks base method. +func (m *MockUtils) Now() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Now") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// Now indicates an expected call of Now. +func (mr *MockUtilsMockRecorder) Now() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Now", reflect.TypeOf((*MockUtils)(nil).Now)) +} + +// OpenBrowser mocks base method. +func (m *MockUtils) OpenBrowser(url string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenBrowser", url) + ret0, _ := ret[0].(error) + return ret0 +} + +// OpenBrowser indicates an expected call of OpenBrowser. +func (mr *MockUtilsMockRecorder) OpenBrowser(url any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenBrowser", reflect.TypeOf((*MockUtils)(nil).OpenBrowser), url) +} diff --git a/storage/storage.go b/storage/storage.go new file mode 100644 index 0000000..68273fa --- /dev/null +++ b/storage/storage.go @@ -0,0 +1,116 @@ +package storage + +import ( + "encoding/json" + "os" + "path" + + "github.com/jsdelivr/globalping-cli/globalping" +) + +type LocalStorage struct { + name string + configName string + config *Config +} + +func NewLocalStorage(name string) *LocalStorage { + return &LocalStorage{ + name: name, + configName: "config.json", + } +} + +func (s *LocalStorage) Init() error { + homeDir, err := s.joinHomeDir("") + if err != nil { + return err + } + err = os.MkdirAll(homeDir, 0755) + if err != nil { + return err + } + _, err = s.LoadConfig() + if err != nil { + if os.IsNotExist(err) { + s.config = &Config{ + Profile: "default", + Profiles: make(map[string]*Profile), + } + s.SaveConfig() + } + } + return nil +} + +type Profile struct { + Token *globalping.Token `json:"token"` +} + +type Config struct { + Profile string `json:"profile"` + Profiles map[string]*Profile `json:"profiles"` +} + +func (s *LocalStorage) LoadConfig() (*Config, error) { + if s.config != nil { + return s.config, nil + } + path, err := s.joinHomeDir(s.configName) + if err != nil { + return nil, err + } + b, err := os.ReadFile(path) + if err != nil { + return nil, err + } + s.config = &Config{ + Profile: "default", + Profiles: make(map[string]*Profile), + } + err = json.Unmarshal(b, s.config) + if err != nil { + return nil, err + } + return s.config, nil +} + +func (s *LocalStorage) SaveConfig() error { + if s.config == nil { + return nil + } + path, err := s.joinHomeDir(s.configName) + if err != nil { + return err + } + b, err := json.Marshal(s.config) + if err != nil { + return err + } + return os.WriteFile(path, b, 0644) +} + +func (s *LocalStorage) GetProfile() *Profile { + p := s.config.Profiles[s.config.Profile] + if p == nil { + p = &Profile{} + s.config.Profiles[s.config.Profile] = p + } + return p +} + +func (s *LocalStorage) Remove() error { + homeDir, err := s.joinHomeDir("") + if err != nil { + return err + } + return os.RemoveAll(homeDir) +} + +func (s *LocalStorage) joinHomeDir(name string) (string, error) { + dir, err := os.UserHomeDir() + if err != nil { + return "", err + } + return path.Join(dir, s.name, name), nil +} diff --git a/storage/storage_test.go b/storage/storage_test.go new file mode 100644 index 0000000..59bbd3a --- /dev/null +++ b/storage/storage_test.go @@ -0,0 +1,59 @@ +package storage + +import ( + "encoding/json" + "os" + "testing" + "time" + + "github.com/jsdelivr/globalping-cli/globalping" + "github.com/stretchr/testify/assert" +) + +func Test_Config(t *testing.T) { + _storage := NewLocalStorage(".test_globalping-cli") + defer _storage.Remove() + err := _storage.Init() + if err != nil { + t.Fatal(err) + } + config, err := _storage.LoadConfig() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, &Config{ + Profile: "default", + Profiles: make(map[string]*Profile), + }, config) + + profile := _storage.GetProfile() + profile.Token = &globalping.Token{ + AccessToken: "token", + RefreshToken: "refresh", + TokenType: "bearer", + Expiry: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + } + err = _storage.SaveConfig() + if err != nil { + t.Fatal(err) + } + path, err := _storage.joinHomeDir(_storage.configName) + if err != nil { + t.Fatal(err) + } + b, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + c := &Config{} + err = json.Unmarshal(b, c) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, &Config{ + Profile: "default", + Profiles: map[string]*Profile{ + "default": {Token: profile.Token}, + }, + }, c) +} diff --git a/utils/config.go b/utils/config.go index 5d4de6a..e6e3a1d 100644 --- a/utils/config.go +++ b/utils/config.go @@ -6,15 +6,23 @@ import ( ) type Config struct { - GlobalpingToken string - GlobalpingAPIURL string - GlobalpingAPIInterval _time.Duration + GlobalpingToken string + GlobalpingAPIURL string + GlobalpingAuthURL string + GlobalpingDashboardURL string + GlobalpingAuthClientID string + GlobalpingAuthClientSecret string + GlobalpingAPIInterval _time.Duration } func NewConfig() *Config { return &Config{ - GlobalpingAPIURL: "https://api.globalping.io/v1", - GlobalpingAPIInterval: 500 * _time.Millisecond, + GlobalpingAPIURL: "https://api.globalping.io/v1", + GlobalpingAuthURL: "https://auth.globalping.io", + GlobalpingDashboardURL: "https://dash.globalping.io", + GlobalpingAuthClientID: "be231712-03f4-45bf-9f15-023506ce0b72", + GlobalpingAuthClientSecret: "public", + GlobalpingAPIInterval: 500 * _time.Millisecond, } } diff --git a/utils/time.go b/utils/time.go deleted file mode 100644 index cb4e4b3..0000000 --- a/utils/time.go +++ /dev/null @@ -1,33 +0,0 @@ -package utils - -import ( - "math" - _time "time" -) - -type Time interface { - Now() _time.Time -} - -type time struct{} - -func NewTime() Time { - return &time{} -} - -func (d *time) Now() _time.Time { - return _time.Now() -} - -func FormatSeconds(seconds int64) string { - if seconds < 60 { - return Pluralize(seconds, "second") - } - if seconds < 3600 { - return Pluralize(int64(math.Round(float64(seconds)/60)), "minute") - } - if seconds < 86400 { - return Pluralize(int64(math.Round(float64(seconds)/3600)), "hour") - } - return Pluralize(int64(math.Round(float64(seconds)/86400)), "day") -} diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 0000000..ec7d100 --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,55 @@ +package utils + +import ( + "errors" + "math" + "os/exec" + "runtime" + _time "time" +) + +type Utils interface { + Now() _time.Time + OpenBrowser(url string) error +} + +type utils struct{} + +func NewUtils() Utils { + return &utils{} +} + +func (u *utils) Now() _time.Time { + return _time.Now() +} + +func (u *utils) OpenBrowser(url string) error { + switch runtime.GOOS { + case "linux": + // WSL workaround + err := exec.Command("rundll32.exe", "url.dll,FileProtocolHandler", url).Start() + if err != nil { + return exec.Command("xdg-open", url).Start() + } + return nil + case "windows": + return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + case "darwin": + return exec.Command("open", url).Start() + default: + return errors.New("unsupported platform") + } +} + +func FormatSeconds(seconds int64) string { + if seconds < 60 { + return Pluralize(seconds, "second") + } + if seconds < 3600 { + return Pluralize(int64(math.Round(float64(seconds)/60)), "minute") + } + if seconds < 86400 { + return Pluralize(int64(math.Round(float64(seconds)/3600)), "hour") + } + return Pluralize(int64(math.Round(float64(seconds)/86400)), "day") +} diff --git a/view/infinite.go b/view/infinite.go index 693c80f..153ad7d 100644 --- a/view/infinite.go +++ b/view/infinite.go @@ -368,7 +368,7 @@ func (v *viewer) parsePingRawOutput( res.Stats.Time, _ = strconv.ParseFloat(words[9][:len(words[9])-2], 64) } } else { - res.Stats.Time = float64(v.time.Now().Sub(hm.StartedAt).Milliseconds()) + res.Stats.Time = float64(v.utils.Now().Sub(hm.StartedAt).Milliseconds()) } if res.Stats.Sent > 0 { res.Stats.Lost = res.Stats.Sent - res.Stats.Rcv @@ -402,7 +402,7 @@ func (v *viewer) getAPICreditConsumptionInfo(width int) string { return apiCreditLastConsumptionInfo } apiCreditLastMeasurementCount = v.ctx.MeasurementsCreated - elapsedMinutes := v.time.Now().Sub(v.ctx.RunSessionStartedAt).Minutes() + elapsedMinutes := v.utils.Now().Sub(v.ctx.RunSessionStartedAt).Minutes() consumption := int64(math.Ceil(float64((apiCreditLastMeasurementCount-1)*(len(v.ctx.AggregatedStats))) / elapsedMinutes)) info := fmt.Sprintf(apiCreditConsumptionInfo, utils.Pluralize(consumption, "API credit")) if len(info) > width-4 { diff --git a/view/infinite_test.go b/view/infinite_test.go index 04160da..909b4ca 100644 --- a/view/infinite_test.go +++ b/view/infinite_test.go @@ -16,8 +16,8 @@ func Test_OutputInfinite_SingleProbe_InProgress(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime.Add(500 * time.Millisecond)).Times(3) + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime.Add(500 * time.Millisecond)).Times(3) ctx := createDefaultContext("ping") hm := ctx.History.Find(measurementID1) @@ -25,7 +25,7 @@ func Test_OutputInfinite_SingleProbe_InProgress(t *testing.T) { errW := new(bytes.Buffer) printer := NewPrinter(nil, w, errW) printer.DisableStyling() - viewer := NewViewer(ctx, printer, timeMock, nil) + viewer := NewViewer(ctx, printer, utilsMock, nil) measurement := createPingMeasurement(measurementID1) measurement.Status = globalping.StatusInProgress @@ -160,8 +160,8 @@ func Test_OutputInfinite_MultipleProbes_MultipleCalls(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime.Add(500 * time.Millisecond)).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime.Add(500 * time.Millisecond)).AnyTimes() measurement := createPingMeasurement_MultipleProbes(measurementID1) measurement.Status = globalping.StatusInProgress @@ -172,7 +172,7 @@ func Test_OutputInfinite_MultipleProbes_MultipleCalls(t *testing.T) { w := new(bytes.Buffer) printer := NewPrinter(nil, w, w) printer.DisableStyling() - viewer := NewViewer(ctx, printer, timeMock, nil) + viewer := NewViewer(ctx, printer, utilsMock, nil) // Call 1 expectedOutput := `Location | Sent | Loss | Last | Min | Avg | Max @@ -277,8 +277,8 @@ func Test_OutputInfinite_MultipleProbes_MultipleConcurrentCalls(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime.Add(500 * time.Millisecond)).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime.Add(500 * time.Millisecond)).AnyTimes() // Call 1 measurement1 := createPingMeasurement_MultipleProbes(measurementID1) @@ -295,7 +295,7 @@ func Test_OutputInfinite_MultipleProbes_MultipleConcurrentCalls(t *testing.T) { w := new(bytes.Buffer) printer := NewPrinter(nil, w, w) printer.DisableStyling() - viewer := NewViewer(ctx, printer, timeMock, nil) + viewer := NewViewer(ctx, printer, utilsMock, nil) expectedOutput := `Location | Sent | Loss | Last | Min | Avg | Max London, GB, EU, OVH SAS (AS0) | 1 | 0.00% | 10.0 ms | 10.0 ms | 10.0 ms | 10.0 ms @@ -411,12 +411,12 @@ func Test_OutputInfinite_MultipleProbes(t *testing.T) { measurement := createPingMeasurement_MultipleProbes(measurementID1) - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime.Add(500 * time.Millisecond)).AnyTimes() + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime.Add(500 * time.Millisecond)).AnyTimes() ctx := createDefaultContext("ping") w := new(bytes.Buffer) - v := NewViewer(ctx, NewPrinter(nil, w, w), timeMock, nil) + v := NewViewer(ctx, NewPrinter(nil, w, w), utilsMock, nil) err := v.OutputInfinite(measurement) assert.NoError(t, err) @@ -701,11 +701,11 @@ func Test_ParsePingRawOutput_NoStats(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime.Add(100 * time.Millisecond)) + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime.Add(100 * time.Millisecond)) ctx := createDefaultContext("ping") - v := viewer{ctx: ctx, time: timeMock} + v := viewer{ctx: ctx, utils: utilsMock} hm := ctx.History.Find(measurementID1) @@ -749,11 +749,11 @@ func Test_ParsePingRawOutput_NoStats_WithStartIncmpSeq(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime.Add(100 * time.Millisecond)) + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime.Add(100 * time.Millisecond)) ctx := createDefaultContext("ping") - v := viewer{ctx: ctx, time: timeMock} + v := viewer{ctx: ctx, utils: utilsMock} hm := ctx.History.Find(measurementID1) @@ -805,11 +805,11 @@ func Test_ParsePingRawOutput_WithRedirect(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - timeMock := mocks.NewMockTime(ctrl) - timeMock.EXPECT().Now().Return(defaultCurrentTime.Add(100 * time.Millisecond)) + utilsMock := mocks.NewMockUtils(ctrl) + utilsMock.EXPECT().Now().Return(defaultCurrentTime.Add(100 * time.Millisecond)) ctx := createDefaultContext("ping") - v := viewer{ctx: ctx, time: timeMock} + v := viewer{ctx: ctx, utils: utilsMock} hm := ctx.History.Find(measurementID1) diff --git a/view/printer.go b/view/printer.go index e2bc719..adade52 100644 --- a/view/printer.go +++ b/view/printer.go @@ -1,6 +1,8 @@ package view import ( + "bufio" + "errors" "fmt" "io" "math" @@ -176,6 +178,25 @@ func (p *Printer) BoldBackground(s string, color Color) string { return fmt.Sprintf("\033[1;48;5;%sm%s\033[0m", color, s) } +func (p *Printer) ReadPassword() (string, error) { + if p.InReader == nil { + return "", errors.New("no input reader") + } + f, ok := p.InReader.(*os.File) + if !ok { + scanner := bufio.NewScanner(p.InReader) + scanner.Scan() + return scanner.Text(), scanner.Err() + } + bytePassword, err := term.ReadPassword(int(f.Fd())) + if err != nil { + scanner := bufio.NewScanner(p.InReader) + scanner.Scan() + return scanner.Text(), scanner.Err() + } + return string(bytePassword), nil +} + func (p *Printer) GetSize() (width, height int) { f, ok := p.OutWriter.(*os.File) if !ok { diff --git a/view/viewer.go b/view/viewer.go index 9e400c7..733b897 100644 --- a/view/viewer.go +++ b/view/viewer.go @@ -15,20 +15,20 @@ type Viewer interface { type viewer struct { ctx *Context printer *Printer - time utils.Time + utils utils.Utils globalping globalping.Client } func NewViewer( ctx *Context, printer *Printer, - time utils.Time, + utils utils.Utils, globalpingClient globalping.Client, ) Viewer { return &viewer{ ctx: ctx, printer: printer, - time: time, + utils: utils, globalping: globalpingClient, } }