diff --git a/const.go b/const.go index bdc54cdc..9ebc549e 100644 --- a/const.go +++ b/const.go @@ -9,5 +9,6 @@ package nebula_go const ( - ErrorTagNotFound = "TagNotFound: Tag not existed!" + ErrorTagNotFound = "TagNotFound: Tag not existed!" + ErrorEdgeNotFound = "EdgeNotFound: Edge not existed!" ) diff --git a/label.go b/label.go index 6eb64203..496d6635 100644 --- a/label.go +++ b/label.go @@ -92,10 +92,22 @@ func (field LabelFieldSchema) BuildAddTagFieldQL(labelName string) string { return q + ");" } +func (field LabelFieldSchema) BuildAddEdgeFieldQL(labelName string) string { + q := "ALTER EDGE " + labelName + " ADD (" + field.Field + " " + field.Type + if !field.Nullable { + q += " NOT NULL" + } + return q + ");" +} + func (field Label) BuildDropTagFieldQL(labelName string) string { return "ALTER TAG " + labelName + " DROP (" + field.Field + ");" } +func (field Label) BuildDropEdgeFieldQL(labelName string) string { + return "ALTER EDGE " + labelName + " DROP (" + field.Field + ");" +} + type LabelName struct { Name string `nebula:"Name"` } diff --git a/label_test.go b/label_test.go index d74b8c2b..2e85b8ce 100644 --- a/label_test.go +++ b/label_test.go @@ -55,14 +55,22 @@ func TestBuildAddFieldQL(t *testing.T) { Type: "string", Nullable: false, } + // tag assert.Equal(t, "ALTER TAG account ADD (name string NOT NULL);", field.BuildAddTagFieldQL("account")) field.Nullable = true assert.Equal(t, "ALTER TAG account ADD (name string);", field.BuildAddTagFieldQL("account")) + // edge + assert.Equal(t, "ALTER EDGE account ADD (name string);", field.BuildAddEdgeFieldQL("account")) + field.Nullable = false + assert.Equal(t, "ALTER EDGE account ADD (name string NOT NULL);", field.BuildAddEdgeFieldQL("account")) } func TestBuildDropFieldQL(t *testing.T) { field := Label{ Field: "name", } + // tag assert.Equal(t, "ALTER TAG account DROP (name);", field.BuildDropTagFieldQL("account")) + // edge + assert.Equal(t, "ALTER EDGE account DROP (name);", field.BuildDropEdgeFieldQL("account")) } diff --git a/schema_manager.go b/schema_manager.go index 43384ebd..ba6aa7a6 100644 --- a/schema_manager.go +++ b/schema_manager.go @@ -107,3 +107,90 @@ func (mgr *SchemaManager) ApplyTag(tag LabelSchema) (*ResultSet, error) { return nil, nil } + +// ApplyEdge applies the given edge to the graph. +// 1. If the edge does not exist, it will be created. +// 2. If the edge exists, it will be checked if the fields are the same. +// 2.1 If not, the new fields will be added. +// 2.2 If the field type is different, it will return an error. +// 2.3 If a field exists in the graph but not in the given edge, +// it will be removed. +// +// Notice: +// We won't change the field type because it has +// unexpected behavior for the data. +func (mgr *SchemaManager) ApplyEdge(edge LabelSchema) (*ResultSet, error) { + // 1. Check if the edge exists + fields, err := mgr.pool.DescEdge(edge.Name) + if err != nil { + // 2. If the edge does not exist, create it + if strings.Contains(err.Error(), ErrorEdgeNotFound) { + return mgr.pool.CreateEdge(edge) + } + return nil, err + } + + // 3. If the edge exists, check if the fields are the same + if err != nil { + return nil, err + } + + // 4. Add new fields + // 4.1 Prepare the new fields + addFieldQLs := []string{} + for _, expected := range edge.Fields { + found := false + for _, actual := range fields { + if expected.Field == actual.Field { + found = true + // 4.1 Check if the field type is different + if expected.Type != actual.Type { + return nil, fmt.Errorf("field type is different. "+ + "Expected: %s, Actual: %s", expected.Type, actual.Type) + } + break + } + } + if !found { + // 4.2 Add the not exists field QL + q := expected.BuildAddEdgeFieldQL(edge.Name) + addFieldQLs = append(addFieldQLs, q) + } + } + // 4.3 Execute the add field QLs if needed + if len(addFieldQLs) > 0 { + queries := strings.Join(addFieldQLs, "") + _, err := mgr.pool.ExecuteAndCheck(queries) + if err != nil { + return nil, err + } + } + + // 5. Remove the not expected field + // 5.1 Prepare the not expected fields + dropFieldQLs := []string{} + for _, actual := range fields { + redundant := true + for _, expected := range edge.Fields { + if expected.Field == actual.Field { + redundant = false + break + } + } + if redundant { + // 5.2 Remove the not expected field + q := actual.BuildDropEdgeFieldQL(edge.Name) + dropFieldQLs = append(dropFieldQLs, q) + } + } + // 5.3 Execute the drop field QLs if needed + if len(dropFieldQLs) > 0 { + queries := strings.Join(dropFieldQLs, "") + _, err := mgr.pool.ExecuteAndCheck(queries) + if err != nil { + return nil, err + } + } + + return nil, nil +} diff --git a/schema_manager_test.go b/schema_manager_test.go index 358683cf..1eaa7863 100644 --- a/schema_manager_test.go +++ b/schema_manager_test.go @@ -60,6 +60,7 @@ func TestSchemaManagerApplyTag(t *testing.T) { Fields: []LabelFieldSchema{ { Field: "name", + Type: "string", Nullable: false, }, }, @@ -158,3 +159,136 @@ func TestSchemaManagerApplyTag(t *testing.T) { assert.Equal(t, "phone", labels[1].Field) assert.Equal(t, "int64", labels[1].Type) } + +func TestSchemaManagerApplyEdge(t *testing.T) { + spaceName := "test_space_apply_edge" + err := prepareSpace(spaceName) + if err != nil { + t.Fatal(err) + } + defer dropSpace(spaceName) + + hostAddress := HostAddress{Host: address, Port: port} + config, err := NewSessionPoolConf( + "root", + "nebula", + []HostAddress{hostAddress}, + spaceName) + if err != nil { + t.Errorf("failed to create session pool config, %s", err.Error()) + } + + // allow only one session in the pool so it is easier to test + config.maxSize = 1 + + // create session pool + sessionPool, err := NewSessionPool(*config, DefaultLogger{}) + if err != nil { + t.Fatal(err) + } + defer sessionPool.Close() + + schemaManager := NewSchemaManager(sessionPool) + + spaces, err := sessionPool.ShowSpaces() + if err != nil { + t.Fatal(err) + } + assert.LessOrEqual(t, 1, len(spaces)) + var spaceNames []string + for _, space := range spaces { + spaceNames = append(spaceNames, space.Name) + } + assert.Contains(t, spaceNames, spaceName) + + edgeSchema := LabelSchema{ + Name: "account_email", + Fields: []LabelFieldSchema{ + { + Field: "email", + Type: "string", + Nullable: false, + }, + }, + } + _, err = schemaManager.ApplyEdge(edgeSchema) + if err != nil { + t.Fatal(err) + } + edges, err := sessionPool.ShowEdges() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 1, len(edges)) + assert.Equal(t, "account_email", edges[0].Name) + labels, err := sessionPool.DescEdge("account_email") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 1, len(labels)) + assert.Equal(t, "email", labels[0].Field) + assert.Equal(t, "string", labels[0].Type) + + edgeSchema = LabelSchema{ + Name: "account_email", + Fields: []LabelFieldSchema{ + { + Field: "email", + Type: "string", + Nullable: false, + }, + { + Field: "created_at", + Type: "timestamp", + Nullable: true, + }, + }, + } + _, err = schemaManager.ApplyEdge(edgeSchema) + if err != nil { + t.Fatal(err) + } + edges, err = sessionPool.ShowEdges() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 1, len(edges)) + assert.Equal(t, "account_email", edges[0].Name) + labels, err = sessionPool.DescEdge("account_email") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 2, len(labels)) + assert.Equal(t, "email", labels[0].Field) + assert.Equal(t, "string", labels[0].Type) + assert.Equal(t, "created_at", labels[1].Field) + assert.Equal(t, "timestamp", labels[1].Type) + + edgeSchema = LabelSchema{ + Name: "account_email", + Fields: []LabelFieldSchema{ + { + Field: "email", + Type: "string", + Nullable: false, + }, + }, + } + _, err = schemaManager.ApplyEdge(edgeSchema) + if err != nil { + t.Fatal(err) + } + edges, err = sessionPool.ShowEdges() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 1, len(edges)) + assert.Equal(t, "account_email", edges[0].Name) + labels, err = sessionPool.DescEdge("account_email") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 1, len(labels)) + assert.Equal(t, "email", labels[0].Field) + assert.Equal(t, "string", labels[0].Type) +}