From e1aa8c4d35b7d9516cae21f3121c3e0da9b2b20f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?N=C3=A9fix=20Estrada?= Date: Wed, 26 Jul 2023 18:05:51 +0200 Subject: [PATCH] feat: add default database option --- README.md | 41 +++++++++++++++++++++++++++++++++++ rethinkdb_mock/ast.py | 11 ++++++++++ rethinkdb_mock/db.py | 31 +++++++++++++++++++++----- rethinkdb_mock/rql_rewrite.py | 10 +++++++++ 4 files changed, 88 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 60a87d0..199e6db 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,47 @@ pipenv install --dev rethinkdb-mock # ] ``` +### Set the default database for the connection + +> Like `r.connect(db='database')` + +```python + from pprint import pprint + from rethinkdb_mock import MockThink + import rethinkdb as r + + db = MockThink({ + 'dbs': { + 'tara': { + 'tables': { + 'people': [ + {'id': 'john-id', 'first_name': 'John', 'last_name': 'Generic'}, + {'id': 'sam-id', 'first_name': 'Sam', 'last_name': 'Dull'}, + {'id': 'adam-id', 'first_name': 'Adam', 'last_name': 'Average'} + ] + } + } + } + 'default': 'tara' + }) + + with db.connect() as conn: + + r.db('tara').table('people').index_create( + 'full_name', + lambda doc: doc['last_name'] + doc['first_name'] + ).run(conn) + + r.table('people').index_wait().run(conn) + + result = r..table('people').get_all( + 'GenericJohn', 'AverageAdam', index='full_name' + ).run(conn) + pprint(list(result)) + # {'id': 'john-id', 'first_name': 'John', 'last_name': 'Generic'}, + # {'id': 'adam-id', 'first_name': 'Adam', 'last_name': 'Average'} +``` + ### Full support for secondary indexes ```python diff --git a/rethinkdb_mock/ast.py b/rethinkdb_mock/ast.py index d517ca1..cbc1f49 100644 --- a/rethinkdb_mock/ast.py +++ b/rethinkdb_mock/ast.py @@ -726,6 +726,17 @@ def __init__(self, *args, **kwargs): def run(self, arg, scope): return arg.list_dbs() +class TableListTL(RBase): + def __init__(self, *args, **kwargs): + pass + + def run(self, arg, scope): + tables = [] + for db in arg.list_dbs(): + tables += arg.list_tables_in_db(db) + + return tables + # ################################# # Index manipulation functions # ################################# diff --git a/rethinkdb_mock/db.py b/rethinkdb_mock/db.py index 0998278..268921f 100644 --- a/rethinkdb_mock/db.py +++ b/rethinkdb_mock/db.py @@ -6,8 +6,9 @@ from . import rtime from . import util -from .rql_rewrite import rewrite_query +from .rql_rewrite import rewrite_query, RQL_TYPE_TRANSLATIONS from .scope import Scope +from .ast_base import BinExp def fill_missing_report_results(report): @@ -198,8 +199,9 @@ def set_table(self, table_name, new_table_instance): class MockDb(object): - def __init__(self, dbs_by_name): + def __init__(self, dbs_by_name, defaultDB = None): self.dbs_by_name = dbs_by_name + self.defaultDB = defaultDB def get_db(self, db_name): return self.dbs_by_name[db_name] @@ -208,7 +210,7 @@ def set_db(self, db_name, db_data_instance): assert (isinstance(db_data_instance, MockDbData)) dbs_by_name = util.obj_clone(self.dbs_by_name) dbs_by_name[db_name] = db_data_instance - return MockDb(dbs_by_name) + return MockDb(dbs_by_name, self.defaultDB) def create_table_in_db(self, db_name, table_name): new_db = self.get_db(db_name) @@ -227,7 +229,7 @@ def create_db(self, db_name): return self.set_db(db_name, MockDbData({})) def drop_db(self, db_name): - return MockDb(util.without([db_name], self.dbs_by_name)) + return MockDb(util.without([db_name], self.dbs_by_name), self.defaultDB) def list_dbs(self): return list(self.dbs_by_name.keys()) @@ -299,8 +301,24 @@ def objects_from_pods(data): table_name, table_data, indexes ) dbs_by_name[db_name] = MockDbData(tables_by_name) - return MockDb(dbs_by_name) + defaultDB = None + if 'default' in data: + defaultDB = data['default'] + + return MockDb(dbs_by_name, defaultDB) + +def set_default_db(query, name): + if len(query._args) > 0: + if not (query._args[0].__class__ in RQL_TYPE_TRANSLATIONS and issubclass(RQL_TYPE_TRANSLATIONS[query._args[0].__class__], BinExp)): + query._args = [rethinkdb.ast.DB(name)] + query._args + + else: + set_default_db(query._args[0], name) + + else: + if query.__class__ in RQL_TYPE_TRANSLATIONS and issubclass(RQL_TYPE_TRANSLATIONS[query.__class__], BinExp): + query._args = [rethinkdb.ast.DB(name)] class MockThinkConn(object): def __init__(self, rethinkdb_mock_parent): @@ -310,6 +328,9 @@ def reset_data(self, data): self.rethinkdb_mock_parent._modify_initial_data(data) def _start(self, rql_query, **global_optargs): + if self.rethinkdb_mock_parent.data.defaultDB: + set_default_db(rql_query, self.rethinkdb_mock_parent.data.defaultDB) + return self.rethinkdb_mock_parent.run_query(rewrite_query(rql_query)) def is_open(self): diff --git a/rethinkdb_mock/rql_rewrite.py b/rethinkdb_mock/rql_rewrite.py index 856b7e9..491bcc0 100644 --- a/rethinkdb_mock/rql_rewrite.py +++ b/rethinkdb_mock/rql_rewrite.py @@ -12,6 +12,7 @@ def rewrite_query(query): RQL_TYPE_HANDLERS = {} +RQL_TYPE_TRANSLATIONS = {} def type_dispatch(rql_node): @@ -144,6 +145,7 @@ def binop_splat(Mt_Constructor, node): NORMAL_ZEROPS = { r_ast.Now: mt_ast.Now, r_ast.DbList: mt_ast.DbList, + r_ast.TableListTL: mt_ast.TableListTL, } @@ -336,27 +338,35 @@ def binop_splat(Mt_Constructor, node): for r_type, mt_type in iteritems(NORMAL_ZEROPS): RQL_TYPE_HANDLERS[r_type] = handle_generic_zerop(mt_type) + RQL_TYPE_TRANSLATIONS[r_type] = mt_type for r_type, mt_type in iteritems(NORMAL_MONOPS): RQL_TYPE_HANDLERS[r_type] = handle_generic_monop(mt_type) + RQL_TYPE_TRANSLATIONS[r_type] = mt_type for r_type, mt_type in iteritems(NORMAL_BINOPS): RQL_TYPE_HANDLERS[r_type] = handle_generic_binop(mt_type) + RQL_TYPE_TRANSLATIONS[r_type] = mt_type for r_type, arg_2_map in iteritems(BINOPS_BY_ARG_2_TYPE): RQL_TYPE_HANDLERS[r_type] = handle_generic_binop_poly_2(arg_2_map) + RQL_TYPE_TRANSLATIONS[r_type] = mt_type for r_type, mt_type in iteritems(SPLATTED_BINOPS): RQL_TYPE_HANDLERS[r_type] = binop_splat(mt_type) + RQL_TYPE_TRANSLATIONS[r_type] = mt_type for r_type, mt_type in iteritems(NORMAL_TERNOPS): RQL_TYPE_HANDLERS[r_type] = handle_generic_ternop(mt_type) + RQL_TYPE_TRANSLATIONS[r_type] = mt_type for r_type, mt_type in iteritems(OPS_BY_ARITY): RQL_TYPE_HANDLERS[r_type] = handle_n_ary(mt_type) + RQL_TYPE_TRANSLATIONS[r_type] = mt_type for r_type, type_map in iteritems(NORMAL_AGGREGATIONS): RQL_TYPE_HANDLERS[r_type] = handle_generic_aggregation(type_map) + RQL_TYPE_TRANSLATIONS[r_type] = mt_type @handles_type(r_ast.Datum)