From 6c61299367a819b8017252b44df0b360b5747c5a Mon Sep 17 00:00:00 2001 From: Dan Goslen Date: Sat, 30 Sep 2023 08:35:29 -0400 Subject: [PATCH 1/7] Adds relationship label types and parser --- .go-version | 1 + config/config.go | 7 +++++ config/config_test.go | 22 ++++++++++++++- config/relationship_label.go | 52 ++++++++++++++++++++++++++++++++++++ 4 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 .go-version create mode 100644 config/relationship_label.go diff --git a/.go-version b/.go-version new file mode 100644 index 0000000..2a4feaf --- /dev/null +++ b/.go-version @@ -0,0 +1 @@ +1.19.6 diff --git a/config/config.go b/config/config.go index a12ff08..4c4c3af 100644 --- a/config/config.go +++ b/config/config.go @@ -18,6 +18,7 @@ const ( UseAllSchemasKey = "useAllSchemas" ShowSchemaPrefix = "showSchemaPrefix" SchemaPrefixSeparator = "schemaPrefixSeparator" + RelationshipLabelsKey = "relationshipLabels" ) type config struct{} @@ -38,6 +39,7 @@ type MermerdConfig interface { UseAllSchemas() bool ShowSchemaPrefix() bool SchemaPrefixSeparator() string + RelationshipLabels() []RelationshipLabel } func NewConfig() MermerdConfig { @@ -72,6 +74,11 @@ func (c config) SelectedTables() []string { return viper.GetStringSlice(SelectedTablesKey) } +func (c config) RelationshipLabels() []RelationshipLabel { + labels := viper.GetStringSlice(RelationshipLabelsKey) + return ParseLabels(labels) +} + func (c config) EncloseWithMermaidBackticks() bool { return viper.GetBool(EncloseWithMermaidBackticksKey) } diff --git a/config/config_test.go b/config/config_test.go index 45cb5f0..466c3cf 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -2,9 +2,10 @@ package config import ( "bytes" + "testing" + "github.com/spf13/viper" "github.com/stretchr/testify/assert" - "testing" ) func TestYamlConfig(t *testing.T) { @@ -32,6 +33,10 @@ showDescriptions: - enumValues - columnComments - notNull +relationshipLabels: + - "schema.table1 schema.table2 : is_a" + - "table-name another-table-name : has_many" + - "incorrect format" useAllSchemas: true showSchemaPrefix: true schemaPrefixSeparator: "_" @@ -63,4 +68,19 @@ connectionStringSuggestions: assert.True(t, config.UseAllSchemas()) assert.True(t, config.ShowSchemaPrefix()) assert.Equal(t, "_", config.SchemaPrefixSeparator()) + assert.ElementsMatch(t, + config.RelationshipLabels(), + []RelationshipLabel{ + RelationshipLabel{ + PkName: "schema.table1", + FkName: "schema.table2", + Label: "is_a", + }, + RelationshipLabel{ + PkName: "table-name", + FkName: "another-table-name", + Label: "has_many", + }, + }, + ) } diff --git a/config/relationship_label.go b/config/relationship_label.go new file mode 100644 index 0000000..1c6b2e3 --- /dev/null +++ b/config/relationship_label.go @@ -0,0 +1,52 @@ +package config + +import ( + "errors" + "regexp" + "strings" + + "github.com/sirupsen/logrus" +) + +type RelationshipLabel struct { + PkName string + FkName string + Label string +} + +func ParseLabels(labels []string) []RelationshipLabel { + var relationshipLabels []RelationshipLabel + for _, label := range labels { + parsed, err := ParseLabel(label) + if err != nil { + logrus.Warnf("label '%s' is not in the correct format", label) + continue + } + relationshipLabels = append(relationshipLabels, parsed) + } + return relationshipLabels +} + +func ParseLabel(label string) (RelationshipLabel, error) { + label = strings.Trim(label, " \t") + matched, groups := match(label) + if !matched { + return RelationshipLabel{}, errors.New("invalid relationship label") + } + + return RelationshipLabel{ + PkName: string(groups[1]), + FkName: string(groups[2]), + Label: string(groups[3]), + }, nil +} + +var labelRegex = regexp.MustCompile(`([\w\._-]+)[\s]+([\w\._-]+)[\s]+:[\s]+([\w._-]+)`) + +func match(label string) (bool, [][]byte) { + groups := labelRegex.FindSubmatch([]byte(label)) + if groups == nil { + return false, [][]byte{} + } + return true, groups +} From 92d5d6cdb64b2e4102d4dc4cdb90be6ee75e3aaa Mon Sep 17 00:00:00 2001 From: Dan Goslen Date: Sat, 30 Sep 2023 08:52:01 -0400 Subject: [PATCH 2/7] Lookup label based on pk and fk names; overrides omitting the label and the constraint label --- diagram/diagram_util.go | 22 ++++++++++-- diagram/diagram_util_test.go | 67 ++++++++++++++++++++++++++++++++++++ mocks/MermerdConfig.go | 21 ++++++++++- 3 files changed, 106 insertions(+), 4 deletions(-) diff --git a/diagram/diagram_util.go b/diagram/diagram_util.go index 93cfdf9..a694666 100644 --- a/diagram/diagram_util.go +++ b/diagram/diagram_util.go @@ -2,9 +2,10 @@ package diagram import ( "fmt" - "github.com/sirupsen/logrus" "strings" + "github.com/sirupsen/logrus" + "github.com/KarnerTh/mermerd/config" "github.com/KarnerTh/mermerd/database" ) @@ -91,19 +92,34 @@ func shouldSkipConstraint(config config.MermerdConfig, tables []ErdTableData, co } func getConstraintData(config config.MermerdConfig, constraint database.ConstraintResult) ErdConstraintData { + pkTableName := getTableName(config, database.TableDetail{Schema: constraint.PkSchema, Name: constraint.PkTable}) + fkTableName := getTableName(config, database.TableDetail{Schema: constraint.FkSchema, Name: constraint.FkTable}) + constraintLabel := constraint.ColumnName if config.OmitConstraintLabels() { constraintLabel = "" } + if relationshipLabel := findRelationshipLabel(config, pkTableName, fkTableName); relationshipLabel != "" { + constraintLabel = relationshipLabel + } return ErdConstraintData{ - PkTableName: getTableName(config, database.TableDetail{Schema: constraint.PkSchema, Name: constraint.PkTable}), - FkTableName: getTableName(config, database.TableDetail{Schema: constraint.FkSchema, Name: constraint.FkTable}), + PkTableName: pkTableName, + FkTableName: fkTableName, Relation: getRelation(constraint), ConstraintLabel: constraintLabel, } } +func findRelationshipLabel(config config.MermerdConfig, pkTableName, fkTableName string) string { + for _, label := range config.RelationshipLabels() { + if label.PkName == pkTableName && label.FkName == fkTableName { + return label.Label + } + } + return "" +} + func getTableName(config config.MermerdConfig, table database.TableDetail) string { if !config.ShowSchemaPrefix() { return table.Name diff --git a/diagram/diagram_util_test.go b/diagram/diagram_util_test.go index 0b5ecda..9bc4371 100644 --- a/diagram/diagram_util_test.go +++ b/diagram/diagram_util_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/assert" + "github.com/KarnerTh/mermerd/config" "github.com/KarnerTh/mermerd/database" "github.com/KarnerTh/mermerd/mocks" ) @@ -284,11 +285,27 @@ func TestShouldSkipConstraint(t *testing.T) { } func TestGetConstraintData(t *testing.T) { + t.Run("The column name is used as the constraint label", func(t *testing.T) { + // Arrange + configMock := mocks.MermerdConfig{} + configMock.On("OmitConstraintLabels").Return(false).Once() + configMock.On("ShowSchemaPrefix").Return(false).Twice() + configMock.On("RelationshipLabels").Return([]config.RelationshipLabel{}).Once() + constraint := database.ConstraintResult{ColumnName: "Column1"} + + // Act + result := getConstraintData(&configMock, constraint) + + // Assert + configMock.AssertExpectations(t) + assert.Equal(t, result.ConstraintLabel, "Column1") + }) t.Run("OmitConstraintLabels should remove the constraint label", func(t *testing.T) { // Arrange configMock := mocks.MermerdConfig{} configMock.On("OmitConstraintLabels").Return(true).Once() configMock.On("ShowSchemaPrefix").Return(false).Twice() + configMock.On("RelationshipLabels").Return([]config.RelationshipLabel{}).Once() constraint := database.ConstraintResult{ColumnName: "Column1"} // Act @@ -298,6 +315,56 @@ func TestGetConstraintData(t *testing.T) { configMock.AssertExpectations(t) assert.Equal(t, result.ConstraintLabel, "") }) + t.Run("If a relationship label exists, it should be used", func(t *testing.T) { + // Arrange + configMock := mocks.MermerdConfig{} + configMock.On("OmitConstraintLabels").Return(true).Once() + configMock.On("ShowSchemaPrefix").Return(false).Twice() + configMock.On("RelationshipLabels").Return([]config.RelationshipLabel{ + { + PkName: "pk", + FkName: "fk", + Label: "relationship-label", + }, + }).Once() + constraint := database.ConstraintResult{ + PkTable: "pk", + FkTable: "fk", + ColumnName: "Column1", + } + + // Act + result := getConstraintData(&configMock, constraint) + + // Assert + configMock.AssertExpectations(t) + assert.Equal(t, result.ConstraintLabel, "relationship-label") + }) + t.Run("If a relationship label exists, it should be used even if we don't omit constraint labels", func(t *testing.T) { + // Arrange + configMock := mocks.MermerdConfig{} + configMock.On("OmitConstraintLabels").Return(false).Once() + configMock.On("ShowSchemaPrefix").Return(false).Twice() + configMock.On("RelationshipLabels").Return([]config.RelationshipLabel{ + { + PkName: "pk", + FkName: "fk", + Label: "relationship-label", + }, + }).Once() + constraint := database.ConstraintResult{ + PkTable: "pk", + FkTable: "fk", + ColumnName: "Column1", + } + + // Act + result := getConstraintData(&configMock, constraint) + + // Assert + configMock.AssertExpectations(t) + assert.Equal(t, result.ConstraintLabel, "relationship-label") + }) } func TestGetTableName(t *testing.T) { diff --git a/mocks/MermerdConfig.go b/mocks/MermerdConfig.go index 3ed867d..3c78e93 100644 --- a/mocks/MermerdConfig.go +++ b/mocks/MermerdConfig.go @@ -2,7 +2,10 @@ package mocks -import mock "github.com/stretchr/testify/mock" +import ( + config "github.com/KarnerTh/mermerd/config" + mock "github.com/stretchr/testify/mock" +) // MermerdConfig is an autogenerated mock type for the MermerdConfig type type MermerdConfig struct { @@ -109,6 +112,22 @@ func (_m *MermerdConfig) OutputFileName() string { return r0 } +// RelationshipLabels provides a mock function with given fields: +func (_m *MermerdConfig) RelationshipLabels() []config.RelationshipLabel { + ret := _m.Called() + + var r0 []config.RelationshipLabel + if rf, ok := ret.Get(0).(func() []config.RelationshipLabel); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]config.RelationshipLabel) + } + } + + return r0 +} + // SchemaPrefixSeparator provides a mock function with given fields: func (_m *MermerdConfig) SchemaPrefixSeparator() string { ret := _m.Called() From 5b3631d56f68570796bd869ed34da25b33ab57eb Mon Sep 17 00:00:00 2001 From: Dan Goslen Date: Sat, 30 Sep 2023 13:10:45 -0400 Subject: [PATCH 3/7] First full working version --- exampleRunConfig.yaml | 2 ++ readme.md | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/exampleRunConfig.yaml b/exampleRunConfig.yaml index c7449a4..32f40ab 100644 --- a/exampleRunConfig.yaml +++ b/exampleRunConfig.yaml @@ -24,5 +24,7 @@ showDescriptions: - enumValues - columnComments - notNull +relationshipLabels: + - "public_article public_article_comment : has_many" showSchemaPrefix: true schemaPrefixSeparator: "_" diff --git a/readme.md b/readme.md index 28c65e5..12e9227 100644 --- a/readme.md +++ b/readme.md @@ -86,6 +86,7 @@ via `mermerd -h` --showSchemaPrefix show schema prefix in table name --useAllSchemas use all available schemas --useAllTables use all available tables + --relationshipLabels strings use a different label besides the column name for specific table relationships; overrides `omitConstraintLabels` if specified ``` If the flag `--showAllConstraints` is provided, mermerd will print out all constraints of the selected tables, even when @@ -152,6 +153,12 @@ showDescriptions: - notNull showSchemaPrefix: true schemaPrefixSeparator: "_" + +# Names must match the pattern +relationshipLabels: + ? - public_table + - public_another-table + : label ``` ## Example usages From 163fb48b398426a8ad14f0f8bdfae5de746f5e88 Mon Sep 17 00:00:00 2001 From: Dan Goslen Date: Sat, 30 Sep 2023 17:03:20 -0400 Subject: [PATCH 4/7] Use a map for faster lookup --- config/config.go | 2 +- config/relationship_label.go | 6 ++--- diagram/diagram.go | 3 ++- diagram/diagram_util.go | 15 +++-------- diagram/diagram_util_test.go | 41 +++++++++++++++--------------- diagram/relationship_label_map.go | 42 +++++++++++++++++++++++++++++++ 6 files changed, 72 insertions(+), 37 deletions(-) create mode 100644 diagram/relationship_label_map.go diff --git a/config/config.go b/config/config.go index 4c4c3af..7cdddb4 100644 --- a/config/config.go +++ b/config/config.go @@ -76,7 +76,7 @@ func (c config) SelectedTables() []string { func (c config) RelationshipLabels() []RelationshipLabel { labels := viper.GetStringSlice(RelationshipLabelsKey) - return ParseLabels(labels) + return parseLabels(labels) } func (c config) EncloseWithMermaidBackticks() bool { diff --git a/config/relationship_label.go b/config/relationship_label.go index 1c6b2e3..78ea064 100644 --- a/config/relationship_label.go +++ b/config/relationship_label.go @@ -14,10 +14,10 @@ type RelationshipLabel struct { Label string } -func ParseLabels(labels []string) []RelationshipLabel { +func parseLabels(labels []string) []RelationshipLabel { var relationshipLabels []RelationshipLabel for _, label := range labels { - parsed, err := ParseLabel(label) + parsed, err := parseLabel(label) if err != nil { logrus.Warnf("label '%s' is not in the correct format", label) continue @@ -27,7 +27,7 @@ func ParseLabels(labels []string) []RelationshipLabel { return relationshipLabels } -func ParseLabel(label string) (RelationshipLabel, error) { +func parseLabel(label string) (RelationshipLabel, error) { label = strings.Trim(label, " \t") matched, groups := match(label) if !matched { diff --git a/diagram/diagram.go b/diagram/diagram.go index 710af9c..76193ca 100644 --- a/diagram/diagram.go +++ b/diagram/diagram.go @@ -59,12 +59,13 @@ func (d diagram) Create(result *database.Result) error { } var constraints []ErdConstraintData + relationshipLabelMap := BuildRelationshipLabelMap(d.config) for _, constraint := range allConstraints { if shouldSkipConstraint(d.config, tableData, constraint) { continue } - constraints = append(constraints, getConstraintData(d.config, constraint)) + constraints = append(constraints, getConstraintData(d.config, relationshipLabelMap, constraint)) } diagramData := ErdDiagramData{ diff --git a/diagram/diagram_util.go b/diagram/diagram_util.go index a694666..fd0d1fd 100644 --- a/diagram/diagram_util.go +++ b/diagram/diagram_util.go @@ -91,7 +91,7 @@ func shouldSkipConstraint(config config.MermerdConfig, tables []ErdTableData, co return !(tableNameInSlice(tables, constraint.PkTable) && tableNameInSlice(tables, constraint.FkTable)) } -func getConstraintData(config config.MermerdConfig, constraint database.ConstraintResult) ErdConstraintData { +func getConstraintData(config config.MermerdConfig, labelMap RelationshipLabelMap, constraint database.ConstraintResult) ErdConstraintData { pkTableName := getTableName(config, database.TableDetail{Schema: constraint.PkSchema, Name: constraint.PkTable}) fkTableName := getTableName(config, database.TableDetail{Schema: constraint.FkSchema, Name: constraint.FkTable}) @@ -99,8 +99,8 @@ func getConstraintData(config config.MermerdConfig, constraint database.Constrai if config.OmitConstraintLabels() { constraintLabel = "" } - if relationshipLabel := findRelationshipLabel(config, pkTableName, fkTableName); relationshipLabel != "" { - constraintLabel = relationshipLabel + if relationshipLabel, found := labelMap.LookupRelationshipLabel(pkTableName, fkTableName); found { + constraintLabel = relationshipLabel.Label } return ErdConstraintData{ @@ -111,15 +111,6 @@ func getConstraintData(config config.MermerdConfig, constraint database.Constrai } } -func findRelationshipLabel(config config.MermerdConfig, pkTableName, fkTableName string) string { - for _, label := range config.RelationshipLabels() { - if label.PkName == pkTableName && label.FkName == fkTableName { - return label.Label - } - } - return "" -} - func getTableName(config config.MermerdConfig, table database.TableDetail) string { if !config.ShowSchemaPrefix() { return table.Name diff --git a/diagram/diagram_util_test.go b/diagram/diagram_util_test.go index 9bc4371..27dc5d3 100644 --- a/diagram/diagram_util_test.go +++ b/diagram/diagram_util_test.go @@ -290,11 +290,10 @@ func TestGetConstraintData(t *testing.T) { configMock := mocks.MermerdConfig{} configMock.On("OmitConstraintLabels").Return(false).Once() configMock.On("ShowSchemaPrefix").Return(false).Twice() - configMock.On("RelationshipLabels").Return([]config.RelationshipLabel{}).Once() constraint := database.ConstraintResult{ColumnName: "Column1"} // Act - result := getConstraintData(&configMock, constraint) + result := getConstraintData(&configMock, &relationshipLabelMap{}, constraint) // Assert configMock.AssertExpectations(t) @@ -305,11 +304,11 @@ func TestGetConstraintData(t *testing.T) { configMock := mocks.MermerdConfig{} configMock.On("OmitConstraintLabels").Return(true).Once() configMock.On("ShowSchemaPrefix").Return(false).Twice() - configMock.On("RelationshipLabels").Return([]config.RelationshipLabel{}).Once() + constraint := database.ConstraintResult{ColumnName: "Column1"} // Act - result := getConstraintData(&configMock, constraint) + result := getConstraintData(&configMock, &relationshipLabelMap{}, constraint) // Assert configMock.AssertExpectations(t) @@ -320,13 +319,14 @@ func TestGetConstraintData(t *testing.T) { configMock := mocks.MermerdConfig{} configMock.On("OmitConstraintLabels").Return(true).Once() configMock.On("ShowSchemaPrefix").Return(false).Twice() - configMock.On("RelationshipLabels").Return([]config.RelationshipLabel{ - { - PkName: "pk", - FkName: "fk", - Label: "relationship-label", - }, - }).Once() + + labelsMap := &relationshipLabelMap{} + labelsMap.AddRelationshipLabel(config.RelationshipLabel{ + PkName: "pk", + FkName: "fk", + Label: "relationship-label", + }) + constraint := database.ConstraintResult{ PkTable: "pk", FkTable: "fk", @@ -334,7 +334,7 @@ func TestGetConstraintData(t *testing.T) { } // Act - result := getConstraintData(&configMock, constraint) + result := getConstraintData(&configMock, labelsMap, constraint) // Assert configMock.AssertExpectations(t) @@ -345,13 +345,14 @@ func TestGetConstraintData(t *testing.T) { configMock := mocks.MermerdConfig{} configMock.On("OmitConstraintLabels").Return(false).Once() configMock.On("ShowSchemaPrefix").Return(false).Twice() - configMock.On("RelationshipLabels").Return([]config.RelationshipLabel{ - { - PkName: "pk", - FkName: "fk", - Label: "relationship-label", - }, - }).Once() + + labelsMap := &relationshipLabelMap{} + labelsMap.AddRelationshipLabel(config.RelationshipLabel{ + PkName: "pk", + FkName: "fk", + Label: "relationship-label", + }) + constraint := database.ConstraintResult{ PkTable: "pk", FkTable: "fk", @@ -359,7 +360,7 @@ func TestGetConstraintData(t *testing.T) { } // Act - result := getConstraintData(&configMock, constraint) + result := getConstraintData(&configMock, labelsMap, constraint) // Assert configMock.AssertExpectations(t) diff --git a/diagram/relationship_label_map.go b/diagram/relationship_label_map.go new file mode 100644 index 0000000..a5f40cc --- /dev/null +++ b/diagram/relationship_label_map.go @@ -0,0 +1,42 @@ +package diagram + +import ( + "fmt" + + "github.com/KarnerTh/mermerd/config" +) + +type RelationshipLabelMap interface { + AddRelationshipLabel(label config.RelationshipLabel) + LookupRelationshipLabel(pkName, fkName string) (label config.RelationshipLabel, found bool) +} + +type relationshipLabelMap struct { + mapping map[string]config.RelationshipLabel +} + +func (r *relationshipLabelMap) AddRelationshipLabel(label config.RelationshipLabel) { + if r.mapping == nil { + r.mapping = make(map[string]config.RelationshipLabel) + } + key := r.buildMapKey(label.PkName, label.FkName) + r.mapping[key] = label +} + +func (r *relationshipLabelMap) LookupRelationshipLabel(pkName, fkName string) (label config.RelationshipLabel, found bool) { + key := r.buildMapKey(pkName, fkName) + label, found = r.mapping[key] + return +} + +func (r *relationshipLabelMap) buildMapKey(pkName, fkName string) string { + return fmt.Sprintf("%s-%s", pkName, fkName) +} + +func BuildRelationshipLabelMap(c config.MermerdConfig) RelationshipLabelMap { + labelMap := &relationshipLabelMap{} + for _, label := range c.RelationshipLabels() { + labelMap.AddRelationshipLabel(label) + } + return labelMap +} From 20324e9d4f948573e7908629883b722e18f653f2 Mon Sep 17 00:00:00 2001 From: Dan Goslen Date: Sat, 30 Sep 2023 18:19:52 -0400 Subject: [PATCH 5/7] Fix example labels --- readme.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/readme.md b/readme.md index 12e9227..2837893 100644 --- a/readme.md +++ b/readme.md @@ -156,9 +156,7 @@ schemaPrefixSeparator: "_" # Names must match the pattern
relationshipLabels: - ? - public_table - - public_another-table - : label + - "public_table public_another-table : label" ``` ## Example usages From bd1d2f3cf3dd4d3a37234936d5cf66149a746efd Mon Sep 17 00:00:00 2001 From: Dan Goslen Date: Sat, 30 Sep 2023 18:44:14 -0400 Subject: [PATCH 6/7] Adds comments for label regex --- config/relationship_label.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/config/relationship_label.go b/config/relationship_label.go index 78ea064..82b2104 100644 --- a/config/relationship_label.go +++ b/config/relationship_label.go @@ -41,7 +41,15 @@ func parseLabel(label string) (RelationshipLabel, error) { }, nil } -var labelRegex = regexp.MustCompile(`([\w\._-]+)[\s]+([\w\._-]+)[\s]+:[\s]+([\w._-]+)`) +// The regex works by creating three capture groups +// Each group allows for all word characters, `.`, `_` and `-` any number of times +// The first two groups (the table names) are separated by any amount of whitespace characters +// The table names and label are are separated by +// - any number of whitespace characters +// - a `:` +// - and then any other number of whitespace characters +// The string must start with the first table name and it must end with the label +var labelRegex = regexp.MustCompile(`^([\w\._-]+)[\s]+([\w\._-]+)[\s]+:[\s]+([\w._-]+)$`) func match(label string) (bool, [][]byte) { groups := labelRegex.FindSubmatch([]byte(label)) From 4a3511606ea84f2eef88fbc48f042d1d66b00804 Mon Sep 17 00:00:00 2001 From: Dan Goslen Date: Sun, 1 Oct 2023 09:39:15 -0400 Subject: [PATCH 7/7] Adds basic tests for relationship label map --- diagram/diagram.go | 2 +- diagram/relationship_label_map.go | 11 +++++++-- diagram/relationship_label_map_test.go | 33 ++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 3 deletions(-) create mode 100644 diagram/relationship_label_map_test.go diff --git a/diagram/diagram.go b/diagram/diagram.go index 76193ca..3223935 100644 --- a/diagram/diagram.go +++ b/diagram/diagram.go @@ -59,7 +59,7 @@ func (d diagram) Create(result *database.Result) error { } var constraints []ErdConstraintData - relationshipLabelMap := BuildRelationshipLabelMap(d.config) + relationshipLabelMap := BuildRelationshipLabelMapFromConfig(d.config) for _, constraint := range allConstraints { if shouldSkipConstraint(d.config, tableData, constraint) { continue diff --git a/diagram/relationship_label_map.go b/diagram/relationship_label_map.go index a5f40cc..46a729a 100644 --- a/diagram/relationship_label_map.go +++ b/diagram/relationship_label_map.go @@ -24,6 +24,9 @@ func (r *relationshipLabelMap) AddRelationshipLabel(label config.RelationshipLab } func (r *relationshipLabelMap) LookupRelationshipLabel(pkName, fkName string) (label config.RelationshipLabel, found bool) { + if r.mapping == nil { + return config.RelationshipLabel{}, false + } key := r.buildMapKey(pkName, fkName) label, found = r.mapping[key] return @@ -33,9 +36,13 @@ func (r *relationshipLabelMap) buildMapKey(pkName, fkName string) string { return fmt.Sprintf("%s-%s", pkName, fkName) } -func BuildRelationshipLabelMap(c config.MermerdConfig) RelationshipLabelMap { +func BuildRelationshipLabelMapFromConfig(c config.MermerdConfig) RelationshipLabelMap { + return BuildRelationshipLabelMap(c.RelationshipLabels()) +} + +func BuildRelationshipLabelMap(labels []config.RelationshipLabel) RelationshipLabelMap { labelMap := &relationshipLabelMap{} - for _, label := range c.RelationshipLabels() { + for _, label := range labels { labelMap.AddRelationshipLabel(label) } return labelMap diff --git a/diagram/relationship_label_map_test.go b/diagram/relationship_label_map_test.go new file mode 100644 index 0000000..192b5b8 --- /dev/null +++ b/diagram/relationship_label_map_test.go @@ -0,0 +1,33 @@ +package diagram_test + +import ( + "testing" + + "github.com/KarnerTh/mermerd/config" + "github.com/KarnerTh/mermerd/diagram" + "github.com/stretchr/testify/assert" +) + +func TestEmptyRelationshipMapDoesNotError(t *testing.T) { + relationshipMap := diagram.BuildRelationshipLabelMap([]config.RelationshipLabel{}) + + _, found := relationshipMap.LookupRelationshipLabel("pk", "fk") + + assert.False(t, found) +} + +func TestRelationshipMapCanAddAndLookupLabel(t *testing.T) { + relationshipMap := diagram.BuildRelationshipLabelMap([]config.RelationshipLabel{}) + + exampleLabel := config.RelationshipLabel{ + PkName: "name", + FkName: "another-name", + Label: "a-label", + } + relationshipMap.AddRelationshipLabel(exampleLabel) + + actual, found := relationshipMap.LookupRelationshipLabel("name", "another-name") + + assert.True(t, found) + assert.Equal(t, actual, exampleLabel) +}