-
Notifications
You must be signed in to change notification settings - Fork 108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Removal of Algorithm classes. #657
Merged
junpenglao
merged 13 commits into
blackjax-devs:main
from
ciguaran:ciguaran_removing_classes
Apr 22, 2024
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
bbfac72
more
ciguaran 0b795be
removing export
ciguaran a81e76e
removal of classes, tests passing
ciguaran 91b6133
linter
ciguaran 95bdc99
fix on test
ciguaran 718ea6c
Merge branch 'main' into ciguaran_removing_classes
ciguaran 399b788
linter
ciguaran 13aa86a
Merge branch 'ciguaran_removing_classes' of github.com:ciguaran/black…
ciguaran eaf9965
removing parametrization on test
ciguaran 4d9139c
code review updates
ciguaran 90a0642
exporting as_top_level_api in dynamic_hmc
ciguaran 693e8f8
linter
ciguaran 90188e9
code review update: replace imports
ciguaran File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,71 +1,163 @@ | ||
import dataclasses | ||
from typing import Callable | ||
|
||
from blackjax._version import __version__ | ||
|
||
from .adaptation.chees_adaptation import chees_adaptation | ||
from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size | ||
from .adaptation.meads_adaptation import meads_adaptation | ||
from .adaptation.pathfinder_adaptation import pathfinder_adaptation | ||
from .adaptation.window_adaptation import window_adaptation | ||
from .base import SamplingAlgorithm, VIAlgorithm | ||
from .diagnostics import effective_sample_size as ess | ||
from .diagnostics import potential_scale_reduction as rhat | ||
from .mcmc.barker import barker_proposal | ||
from .mcmc.dynamic_hmc import dynamic_hmc | ||
from .mcmc.elliptical_slice import elliptical_slice | ||
from .mcmc.ghmc import ghmc | ||
from .mcmc.hmc import hmc | ||
from .mcmc.mala import mala | ||
from .mcmc.marginal_latent_gaussian import mgrad_gaussian | ||
from .mcmc.mclmc import mclmc | ||
from .mcmc.nuts import nuts | ||
from .mcmc.periodic_orbital import orbital_hmc | ||
from .mcmc.random_walk import additive_step_random_walk, irmh, rmh | ||
from .mcmc.rmhmc import rmhmc | ||
from .mcmc import barker | ||
from .mcmc import dynamic_hmc as _dynamic_hmc | ||
from .mcmc import elliptical_slice as _elliptical_slice | ||
from .mcmc import ghmc as _ghmc | ||
from .mcmc import hmc as _hmc | ||
from .mcmc import mala as _mala | ||
from .mcmc import marginal_latent_gaussian | ||
from .mcmc import mclmc as _mclmc | ||
from .mcmc import nuts as _nuts | ||
from .mcmc import periodic_orbital, random_walk | ||
from .mcmc import rmhmc as _rmhmc | ||
from .mcmc.random_walk import additive_step_random_walk as _additive_step_random_walk | ||
from .mcmc.random_walk import ( | ||
irmh_as_top_level_api, | ||
normal_random_walk, | ||
rmh_as_top_level_api, | ||
) | ||
from .optimizers import dual_averaging, lbfgs | ||
from .sgmcmc.csgld import csgld | ||
from .sgmcmc.sghmc import sghmc | ||
from .sgmcmc.sgld import sgld | ||
from .sgmcmc.sgnht import sgnht | ||
from .smc.adaptive_tempered import adaptive_tempered_smc | ||
from .smc.inner_kernel_tuning import inner_kernel_tuning | ||
from .smc.tempered import tempered_smc | ||
from .vi.meanfield_vi import meanfield_vi | ||
from .vi.pathfinder import pathfinder | ||
from .vi.schrodinger_follmer import schrodinger_follmer | ||
from .vi.svgd import svgd | ||
from .sgmcmc import csgld as _csgld | ||
from .sgmcmc import sghmc as _sghmc | ||
from .sgmcmc import sgld as _sgld | ||
from .sgmcmc import sgnht as _sgnht | ||
from .smc import adaptive_tempered | ||
from .smc import inner_kernel_tuning as _inner_kernel_tuning | ||
from .smc import tempered | ||
from .vi import meanfield_vi as _meanfield_vi | ||
from .vi import pathfinder as _pathfinder | ||
from .vi import schrodinger_follmer as _schrodinger_follmer | ||
from .vi import svgd as _svgd | ||
from .vi.pathfinder import PathFinderAlgorithm | ||
|
||
""" | ||
The above three classes exist as a backwards compatible way of exposing both the high level, differentiable | ||
factory and the low level components, which may not be differentiable. Moreover, this design allows for the lower | ||
level to be mostly functional programming in nature and reducing boilerplate code. | ||
""" | ||
|
||
|
||
@dataclasses.dataclass | ||
class GenerateSamplingAPI: | ||
differentiable: Callable | ||
init: Callable | ||
build_kernel: Callable | ||
|
||
def __call__(self, *args, **kwargs) -> SamplingAlgorithm: | ||
return self.differentiable(*args, **kwargs) | ||
|
||
def register_factory(self, name, callable): | ||
setattr(self, name, callable) | ||
|
||
|
||
@dataclasses.dataclass | ||
class GenerateVariationalAPI: | ||
differentiable: Callable | ||
init: Callable | ||
step: Callable | ||
sample: Callable | ||
|
||
def __call__(self, *args, **kwargs) -> VIAlgorithm: | ||
return self.differentiable(*args, **kwargs) | ||
|
||
|
||
@dataclasses.dataclass | ||
class GeneratePathfinderAPI: | ||
differentiable: Callable | ||
approximate: Callable | ||
sample: Callable | ||
|
||
def __call__(self, *args, **kwargs) -> PathFinderAlgorithm: | ||
return self.differentiable(*args, **kwargs) | ||
|
||
|
||
def generate_top_level_api_from(module): | ||
return GenerateSamplingAPI( | ||
module.as_top_level_api, module.init, module.build_kernel | ||
) | ||
|
||
|
||
# MCMC | ||
hmc = generate_top_level_api_from(_hmc) | ||
nuts = generate_top_level_api_from(_nuts) | ||
rmh = GenerateSamplingAPI(rmh_as_top_level_api, random_walk.init, random_walk.build_rmh) | ||
irmh = GenerateSamplingAPI( | ||
irmh_as_top_level_api, random_walk.init, random_walk.build_irmh | ||
) | ||
dynamic_hmc = generate_top_level_api_from(_dynamic_hmc) | ||
rmhmc = generate_top_level_api_from(_rmhmc) | ||
mala = generate_top_level_api_from(_mala) | ||
mgrad_gaussian = generate_top_level_api_from(marginal_latent_gaussian) | ||
orbital_hmc = generate_top_level_api_from(periodic_orbital) | ||
|
||
additive_step_random_walk = GenerateSamplingAPI( | ||
_additive_step_random_walk, random_walk.init, random_walk.build_additive_step | ||
) | ||
|
||
additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk) | ||
|
||
mclmc = generate_top_level_api_from(_mclmc) | ||
elliptical_slice = generate_top_level_api_from(_elliptical_slice) | ||
ghmc = generate_top_level_api_from(_ghmc) | ||
barker_proposal = generate_top_level_api_from(barker) | ||
|
||
hmc_family = [hmc, nuts] | ||
|
||
# SMC | ||
adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered) | ||
tempered_smc = generate_top_level_api_from(tempered) | ||
inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning) | ||
|
||
smc_family = [tempered_smc, adaptive_tempered_smc] | ||
"Step_fn returning state has a .particles attribute" | ||
|
||
# stochastic gradient mcmc | ||
sgld = generate_top_level_api_from(_sgld) | ||
sghmc = generate_top_level_api_from(_sghmc) | ||
sgnht = generate_top_level_api_from(_sgnht) | ||
csgld = generate_top_level_api_from(_csgld) | ||
svgd = generate_top_level_api_from(_svgd) | ||
|
||
# variational inference | ||
meanfield_vi = GenerateVariationalAPI( | ||
_meanfield_vi.as_top_level_api, | ||
_meanfield_vi.init, | ||
_meanfield_vi.step, | ||
_meanfield_vi.sample, | ||
) | ||
schrodinger_follmer = GenerateVariationalAPI( | ||
_schrodinger_follmer.as_top_level_api, | ||
_schrodinger_follmer.init, | ||
_schrodinger_follmer.step, | ||
_schrodinger_follmer.sample, | ||
) | ||
|
||
pathfinder = GeneratePathfinderAPI( | ||
_pathfinder.as_top_level_api, _pathfinder.approximate, _pathfinder.sample | ||
) | ||
|
||
|
||
__all__ = [ | ||
"__version__", | ||
"dual_averaging", # optimizers | ||
"lbfgs", | ||
"hmc", # mcmc | ||
"dynamic_hmc", | ||
"rmhmc", | ||
"mala", | ||
"mgrad_gaussian", | ||
"nuts", | ||
"orbital_hmc", | ||
"additive_step_random_walk", | ||
"rmh", | ||
"irmh", | ||
"mclmc", | ||
"elliptical_slice", | ||
"ghmc", | ||
"barker_proposal", | ||
"sgld", # stochastic gradient mcmc | ||
"sghmc", | ||
"sgnht", | ||
"csgld", | ||
"window_adaptation", # mcmc adaptation | ||
"meads_adaptation", | ||
"chees_adaptation", | ||
"pathfinder_adaptation", | ||
"mclmc_find_L_and_step_size", # mclmc adaptation | ||
"adaptive_tempered_smc", # smc | ||
"tempered_smc", | ||
"inner_kernel_tuning", | ||
"meanfield_vi", # variational inference | ||
"pathfinder", | ||
"schrodinger_follmer", | ||
"svgd", | ||
"ess", # diagnostics | ||
"rhat", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this type not
SamplingAlgorithm
? I guess it is somewhat more specific than that, but it seems like it needs to at least implement theSamplingAlgorithm
protocol...There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, but adding that type annotation would tell the user that any subtype of SamplingAlgorithm can be used as parameter to that function, which would not be true.