From 048b3480e198da933c696aec41bab49ce97e0568 Mon Sep 17 00:00:00 2001 From: Vanderhoof Date: Sun, 17 Mar 2024 08:02:24 +0100 Subject: [PATCH] feat: fix equality check again, don't allow duplicate refs (v3.1.6) --- pydbml/classes/base.py | 5 ++--- pydbml/classes/column.py | 8 ++++++++ pydbml/classes/index.py | 1 + pydbml/classes/note.py | 1 + pydbml/classes/project.py | 2 ++ pydbml/classes/reference.py | 1 + pydbml/classes/table.py | 1 + pydbml/classes/table_group.py | 1 + pydbml/database.py | 8 ++++---- 9 files changed, 21 insertions(+), 7 deletions(-) diff --git a/pydbml/classes/base.py b/pydbml/classes/base.py index 07c4330..ec8bfd9 100644 --- a/pydbml/classes/base.py +++ b/pydbml/classes/base.py @@ -9,6 +9,7 @@ class SQLObject: Base class for all SQL objects. ''' required_attributes: Tuple[str, ...] = () + dont_compare_fields = () def check_attributes_for_sql(self): ''' @@ -36,14 +37,12 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return False # not comparing those because they are circular references - not_compared_fields = ('parent', 'table', 'database') self_dict = dict(self.__dict__) other_dict = dict(other.__dict__) - for field in not_compared_fields: + for field in self.dont_compare_fields: self_dict.pop(field, None) other_dict.pop(field, None) return self_dict == other_dict - return False diff --git a/pydbml/classes/column.py b/pydbml/classes/column.py index 4743581..d599c2a 100644 --- a/pydbml/classes/column.py +++ b/pydbml/classes/column.py @@ -21,6 +21,7 @@ class Column(SQLObject): '''Class representing table column.''' required_attributes = ('name', 'type') + dont_compare_fields = ('table',) def __init__(self, name: str, @@ -45,6 +46,13 @@ def __init__(self, self.default = default self.table: Optional['Table'] = None + def __eq__(self, other: 'Column') -> bool: + self_table = self.table.full_name if self.table else None + other_table = other.table.full_name if other.table else None + if self_table != other_table: + return False + return super().__eq__(other) + @property def note(self): return self._note diff --git a/pydbml/classes/index.py b/pydbml/classes/index.py index b51a4b4..a3d771a 100644 --- a/pydbml/classes/index.py +++ b/pydbml/classes/index.py @@ -19,6 +19,7 @@ class Index(SQLObject): '''Class representing index.''' required_attributes = ('subjects', 'table') + dont_compare_fields = ('table',) def __init__(self, subjects: List[Union[str, 'Column', 'Expression']], diff --git a/pydbml/classes/note.py b/pydbml/classes/note.py index a3965eb..eee65cf 100644 --- a/pydbml/classes/note.py +++ b/pydbml/classes/note.py @@ -7,6 +7,7 @@ class Note(SQLObject): + dont_compare_fields = ('parent',) def __init__(self, text: Union[str, 'Note']) -> None: self.text: str diff --git a/pydbml/classes/project.py b/pydbml/classes/project.py index 3133a85..6069fc8 100644 --- a/pydbml/classes/project.py +++ b/pydbml/classes/project.py @@ -8,6 +8,8 @@ class Project: + dont_compare_fields = ('database',) + def __init__(self, name: str, items: Optional[Dict[str, str]] = None, diff --git a/pydbml/classes/reference.py b/pydbml/classes/reference.py index f093016..cb37e3b 100644 --- a/pydbml/classes/reference.py +++ b/pydbml/classes/reference.py @@ -25,6 +25,7 @@ class Reference(SQLObject): and its `sql` property contains the ALTER TABLE clause. ''' required_attributes = ('type', 'col1', 'col2') + dont_compare_fields = ('database', '_inline') def __init__(self, type: Literal['>', '<', '-', '<>'], diff --git a/pydbml/classes/table.py b/pydbml/classes/table.py index 5022725..f493e9c 100644 --- a/pydbml/classes/table.py +++ b/pydbml/classes/table.py @@ -28,6 +28,7 @@ class Table(SQLObject): '''Class representing table.''' required_attributes = ('name', 'schema') + dont_compare_fields = ('database',) def __init__(self, name: str, diff --git a/pydbml/classes/table_group.py b/pydbml/classes/table_group.py index 3386fa0..1f38978 100644 --- a/pydbml/classes/table_group.py +++ b/pydbml/classes/table_group.py @@ -11,6 +11,7 @@ class TableGroup: but after parsing the whole document, PyDBMLParseResults class replaces them with references to actual tables. ''' + dont_compare_fields = ('database',) def __init__(self, name: str, diff --git a/pydbml/database.py b/pydbml/database.py index f2e8a35..910b6c5 100644 --- a/pydbml/database.py +++ b/pydbml/database.py @@ -83,7 +83,7 @@ def add(self, obj: Any) -> Any: raise DatabaseValidationError(f'Unsupported type {type(obj)}.') def add_table(self, obj: Table) -> Table: - if obj.database == self and obj in self.tables: + if obj in self.tables: raise DatabaseValidationError(f'{obj} is already in the database.') if obj.full_name in self.table_dict: raise DatabaseValidationError(f'Table {obj.full_name} is already in the database.') @@ -107,7 +107,7 @@ def add_reference(self, obj: Reference): 'Cannot add reference. At least one of the referenced tables' ' should belong to this database' ) - if obj.database == self and obj in self.refs: + if obj in self.refs: raise DatabaseValidationError(f'{obj} is already in the database.') self._set_database(obj) @@ -115,7 +115,7 @@ def add_reference(self, obj: Reference): return obj def add_enum(self, obj: Enum) -> Enum: - if obj.database == self and obj in self.enums: + if obj in self.enums: raise DatabaseValidationError(f'{obj} is already in the database.') for enum in self.enums: if enum.name == obj.name and enum.schema == obj.schema: @@ -126,7 +126,7 @@ def add_enum(self, obj: Enum) -> Enum: return obj def add_table_group(self, obj: TableGroup) -> TableGroup: - if obj.database == self and obj in self.table_groups: + if obj in self.table_groups: raise DatabaseValidationError(f'{obj} is already in the database.') for table_group in self.table_groups: if table_group.name == obj.name: