Skip to content

Commit

Permalink
make sure related attribute always contains a Collection
Browse files Browse the repository at this point in the history
  • Loading branch information
circulon committed Oct 29, 2024
1 parent a3e775b commit ab9308a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
19 changes: 12 additions & 7 deletions src/masoniteorm/relationships/HasManyThrough.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .BaseRelationship import BaseRelationship
from ..collection import Collection
from .BaseRelationship import BaseRelationship


class HasManyThrough(BaseRelationship):
Expand Down Expand Up @@ -130,6 +130,9 @@ def register_related(self, key, model, collection):
None
"""
related = collection.get(getattr(model, self.local_owner_key), None)
if related and not isinstance(related, Collection):
related = Collection(related)

model.add_relation({key: related if related else None})

def get_related(self, query, relation, eagers=None, callback=None):
Expand Down Expand Up @@ -203,9 +206,7 @@ def query_has(self, current_builder, method="where_exists"):

return self.distant_builder

def query_where_exists(
self, current_builder, callback, method="where_exists"
):
def query_where_exists(self, current_builder, callback, method="where_exists"):
dist_table = self.distant_builder.get_table_name()
int_table = self.intermediary_builder.get_table_name()

Expand All @@ -215,10 +216,12 @@ def query_where_exists(
f"{int_table}.{self.foreign_key}",
"=",
f"{dist_table}.{self.other_owner_key}",
).where_column(
)
.where_column(
f"{int_table}.{self.local_key}",
f"{current_builder.get_table_name()}.{self.local_owner_key}",
).when(callback, lambda q: (callback(q)))
)
.when(callback, lambda q: (callback(q)))
)

def get_with_count_query(self, current_builder, callback):
Expand Down Expand Up @@ -249,7 +252,9 @@ def get_with_count_query(self, current_builder, callback):
lambda q: (
q.where_in(
self.foreign_key,
callback(self.distant_builder.select(self.other_owner_key)),
callback(
self.distant_builder.select(self.other_owner_key)
),
)
),
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest

from src.masoniteorm.collection import Collection
from src.masoniteorm.models import Model
from src.masoniteorm.relationships import has_many_through
from tests.integrations.config.database import DATABASES
Expand Down Expand Up @@ -90,13 +91,14 @@ def test_has_many_through_can_eager_load(self):
courses = Course.where("name", "Math 101").with_("students").get()
students = courses.first().students

self.assertEqual(len(students), 2)
self.assertIsInstance(students, Collection)
self.assertEqual(students.count(), 2)

student1 = students[0]
student1 = students.shift()
self.assertIsInstance(student1, Student)
self.assertEqual(student1.name, "Alice")

student2 = students[1]
student2 = students.shift()
self.assertIsInstance(student2, Student)
self.assertEqual(student2.name, "Bob")

Expand All @@ -106,9 +108,14 @@ def test_has_many_through_can_eager_load(self):
.with_("students")
.first()
)
self.assertIsInstance(single.students, Collection)

single_get = (
Course.where("name", "History 101").with_("students").get()
)

print(single.students)
print(single_get.first().students)
self.assertEqual(single.students.count(), 1)
self.assertEqual(single_get.first().students.count(), 1)

Expand All @@ -126,6 +133,7 @@ def test_has_many_through_eager_load_can_be_empty(self):

def test_has_many_through_can_get_related(self):
course = Course.where("name", "Math 101").first()
self.assertIsInstance(course.students, Collection)
self.assertIsInstance(course.students.first(), Student)
self.assertEqual(course.students.count(), 2)

Expand Down

0 comments on commit ab9308a

Please sign in to comment.