diff --git a/postgres_composite_types/forms.py b/postgres_composite_types/forms.py
index 03d530f..f32939e 100644
--- a/postgres_composite_types/forms.py
+++ b/postgres_composite_types/forms.py
@@ -109,6 +109,16 @@ def __init__(self, *args, fields=None, model=None, **kwargs):
self.widget.widgets.values()):
widget.attrs['placeholder'] = field.label
+ def prepare_value(self, value):
+ """
+ Prepare the field data for the CompositeTypeWidget, which expects data
+ as a dict.
+ """
+ if isinstance(value, CompositeType):
+ return value.__to_dict__()
+
+ return value
+
def validate(self, value):
pass
@@ -155,7 +165,9 @@ def get_bound_field(self, form, field_name):
class CompositeTypeWidget(forms.Widget):
"""
- Takes an ordered dict of widgets to produce a composite form widget
+ Takes an ordered dict of widgets to produce a composite form widget. This
+ widget knows nothing about CompositeTypes, and works only with dicts for
+ initial and output data.
"""
template_name = \
'postgres_composite_types/forms/widgets/composite_type.html'
@@ -170,7 +182,7 @@ def __init__(self, widgets, **kwargs):
@property
def is_hidden(self):
- return all(w.is_hidden for w in self.widgets)
+ return all(w.is_hidden for w in self.widgets.values())
def get_context(self, name, value, attrs):
context = super().get_context(name, value, attrs)
@@ -187,12 +199,9 @@ def get_context(self, name, value, attrs):
if id_:
widget_attrs['id'] = '%s-%s' % (id_, subname)
- subwidgets[subname] = widget.render('%s-%s' % (name, subname),
- getattr(value, subname, None),
- final_attrs)
widget_context = widget.get_context(
'%s-%s' % (name, subname),
- getattr(value, subname, None),
+ value.get(subname),
widget_attrs)
subwidgets[subname] = widget_context['widget']
diff --git a/tests/test_forms.py b/tests/test_forms.py
index 922b5e0..02744b8 100644
--- a/tests/test_forms.py
+++ b/tests/test_forms.py
@@ -3,6 +3,7 @@
from django import forms
from django.test import SimpleTestCase
+from django.test.testcases import assert_and_parse_html
from postgres_composite_types.forms import CompositeTypeField
@@ -56,6 +57,14 @@ def test_validation(self):
self.assertEqual(str(form.errors['simple_field'][0]),
'A number: Enter a whole number.')
+ # Fields with validation errors should render with their invalid input
+ self.assertHTMLContains(
+ """
+
+ """,
+ str(form['simple_field']))
+
def test_subfield_validation(self):
"""Errors on subfields should be accessible"""
form = self.SimpleForm(data={
@@ -85,6 +94,42 @@ def test_nested_prefix(self):
self.assertEqual(a_bound_field.html_name,
'step1-simple_field-a')
+ def test_initial_data(self):
+ """
+ Check that forms with initial data render with the fields prepopulated.
+ """
+ initial = SimpleType(
+ a=1, b='foo', c=datetime.datetime(2016, 5, 24, 17, 38, 32))
+ form = self.SimpleForm(initial={'simple_field': initial})
+
+ self.assertHTMLContains(
+ """
+
+ """,
+ str(form['simple_field']))
+
+ # pylint:disable=invalid-name
+ def assertHTMLContains(self, text, content, count=None, msg=None):
+ """
+ Assert that the HTML snippet ``text`` is found within the HTML snippet
+ ``content``. Like assertContains, but works with plain strings instead
+ of Response instances.
+ """
+ content = assert_and_parse_html(
+ self, content, None, "HTML content to search in is not valid:")
+ text = assert_and_parse_html(
+ self, text, None, "HTML content to search for is not valid:")
+
+ matches = content.count(text)
+ if count is None:
+ self.assertTrue(
+ matches > 0, msg=msg or 'Could not find HTML snippet')
+ else:
+ self.assertEqual(
+ matches, count,
+ msg=msg or 'Found %d matches, expecting %d' % (matches, count))
+
class OptionalFieldTests(SimpleTestCase):
"""