-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add DictionaryMatchingOp #509
base: main
Are you sure you want to change the base?
Conversation
📚 Documentation |
@rkcatarina : If you join this weeks hackathon, would you be interested to work on this? |
This already looks promising :) Why did you change from the list of x values (in the initial commit here) to only supporting a single T1? It would be great if the matching op can be used with multiple parameters (fingerprinting for example) |
… Change to double precision of dictionary matching without DictionaryMatchingOp
… Change to double precision of dictionary matching without DictionaryMatchingOp
…dified DictionaryMatchOp and corresponding example
Coverage Report
|
@rkcatarina
to reset your local state to what is on github |
torch.testing.assert_close(t1_matched, t1, atol=1e-4, rtol=0.0) | ||
|
||
|
||
# TODO: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some todos
m0_1 = rng_1.rand_tensor(shape, dtype=dtype, low=0.2, high=1.0) | ||
t1_1 = rng_1.rand_tensor(shape, dtype=dtype.to_real(), low=0.1, high=1.0) | ||
|
||
rng_2 = RandomGenerator(3) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can reuse the same rng
t1_2 = rng_2.rand_tensor(shape, dtype=dtype.to_real(), low=0.1, high=1.0) | ||
|
||
#concatenation of the tensors | ||
m0_cat = torch.cat((m0_1, m0_2)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move this to the #dictionary matching when appending the individual tensors section
operator = DictionaryMatchOp(model, index_of_scaling_parameter=index_of_scaling_parameter) | ||
operator.append(m0, t1) | ||
|
||
if index_of_scaling_parameter is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just remove this test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the cutrrent version does not actually test behavior of the operator.
and in the operator we actually currently seem to support any value for the index, and just wrap it around.
so even for a two parameter model, you could set it to 1000.
there is no really nice way to check it when creating the operator, only later on...
from tests import RandomGenerator | ||
|
||
|
||
@pytest.mark.parametrize('dtype', [torch.float32, torch.complex64], ids=['float32', 'complex64']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add here a "-2" case for index_of_scaling_parameter , name it predict_scale_negative_index
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
some comments
Draft of dictionary matching function
Does only non-differentiable argmax matching, but good enough for now.
Needs tests, example, docstrings etc..
Related #465