diff --git a/rest_framework_hstore/serializers.py b/rest_framework_hstore/serializers.py index 01f911c..347080b 100755 --- a/rest_framework_hstore/serializers.py +++ b/rest_framework_hstore/serializers.py @@ -18,46 +18,16 @@ class HStoreSerializer(ModelSerializer): to django-rest-framework """ def __init__(self, *args, **kwargs): - self.__map_virtual_fields() + self.contribute_to_field_mapping() super(HStoreSerializer, self).__init__(*args, **kwargs) - def __map_virtual_fields(self): + def contribute_to_field_mapping(self): """ - the standard DRF field_mapping uses model field classes as keys: - - field_mapping = { - ... - models.DateField: DateField, - ... - } - - we need to add strings with the name of the classes to the mapping: - - field_mapping = { - ... - models.DateField: DateField, - "models.DateField": DateField, - ... - } - - The reason is that a virtual field won't match the standard django field, - but can match the string version. + add DictionaryField to field_mapping """ - # add DictionaryField to field_mapping - self.field_mapping[DictionaryField] = HStoreField # TODO: support ReferenceField + self.field_mapping[DictionaryField] = HStoreField - additional_fields = {} - # iterate over self.field_mapping - for field_class, serializer_field in self.field_mapping.items(): - # if the field can be represented as a string - if hasattr(field_class, '__name__'): - # add mapping using string instead of class - additional_fields[field_class.__name__] = serializer_field - - # update field_mapping dictionary - self.field_mapping.update(additional_fields) - def get_field(self, model_field): """ Creates a default instance of a basic non-relational field. @@ -112,9 +82,14 @@ def get_field(self, model_field): if model_field.__class__ == DictionaryField and model_field.schema: kwargs['schema'] = True + + # === django-rest-framework-hstore specific ==== + # if available, use __basefield__ attribute instead of __class__ + # this will cause DRF to pick the correct DRF-field + key = getattr(model_field, '__basefield__', model_field.__class__) try: - return self.field_mapping[model_field.__class__](**kwargs) + return self.field_mapping[key](**kwargs) except KeyError: pass diff --git a/setup.py b/setup.py index 30e6700..a620aef 100755 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ license = 'BSD' install_requires = [ 'djangorestframework', - 'django_hstore' + 'django_hstore>=1.3.1' ] classifiers=[ 'Development Status :: 3 - Alpha',