Skip to content

Commit

Permalink
cmd/atlas/internal/cmdstate: shared package for storing cmd states (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m authored Nov 8, 2023
1 parent e9dafeb commit 00c0306
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 52 deletions.
3 changes: 1 addition & 2 deletions cmd/atlas/internal/cmdapi/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1273,8 +1273,7 @@ func TestMigrate_Diff(t *testing.T) {

t.Run("Edit", func(t *testing.T) {
p := t.TempDir()
require.NoError(t, os.Setenv("EDITOR", "echo '-- Comment' >>"))
t.Cleanup(func() { require.NoError(t, os.Unsetenv("EDITOR")) })
t.Setenv("EDITOR", "echo '-- Comment' >>")
args := []string{
"--edit",
"--dir", "file://" + p,
Expand Down
48 changes: 16 additions & 32 deletions cmd/atlas/internal/cmdapi/vercheck/vercheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,21 @@ import (
"fmt"
"net/http"
"net/url"
"os"
"path/filepath"
"text/template"
"time"

"ariga.io/atlas/cmd/atlas/internal/cmdstate"
)

// StateFileName is the name of the file where the vercheck state is stored.
const StateFileName = "release.json"

// New returns a new VerChecker for the endpoint.
func New(endpoint, statePath string) *VerChecker {
return &VerChecker{endpoint: endpoint, statePath: statePath}
func New(endpoint string) *VerChecker {
return &VerChecker{
endpoint: endpoint,
state: &cmdstate.File[State]{Name: StateFileName},
}
}

type (
Expand All @@ -46,8 +52,8 @@ type (
}
// VerChecker retrieves version information from the vercheck service.
VerChecker struct {
endpoint string
statePath string
endpoint string
state *cmdstate.File[State]
}
// State stores information about local runs of VerChecker to limit the
// frequency in which clients poll the service for information.
Expand Down Expand Up @@ -91,37 +97,15 @@ func (v *VerChecker) Check(ctx context.Context, ver string) (*Payload, error) {
if err := json.NewDecoder(resp.Body).Decode(&p); err != nil {
return nil, err
}
if v.statePath != "" {
s := State{CheckedAt: time.Now()}
st, err := json.Marshal(s)
if err != nil {
return nil, err
}
// Create containing directory if it doesn't exist.
if err := os.MkdirAll(filepath.Dir(v.statePath), os.ModePerm); err != nil {
return nil, err
}
if err := os.WriteFile(v.statePath, st, 0666); err != nil {
return nil, err
}
if err := v.state.Write(State{CheckedAt: time.Now()}); err != nil {
return nil, err
}
return &p, nil
}

func (v *VerChecker) verifyTime() error {
// Skip check if path to state file isn't configured.
if v.statePath == "" {
return nil
}
var s State
f, err := os.Open(v.statePath)
if err != nil {
return nil
}
if err := json.NewDecoder(f).Decode(&s); err != nil {
return nil
}
if time.Since(s.CheckedAt) >= (time.Hour * 24) {
s, err := v.state.Read()
if err != nil || time.Since(s.CheckedAt) >= (time.Hour*24) {
return nil
}
return errSkip
Expand Down
37 changes: 27 additions & 10 deletions cmd/atlas/internal/cmdapi/vercheck/vercheck_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"time"

"ariga.io/atlas/cmd/atlas/internal/cloudapi"

"github.com/mitchellh/go-homedir"
"github.com/stretchr/testify/require"
)

Expand All @@ -31,7 +33,9 @@ func TestVerCheck(t *testing.T) {
}))
defer srv.Close()

vc := New(srv.URL, "")
home := t.TempDir()
t.Setenv("HOME", home)
vc := New(srv.URL)
ver := "v0.1.2"
check, err := vc.Check(context.Background(), ver)

Expand All @@ -47,9 +51,15 @@ func TestVerCheck(t *testing.T) {
Link: "https://github.com/ariga/atlas/releases/tag/v0.7.2",
},
}, check)

dirs, err := os.ReadDir(filepath.Join(home, ".atlas"))
require.NoError(t, err)
require.Len(t, dirs, 1)
}

func TestState(t *testing.T) {
homedir.DisableCache = true
t.Cleanup(func() { homedir.DisableCache = false })
hrAgo, err := json.Marshal(State{CheckedAt: time.Now().Add(-time.Hour)})
require.NoError(t, err)
weekAgo, err := json.Marshal(State{CheckedAt: time.Now().Add(-time.Hour * 24 * 7)})
Expand Down Expand Up @@ -87,33 +97,40 @@ func TestState(t *testing.T) {
_, _ = w.Write([]byte(`{}`))
}))
t.Cleanup(srv.Close)
path := filepath.Join(t.TempDir(), "release.json")
home := t.TempDir()
path := filepath.Join(home, ".atlas", StateFileName)
if tt.state != "" {
err := os.WriteFile(path, []byte(tt.state), 0666)
require.NoError(t, err)
require.NoError(t, os.MkdirAll(filepath.Dir(path), os.ModePerm))
require.NoError(t, os.WriteFile(path, []byte(tt.state), 0666))
}
vc := New(srv.URL, path)
t.Setenv("HOME", home)
vc := New(srv.URL)
_, _ = vc.Check(context.Background(), "v0.1.2")
require.EqualValues(t, tt.expectedRun, ran)

b, err := os.ReadFile(path)
buf, err := os.ReadFile(path)
require.NoError(t, err)
if tt.expectedRun {
require.NotEqualValues(t, tt.state, b)
require.NotEqualValues(t, tt.state, buf)
} else {
require.EqualValues(t, tt.state, b)
require.EqualValues(t, tt.state, buf)
}
})
}
}

