diff --git a/butterfree/migrations/database_migration/cassandra_migration.py b/butterfree/migrations/database_migration/cassandra_migration.py index 9bcc268d..4141a7d5 100644 --- a/butterfree/migrations/database_migration/cassandra_migration.py +++ b/butterfree/migrations/database_migration/cassandra_migration.py @@ -90,7 +90,7 @@ def _get_alter_column_type_query(self, columns: List[Diff], table_name: str) -> return f"ALTER TABLE {table_name} ALTER ({parsed_columns});" @staticmethod - def _get_create_table_query(columns: List[Dict[str, Any]], table_name: str,) -> str: + def _get_create_table_query(columns: List[Dict[str, Any]], table_name: str) -> str: """Creates CQL statement to create a table. Args: @@ -193,30 +193,3 @@ def _get_queries( logging.info("This operation is not supported by Cassandra DB.") return queries - - def create_query( - self, - fs_schema: List[Dict[str, Any]], - table_name: str, - db_schema: List[Dict[str, Any]] = None, - write_on_entity: bool = None, - ) -> List[str]: - """Create a query regarding Cassandra. - - Args: - fs_schema: object that contains feature set's schemas. - table_name: table name. - db_schema: object that contains the table of a given db schema. - write_on_entity: boolean flag that indicates if data is being - loaded into an entity table. - - Returns: - List of queries regarding schemas' changes. - - """ - if not db_schema: - return [self._get_create_table_query(fs_schema, table_name)] - - schema_diff = self._get_diff(fs_schema, db_schema) - - return self._get_queries(schema_diff, table_name, write_on_entity) diff --git a/butterfree/migrations/database_migration/database_migration.py b/butterfree/migrations/database_migration/database_migration.py index b188e509..a2106f3c 100644 --- a/butterfree/migrations/database_migration/database_migration.py +++ b/butterfree/migrations/database_migration/database_migration.py @@ -40,6 +40,37 @@ class DatabaseMigration(ABC): """Abstract base class for Migrations.""" @abstractmethod + def _get_create_table_query( + self, columns: List[Dict[str, Any]], table_name: str + ) -> Any: + """Creates desired statement to create a table. + + Args: + columns: object that contains column's schemas. + table_name: table name. + + Returns: + Create table query. + + """ + pass + + @abstractmethod + def _get_queries( + self, schema_diff: Set[Diff], table_name: str, write_on_entity: bool = None + ) -> Any: + """Create the desired queries for migration. + + Args: + schema_diff: list of Diff objects. + table_name: table name. + + Returns: + List of queries. + + """ + pass + def create_query( self, fs_schema: List[Dict[str, Any]], @@ -53,6 +84,12 @@ def create_query( The desired queries for the given database. """ + if not db_schema: + return [self._get_create_table_query(fs_schema, table_name)] + + schema_diff = self._get_diff(fs_schema, db_schema) + + return self._get_queries(schema_diff, table_name, write_on_entity) def _apply_migration(self, feature_set: FeatureSet) -> None: """Apply the migration in the respective database."""