Skip to content

Commit aaf6f5f

Browse files
authored
Fix class of DTypeCast wrapper module (#67)
We were incorrectly always using the DTypeCastModule for all submodules in the model, which would lead to always performing the input and output casting. This should only be done at the root model
1 parent 385d06e commit aaf6f5f

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

autoparallel/cast_parametrization.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,8 @@ def getter(
9696
# Different instances of the same class can resolve their parameter access to instance-specific getters
9797
# (which contains unique objects used in that instance-specific parameter's unshard operation).
9898
namespace[p_name] = create_dtype_cast_managed_attr(p_name)
99-
new_cls = type(
100-
f"DTypeCast{cls.__name__}", (DTypeCastModule, cls), namespace
101-
)
99+
cls_t = (DTypeCastModule, cls) if mod is model else (cls,)
100+
new_cls = type(f"DTypeCast{cls.__name__}", cls_t, namespace)
102101
cls_key_to_dtype_cast_cls[(cls, param_properties_key)] = new_cls
103102
mod.__class__ = new_cls
104103
mod._name_to_dtype_cast_managed_attr_getter = param_properties

0 commit comments

Comments
 (0)