-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
BindReturn type hint for make_funsor #518
base: master
Are you sure you want to change the base?
Conversation
Thanks for adding this! WDYT about making |
Do you mean that in the example below @make_funsor
def Unroll(
x: Has[{"ax"}], # noqa: F821
ax: Fresh[lambda ax, k: Bint[ax.size - k + 1]],
k: Value[int],
kernel: Fresh[lambda k: Bint[k]],
) -> Fresh[lambda x: x]:
return x(**{ax.name: ax + kernel}) |
Yes, exactly. On a related note, we should also start using |
I like the idea. I'll make the changes then. |
Sorry for taking so long to review this (especially since I suggested you try it in the first place). I am still not sure how to go about fixing alpha-conversion as a whole in a way that remains compatible with cons-hashing, so I had put off thinking about details. I think for the behavior implemented in this PR to be safe and correct by construction in general, we would need to eagerly alpha-mangle the arguments to a We could write a simple decorator for rewrite rules to perform this extra step: def bind_args(term):
def binding_wrapper(rule):
def wrapped_rule(*args):
mangled_args = reflect.interpret(term, *args)._ast_values
return rule(*mangled_args)
return functools.wraps(rule)(wrapped_rule)
return binding_wrapper To illustrate the use of @make_funsor
def Softmax(
x: Has[{"ax"}], # noqa: F821
ax: Fresh[lambda ax: ax],
) -> Fresh[lambda x: x]:
return None
@eager.register(Softmax, Tensor, Variable)
@bind_args(Softmax)
def _eager_softmax(x, ax):
y = x - x.reduce(ops.logaddexp, ax)
return y.exp() Of course, this is less ergonomic than the original syntax, so I could imagine folding |
Addresses #481.
BindReturn
type hint is used for binding and returning a variable. For example:or