diff --git a/Makefile b/Makefile index d15cdb7..bd59a74 100644 --- a/Makefile +++ b/Makefile @@ -32,7 +32,7 @@ PID := $(GOBUILD)/.$(PROJECTNAME).pid # Make is verbose in Linux. Make it silent. MAKEFLAGS += --silent -default: install lint format test compile +default: lint format test ci-checks: lint format test @@ -50,9 +50,7 @@ compile: #@cat $(STDERR) | sed -e '1s/.*/\nError:\n/' | sed 's/make\[.*/ /' | sed "/^/s/^/ /" 1>&2 -test: install go-test - -cover: install go-cover +test: go-test clean: @-rm $(GOBIN)/$(PROGRAMNAME)* 2> /dev/null @@ -71,11 +69,6 @@ go-build: go-get go-build-linux-amd64 go-build-linux-arm64 go-build-darwin-amd64 go-test: go test $(MODFLAGS) `go list $(MODFLAGS) ./...` -go-cover: - go test $(MODFLAGS) -coverprofile=$(GOBUILD)/.coverprof `go list $(MODFLAGS) ./...` - go tool cover -html=$(GOBUILD)/.coverprof -o $(GOBUILD)/coverage.html - @open $(GOBUILD)/coverage.html - go-build-linux-amd64: @echo " > Building linux amd64 binaries..." @GOPATH=$(GOPATH) GOOS=$(GOOS_LINUX) GOARCH=$(GOARCH_AMD64) GOBIN=$(GOBIN) go build $(MODFLAGS) $(LDFLAGS) -o $(GOBIN)/$(PROGRAMNAME)-$(GOOS_LINUX)-$(GOARCH_AMD64) $(GOBASE)/internal diff --git a/spinner.go b/spinner.go index 660f7bc..9ea452e 100644 --- a/spinner.go +++ b/spinner.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "strings" "sync" "time" ) @@ -56,16 +57,16 @@ func (f *SimpleSpinnerFormatter) CharSeq() []string { type Spinner interface { Start() (context.CancelFunc, error) Stop(string) error - SetTitle(title string) + SetTitle(title string) error } type spinner struct { writer io.Writer interval time.Duration - mx *sync.RWMutex - titleMx *sync.RWMutex + stateMx *sync.RWMutex active bool stopC chan bool + titleC chan string title string formatter SpinnerFormatter } @@ -75,10 +76,10 @@ func NewSpinner(writer io.Writer, title string, interval time.Duration, formatte return &spinner{ writer: writer, interval: interval, - mx: &sync.RWMutex{}, - titleMx: &sync.RWMutex{}, + stateMx: &sync.RWMutex{}, active: false, stopC: make(chan bool), + titleC: make(chan string), title: title, formatter: formatter, } @@ -86,7 +87,7 @@ func NewSpinner(writer io.Writer, title string, interval time.Duration, formatte // NewDefaultSpinner creates a new Spinner that writes to Stdout with a default update interval func NewDefaultSpinner() Spinner { - return NewSpinner(StdoutWriter, "", 500, DefaultSpinnerFormatter()) + return NewSpinner(StdoutWriter, "", time.Millisecond*100, DefaultSpinnerFormatter()) } func (s *spinner) writeString(str string) (n int, err error) { @@ -95,8 +96,8 @@ func (s *spinner) writeString(str string) (n int, err error) { // Start starts the spinner in the background and returns a cancellation handle and an error in case the spinner is already running. func (s *spinner) Start() (cancel context.CancelFunc, err error) { - s.mx.Lock() - defer s.mx.Unlock() + s.stateMx.Lock() + defer s.stateMx.Unlock() if s.active { return nil, errors.New("spinner already active") @@ -115,27 +116,39 @@ func (s *spinner) Start() (cancel context.CancelFunc, err error) { defer s.setActiveSafe(false) + update := func(title string) { + indicatorValue := s.formatter.FormatIndicator(fmt.Sprintf("%v", spinring.Value)) + if title != "" { + _, _ = s.writeString(fmt.Sprintf("%s%s %s", TermControlEraseLine, indicatorValue, s.formatter.FormatTitle(title))) + } else { + _, _ = s.writeString(fmt.Sprintf("%s%s", TermControlEraseLine, indicatorValue)) + } + } + for { select { case <-context.Done(): timer.Stop() + close(s.titleC) + s.printExitMessage("Cancelled...") + return case <-s.stopC: timer.Stop() + close(s.titleC) return + case title := <-s.titleC: + // The title is only written by this routine, so we're safe. + s.title = title + update(title) + case <-timer.C: spinring = spinring.Next() - title := s.getTitle() - indicatorValue := s.formatter.FormatIndicator(fmt.Sprintf("%v", spinring.Value)) - if title != "" { - s.writeString(fmt.Sprintf("%s%s %s", TermControlEraseLine, indicatorValue, s.formatter.FormatTitle(title))) - } else { - s.writeString(fmt.Sprintf("%s%s", TermControlEraseLine, indicatorValue)) - } - + title := s.title + update(title) } } }() @@ -147,8 +160,8 @@ func (s *spinner) Start() (cancel context.CancelFunc, err error) { // Stop stops the spinner and displays the specified message func (s *spinner) Stop(message string) (err error) { - s.mx.Lock() - defer s.mx.Unlock() + s.stateMx.Lock() + defer s.stateMx.Unlock() if !s.active { err = errors.New("spinner not active") @@ -162,23 +175,21 @@ func (s *spinner) Stop(message string) (err error) { } // SetTitle updates the spinner text. -func (s *spinner) SetTitle(title string) { - s.titleMx.Lock() - defer s.titleMx.Unlock() - - s.title = title -} +func (s *spinner) SetTitle(title string) (err error) { + defer func() { + if recover() != nil { + err = errors.New("spinner not active") + } + }() -func (s *spinner) getTitle() string { - s.titleMx.RLock() - defer s.titleMx.RUnlock() + s.titleC <- strings.TrimSpace(title) - return s.title + return err } func (s *spinner) printExitMessage(message string) { - s.writeString(TermControlEraseLine) - s.writeString(message) + _, _ = s.writeString(TermControlEraseLine) + _, _ = s.writeString(message) } func (s *spinner) createSpinnerRing() *ring.Ring { @@ -193,15 +204,15 @@ func (s *spinner) createSpinnerRing() *ring.Ring { } func (s *spinner) isActiveSafe() bool { - s.mx.RLock() - defer s.mx.RUnlock() + s.stateMx.RLock() + defer s.stateMx.RUnlock() return s.active } func (s *spinner) setActiveSafe(active bool) { - s.mx.Lock() - defer s.mx.Unlock() + s.stateMx.Lock() + defer s.stateMx.Unlock() s.active = active } diff --git a/spinner_test.go b/spinner_test.go index 85cfa28..12c1eac 100644 --- a/spinner_test.go +++ b/spinner_test.go @@ -54,7 +54,7 @@ func TestSpinnerStopAlreadyStopped(t *testing.T) { emulatedStdout := new(bytes.Buffer) spin := NewSpinner(emulatedStdout, "", interval, DefaultSpinnerFormatter()) - spin.Start() + _, _ = spin.Start() err := spin.Stop("") assert.NoError(t, err) @@ -66,8 +66,10 @@ func TestSpinnerStopMessage(t *testing.T) { emulatedStdout := new(bytes.Buffer) spin := NewSpinner(emulatedStdout, "", interval, DefaultSpinnerFormatter()) - spin.Start() - err := spin.Stop(expectedStopMessage) + _, err := spin.Start() + assert.NoError(t, err) + + err = spin.Stop(expectedStopMessage) assert.NoError(t, err) assertBufferEventuallyContains(t, emulatedStdout, expectedStopMessage) @@ -96,11 +98,25 @@ func TestSpinnerSetTitle(t *testing.T) { assertBufferEventuallyContains(t, emulatedStdout, expectedInitialTitle) - spin.SetTitle(expectedUpdatedTitle) + assert.NoError(t, spin.SetTitle(expectedUpdatedTitle)) assertBufferEventuallyContains(t, emulatedStdout, expectedUpdatedTitle) } +func TestSpinnerSetTitleOnStoppedSpinner(t *testing.T) { + expectedInitialTitle := generateRandomString() + expectedUpdatedTitle := generateRandomString() + emulatedStdout := new(bytes.Buffer) + + spin := NewSpinner(emulatedStdout, expectedInitialTitle, interval, DefaultSpinnerFormatter()) + _, _ = spin.Start() + + assertBufferEventuallyContains(t, emulatedStdout, expectedInitialTitle) + + assert.NoError(t, spin.Stop("")) + assert.Error(t, spin.SetTitle(expectedUpdatedTitle)) +} + func assertBufferEventuallyContains(t *testing.T, outBuffer *bytes.Buffer, expected string) { assert.Eventually( t, @@ -134,45 +150,35 @@ func assertStoppedEventually(t *testing.T, outBuffer *bytes.Buffer, spinner *spi ) } -// TODO can this be simplified? func assertSpinnerCharSequence(t *testing.T, outBuffer *bytes.Buffer) { charSeq := DefaultSpinnerCharSeq() - readChars := make([]string, len(charSeq)) - readCharsCount := 0 + readChars := []string{} - readSequence := func() string { - startTime := time.Now() + scan := func() { for { - s, _ := outBuffer.ReadString(TermControlEraseLine[len(TermControlEraseLine)-1]) // read everything you got - if strippedString := strings.Trim(s, TermControlEraseLine+"\x00"); strippedString != "" { - return strippedString + r, _, e := outBuffer.ReadRune() + print(string(r), ",") + if e != nil { + continue } - - // guard again infinite loop - if time.Now().After(startTime.Add(time.Second * 30)) { - return "" + readChar := string(r) + if len(readChars) == 0 && readChar == charSeq[0] { + readChars = append(readChars, readChar) + } else if len(readChars) > 0 { + for _, ch := range charSeq { + if ch == readChar { + readChars = append(readChars, ch) + } + + if len(readChars) == len(charSeq) { + return + } + } } } } - // find the first character in the spinner sequence, so we can validate order properly - for { - strippedString := readSequence() - if strippedString != "" && strippedString == charSeq[0] { - readChars[0] = strippedString - break - } - // guard against infinite loop caused by bugs - readCharsCount++ - if readCharsCount > len(charSeq)*2 { - assert.Fail(t, "something went wrong...") - } - } - - for i := 1; i < len(charSeq); i++ { - readChars[i] = readSequence() - } + scan() assert.Equal(t, charSeq, readChars) - }