Skip to content

Commit

Permalink
Set fixed values and column attrs on autogen class (#800)
Browse files Browse the repository at this point in the history
  • Loading branch information
rly authored Oct 26, 2023
1 parent c00ae11 commit efced9e
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
- Added `target_tables` attribute to `DynamicTable` to allow users to specify the target table of any predefined
`DynamicTableRegion` columns of a `DynamicTable` subclass. @rly [#971](https://github.com/hdmf-dev/hdmf/pull/971)

### Bug fixes
- Updated custom class generation to handle specs with fixed values and required names. @rly [#800](https://github.com/hdmf-dev/hdmf/pull/800)
- Fixed custom class generation of `DynamicTable` subtypes to set attributes corresponding to column names for correct write. @rly [#800](https://github.com/hdmf-dev/hdmf/pull/800)

## HDMF 3.10.0 (October 3, 2023)

Since version 3.9.1 should have been released as 3.10.0 but failed to release on PyPI and conda-forge, this release
Expand Down
36 changes: 30 additions & 6 deletions src/hdmf/build/classgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,19 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i
'doc': field_spec['doc']}
if cls._ischild(dtype) and issubclass(parent_cls, Container) and not isinstance(field_spec, LinkSpec):
fields_conf['child'] = True
# if getattr(field_spec, 'value', None) is not None: # TODO set the fixed value on the class?
# fields_conf['settable'] = False
fixed_value = getattr(field_spec, 'value', None)
if fixed_value is not None:
fields_conf['settable'] = False
if isinstance(field_spec, (BaseStorageSpec, LinkSpec)) and field_spec.data_type is not None:
# subgroups, datasets, and links with data types can have fixed names
fixed_name = getattr(field_spec, 'name', None)
if fixed_name is not None:
fields_conf['required_name'] = fixed_name
classdict.setdefault(parent_cls._fieldsname, list()).append(fields_conf)

if fixed_value is not None: # field has fixed value - do not create arg on __init__
return

docval_arg = dict(
name=attr_name,
doc=field_spec.doc,
Expand Down Expand Up @@ -285,17 +294,27 @@ def post_process(cls, classdict, bases, docval_args, spec):
# set default name in docval args if provided
cls._set_default_name(docval_args, spec.default_name)

@classmethod
def _get_attrs_not_to_set_init(cls, classdict, parent_docval_args):
return parent_docval_args

@classmethod
def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name):
# get docval arg names from superclass
base = bases[0]
parent_docval_args = set(arg['name'] for arg in get_docval(base.__init__))
new_args = list()
attrs_to_set = list()
fixed_value_attrs_to_set = list()
attrs_not_to_set = cls._get_attrs_not_to_set_init(classdict, parent_docval_args)
for attr_name, field_spec in not_inherited_fields.items():
# store arguments for fields that are not in the superclass and not in the superclass __init__ docval
# so that they are set after calling base.__init__
if attr_name not in parent_docval_args:
new_args.append(attr_name)
# except for fields that have fixed values -- these are set at the class level
fixed_value = getattr(field_spec, 'value', None)
if fixed_value is not None:
fixed_value_attrs_to_set.append(attr_name)
elif attr_name not in attrs_not_to_set:
attrs_to_set.append(attr_name)

@docval(*docval_args, allow_positional=AllowPositional.WARNING)
def __init__(self, **kwargs):
Expand All @@ -305,7 +324,7 @@ def __init__(self, **kwargs):
# remove arguments from kwargs that correspond to fields that are new (not inherited)
# set these arguments after calling base.__init__
new_kwargs = dict()
for f in new_args:
for f in attrs_to_set:
new_kwargs[f] = popargs(f, kwargs) if f in kwargs else None

# NOTE: the docval of some constructors do not include all of the fields. the constructor may set
Expand All @@ -319,6 +338,11 @@ def __init__(self, **kwargs):
for f, arg_val in new_kwargs.items():
setattr(self, f, arg_val)

# set the fields that have fixed values using the fields dict directly
# because the setters do not allow setting the value
for f in fixed_value_attrs_to_set:
self.fields[f] = getattr(not_inherited_fields[f], 'value')

classdict['__init__'] = __init__


Expand Down
9 changes: 9 additions & 0 deletions src/hdmf/common/io/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,12 @@ def post_process(cls, classdict, bases, docval_args, spec):
columns = classdict.get('__columns__')
if columns is not None:
classdict['__columns__'] = tuple(columns)

@classmethod
def _get_attrs_not_to_set_init(cls, classdict, parent_docval_args):
# exclude columns from the args that are set in __init__
attrs_not_to_set = parent_docval_args.copy()
if "__columns__" in classdict:
column_names = [column_conf["name"] for column_conf in classdict["__columns__"]]
attrs_not_to_set.update(column_names)
return attrs_not_to_set
5 changes: 5 additions & 0 deletions src/hdmf/spec/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,11 @@ def data_type_inc(self):
''' The data type of target specification '''
return self.get(_target_type_key)

@property
def data_type(self):
''' The data type of target specification '''
return self.get(_target_type_key)

def is_many(self):
return self.quantity not in (1, ZERO_OR_ONE)

Expand Down
152 changes: 151 additions & 1 deletion tests/unit/build_tests/test_classgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,156 @@ def test_multi_container_spec_one_or_more_ok(self):
assert len(multi.bars) == 1


class TestDynamicContainerFixedValue(TestCase):

def setUp(self):
self.baz_spec = GroupSpec(
doc='A test group specification with a data type',
data_type_def='Baz',
attributes=[AttributeSpec(name='attr1', doc='a string attribute', dtype='text', value="fixed")]
)
self.type_map = create_test_type_map([], {}) # empty typemap
self.spec_catalog = self.type_map.namespace_catalog.get_namespace(CORE_NAMESPACE).catalog
self.spec_catalog.register_spec(self.baz_spec, 'extension.yaml')

def test_init_docval(self):
cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) # generate the class
expected_args = {'name'} # 'attr1' should not be included
received_args = set()
for x in get_docval(cls.__init__):
received_args.add(x['name'])
self.assertSetEqual(expected_args, received_args)

def test_init_fields(self):
cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) # generate the class
self.assertEqual(cls.get_fields_conf(), ({'name': 'attr1', 'doc': 'a string attribute', 'settable': False},))

def test_init_object(self):
cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) # generate the class
obj = cls(name="test")
self.assertEqual(obj.attr1, "fixed")

