diff --git a/postgres_composite_types/forms.py b/postgres_composite_types/forms.py index 03d530f..f32939e 100644 --- a/postgres_composite_types/forms.py +++ b/postgres_composite_types/forms.py @@ -109,6 +109,16 @@ def __init__(self, *args, fields=None, model=None, **kwargs): self.widget.widgets.values()): widget.attrs['placeholder'] = field.label + def prepare_value(self, value): + """ + Prepare the field data for the CompositeTypeWidget, which expects data + as a dict. + """ + if isinstance(value, CompositeType): + return value.__to_dict__() + + return value + def validate(self, value): pass @@ -155,7 +165,9 @@ def get_bound_field(self, form, field_name): class CompositeTypeWidget(forms.Widget): """ - Takes an ordered dict of widgets to produce a composite form widget + Takes an ordered dict of widgets to produce a composite form widget. This + widget knows nothing about CompositeTypes, and works only with dicts for + initial and output data. """ template_name = \ 'postgres_composite_types/forms/widgets/composite_type.html' @@ -170,7 +182,7 @@ def __init__(self, widgets, **kwargs): @property def is_hidden(self): - return all(w.is_hidden for w in self.widgets) + return all(w.is_hidden for w in self.widgets.values()) def get_context(self, name, value, attrs): context = super().get_context(name, value, attrs) @@ -187,12 +199,9 @@ def get_context(self, name, value, attrs): if id_: widget_attrs['id'] = '%s-%s' % (id_, subname) - subwidgets[subname] = widget.render('%s-%s' % (name, subname), - getattr(value, subname, None), - final_attrs) widget_context = widget.get_context( '%s-%s' % (name, subname), - getattr(value, subname, None), + value.get(subname), widget_attrs) subwidgets[subname] = widget_context['widget'] diff --git a/tests/test_forms.py b/tests/test_forms.py index 922b5e0..02744b8 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -3,6 +3,7 @@ from django import forms from django.test import SimpleTestCase +from django.test.testcases import assert_and_parse_html from postgres_composite_types.forms import CompositeTypeField @@ -56,6 +57,14 @@ def test_validation(self): self.assertEqual(str(form.errors['simple_field'][0]), 'A number: Enter a whole number.') + # Fields with validation errors should render with their invalid input + self.assertHTMLContains( + """ + + """, + str(form['simple_field'])) + def test_subfield_validation(self): """Errors on subfields should be accessible""" form = self.SimpleForm(data={ @@ -85,6 +94,42 @@ def test_nested_prefix(self): self.assertEqual(a_bound_field.html_name, 'step1-simple_field-a') + def test_initial_data(self): + """ + Check that forms with initial data render with the fields prepopulated. + """ + initial = SimpleType( + a=1, b='foo', c=datetime.datetime(2016, 5, 24, 17, 38, 32)) + form = self.SimpleForm(initial={'simple_field': initial}) + + self.assertHTMLContains( + """ + + """, + str(form['simple_field'])) + + # pylint:disable=invalid-name + def assertHTMLContains(self, text, content, count=None, msg=None): + """ + Assert that the HTML snippet ``text`` is found within the HTML snippet + ``content``. Like assertContains, but works with plain strings instead + of Response instances. + """ + content = assert_and_parse_html( + self, content, None, "HTML content to search in is not valid:") + text = assert_and_parse_html( + self, text, None, "HTML content to search for is not valid:") + + matches = content.count(text) + if count is None: + self.assertTrue( + matches > 0, msg=msg or 'Could not find HTML snippet') + else: + self.assertEqual( + matches, count, + msg=msg or 'Found %d matches, expecting %d' % (matches, count)) + class OptionalFieldTests(SimpleTestCase): """