Implement Truncated RVs #96

Open
ricardoV94 opened this issue Dec 5, 2021 · 0 comments · May be fixed by #131
Labels
enhancement New feature or request help wanted Extra attention is needed

Comments

@ricardoV94
Contributor

ricardoV94 commented Dec 5, 2021

The logprob graph for truncated distributions is straightforward, and not very different from that of the censored distributions we have already implemented.
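For reference, here is a minimal sketch of that difference. It uses the standard-library `statistics.NormalDist` as a stand-in for an RV (an assumption for illustration, not Aesara/AePPL code), and `truncated_logpdf` / `censored_logpdf` are hypothetical helper names. Truncation renormalizes the density by the mass left in [lower, upper]. Censoring keeps the interior density unchanged and places the excluded tail mass on the two boundaries.

```python
import math
from statistics import NormalDist


def truncated_logpdf(x, dist, lower, upper):
    """Log-density of `dist` truncated to [lower, upper]."""
    if not lower <= x <= upper:
        return -math.inf
    # Renormalize by the probability mass that survives the truncation
    mass = dist.cdf(upper) - dist.cdf(lower)
    return math.log(dist.pdf(x)) - math.log(mass)


def censored_logpdf(x, dist, lower, upper):
    """Log-density/mass of `dist` censored (clipped) to [lower, upper]."""
    if x < lower or x > upper:
        return -math.inf
    if x == lower:
        # All mass below `lower` piles up on the lower boundary
        return math.log(dist.cdf(lower))
    if x == upper:
        # All mass above `upper` piles up on the upper boundary
        return math.log(1.0 - dist.cdf(upper))
    # Interior points keep the original, un-renormalized density
    return math.log(dist.pdf(x))
```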

However, we need a new Op (or graph) to represent and generate truncated draws from an arbitrary RV.

Symbolically, we can take truncated draws from any RV by rejection sampling, i.e. redrawing until all desired draws fall within [lower, upper]. However, this can be extremely slow if the truncation interval covers only a small fraction of the original distribution's mass.
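A plain-Python sketch of the rejection approach, again assuming `statistics.NormalDist` as a stand-in for the RV. It draws one proposal at a time; a symbolic version would presumably redraw all rejected entries on each iteration of a `while` Scan. `truncated_draws_rejection` and its `max_tries` guard are illustrative names, not an existing API.

```python
import random
from statistics import NormalDist


def truncated_draws_rejection(dist, lower, upper, n, rng, max_tries=1_000_000):
    """Draw `n` values of `dist` restricted to [lower, upper] by rejection.

    Returns the accepted draws and the total number of proposals made.
    """
    draws = []
    tries = 0
    while len(draws) < n:
        if tries >= max_tries:
            raise RuntimeError("rejection sampling exceeded max_tries")
        # Propose from the untruncated distribution
        x = rng.gauss(dist.mean, dist.stdev)
        tries += 1
        # Keep the proposal only if it lands inside the truncation interval
        if lower <= x <= upper:
            draws.append(x)
    return draws, tries
```

The expected acceptance rate is `dist.cdf(upper) - dist.cdf(lower)`. For a standard normal truncated to [3, 4] that is roughly 0.13%, i.e. on the order of 750 proposals per accepted draw, which is exactly the slowdown described above.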

If we have an inverse CDF function, we can obtain truncated draws much more cheaply: draw from a uniform distribution between cdf(lower) and cdf(upper), then map those points through the inverse CDF.
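A minimal sketch of inverse-CDF sampling under the same assumptions (`statistics.NormalDist` standing in for an RV that exposes `cdf` and `inv_cdf`; `truncated_draws_icdf` is a hypothetical name). Every uniform draw yields an accepted sample, no matter how little mass the interval holds.

```python
import random
from statistics import NormalDist


def truncated_draws_icdf(dist, lower, upper, n, rng):
    """Draw `n` values of `dist` truncated to [lower, upper] via the inverse CDF."""
    cdf_lower = dist.cdf(lower)
    cdf_upper = dist.cdf(upper)
    # inv_cdf only accepts probabilities strictly inside (0, 1)
    if not 0.0 < cdf_lower < cdf_upper < 1.0:
        raise ValueError("this sketch needs bounds with cdf values strictly inside (0, 1)")
    draws = []
    for _ in range(n):
        # Uniform in CDF space, restricted to the truncated range
        u = rng.uniform(cdf_lower, cdf_upper)
        # Map back to the original sample space
        draws.append(dist.inv_cdf(u))
    return draws
```

One caveat: far in the tails the CDF values round to exactly 0.0 or 1.0 in double precision, so this sketch rejects such intervals; a real implementation would likely have to work with the survival function or in log space there.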

Finally, we might have specialized Ops or graphs that could be dispatched on a given RV (for instance, if one has already implemented a TruncatedNormal based on the scipy distribution).

I think we could use a hierarchical dispatch strategy that goes from the most to the least specialized form, and perhaps wrap the result in an OpFromGraph for easy logprob parsing. Here is a pseudocode suggestion of how this could be implemented:

```python
def truncate(rv, lower, upper):
    # 1. Try to dispatch on a specialized graph/Op registered for this RV
    #    (`_truncate` is a proposed dispatcher; it needs the bounds too)
    try:
        truncated = _truncate(rv, lower, upper)
    except NotImplementedError:
        # 2. Try to use [i]cdf if they are implemented for the given RV
        #    (`_icdf` is likewise a proposed dispatcher)
        try:
            icdf = _icdf(rv)
            cdf_lower = at.exp(_logcdf(rv, lower))
            cdf_upper = at.exp(_logcdf(rv, upper))
            # Uniform draws in CDF space, one per entry of `rv`
            uniform = at.random.uniform(cdf_lower, cdf_upper, size=rv.shape)
            truncated = icdf(uniform)
        except NotImplementedError:
            # 3. Default to a slow rejection-sampling graph (a `while` Scan)
            ...

    # Wrap `truncated` in a custom OpFromGraph that can be easily parsed for the logprob component?
    ...
```
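The most-to-least-specialized fallback chain from the sketch above can also be exercised end to end in plain Python. This is a hedged sketch, not AePPL code: `statistics.NormalDist` stands in for an RV that exposes `cdf`/`inv_cdf`, `Exponential` is a hypothetical RV that only knows how to sample, and `truncate`/`_truncate` here return concrete draws rather than a graph.

```python
import random
from functools import singledispatch
from statistics import NormalDist


class Exponential:
    """Hypothetical RV that can only generate draws (no cdf / inv_cdf)."""

    def __init__(self, rate):
        self.rate = rate

    def sample(self, rng):
        return rng.expovariate(self.rate)


@singledispatch
def _truncate(dist, lower, upper, n, rng):
    # Hypothetical hook for specialized implementations; none registered here
    raise NotImplementedError


def truncate(dist, lower, upper, n, rng, max_tries=1_000_000):
    """Return (draws, strategy) for `n` draws of `dist` truncated to [lower, upper]."""
    # 1. Most specialized: an implementation registered for this type
    try:
        return _truncate(dist, lower, upper, n, rng), "specialized"
    except NotImplementedError:
        pass

    # 2. Inverse-CDF sampling when both cdf and inv_cdf are available
    if hasattr(dist, "cdf") and hasattr(dist, "inv_cdf"):
        cdf_lower, cdf_upper = dist.cdf(lower), dist.cdf(upper)
        # inv_cdf needs probabilities strictly inside (0, 1)
        if 0.0 < cdf_lower < cdf_upper < 1.0:
            draws = [dist.inv_cdf(rng.uniform(cdf_lower, cdf_upper)) for _ in range(n)]
            return draws, "icdf"

    # 3. Least specialized: rejection sampling from the untruncated RV
    if not hasattr(dist, "sample"):
        raise NotImplementedError(f"no truncation strategy for {type(dist).__name__}")
    draws, tries = [], 0
    while len(draws) < n:
        if tries >= max_tries:
            raise RuntimeError("rejection sampling exceeded max_tries")
        x = dist.sample(rng)
        tries += 1
        if lower <= x <= upper:
            draws.append(x)
    return draws, "rejection"
```

Registering an implementation with `_truncate.register` for a given distribution type would short-circuit the other two strategies.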

This is just a quick sketch; happy to hear other ideas.

@brandonwillard added the labels enhancement (New feature or request) and help wanted (Extra attention is needed) on Dec 5, 2021
@rlouf moved this to Graph features in AePPL Roadmap on Feb 6, 2023
@rlouf moved this from Graph features to Transforms in AePPL Roadmap on Feb 6, 2023
Projects
Status: Transforms

2 participants