Skip to content

Commit

Permalink
[PT FE] Support aten::lerp and aten::lerp_ (#27272)
Browse files Browse the repository at this point in the history
### Details:
 - *Support `aten::lerp` and `aten::lerp_`*

### Tickets:
 - *CVS-156191*
  • Loading branch information
mvafin authored Oct 28, 2024
1 parent 0d113d9 commit 1baf261
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 0 deletions.
36 changes: 36 additions & 0 deletions src/frontends/pytorch/src/op/lerp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/subtract.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_lerp(const NodeContext& context) {
// Tensor = aten::lerp(%lhs.1, %rhs.1, %self.weight)
num_inputs_check(context, 3, 3);
Output<Node> start;
Output<Node> end;
std::tie(start, end) = get_inputs_with_promoted_types(context, 0, 1);

Output<Node> weight = context.get_input(2);
auto scale = context.mark_node(std::make_shared<v1::Subtract>(end, start));
weight = context.mark_node(std::make_shared<v1::ConvertLike>(weight, scale));
auto delta = context.mark_node(std::make_shared<v1::Multiply>(scale, weight));
return {context.mark_node(std::make_shared<v1::Add>(start, delta))};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ OP_CONVERTER(translate_inverse);
OP_CONVERTER(translate_is_nonzero);
OP_CONVERTER(translate_layer_norm);
OP_CONVERTER(translate_len);
OP_CONVERTER(translate_lerp);
OP_CONVERTER(translate_linalg_cross);
OP_CONVERTER(translate_linalg_norm);
OP_CONVERTER(translate_linalg_matrix_norm);
Expand Down Expand Up @@ -509,6 +510,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::le", op::translate_1to1_match_2_inputs_align_types<opset10::LessEqual>},
{"aten::leaky_relu", op::translate_1to1_match_2_inputs<opset10::PRelu>},
{"aten::len", op::translate_len},
{"aten::lerp", op::translate_lerp},
// lift op is torchscript specific op responsible for tensors coping with guarantee of new memory allocation
{"aten::lift", op::skip_node},
{"aten::lift_fresh", op::skip_node},
Expand Down
56 changes: 56 additions & 0 deletions tests/layer_tests/pytorch_tests/test_lerp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest, skip_if_export


class TestLerp(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(2, 5, 3, 4).astype(np.float32), self.input_rhs)

def create_model(self, weight, op_type):
class aten_lerp(torch.nn.Module):
def __init__(self, weight, op) -> None:
super().__init__()
self.weight = weight
self.forward = self.forward1 if op == "lerp" else self.forward2

def forward1(self, lhs, rhs):
return torch.lerp(lhs, rhs, weight=self.weight)

def forward2(self, lhs, rhs):
return lhs.lerp_(rhs, weight=self.weight)

return aten_lerp(weight, op_type), None, f"aten::{op_type}"

@pytest.mark.parametrize("weight", (-0.5,
0,
0.5,
1,
2,
skip_if_export([1, 5, 3, 4]))
)
@pytest.mark.parametrize("input_shape_rhs", [[2, 5, 3, 4],
[1, 5, 3, 4],
[1]])
@pytest.mark.parametrize("op_type", ["lerp", "lerp_"])
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_torch_export
@pytest.mark.precommit_fx_backend
def test_lerp(self, ie_device, precision, ir_version,
weight, input_shape_rhs, op_type):
self.input_rhs = np.random.randn(*input_shape_rhs).astype(np.float32)
if isinstance(weight, list):
weight = torch.rand(weight)
self._test(
*self.create_model(weight, op_type),
ie_device,
precision,
ir_version,
use_convert_model=True,
)

0 comments on commit 1baf261

Please sign in to comment.