-
Notifications
You must be signed in to change notification settings - Fork 133
Allow string keys in eval
utility
#242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
b088b01
8976c94
b7f745e
8869a9d
d889182
6e2efaa
9008a32
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -558,6 +558,46 @@ def get_parents(self): | |||||
return [self.owner] | ||||||
return [] | ||||||
|
||||||
def convert_string_keys_to_pytensor_variables(self, inputs_to_values): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given we are in
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sounds good, will replace the function name with suggested one. |
||||||
r"""Convert the string keys to corresponding `Variable` with nearest name. | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
inputs_to_values : | ||||||
A dictionary mapping PyTensor `Variable`\s to values. | ||||||
|
||||||
Examples | ||||||
-------- | ||||||
|
||||||
>>> import numpy as np | ||||||
>>> import pytensor.tensor as at | ||||||
>>> x = at.dscalar('x') | ||||||
>>> y = at.dscalar('y') | ||||||
>>> z = x + y | ||||||
>>> np.allclose(z.eval({'x' : 3, 'y' : 1}), 4) | ||||||
True | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This does not show a use of the function. In any case we don't need a docstring with an example I think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we are making it internal to eval I also think that the doc-string will be no more required. |
||||||
""" | ||||||
process_input_to_values = {} | ||||||
for i in inputs_to_values: | ||||||
if isinstance(i, str): | ||||||
nodes_with_matching_names = get_var_by_name([self], i) | ||||||
length_of_nodes_with_matching_names = len(nodes_with_matching_names) | ||||||
if length_of_nodes_with_matching_names == 0: | ||||||
raise Exception(f"{i} not found in graph") | ||||||
else: | ||||||
if length_of_nodes_with_matching_names > 1: | ||||||
warnings.warn( | ||||||
f"Found {length_of_nodes_with_matching_names} pytensor variables with name {i} taking the first declared named variable for computation" | ||||||
) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems safer to just fail. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of warning will throw Exception instead. |
||||||
process_input_to_values[ | ||||||
nodes_with_matching_names[ | ||||||
length_of_nodes_with_matching_names - 1 | ||||||
] | ||||||
] = inputs_to_values[i] | ||||||
else: | ||||||
process_input_to_values[i] = inputs_to_values[i] | ||||||
return process_input_to_values | ||||||
|
||||||
def eval(self, inputs_to_values=None): | ||||||
r"""Evaluate the `Variable`. | ||||||
|
||||||
|
@@ -597,6 +637,10 @@ def eval(self, inputs_to_values=None): | |||||
if inputs_to_values is None: | ||||||
inputs_to_values = {} | ||||||
|
||||||
inputs_to_values = self.convert_string_keys_to_pytensor_variables( | ||||||
inputs_to_values | ||||||
) | ||||||
|
||||||
if not hasattr(self, "_fn_cache"): | ||||||
self._fn_cache = dict() | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -290,9 +290,11 @@ def test_outputs_clients(self): | |
|
||
class TestEval: | ||
def setup_method(self): | ||
self.x, self.y = scalars("x", "y") | ||
self.x, self.y, self.e = scalars("x", "y", "e") | ||
self.z = self.x + self.y | ||
self.w = 2 * self.z | ||
self.t = self.e + 1 | ||
self.t.name = "e" | ||
|
||
def test_eval(self): | ||
assert self.w.eval({self.x: 1.0, self.y: 2.0}) == 6.0 | ||
|
@@ -302,6 +304,13 @@ def test_eval(self): | |
pickle.loads(pickle.dumps(self.w)), "_fn_cache" | ||
), "temporary functions must not be serialized" | ||
|
||
def test_eval_with_strings(self): | ||
assert self.w.eval({"x": 1.0, self.y: 2.0}) == 6.0 | ||
assert self.w.eval({self.z: 3}) == 6.0 | ||
|
||
def test_eval_with_strings_with_mulitple_same_name(self): | ||
assert self.t.eval({"e": 1.0}) == 2.0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would create the new variables here instead of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, instead of creating variables in setup_method will add the variables in test_function. |
||
|
||
|
||
class TestAutoName: | ||
def test_auto_name(self): | ||
|
Uh oh!
There was an error while loading. Please reload this page.