func TestStatePersist(t *testing.T) {
homedir.DisableCache = true
t.Cleanup(func() { homedir.DisableCache = false })

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{}`))
}))
t.Cleanup(srv.Close)
path := filepath.Join(t.TempDir(), ".atlas", "release.json")
vc := New(srv.URL, path)
home := t.TempDir()
path := filepath.Join(home, ".atlas", StateFileName)
t.Setenv("HOME", home)
vc := New(srv.URL)
_, err := vc.Check(context.Background(), "v0.1.2")
require.NoError(t, err)

Expand Down
2 changes: 1 addition & 1 deletion cmd/atlas/internal/cmdext/cmdext_oss.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
)

// RemoteSchema is a data source that for reading remote schemas.
func RemoteSchema(ctx *hcl.EvalContext, block *hclsyntax.Block) (cty.Value, error) {
func RemoteSchema(*hcl.EvalContext, *hclsyntax.Block) (cty.Value, error) {
return cty.Zero, fmt.Errorf("data.remote_schema is not supported by this release. See: https://atlasgo.io/getting-started")
}

Expand Down
82 changes: 82 additions & 0 deletions cmd/atlas/internal/cmdstate/cmdstate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright 2021-present The Atlas Authors. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package cmdstate

import (
"encoding/json"
"os"
"path/filepath"
"reflect"

"github.com/mitchellh/go-homedir"
)

// DefaultDir is the directory where CLI state is stored.
const DefaultDir = "~/.atlas"

// File is a state file for the given type.
type File[T any] struct {
// Dir where the file is stored. If empty, DefaultDir is used.
Dir string
// Name of the file. Suffixed with .json.
Name string
}

// Read reads the value from the file system.
func (f File[T]) Read() (v T, err error) {
path, err := f.Path()
if err != nil {
return v, err
}
switch buf, err := os.ReadFile(path); {
case os.IsNotExist(err):
return newT(v), nil
case err != nil:
return v, err
default:
err = json.Unmarshal(buf, &v)
return v, err
}
}

// Write writes the value to the file system.
func (f File[T]) Write(t T) error {
buf, err := json.Marshal(t)
if err != nil {
return err
}
path, err := f.Path()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil {
return err
}
return os.WriteFile(path, buf, 0666)
}

// Path returns the path to the file.
func (f File[T]) Path() (string, error) {
name := f.Name
if filepath.Ext(name) == "" {
name += ".json"
}
if f.Dir != "" {
return filepath.Join(f.Dir, name), nil
}
path, err := homedir.Expand(filepath.Join(DefaultDir, name))
if err != nil {
return "", err
}
return path, nil
}

// newT ensures the type is initialized.
func newT[T any](t T) T {
if rt := reflect.TypeOf(t); rt.Kind() == reflect.Ptr {
return reflect.New(rt.Elem()).Interface().(T)
}
return t
}
53 changes: 53 additions & 0 deletions cmd/atlas/internal/cmdstate/cmdstate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2021-present The Atlas Authors. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package cmdstate_test

import (
"os"
"path/filepath"
"testing"

"ariga.io/atlas/cmd/atlas/internal/cmdstate"

"github.com/mitchellh/go-homedir"
"github.com/stretchr/testify/require"
)

func TestFile(t *testing.T) {
homedir.DisableCache = true
t.Cleanup(func() { homedir.DisableCache = false })

type T struct{ V string }
f := cmdstate.File[T]{Name: "test", Dir: t.TempDir()}
v, err := f.Read()
require.NoError(t, err)
require.Equal(t, T{}, v)
require.NoError(t, f.Write(T{V: "v"}))
v, err = f.Read()
require.NoError(t, err)
require.Equal(t, T{V: "v"}, v)

home := t.TempDir()
t.Setenv("HOME", home)
f = cmdstate.File[T]{Name: "t"}
_, err = f.Read()
require.NoError(t, err)
dirs, err := os.ReadDir(home)
require.NoError(t, err)
require.Empty(t, dirs)

require.NoError(t, f.Write(T{V: "v"}))
dirs, err = os.ReadDir(home)
require.NoError(t, err)
require.Len(t, dirs, 1)
require.Equal(t, ".atlas", dirs[0].Name())
dirs, err = os.ReadDir(filepath.Join(home, ".atlas"))
require.NoError(t, err)
require.Len(t, dirs, 1)
require.Equal(t, "t.json", dirs[0].Name())
v, err = f.Read()
require.NoError(t, err)
require.Equal(t, T{V: "v"}, v)
}
8 changes: 1 addition & 7 deletions cmd/atlas/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
_ "ariga.io/atlas/sql/postgres/postgrescheck"
_ "ariga.io/atlas/sql/sqlite"
_ "ariga.io/atlas/sql/sqlite/sqlitecheck"
"github.com/mitchellh/go-homedir"
"golang.org/x/mod/semver"

_ "github.com/go-sql-driver/mysql"
Expand Down Expand Up @@ -48,7 +47,6 @@ const (
// envNoUpdate when enabled it cancels checking for update
envNoUpdate = "ATLAS_NO_UPDATE_NOTIFIER"
vercheckURL = "https://vercheck.ariga.io"
versionFile = "~/.atlas/release.json"
)

func noText() string { return "" }
Expand All @@ -65,15 +63,11 @@ func checkForUpdate(ctx context.Context) func() string {
if !semver.IsValid(version) {
return noText
}
path, err := homedir.Expand(versionFile)
if err != nil {
return noText
}
var message string
go func() {
defer close(done)
endpoint := vercheckEndpoint(ctx)
vc := vercheck.New(endpoint, path)
vc := vercheck.New(endpoint)
payload, err := vc.Check(ctx, version)
if err != nil {
return
Expand Down

0 comments on commit 00c0306

Please sign in to comment.