Skip to content

Commit

Permalink
[BUG] Fix NotImpelementedError encountered while compiling the mode…
Browse files Browse the repository at this point in the history
…l `doctr_det_predictor` (#462)
  • Loading branch information
BolinSNLHM authored and vadiklyutiy committed Dec 19, 2024
1 parent d8ecff2 commit a7758db
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion python/hidet/graph/frontend/torch/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,17 @@ def _lookup_hidet_method(self, torch_method) -> Callable:
def _lookup_hidet_function(self, torch_func) -> Optional[OverloadedFunction]:
if torch_func not in Registry.registered_functions:
name = self._get_callable_name(torch_func)
from hidet.graph.ops import cast

pattern2func = {
'_dynamo_get_item_lambda': OverloadedFunction.from_lambda(lambda target, index: target[index])
'_dynamo_get_item_lambda': OverloadedFunction.from_lambda(lambda target, index: target[index]),
# Turns out the wrapped ops in issue #358 are some `numpy_method_wrapper` and `numpy_operator_wrapper`.
# According to the class definition in pytorch/torch/_dynamo/utils.py(line 2461-2497),
# they're just functionally equivalent to the original numpy functions.
'<Wrapped operator <original wrapped_ge>>': OverloadedFunction.from_lambda(lambda x, y: x >= y),
'<Wrapped method <original astype>>': OverloadedFunction.from_lambda(
lambda x, dtype: cast(x, data_type(dtype))
),
}
for pattern, func in pattern2func.items():
if pattern in name:
Expand Down

0 comments on commit a7758db

Please sign in to comment.