def test_set_value(self):
cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) # generate the class
obj = cls(name="test")
with self.assertRaises(AttributeError):
obj.attr1 = "new"


class TestDynamicContainerIncludingFixedName(TestCase):

def setUp(self):
self.baz_spec1 = GroupSpec(
doc='A test group specification with a data type',
data_type_def='Baz1',
)
self.baz_spec2 = GroupSpec(
doc='A test dataset specification with a data type',
data_type_def='Baz2',
)
self.baz_spec3 = GroupSpec(
doc='A test group specification with a data type',
data_type_def='Baz3',
groups=[
GroupSpec(
doc='A composition inside with a fixed name',
name="my_baz1",
data_type_inc='Baz1'
),
],
datasets=[
DatasetSpec(
doc='A composition inside with a fixed name',
name="my_baz2",
data_type_inc='Baz2'
),
],
links=[
LinkSpec(
doc='A composition inside with a fixed name',
name="my_baz1_link",
target_type='Baz1'
),
],
)
self.type_map = create_test_type_map([], {}) # empty typemap
self.spec_catalog = self.type_map.namespace_catalog.get_namespace(CORE_NAMESPACE).catalog
self.spec_catalog.register_spec(self.baz_spec1, 'extension.yaml')
self.spec_catalog.register_spec(self.baz_spec2, 'extension.yaml')
self.spec_catalog.register_spec(self.baz_spec3, 'extension.yaml')

def test_gen_parent_class(self):
baz1_cls = self.type_map.get_dt_container_cls('Baz1', CORE_NAMESPACE) # generate the class
baz2_cls = self.type_map.get_dt_container_cls('Baz2', CORE_NAMESPACE)
baz3_cls = self.type_map.get_dt_container_cls('Baz3', CORE_NAMESPACE)
self.assertEqual(get_docval(baz3_cls.__init__), (
{'name': 'name', 'type': str, 'doc': 'the name of this container'},
{'name': 'my_baz1', 'doc': 'A composition inside with a fixed name', 'type': baz1_cls},
{'name': 'my_baz2', 'doc': 'A composition inside with a fixed name', 'type': baz2_cls},
{'name': 'my_baz1_link', 'doc': 'A composition inside with a fixed name', 'type': baz1_cls},
))

