From 8da6f831f5b03530ce262418a4783bf7f35a1a80 Mon Sep 17 00:00:00 2001 From: Paulo Suzart Date: Sun, 3 Mar 2024 09:47:46 +0100 Subject: [PATCH] Small changes for manual mocking and testing --- db/db.go | 34 +++++++++++++++++++++++--- db/db_test.go | 4 +-- fga.go | 13 +++------- fga_test.go | 67 +++++++++++++++++++++++++++++++++++++++++++++++++++ main.go | 35 +++++++++++++++++++++++++++ ui.go | 5 ++-- 6 files changed, 140 insertions(+), 18 deletions(-) create mode 100644 fga_test.go diff --git a/db/db.go b/db/db.go index b63ce31..e4a5787 100644 --- a/db/db.go +++ b/db/db.go @@ -11,9 +11,34 @@ import ( ) var ( - db *sqlx.DB + db *sqlx.DB + Repository TupleRepository ) +type TupleRepository interface { + CountTuples(filter *Filter) int + GetMarkedForDeletion() []Tuple +} + +type SqlxRepository struct { + TupleRepository + _db *sqlx.DB +} + +func (r *SqlxRepository) CountTuples(filter *Filter) int { + return countTuples(filter) +} + +func (r *SqlxRepository) GetMarkedForDeletion() []Tuple { + return getMarkedForDeletion() +} + +func newRepository() TupleRepository { + var repo TupleRepository + repo = &SqlxRepository{} + return repo +} + // Transact keeps it simple and executes the passed function func Transact(f func()) error { tx := db.MustBegin() @@ -55,6 +80,7 @@ func setupDb(dataSource string) { ) ` db.MustExec(sts) + Repository = newRepository() log.Printf("Finished db setup") } func SetupDb() { @@ -268,7 +294,7 @@ func Load(offset int, filter *Filter) *LoadResult { upperBound: res[len(res)-1].Row, Res: res, Filter: filter, - total: CountTuples(filter), + total: Repository.CountTuples(filter), } } @@ -283,7 +309,7 @@ func GetContinuationToken(apiUrl, storeId string) *string { return &token } -func CountTuples(filter *Filter) int { +func countTuples(filter *Filter) int { selectClause := "select count(*) as count from tuples" var params = make(map[string]interface{}) if filter != nil && filter.isSet() { @@ -389,7 +415,7 @@ func GetObjectTypes() []string { return getTypes("object_type") } -func GetMarkedForDeletion() []Tuple { +func getMarkedForDeletion() []Tuple { sql := `select tuples.* from tuples join pending_actions on pending_actions.tuple_key = tuples.tuple_key and pending_actions.action = 'D' limit 10 ` diff --git a/db/db_test.go b/db/db_test.go index e1d72c8..86a3edb 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -24,7 +24,7 @@ func TestCount(t *testing.T) { ApplyChange(tupleChange) // then - if c := CountTuples(nil); c != 1 { + if c := Repository.CountTuples(nil); c != 1 { t.Error("There must be 1 entry") } }) @@ -43,7 +43,7 @@ func TestCount(t *testing.T) { ApplyChange(tupleChange) // then - if c := CountTuples(nil); c != 0 { + if c := Repository.CountTuples(nil); c != 0 { t.Error("There must be 1 entry") } }) diff --git a/fga.go b/fga.go index d3644fb..edd50e6 100644 --- a/fga.go +++ b/fga.go @@ -21,10 +21,7 @@ func create(ctx context.Context, tupleKey string) { key := openfga.NewTupleKey(user, relation, object) tuple := openfga.NewWriteRequestWrites([]openfga.TupleKey{*key}) - _, _, err := fgaClient.OpenFgaApi.Write(ctx). - Body(openfga.WriteRequest{ - Writes: tuple, - }).Execute() + err := fga.write(ctx, tuple) if err != nil { log.Printf("Error writing tuple: %v", err) @@ -35,7 +32,7 @@ func create(ctx context.Context, tupleKey string) { func deleteMarked(ctx context.Context) { for { - results := db.GetMarkedForDeletion() + results := db.Repository.GetMarkedForDeletion() if results != nil { for _, tuple := range results { deleteTuple := openfga.TupleKeyWithoutCondition{ @@ -44,11 +41,7 @@ func deleteMarked(ctx context.Context) { Object: tuple.ObjectType + ":" + tuple.ObjectId, } deletes := []openfga.TupleKeyWithoutCondition{deleteTuple} - _, resp, err := fgaClient.OpenFgaApi. - Write(ctx). - Body(openfga.WriteRequest{ - Deletes: &openfga.WriteRequestDeletes{ - TupleKeys: deletes}}).Execute() + resp, err := fga.delete(ctx, deletes) if err != nil && resp.StatusCode != 200 { log.Printf("Error deleting tuples %v: %v", err, resp) } diff --git a/fga_test.go b/fga_test.go new file mode 100644 index 0000000..7b5eb65 --- /dev/null +++ b/fga_test.go @@ -0,0 +1,67 @@ +package main + +import ( + "context" + openfga "github.com/openfga/go-sdk" + "github.com/paulosuzart/fgamanager/db" + "net/http" + "testing" + "time" +) + +type mockRepo struct { + db.TupleRepository + GetMarkedForDeletionFunc func() []db.Tuple +} + +func (r mockRepo) GetMarkedForDeletion() []db.Tuple { + return r.GetMarkedForDeletionFunc() +} + +func (r mockRepo) CountTuples(filter *db.Filter) int { + return 0 +} + +type mockFga struct { + fgaService + deleteFunc func(ctx context.Context, deletes []openfga.TupleKeyWithoutCondition) (*http.Response, error) +} + +func (m mockFga) delete(ctx context.Context, deletes []openfga.TupleKeyWithoutCondition) (*http.Response, error) { + return m.deleteFunc(ctx, deletes) +} + +func Test(t *testing.T) { + + db.Repository = mockRepo{ + GetMarkedForDeletionFunc: func() []db.Tuple { + return []db.Tuple{ + { + TupleKey: "user:jack member org:acme", + UserType: "user", + UserId: "jack", + Relation: "member", + ObjectType: "org", + ObjectId: "acme", + }, + } + }, + } + t.Run("Test Writes", func(t *testing.T) { + invokedChan := make(chan interface{}) + fga = mockFga{deleteFunc: func(ctx context.Context, deletes []openfga.TupleKeyWithoutCondition) (*http.Response, error) { + if len(deletes) == 0 { + t.Error("At least one tuple for deletion is expected") + } + if deletes[0].User != "user:jack" { + t.Error("User data mismatch") + } + invokedChan <- true + return &http.Response{StatusCode: 200}, nil + }} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + go deleteMarked(ctx) + <-invokedChan + cancel() + }) +} diff --git a/main.go b/main.go index 3a8fb45..c6ef3f5 100644 --- a/main.go +++ b/main.go @@ -8,8 +8,10 @@ import ( "github.com/paulosuzart/fgamanager/db" "github.com/rivo/tview" "log" + "net/http" "net/url" "os" + "testing" ) var ( @@ -19,6 +21,11 @@ var ( ) func init() { + if testing.Testing() { + testId := "TESTID" + storeId = &testId + return + } err := parser.Parse(os.Args) if err != nil { fmt.Printf("Error: %v", err) @@ -32,6 +39,7 @@ func init() { var ( fgaClient *openfga.APIClient + fga fgaService ) type WatchUpdate struct { @@ -40,6 +48,32 @@ type WatchUpdate struct { WatchEnabled string } +type fgaService interface { + write(ctx context.Context, tuple *openfga.WriteRequestWrites) error + delete(ctx context.Context, deletes []openfga.TupleKeyWithoutCondition) (*http.Response, error) +} + +type fgaWrapper struct { + fgaService +} + +func (f *fgaWrapper) write(ctx context.Context, tuple *openfga.WriteRequestWrites) error { + _, _, err := fgaClient.OpenFgaApi.Write(ctx). + Body(openfga.WriteRequest{ + Writes: tuple, + }).Execute() + return err +} + +func (f *fgaWrapper) delete(ctx context.Context, deletes []openfga.TupleKeyWithoutCondition) (*http.Response, error) { + _, resp, err := fgaClient.OpenFgaApi. + Write(ctx). + Body(openfga.WriteRequest{ + Deletes: &openfga.WriteRequestDeletes{ + TupleKeys: deletes}}).Execute() + return resp, err +} + func main() { // log to custom file LOG_FILE := "/tmp/fgamanager.log" @@ -69,6 +103,7 @@ func main() { StoreId: *storeId, }) fgaClient = openfga.NewAPIClient(configuration) + fga = &fgaWrapper{} if err != nil { log.Panic("Unable to create openfga config") diff --git a/ui.go b/ui.go index 95a7ea9..7e55be5 100644 --- a/ui.go +++ b/ui.go @@ -29,7 +29,8 @@ func (c *count) setTotal(newTotal int) { func (c *count) refresh(d time.Duration) { for { - dbCount := db.CountTuples(nil) + + dbCount := db.Repository.CountTuples(nil) c.setTotal(dbCount) c.newCountChan <- dbCount time.Sleep(d) @@ -75,7 +76,7 @@ func (a Action) String() string { func (t *TupleView) GetRowCount() int { if t.filterSet { - return db.CountTuples(&t.filter) + 1 + return db.Repository.CountTuples(&t.filter) + 1 } if t.page == nil || t.page.GetTotal() == 0 { return 1