Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Allow serialization of custom networks #284

Merged
merged 4 commits into from
Dec 20, 2024

Conversation

vpratz
Copy link
Collaborator

@vpratz vpratz commented Dec 13, 2024

This PR addresses #228. It introduces utility functions and extends existing networks to
enable serialization of complete networks when custom network types are
passed as arguments (e.g., for sub-networks in coupling flows).

The main complications were:

  • Objects of type type (uninstantiated classes) cannot be serialized using keras.saving.serialize_keras_object, as the have no get_config function. keras.saving.get_registered_name has to be used.

  • We want to support both strings and types as parameters, leading to the need to distinguish those during manual serialization/deserialization.

  • Auto-discovery of init parameters is only active when get_config is not overridden, necessitating to manually store the configuration for serialization.

For storing the types, we use keras.saving.get_registered_name, which can be reconstructed at deserialization using keras.saving.get_registered_object.

Handling the different cases is moved the utility functions (de)serialize_val_or_type, which uses a naming scheme to determine which deserialization method to use.

The same setup can be extended to other custom types, e.g. distributions.

This commit adds utility functions and extends existing networks to
enable serialization of complete networks when custom network types are
passed as arguments (e.g., for sub-networks in coupling flows).

The main complications were:

* Objects of type `type` (uninstantiated classes) cannot be serialized
  using `keras.saving.serialize_keras_object`, as the have no
  `get_config` function.

* We want to support both strings and types as parameters, leading to
  the need to distinguish those during manual
  serialization/deserialization.

* Auto-discovery of __init__ parameters is only active when `get_config`
  is not overridden, necessitating to manually store the configuration
  for serialization.

For storing the types, we use `keras.saving.get_registered_name`,
which can be reconstructed at deserialization using
`keras.saving.get_registered_object`.

Handling the different cases is moved the utility functions
`(de)serialize_val_or_type`, which uses a naming scheme to determine
which deserialization method to use.

The same setup can be extended to other custom types, e.g.
distributions.
@vpratz vpratz requested a review from LarsKue December 13, 2024 11:28
@vpratz
Copy link
Collaborator Author

vpratz commented Dec 13, 2024

If you agree with the general design, I think the main question is for which parameters we want to use this. For now, I only implemented it for subnets, but it should be easily transferable to other parameters like distributions. Do we want to apply this broadly before merging this, or do we do it incrementally?

I am not totally happy with my adaptations to the test, but I did not see a better way to solve this. Simply adding the subnet fixture to all inference nets doubles the test time for the networks, approximately from 5min to 10min. As this is only relevant for serialization, I deemed this not acceptable and split it up, leading to a reduced runtime but less readable test code. Any ideas to make those tests prettier/easier maintainable are welcome.

@vpratz vpratz marked this pull request as ready for review December 13, 2024 11:35
@paul-buerkner
Copy link
Contributor

Thank you for working on this feature! I think it is very important to have!

I am not the best person to ask about the internals so I am refering to @LarsKue and @stefanradev93 for a proper review.

@stefanradev93
Copy link
Contributor

Hi Valentin, I like the general design. I would only opt for a more verbose name of the utility: (de)serialize_value_or_type. I also agree with your choice to keep the runtime of the tests low. Do you like me to merge the PR or did I miss a certain TODO?

@vpratz
Copy link
Collaborator Author

vpratz commented Dec 20, 2024

Hey Stefan, thanks for taking a look, I have renamed the functions. I think for this set of changes we are ready to merge. If you have time you can comment on the following:
For which parameters do we want to use this? We have it for subnets now, but we could also do it e.g. for distributions.

@stefanradev93
Copy link
Contributor

stefanradev93 commented Dec 20, 2024

Thanks! I think the most pertinent usage will be for subnets. We can always enable it for distros later if there is demand.

@stefanradev93 stefanradev93 merged commit 2068b5e into bayesflow-org:dev Dec 20, 2024
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants