From 7867f87a03521d8a1089c763b9c114fc2afdb697 Mon Sep 17 00:00:00 2001 From: Katarzyna Mitrus Date: Tue, 29 Oct 2024 19:14:14 +0100 Subject: [PATCH] [STFT][Op][Python] Fix STFT Python API to pass attribute (#27311) ### Details: - Fix STFT Python API to pass "transpose_frames" attribute ### Tickets: - 147160 --- .../python/src/openvino/runtime/opset15/ops.py | 2 +- src/bindings/python/tests/test_graph/test_create_op.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/bindings/python/src/openvino/runtime/opset15/ops.py b/src/bindings/python/src/openvino/runtime/opset15/ops.py index 45b01a11bc3588..b3a131602af703 100644 --- a/src/bindings/python/src/openvino/runtime/opset15/ops.py +++ b/src/bindings/python/src/openvino/runtime/opset15/ops.py @@ -326,7 +326,7 @@ def stft( :return: The new node performing STFT operation. """ inputs = as_nodes(data, window, frame_size, frame_step, name=name) - return _get_node_factory_opset15().create("STFT", inputs) + return _get_node_factory_opset15().create("STFT", inputs, {"transpose_frames": transpose_frames}) @nameable_op diff --git a/src/bindings/python/tests/test_graph/test_create_op.py b/src/bindings/python/tests/test_graph/test_create_op.py index 87787e1e29bc32..98d0ec3583882c 100644 --- a/src/bindings/python/tests/test_graph/test_create_op.py +++ b/src/bindings/python/tests/test_graph/test_create_op.py @@ -2492,8 +2492,8 @@ def test_stft(): window = ov.parameter([7], name="window", dtype=np.float32) frame_size = ov.constant(np.array(11, dtype=np.int32)) frame_step = ov.constant(np.array(3, dtype=np.int32)) - transpose_frames = True + transpose_frames = False op = ov_opset15.stft(data, window, frame_size, frame_step, transpose_frames) assert op.get_type_name() == "STFT" @@ -2501,6 +2501,14 @@ def test_stft(): assert op.get_output_element_type(0) == Type.f32 assert op.get_output_shape(0) == [4, 13, 6, 2] + transpose_frames = True + op = ov_opset15.stft(data, window, frame_size, frame_step, transpose_frames) + + assert op.get_type_name() == "STFT" + assert op.get_output_size() == 1 + assert op.get_output_element_type(0) == Type.f32 + assert op.get_output_shape(0) == [4, 6, 13, 2] + def test_search_sorted(): sorted_sequence = ov.parameter([7, 256, 200, 200], name="sorted", dtype=np.float32)