Skip to content

Commit

Permalink
feat: fix equality check again, don't allow duplicate refs (v3.1.6)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vanderhoof committed Mar 17, 2024
1 parent a6a4e84 commit 048b348
Show file tree
Hide file tree
Showing 9 changed files with 21 additions and 7 deletions.
5 changes: 2 additions & 3 deletions pydbml/classes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class SQLObject:
Base class for all SQL objects.
'''
required_attributes: Tuple[str, ...] = ()
dont_compare_fields = ()

def check_attributes_for_sql(self):
'''
Expand Down Expand Up @@ -36,14 +37,12 @@ def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
# not comparing those because they are circular references
not_compared_fields = ('parent', 'table', 'database')

self_dict = dict(self.__dict__)
other_dict = dict(other.__dict__)

for field in not_compared_fields:
for field in self.dont_compare_fields:
self_dict.pop(field, None)
other_dict.pop(field, None)

return self_dict == other_dict
return False
8 changes: 8 additions & 0 deletions pydbml/classes/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class Column(SQLObject):
'''Class representing table column.'''

required_attributes = ('name', 'type')
dont_compare_fields = ('table',)

def __init__(self,
name: str,
Expand All @@ -45,6 +46,13 @@ def __init__(self,
self.default = default
self.table: Optional['Table'] = None

def __eq__(self, other: 'Column') -> bool:
self_table = self.table.full_name if self.table else None
other_table = other.table.full_name if other.table else None
if self_table != other_table:
return False
return super().__eq__(other)

@property
def note(self):
return self._note
Expand Down
1 change: 1 addition & 0 deletions pydbml/classes/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
class Index(SQLObject):
'''Class representing index.'''
required_attributes = ('subjects', 'table')
dont_compare_fields = ('table',)

def __init__(self,
subjects: List[Union[str, 'Column', 'Expression']],
Expand Down
1 change: 1 addition & 0 deletions pydbml/classes/note.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


class Note(SQLObject):
dont_compare_fields = ('parent',)

def __init__(self, text: Union[str, 'Note']) -> None:
self.text: str
Expand Down
2 changes: 2 additions & 0 deletions pydbml/classes/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@


class Project:
dont_compare_fields = ('database',)

def __init__(self,
name: str,
items: Optional[Dict[str, str]] = None,
Expand Down
1 change: 1 addition & 0 deletions pydbml/classes/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Reference(SQLObject):
and its `sql` property contains the ALTER TABLE clause.
'''
required_attributes = ('type', 'col1', 'col2')
dont_compare_fields = ('database', '_inline')

def __init__(self,
type: Literal['>', '<', '-', '<>'],
Expand Down
1 change: 1 addition & 0 deletions pydbml/classes/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class Table(SQLObject):
'''Class representing table.'''

required_attributes = ('name', 'schema')
dont_compare_fields = ('database',)

def __init__(self,
name: str,
Expand Down
1 change: 1 addition & 0 deletions pydbml/classes/table_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class TableGroup:
but after parsing the whole document, PyDBMLParseResults class replaces
them with references to actual tables.
'''
dont_compare_fields = ('database',)

def __init__(self,
name: str,
Expand Down
8 changes: 4 additions & 4 deletions pydbml/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def add(self, obj: Any) -> Any:
raise DatabaseValidationError(f'Unsupported type {type(obj)}.')

def add_table(self, obj: Table) -> Table:
if obj.database == self and obj in self.tables:
if obj in self.tables:
raise DatabaseValidationError(f'{obj} is already in the database.')
if obj.full_name in self.table_dict:
raise DatabaseValidationError(f'Table {obj.full_name} is already in the database.')
Expand All @@ -107,15 +107,15 @@ def add_reference(self, obj: Reference):
'Cannot add reference. At least one of the referenced tables'
' should belong to this database'
)
if obj.database == self and obj in self.refs:
if obj in self.refs:
raise DatabaseValidationError(f'{obj} is already in the database.')

self._set_database(obj)
self.refs.append(obj)
return obj

def add_enum(self, obj: Enum) -> Enum:
if obj.database == self and obj in self.enums:
if obj in self.enums:
raise DatabaseValidationError(f'{obj} is already in the database.')
for enum in self.enums:
if enum.name == obj.name and enum.schema == obj.schema:
Expand All @@ -126,7 +126,7 @@ def add_enum(self, obj: Enum) -> Enum:
return obj

def add_table_group(self, obj: TableGroup) -> TableGroup:
if obj.database == self and obj in self.table_groups:
if obj in self.table_groups:
raise DatabaseValidationError(f'{obj} is already in the database.')
for table_group in self.table_groups:
if table_group.name == obj.name:
Expand Down

0 comments on commit 048b348

Please sign in to comment.