Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BindReturn type hint for make_funsor #518

Open
wants to merge 12 commits into
base: master
Choose a base branch
from

Conversation

ordabayevy
Copy link
Member

@ordabayevy ordabayevy commented Apr 4, 2021

Addresses #481.

BindReturn type hint is used for binding and returning a variable. For example:

@make_funsor
def Unroll(
    x: Has[{"ax"}],  # noqa: F821
    ax: BindReturn[lambda ax, k: Bint[ax.size - k + 1]],
    k: Value[int],
    kernel: Fresh[lambda k: Bint[k]],
) -> Fresh[lambda x: x]:
    return x(**{ax.name: ax + kernel})

x = random_tensor(OrderedDict(a=Bint[5]))
with reflect:
    y = Unroll(x, "a", 2, "kernel")
assert y.fresh == frozenset({"a", "kernel"})
assert all(bound in y.x.inputs and bound[1:8] == "__BOUND" for bound in y.bound)
check_funsor(y, {"a": Bint[5 - 2 + 1], "kernel": Bint[2]}, Real)

or

@make_funsor
def Softmax(
    x: Has[{"ax"}],  # noqa: F821
    ax: BindReturn[lambda ax: ax],
) -> Fresh[lambda x: x]:
    y = x - x.reduce(ops.logaddexp, ax)
    return y.exp()

x = random_tensor(OrderedDict(a=Bint[3], b=Bint[4]))
with reflect:
    y = Softmax(x, "a")
assert y.fresh == frozenset({"a"})
assert all(bound in y.x.inputs and bound[1:8] == "__BOUND" for bound in y.bound)
check_funsor(y, {"a": Bint[3], "b": Bint[4]}, Real)

@ordabayevy ordabayevy added enhancement New feature or request awaiting review labels Apr 4, 2021
@eb8680
Copy link
Member

eb8680 commented Apr 5, 2021

Thanks for adding this! WDYT about making BindReturn the default behavior of Fresh? That would be consistent with default behavior in existing terms like Cat, Stack and Independent and avoid growing the number of special make_funsor type annotations. The conservative alternative is having make_funsor raise an error when a Fresh variable appears in the inputs of another argument.

@ordabayevy
Copy link
Member Author

WDYT about making BindReturn the default behavior of Fresh?

Do you mean that in the example below Fresh type hint would be smart to make ax both bound and fresh and make kernel only fresh?

@make_funsor
def Unroll(
    x: Has[{"ax"}],  # noqa: F821
    ax: Fresh[lambda ax, k: Bint[ax.size - k + 1]],
    k: Value[int],
    kernel: Fresh[lambda k: Bint[k]],
) -> Fresh[lambda x: x]:
    return x(**{ax.name: ax + kernel})

@eb8680
Copy link
Member

eb8680 commented Apr 5, 2021

Do you mean that in the example below Fresh type hint would be smart to make ax both bound and fresh and make kernel only fresh?

Yes, exactly.

On a related note, we should also start using funsor.domains.Dependent rather than Fresh for annotating return types of make_funsor, but that's for another PR.

@ordabayevy
Copy link
Member Author

Yes, exactly.

I like the idea. I'll make the changes then.

@eb8680
Copy link
Member

eb8680 commented May 12, 2021

Sorry for taking so long to review this (especially since I suggested you try it in the first place). I am still not sure how to go about fixing alpha-conversion as a whole in a way that remains compatible with cons-hashing, so I had put off thinking about details.

I think for the behavior implemented in this PR to be safe and correct by construction in general, we would need to eagerly alpha-mangle the arguments to a make_funsor term with a bound Fresh variable before evaluating that term with a rewrite rule to guarantee that there are no collisions between the fresh variable and the implicitly bound variable inside the rule body. In fact, we could always do this for any Funsor; the fact that we only alpha-convert in reflect is an (important) optimization.

We could write a simple decorator for rewrite rules to perform this extra step:

def bind_args(term):

    def binding_wrapper(rule):
        def wrapped_rule(*args):
            mangled_args = reflect.interpret(term, *args)._ast_values
            return rule(*mangled_args)
        return functools.wraps(rule)(wrapped_rule)

    return binding_wrapper

To illustrate the use of bind_args, in your Softmax example, we could separate the term definition and the default reduction rule and manually apply bind_args to the rule:

@make_funsor
def Softmax(
    x: Has[{"ax"}],  # noqa: F821
    ax: Fresh[lambda ax: ax],
) -> Fresh[lambda x: x]:
    return None

@eager.register(Softmax, Tensor, Variable)
@bind_args(Softmax)
def _eager_softmax(x, ax):
    y = x - x.reduce(ops.logaddexp, ax)
    return y.exp()

Of course, this is less ergonomic than the original syntax, so I could imagine folding bind_args into make_funsor or even into interpretations, although this would come at a considerable computational cost with the current implementation of alpha-conversion and is not necessary for this particular example.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants