diff --git a/pkg/go/.golangci.yaml b/pkg/go/.golangci.yaml index 81f7b04a..a722a8d3 100644 --- a/pkg/go/.golangci.yaml +++ b/pkg/go/.golangci.yaml @@ -47,6 +47,7 @@ linters-settings: - "$test" allow: - $gostd + - github.com/google/go-cmp/cmp - github.com/stretchr/testify - github.com/openfga/api - gopkg.in/yaml.v3 diff --git a/pkg/go/go.mod b/pkg/go/go.mod index 55828ec7..3d6b6ef3 100644 --- a/pkg/go/go.mod +++ b/pkg/go/go.mod @@ -4,9 +4,11 @@ go 1.21.9 require ( github.com/antlr4-go/antlr/v4 v4.13.0 + github.com/google/go-cmp v0.6.0 github.com/hashicorp/go-multierror v1.1.1 github.com/openfga/api/proto v0.0.0-20240318145204-66b9e5cb403c github.com/stretchr/testify v1.9.0 + gonum.org/v1/gonum v0.15.0 google.golang.org/protobuf v1.34.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/pkg/go/go.sum b/pkg/go/go.sum index 82cc9298..45a1f5ff 100644 --- a/pkg/go/go.sum +++ b/pkg/go/go.sum @@ -34,6 +34,8 @@ golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gonum.org/v1/gonum v0.15.0 h1:2lYxjRbTYyxkJxlhC+LvJIx3SsANPdRybu1tGj9/OrQ= +gonum.org/v1/gonum v0.15.0/go.mod h1:xzZVBJBtS+Mz4q0Yl2LJTk+OxOg4jiXZ7qBoM0uISGo= google.golang.org/genproto/googleapis/api v0.0.0-20240401170217-c3f982113cda h1:b6F6WIV4xHHD0FA4oIyzU6mHWg2WI2X1RBehwa5QN38= google.golang.org/genproto/googleapis/api v0.0.0-20240401170217-c3f982113cda/go.mod h1:AHcE/gZH76Bk/ROZhQphlRoWo5xKDEtz3eVEO1LfA8c= google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda h1:LI5DOvAxUPMv/50agcLLoo+AdWc1irS9Rzz4vPuD1V4= diff --git a/pkg/go/graph/graph.go b/pkg/go/graph/graph.go new file mode 100644 index 00000000..c7fc065d --- /dev/null +++ b/pkg/go/graph/graph.go @@ -0,0 +1,40 @@ +package graph + +import ( + "errors" + + "gonum.org/v1/gonum/graph/encoding" + "gonum.org/v1/gonum/graph/encoding/dot" + "gonum.org/v1/gonum/graph/multi" +) + +var ErrBuildingGraph = errors.New("cannot build graph") + +type AuthorizationModelGraph struct { + *multi.DirectedGraph +} + +var _ dot.Attributers = (*AuthorizationModelGraph)(nil) + +func (g *AuthorizationModelGraph) DOTAttributers() (graph, node, edge encoding.Attributer) { + return g, nil, nil +} + +func (g *AuthorizationModelGraph) Attributes() []encoding.Attribute { + // https://graphviz.org/docs/attrs/rankdir/ - bottom to top + return []encoding.Attribute{{ + Key: "rankdir", + Value: "BT", + }} +} + +// GetDOT returns the DOT visualization. The output text is stable. +// It should only be used for debugging. +func (g *AuthorizationModelGraph) GetDOT() string { + dotRepresentation, err := dot.MarshalMulti(g, "", "", "") + if err != nil { + return "" + } + + return string(dotRepresentation) +} diff --git a/pkg/go/graph/graph_builder.go b/pkg/go/graph/graph_builder.go new file mode 100644 index 00000000..a29d48a6 --- /dev/null +++ b/pkg/go/graph/graph_builder.go @@ -0,0 +1,127 @@ +package graph + +import ( + "cmp" + "fmt" + "slices" + + openfgav1 "github.com/openfga/api/proto/openfga/v1" + "gonum.org/v1/gonum/graph" + "gonum.org/v1/gonum/graph/multi" +) + +type AuthorizationModelGraphBuilder struct { + graph.DirectedMultigraphBuilder + + ids map[string]int64 // nodes: unique labels to ids. Used to find nodes by label. +} + +// NewAuthorizationModelGraph builds an authorization model in graph form. +// For example, types such as `group`, usersets such as `group#member` and wildcards `group:*` are encoded as nodes. +// +// The edges are defined by the relations, e.g. +// `define viewer: [group]` defines an edge from group to document#viewer. +func NewAuthorizationModelGraph(model *openfgav1.AuthorizationModel) (*AuthorizationModelGraph, error) { + res, err := parseModel(model) + if err != nil { + return nil, err + } + + return &AuthorizationModelGraph{res}, nil +} + +func parseModel(model *openfgav1.AuthorizationModel) (*multi.DirectedGraph, error) { + graphBuilder := &AuthorizationModelGraphBuilder{ + multi.NewDirectedGraph(), map[string]int64{}, + } + + // sort types by name to guarantee stable output + sortedTypeDefs := make([]*openfgav1.TypeDefinition, len(model.GetTypeDefinitions())) + copy(sortedTypeDefs, model.GetTypeDefinitions()) + + slices.SortFunc(sortedTypeDefs, func(a, b *openfgav1.TypeDefinition) int { + return cmp.Compare(a.GetType(), b.GetType()) + }) + + for _, typeDef := range sortedTypeDefs { + graphBuilder.GetOrAddNode(typeDef.GetType(), typeDef.GetType(), SpecificType) + + // sort relations by name to guarantee stable output + sortedRelations := make([]string, 0, len(typeDef.GetRelations())) + for relationName := range typeDef.GetRelations() { + sortedRelations = append(sortedRelations, relationName) + } + + slices.Sort(sortedRelations) + + for _, relation := range sortedRelations { + uniqueLabel := fmt.Sprintf("%s#%s", typeDef.GetType(), relation) + relationNode := graphBuilder.GetOrAddNode(uniqueLabel, uniqueLabel, UsersetType) + + rewrite := typeDef.GetRelations()[relation] + switch rewrite.GetUserset().(type) { + case *openfgav1.Userset_This: + directlyRelated := make([]*openfgav1.RelationReference, 0) + if metadata, ok := typeDef.GetMetadata().GetRelations()[relation]; ok { + directlyRelated = metadata.GetDirectlyRelatedUserTypes() + } + + for _, directlyRelatedDef := range directlyRelated { + assignableType := directlyRelatedDef.GetType() + + newNode := graphBuilder.GetOrAddNode(assignableType, assignableType, SpecificType) + graphBuilder.AddEdge(newNode, relationNode) + } + } + } + } + + multigraph, ok := graphBuilder.DirectedMultigraphBuilder.(*multi.DirectedGraph) + if ok { + return multigraph, nil + } + + return nil, fmt.Errorf("%w: could not cast to directed graph", ErrBuildingGraph) +} + +func (g *AuthorizationModelGraphBuilder) GetOrAddNode(uniqueLabel, label string, nodeType NodeType) *AuthorizationModelNode { + if existingNode := g.GetNodeFor(uniqueLabel); existingNode != nil { + return existingNode + } + + node := g.NewNode() + nodeid := node.ID() + newNode := &AuthorizationModelNode{ + Node: node, + label: label, + nodeType: nodeType, + uniqueLabel: uniqueLabel, + } + g.AddNode(newNode) + g.ids[uniqueLabel] = nodeid + + return newNode +} + +func (g *AuthorizationModelGraphBuilder) GetNodeFor(uniqueLabel string) *AuthorizationModelNode { + id, ok := g.ids[uniqueLabel] + if !ok { + return nil + } + + authModelNode, ok := g.Node(id).(*AuthorizationModelNode) + if !ok { + return nil + } + + return authModelNode +} + +func (g *AuthorizationModelGraphBuilder) AddEdge(from, to graph.Node) *AuthorizationModelEdge { + l := g.NewLine(from, to) + lineid := l.ID() + newLine := &AuthorizationModelEdge{Line: l, uniqueid: lineid} + g.SetLine(newLine) + + return newLine +} diff --git a/pkg/go/graph/graph_edge.go b/pkg/go/graph/graph_edge.go new file mode 100644 index 00000000..537a5a10 --- /dev/null +++ b/pkg/go/graph/graph_edge.go @@ -0,0 +1,19 @@ +package graph + +import ( + "gonum.org/v1/gonum/graph" + "gonum.org/v1/gonum/graph/encoding" +) + +type AuthorizationModelEdge struct { + graph.Line + uniqueid int64 +} + +var _ encoding.Attributer = (*AuthorizationModelEdge)(nil) + +func (n *AuthorizationModelEdge) Attributes() []encoding.Attribute { + var attrs []encoding.Attribute + + return attrs +} diff --git a/pkg/go/graph/graph_node.go b/pkg/go/graph/graph_node.go new file mode 100644 index 00000000..3b4284cd --- /dev/null +++ b/pkg/go/graph/graph_node.go @@ -0,0 +1,35 @@ +package graph + +import ( + "gonum.org/v1/gonum/graph" + "gonum.org/v1/gonum/graph/encoding" +) + +type NodeType int64 + +const ( + SpecificType NodeType = 0 // `group` + UsersetType NodeType = 2 // `group#viewer` +) + +type AuthorizationModelNode struct { + graph.Node + label string // e.g. `union`, for DOT + nodeType NodeType + uniqueLabel string // e.g. `union[a,b]` +} + +func (n *AuthorizationModelNode) String() string { return n.uniqueLabel } + +var _ encoding.Attributer = (*AuthorizationModelNode)(nil) + +func (n *AuthorizationModelNode) Attributes() []encoding.Attribute { + var attrs []encoding.Attribute + + attrs = append(attrs, encoding.Attribute{ + Key: "label", + Value: n.label, + }) + + return attrs +} diff --git a/pkg/go/graph/graph_test.go b/pkg/go/graph/graph_test.go new file mode 100644 index 00000000..41f1cf12 --- /dev/null +++ b/pkg/go/graph/graph_test.go @@ -0,0 +1,76 @@ +package graph + +import ( + "sort" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" + + language "github.com/openfga/language/pkg/go/transformer" +) + +// TestGetDOTRepresentation also tests that the graph is built correctly. +func TestGetDOTRepresentation(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + model string + expectedOutput string + }{ + `direct_assignment`: { + model: ` + model + schema 1.1 + type folder + relations + define viewer: [user] + type user`, + expectedOutput: `digraph { +graph [ +rankdir=BT +]; + +// Node definitions. +0 [label=folder]; +1 [label="folder#viewer"]; +2 [label=user]; + +// Edge definitions. +2 -> 1; +}`, + }, + } + + for name, test := range testCases { + test := test + + t.Run(name, func(t *testing.T) { + t.Parallel() + + model := language.MustTransformDSLToProto(test.model) + graph, err := NewAuthorizationModelGraph(model) + require.NoError(t, err) + + actualDOT := graph.GetDOT() + actualSorted := getSorted(actualDOT) + expectedSorted := getSorted(test.expectedOutput) + + diff := cmp.Diff(expectedSorted, actualSorted) + + require.Empty(t, diff, "expected %s\ngot %s", test.expectedOutput, actualDOT) + }) + } +} + +// getSorted assumes the input has multiple lines and returns the sorted version of it. +func getSorted(input string) string { + lines := strings.FieldsFunc(input, func(r rune) bool { + return r == '\n' + }) + + sort.Strings(lines) + + return strings.Join(lines, "\n") +}