Skip to content
This repository has been archived by the owner on Jan 11, 2021. It is now read-only.

Commit

Permalink
fix #167
Browse files Browse the repository at this point in the history
  • Loading branch information
Ellery Newcomer committed Dec 10, 2014
1 parent 3d1be10 commit 99ac413
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 3 deletions.
14 changes: 12 additions & 2 deletions rest_framework_swagger/docgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,24 @@ def _find_field_serializers(self, serializers, found_serializers=set()):
"""
Returns set of serializers discovered from fields
"""
def get_thing(field, key):
if rest_framework.VERSION >= '3.0.0':
from rest_framework.serializers import ListSerializer
if isinstance(field, ListSerializer):
return key(field.child)
return key(field)

serializers_set = set()
for serializer in serializers:
fields = serializer().get_fields()
for name, field in fields.items():
if isinstance(field, BaseSerializer):
serializers_set.add(field)
serializers_set.add(get_thing(field, lambda f: f))
if field not in found_serializers:
serializers_set.update(self._find_field_serializers((field.__class__, ), serializers_set))
serializers_set.update(
self._find_field_serializers(
(get_thing(field, lambda f: f.__class__),),
serializers_set))

return serializers_set

Expand Down
5 changes: 5 additions & 0 deletions rest_framework_swagger/introspectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def strip_params_from_docstring(docstring):
def get_serializer_name(serializer):
if serializer is None:
return None
if rest_framework.VERSION >= '3.0.0':
from rest_framework.serializers import ListSerializer
assert serializer != ListSerializer, "uh oh, what now?"
if isinstance(serializer, ListSerializer):
serializer = serializer.child

if inspect.isclass(serializer):
return serializer.__name__
Expand Down
35 changes: 35 additions & 0 deletions rest_framework_swagger/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,20 @@ class ASerializer(serializers.Serializer):
self.assertFalse([s for s in serializerses
if isinstance(s, ASerializer)])

def test_nested_many_serializer(self):
class ASerializer(serializers.Serializer):
point = CommentSerializer()
query = QuerySerializer(many=True)

docgen = DocumentationGenerator()
serializerses = docgen._find_field_serializers([ASerializer])
self.assertTrue([s for s in serializerses
if isinstance(s, CommentSerializer)])
self.assertTrue([s for s in serializerses
if isinstance(s, QuerySerializer)])
self.assertFalse([s for s in serializerses
if isinstance(s, ASerializer)])


class IntrospectorHelperTest(TestCase):
def test_strip_yaml_from_docstring(self):
Expand Down Expand Up @@ -467,6 +481,27 @@ class TestView(APIView):

self.assertEqual(expected, docstring)

def test_get_serializer_name1(self):
self.assertEqual(
"CommentSerializer",
IntrospectorHelper.get_serializer_name(CommentSerializer))
self.assertEqual(
"CommentSerializer",
IntrospectorHelper.get_serializer_name(CommentSerializer()))

def test_get_serializer_name2(self):
class DaSerializer(serializers.Serializer):
comments = CommentSerializer(many=True)

serializer = DaSerializer()
comments = serializer.get_fields()["comments"]
self.assertEqual(
"DaSerializer",
IntrospectorHelper.get_serializer_name(serializer))
self.assertEqual(
"CommentSerializer",
IntrospectorHelper.get_serializer_name(comments))


class ViewSetTestIntrospectorTest(TestCase):
def test_get_allowed_methods_list(self):
Expand Down
4 changes: 4 additions & 0 deletions tests/cigar_example/cigar_example/restapi/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,7 @@ class JambalayaQuerySerializer(serializers.Serializer):
class CigarJambalayaSerializer(serializers.Serializer):
cigar = CigarSerializer()
jambalaya = JambalayaSerializer()


class JambalayaCigarsSerializer(serializers.Serializer):
cigar = CigarSerializer(many=True)
1 change: 1 addition & 0 deletions tests/cigar_example/cigar_example/restapi/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
url(r'^jambalaya_find/$', views.find_jambalaya, name='find-jambalaya'),
url(r'^jambalaya_retrieve/$', views.retrieve_jambalaya, name='retrieve-jambalaya'),
url(r'^drop_cigar_in_jambalaya/$', views.drop_cigar_in_jambalaya, name='cigar-jambalaya'),
url(r'^mix_cigars_in_jambalaya/$', views.mix_cigars_in_jambalaya, name='mix-cigars-jambalaya'),
)

urlpatterns += router.urls
17 changes: 16 additions & 1 deletion tests/cigar_example/cigar_example/restapi/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def link():
from cigar_example.app.models import Cigar, Manufacturer, Country, Jambalaya
from .serializers import CigarSerializer, ManufacturerSerializer, \
CountrySerializer, JambalayaSerializer, JambalayaQuerySerializer, \
CigarJambalayaSerializer
CigarJambalayaSerializer, JambalayaCigarsSerializer


class CigarViewSet(viewsets.ModelViewSet):
Expand Down Expand Up @@ -227,3 +227,18 @@ def drop_cigar_in_jambalaya(request):
if serializer.is_valid():
return Response("mmm.. an acquired taste!", status=status.HTTP_201_CREATED)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)


@api_view(['POST'])
def mix_cigars_in_jambalaya(request):
"""
Make a diverse cigar jambalaya. (In case you're wondering, I have no idea
how to try out this api, it just illustrates what nested many=True
serializers look like in swagger)
---
serializer: ..serializers.JambalayaCigarsSerializer
"""
serializer = JambalayaCigarsSerializer(data=request.DATA)
if serializer.is_valid():
return Response("mmm.. an acquired taste!", status=status.HTTP_201_CREATED)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

0 comments on commit 99ac413

Please sign in to comment.