Skip to content

Commit

Permalink
Fix one of the models TODOs (#1669)
Browse files Browse the repository at this point in the history
* fix mypy complaints for tortoise.queryset

* fix one TODO in tortoise.models

* fix codacy complaint
  • Loading branch information
waketzheng authored Jul 16, 2024
1 parent 57ac9d6 commit 9d2edb0
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 58 deletions.
32 changes: 6 additions & 26 deletions tortoise/backends/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,36 +133,16 @@ async def execute_select(
for row in raw_results:
if self.select_related_idx:
_, current_idx, _, _, path = self.select_related_idx[0]
dict_row = dict(row)
keys = list(dict_row.keys())
values = list(dict_row.values())
instance: "Model" = self.model._init_from_db(
**dict(zip(keys[:current_idx], values[:current_idx]))
)
row_items = list(dict(row).items())
instance: "Model" = self.model._init_from_db(**dict(row_items[:current_idx]))
instances: Dict[Any, Any] = {path: instance}
for (
model,
index,
model_name,
parent_model,
full_path,
) in self.select_related_idx[1:]:
for model, index, *__, full_path in self.select_related_idx[1:]:
(*path, attr) = full_path
related_values = values[current_idx : current_idx + index] # noqa
if not any(related_values):
related_items = row_items[current_idx : current_idx + index]
if not any((v for _, v in related_items)):
obj = None
else:
obj = model._init_from_db(
**dict(
zip(
(
x.split(".")[1]
for x in keys[current_idx : current_idx + index]
),
related_values,
)
)
)
obj = model._init_from_db(**{k.split(".")[1]: v for k, v in related_items})
target = instances.get(tuple(path))
if target is not None:
setattr(target, f"_{attr}", obj)
Expand Down
37 changes: 25 additions & 12 deletions tortoise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,18 +455,15 @@ def _generate_db_fields(self) -> None:
model_field = self.fields_db_projection_reverse[key]
field = self.fields_map[model_field]

is_native_field_type = field.field_type in self.db.executor_class.DB_NATIVE
default_converter = field.__class__.to_python_value is Field.to_python_value
if (
field.skip_to_python_if_native
and field.field_type in self.db.executor_class.DB_NATIVE
):
self.db_native_fields.append((key, model_field, field))
elif not default_converter:
self.db_complex_fields.append((key, model_field, field))
elif field.field_type in self.db.executor_class.DB_NATIVE:

if is_native_field_type and (default_converter or field.skip_to_python_if_native):
self.db_native_fields.append((key, model_field, field))
else:
elif default_converter:
self.db_default_fields.append((key, model_field, field))
else:
self.db_complex_fields.append((key, model_field, field))

def _generate_filters(self) -> None:
get_overridden_filter_func = self.db.executor_class.get_overridden_filter_func
Expand Down Expand Up @@ -722,28 +719,44 @@ def _init_from_db(cls: Type[MODEL], **kwargs: Any) -> MODEL:
self._await_when_save = {}

meta = self._meta

inited_keys: Set[str] = set()
try:
# This is like so for performance reasons.
# We want to avoid conditionals and calling .to_python_value()
# Native fields are fields that are already converted to/from python to DB type
# by the DB driver
for key, model_field, field in meta.db_native_fields:
setattr(self, model_field, kwargs[key])
inited_keys.add(key)
# Fields that don't override .to_python_value() are converted without a call
# as we already know what we will be doing.
for key, model_field, field in meta.db_default_fields:
if (value := kwargs[key]) is not None:
value = field.field_type(value)
setattr(self, model_field, value)
inited_keys.add(key)
# These fields need manual .to_python_value()
for key, model_field, field in meta.db_complex_fields:
setattr(self, model_field, field.to_python_value(kwargs[key]))
inited_keys.add(key)
except KeyError:
self._partial = True
# TODO: Apply similar perf optimisation as above for partial
native_fields: List[Field] = [f for *_, f in meta.db_native_fields]
default_fields = complex_fields = None
for key, value in kwargs.items():
setattr(self, key, meta.fields_map[key].to_python_value(value))
if key in inited_keys or key not in meta.fields_map:
continue
if (field := meta.fields_map[key]) not in native_fields:
if default_fields is None:
default_fields = [f for *_, f in meta.db_default_fields]
if field in default_fields:
if value is not None:
value = field.field_type(value)
else:
if complex_fields is None:
complex_fields = [f for *_, f in meta.db_complex_fields]
value = field.to_python_value(value)
setattr(self, key, value)

return self

Expand Down
33 changes: 13 additions & 20 deletions tortoise/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,20 +556,17 @@ def values_list(self, *fields_: str, flat: bool = False) -> "ValuesListQuery[Lit
If no arguments are passed it will default to a tuple containing all fields
in order of declaration.
"""
fields_for_select_list = fields_ or [
field for field in self.model._meta.fields_map if field in self.model._meta.db_fields
] + list(self._annotations.keys())
return ValuesListQuery(
db=self._db,
model=self.model,
q_objects=self._q_objects,
single=self._single,
raise_does_not_exist=self._raise_does_not_exist,
flat=flat,
fields_for_select_list=fields_ # type: ignore
or [
field
for field in self.model._meta.fields_map.keys()
if field in self.model._meta.db_fields
]
+ list(self._annotations.keys()),
fields_for_select_list=fields_for_select_list,
distinct=self._distinct,
limit=self._limit,
offset=self._offset,
Expand Down Expand Up @@ -1480,7 +1477,7 @@ def __init__(
q_objects: List[Q],
single: bool,
raise_does_not_exist: bool,
fields_for_select_list: List[str],
fields_for_select_list: Union[Tuple[str, ...], List[str]],
limit: Optional[int],
offset: Optional[int],
distinct: bool,
Expand Down Expand Up @@ -1875,10 +1872,7 @@ def __init__(

def _make_query(self) -> None:
self.executor = self._db.executor_class(model=self.model, db=self._db)
if not self.ignore_conflicts and not self.update_fields:
self.insert_query_all = self.executor.insert_query_all
self.insert_query = self.executor.insert_query
else:
if self.ignore_conflicts or self.update_fields:
regular_columns, columns = self.executor._prepare_insert_columns()
self.insert_query = self.executor._prepare_insert_statement(
columns, ignore_conflicts=self.ignore_conflicts
Expand All @@ -1895,17 +1889,16 @@ def _make_query(self) -> None:
)
if self.update_fields:
alias = f"new_{self.model._meta.db_table}"
self.insert_query_all = self.insert_query_all.as_(alias).on_conflict( # type:ignore
*self.on_conflict
)
self.insert_query = self.insert_query.as_(alias).on_conflict( # type:ignore
self.insert_query_all = self.insert_query_all.as_(alias).on_conflict(
*self.on_conflict
)
self.insert_query = self.insert_query.as_(alias).on_conflict(*self.on_conflict)
for update_field in self.update_fields:
self.insert_query_all = self.insert_query_all.do_update( # type:ignore
update_field
)
self.insert_query = self.insert_query.do_update(update_field) # type:ignore
self.insert_query_all = self.insert_query_all.do_update(update_field)
self.insert_query = self.insert_query.do_update(update_field)
else:
self.insert_query_all = self.executor.insert_query_all
self.insert_query = self.executor.insert_query

async def _execute(self) -> None:
for instance_chunk in chunk(self.objects, self.batch_size):
Expand Down

0 comments on commit 9d2edb0

Please sign in to comment.