def test_init_fields(self):
cls = self.type_map.get_dt_container_cls('Baz3', CORE_NAMESPACE) # generate the class
self.assertEqual(cls.get_fields_conf(), (
{
'name': 'my_baz1',
'doc': 'A composition inside with a fixed name',
'child': True,
'required_name': 'my_baz1'
},
{
'name': 'my_baz2',
'doc': 'A composition inside with a fixed name',
'child': True,
'required_name': 'my_baz2'
},
{
'name': 'my_baz1_link',
'doc': 'A composition inside with a fixed name',
'required_name': 'my_baz1_link'
},
))

def test_set_field(self):
baz1_cls = self.type_map.get_dt_container_cls('Baz1', CORE_NAMESPACE) # generate the class
baz2_cls = self.type_map.get_dt_container_cls('Baz2', CORE_NAMESPACE)
baz3_cls = self.type_map.get_dt_container_cls('Baz3', CORE_NAMESPACE)
baz1 = baz1_cls(name="my_baz1")
baz2 = baz2_cls(name="my_baz2")
baz1_link = baz1_cls(name="my_baz1_link")
baz3 = baz3_cls(name="test", my_baz1=baz1, my_baz2=baz2, my_baz1_link=baz1_link)
self.assertEqual(baz3.my_baz1, baz1)
self.assertEqual(baz3.my_baz2, baz2)
self.assertEqual(baz3.my_baz1_link, baz1_link)

def test_set_field_bad(self):
baz1_cls = self.type_map.get_dt_container_cls('Baz1', CORE_NAMESPACE) # generate the class
baz2_cls = self.type_map.get_dt_container_cls('Baz2', CORE_NAMESPACE)
baz3_cls = self.type_map.get_dt_container_cls('Baz3', CORE_NAMESPACE)

baz1 = baz1_cls(name="test")
baz2 = baz2_cls(name="my_baz2")
baz1_link = baz1_cls(name="my_baz1_link")
msg = "Field 'my_baz1' on Baz3 must be named 'my_baz1'."
with self.assertRaisesWith(ValueError, msg):
baz3_cls(name="test", my_baz1=baz1, my_baz2=baz2, my_baz1_link=baz1_link)

baz1 = baz1_cls(name="my_baz1")
baz2 = baz2_cls(name="test")
baz1_link = baz1_cls(name="my_baz1_link")
msg = "Field 'my_baz2' on Baz3 must be named 'my_baz2'."
with self.assertRaisesWith(ValueError, msg):
baz3_cls(name="test", my_baz1=baz1, my_baz2=baz2, my_baz1_link=baz1_link)

baz1 = baz1_cls(name="my_baz1")
baz2 = baz2_cls(name="my_baz2")
baz1_link = baz1_cls(name="test")
msg = "Field 'my_baz1_link' on Baz3 must be named 'my_baz1_link'."
with self.assertRaisesWith(ValueError, msg):
baz3_cls(name="test", my_baz1=baz1, my_baz2=baz2, my_baz1_link=baz1_link)


class TestGetClassSeparateNamespace(TestCase):

def setUp(self):
Expand Down Expand Up @@ -899,7 +1049,7 @@ def test_process_field_spec_link(self):
spec=GroupSpec('dummy', 'doc')
)

expected = {'__fields__': [{'name': 'attr3', 'doc': 'a link'}]}
expected = {'__fields__': [{'name': 'attr3', 'doc': 'a link', 'required_name': 'attr3'}]}
self.assertDictEqual(classdict, expected)

def test_post_process_fixed_name(self):
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/common/test_generate_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,13 @@ def test_dynamic_table_region_non_dtr_target(self):
self.TestDTRTable(name='test_dtr_table', description='my table',
target_tables={'optional_col3': test_table})

def test_attribute(self):
test_table = self.TestTable(name='test_table', description='my test table')
assert test_table.my_col is not None
assert test_table.indexed_col is not None
assert test_table.my_col is test_table['my_col']
assert test_table.indexed_col is test_table['indexed_col'].target

def test_roundtrip(self):
# NOTE this does not use H5RoundTripMixin because this requires custom validation
test_table = self.TestTable(name='test_table', description='my test table')
Expand Down

0 comments on commit efced9e

Please sign in to comment.