Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
2944c22
support for celu, one hot, avg pooling
danielenricocahall Oct 31, 2025
2b09bc3
ctc loss
danielenricocahall Oct 31, 2025
a37a7f4
revert test change
danielenricocahall Oct 31, 2025
2d6527e
Update keras/src/backend/openvino/nn.py
danielenricocahall Oct 31, 2025
b9d7618
Update keras/src/backend/openvino/nn.py
danielenricocahall Oct 31, 2025
a54b884
fix one hot with sparse check
danielenricocahall Oct 31, 2025
aa4ff5e
simplify pooling
danielenricocahall Oct 31, 2025
5421b36
handle dtype of one_hot
danielenricocahall Oct 31, 2025
2548b33
use swish op for silu
danielenricocahall Oct 31, 2025
80dcb9c
support for log_sigmoid
danielenricocahall Oct 31, 2025
467f418
address gemini feedback
danielenricocahall Oct 31, 2025
e5e5c0b
Update keras/src/backend/openvino/nn.py
danielenricocahall Oct 31, 2025
357f5e9
fix consolidated pool function
danielenricocahall Oct 31, 2025
a32fc57
enable testing
danielenricocahall Nov 6, 2025
4805c2e
fix max_pool call
danielenricocahall Nov 6, 2025
618cf53
fix dtype for ctc_loss
danielenricocahall Nov 6, 2025
51152ee
fix dtype for swish
danielenricocahall Nov 6, 2025
fb133d2
permit nn test to be run
danielenricocahall Nov 6, 2025
c745159
support for more activation functions, enabling tests that should sta…
danielenricocahall Nov 6, 2025
6397f92
support squareplus and sparse_plus
danielenricocahall Nov 6, 2025
38eca85
enable selu test
danielenricocahall Nov 6, 2025
fdb2d0a
support threshold
danielenricocahall Nov 6, 2025
a993580
selu test
danielenricocahall Nov 6, 2025
8714a49
Merge branch 'master' into openvino-nn-functions
danielenricocahall Nov 11, 2025
764e4fb
support scaled dot product attention
danielenricocahall Nov 11, 2025
93fd96d
Update keras/src/backend/openvino/nn.py
danielenricocahall Nov 11, 2025
c7b59f0
Update keras/src/ops/nn_test.py
danielenricocahall Nov 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,28 @@ TestMathErrors::test_stft_invalid_window
TestMathErrors::test_stft_invalid_window_shape
LinalgOpsCorrectnessTest::test_cholesky
LinalgOpsCorrectnessTest::test_cholesky_inverse
NNOpsDynamicShapeTest::test_binary_crossentropy
NNOpsDynamicShapeTest::test_categorical_crossentropy
NNOpsDynamicShapeTest::test_multi_hot_dtype_
NNOpsCorrectnessTest::test_conv_transpose_
NNOpsCorrectnessTest::test_ctc_decode
NNOpsCorrectnessTest::test_multi_hot_
NNOpsCorrectnessTest::test_binary_crossentropy
NNOpsCorrectnessTest::test_categorical_crossentropy
NNOpsCorrectnessTest::test_log_softmax_correctness_with_axis_tuple
NNOpsCorrectnessTest::test_softmax_correctness_with_axis_tuple
NNOpsCorrectnessTest::test_separable_conv_
NNOpsCorrectnessTest::test_glu
NNOpsCorrectnessTest::test_moments
NNOpsCorrectnessTest::test_normalize
NNOpsCorrectnessTest::test_polar_corectness
NNOpsCorrectnessTest::test_psnr
NNOpsCorrectnessTest::test_sparse_categorical_crossentropy
NNOpsCorrectnessTest::test_sparsemax
NNOpsCorrectnessTest::test_rms_normalization_10.0
NNOpsDtypeTest::test_ctc_decode
NNOpsDtypeTest::test_glu_
NNOpsDtypeTest::test_polar_
NNOpsDynamicShapeTest::test_glu
NNOpsBehaviorTest::test_invalid_strategy_ctc_decode
NNOpsBehaviorTest::test_logit_recovery_binary_crossentropy
1 change: 0 additions & 1 deletion keras/src/backend/openvino/excluded_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ keras/src/metrics
keras/src/models
keras/src/ops/image_test.py
keras/src/ops/linalg_test.py
keras/src/ops/nn_test.py
keras/src/optimizers
keras/src/quantizers
keras/src/random/seed_generator_test.py
Expand Down
263 changes: 247 additions & 16 deletions keras/src/backend/openvino/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from openvino import Type

from keras.src import backend
from keras.src.backend.openvino.core import OPENVINO_DTYPES
from keras.src.backend.openvino.core import OpenVINOKerasTensor
from keras.src.backend.openvino.core import get_ov_output

Expand All @@ -16,6 +17,23 @@ def relu6(x):
return OpenVINOKerasTensor(ov_opset.clamp(x, 0.0, 6.0).output(0))


def celu(x, alpha=1.0):
x = get_ov_output(x)
const_zero = get_ov_output(0.0, x.get_element_type())
const_alpha = get_ov_output(alpha, x.get_element_type())
const_one = get_ov_output(1.0, x.get_element_type())
exp_x_div_alpha = ov_opset.exp(ov_opset.divide(x, const_alpha)).output(0)
negative_branch = ov_opset.multiply(
const_alpha, ov_opset.subtract(exp_x_div_alpha, const_one)
)

celu_x = ov_opset.add(
ov_opset.maximum(x, const_zero).output(0),
ov_opset.minimum(negative_branch, const_zero).output(0),
)
return OpenVINOKerasTensor(celu_x.output(0))


def sigmoid(x):
x = get_ov_output(x)
return OpenVINOKerasTensor(ov_opset.sigmoid(x).output(0))
Expand All @@ -26,6 +44,42 @@ def tanh(x):
return OpenVINOKerasTensor(ov_opset.tanh(x).output(0))


def tanh_shrink(x):
x = get_ov_output(x)
return OpenVINOKerasTensor(ov_opset.subtract(x, ov_opset.tanh(x)).output(0))


def hard_tanh(x):
x = get_ov_output(x)
return OpenVINOKerasTensor(ov_opset.clamp(x, -1.0, 1.0).output(0))


def soft_shrink(x, threshold=0.5):
x = get_ov_output(x)
et = x.get_element_type()
thr = get_ov_output(threshold, et)
zero = get_ov_output(0.0, et)
abs_x = ov_opset.abs(x)
sub = ov_opset.subtract(abs_x, thr)
shrunk = ov_opset.maximum(sub, zero)
sign = ov_opset.sign(x)
out = ov_opset.multiply(sign, shrunk)
return OpenVINOKerasTensor(out.output(0))


def hard_shrink(x, threshold=0.5):
x = get_ov_output(x)
et = x.get_element_type()

thr = get_ov_output(threshold, et)
zero = get_ov_output(0.0, et)

cond = ov_opset.greater(ov_opset.abs(x), thr)

out = ov_opset.select(cond, x, zero)
return OpenVINOKerasTensor(out.output(0))


def softplus(x):
x = get_ov_output(x)
return OpenVINOKerasTensor(ov_opset.softplus(x).output(0))
Expand All @@ -38,14 +92,15 @@ def softsign(x):

def silu(x):
x = get_ov_output(x)
return OpenVINOKerasTensor(
ov_opset.multiply(x, ov_opset.sigmoid(x)).output(0)
)
beta = get_ov_output(1.0, x.get_element_type())
return OpenVINOKerasTensor(ov_opset.swish(x, beta=beta).output(0))


def log_sigmoid(x):
raise NotImplementedError(
"`log_sigmoid` is not supported with openvino backend"
x = get_ov_output(x)
neg_x = ov_opset.negative(x)
return OpenVINOKerasTensor(
ov_opset.negative(ov_opset.softplus(neg_x)).output(0)
)


Expand All @@ -58,6 +113,20 @@ def leaky_relu(x, negative_slope=0.2):
return OpenVINOKerasTensor(leaky_relu)


def sparse_sigmoid(x):
x = get_ov_output(x)
et = x.get_element_type()

one = get_ov_output(1.0, et)
neg_one = get_ov_output(-1.0, et)
half = get_ov_output(0.5, et)

y = ov_opset.minimum(ov_opset.maximum(x, neg_one), one)

out = ov_opset.multiply(half, ov_opset.add(y, one))
return OpenVINOKerasTensor(out.output(0))


def hard_sigmoid(x):
x = get_ov_output(x)
alpha = get_ov_output(1.0 / 6.0, x.get_element_type())
Expand Down Expand Up @@ -121,15 +190,80 @@ def log_softmax(x, axis=-1):
return OpenVINOKerasTensor(ov_opset.log_softmax(x, axis).output(0))


def squareplus(x, b=4):
x = get_ov_output(x)
et = x.get_element_type()

b = get_ov_output(b, et)
two = get_ov_output(2.0, et)

x_squared = ov_opset.multiply(x, x)
inside = ov_opset.add(x_squared, b)
root = ov_opset.sqrt(inside)
summed = ov_opset.add(x, root)

out = ov_opset.divide(summed, two)

return OpenVINOKerasTensor(out.output(0))


def sparse_plus(x):
x = get_ov_output(x)
et = x.get_element_type()

one = get_ov_output(1.0, et)
neg_one = get_ov_output(-1.0, et)
zero = get_ov_output(0.0, et)
quarter = get_ov_output(0.25, et)

x_plus_1 = ov_opset.add(x, one)
quad = ov_opset.multiply(quarter, ov_opset.multiply(x_plus_1, x_plus_1))

leq_than_neg_one = ov_opset.less_equal(x, neg_one)
less_than_one = ov_opset.less(x, one)

out = ov_opset.select(
leq_than_neg_one,
zero,
ov_opset.select(less_than_one, quad, x),
)

return OpenVINOKerasTensor(out.output(0))


def threshold(x, threshold, default_value):
x = get_ov_output(x)
et = x.get_element_type()

thr = get_ov_output(threshold, et)
dv = get_ov_output(default_value, et)

cond = ov_opset.greater(x, thr)

out = ov_opset.select(cond, x, dv)

return OpenVINOKerasTensor(out.output(0))


def max_pool(
inputs,
pool_size,
strides=None,
padding="valid",
data_format=None,
):
raise NotImplementedError(
"`max_pool` is not supported with openvino backend"
num_spatial_dims = (
get_ov_output(inputs).get_partial_shape().rank.get_length() - 2
)
kwargs = {"dilations": [1] * num_spatial_dims} # required for ov max_pool
return _pool(
inputs,
pool_size,
ov_opset.max_pool,
strides,
padding,
data_format,
**kwargs,
)


Expand All @@ -140,11 +274,52 @@ def average_pool(
padding="valid",
data_format=None,
):
raise NotImplementedError(
"`average_pool` is not supported with openvino backend"
return _pool(
inputs,
pool_size,
ov_opset.avg_pool,
strides,
padding,
data_format,
exclude_pad=True,
)


def _pool(
inputs,
pool_size,
pooling_func,
strides=None,
padding="valid",
data_format=None,
**kwargs,
):
data_format = backend.standardize_data_format(data_format)
inputs = get_ov_output(inputs)

num_spatial_dims = inputs.get_partial_shape().rank.get_length() - 2
if isinstance(pool_size, int):
pool_size = [pool_size] * num_spatial_dims

if strides is None:
strides = pool_size

strides = _adjust_strides_dilation(strides, num_spatial_dims)
pad_mode, pads_begin, pads_end = _adjust_padding(padding)
inputs = _adjust_input(inputs, num_spatial_dims, data_format)
pool_kwargs = {
"kernel_shape": pool_size,
"strides": strides,
"auto_pad": pad_mode,
"pads_begin": pads_begin,
"pads_end": pads_end,
**kwargs,
}
pooled = pooling_func(inputs, **pool_kwargs).output(0)
adjusted_pooled = _adjust_outputs(pooled, num_spatial_dims, data_format)
return OpenVINOKerasTensor(adjusted_pooled)


def _adjust_strides_dilation(
x,
num_spatial_dims,
Expand Down Expand Up @@ -374,9 +549,22 @@ def conv_transpose(


def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
raise NotImplementedError(
"`one_hot` is not supported with openvino backend"
)
if sparse:
raise ValueError("`sparse=True` is not supported with openvino backend")
x = get_ov_output(x)
if dtype is None:
dtype = backend.floatx()
ov_dtype = OPENVINO_DTYPES[dtype]
on_value = get_ov_output(1, ov_dtype)
off_value = get_ov_output(0, ov_dtype)
one_hot_encoded = ov_opset.one_hot(
x,
depth=num_classes,
axis=axis,
on_value=on_value,
off_value=off_value,
).output(0)
return OpenVINOKerasTensor(one_hot_encoded)


def multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
Expand Down Expand Up @@ -465,9 +653,15 @@ def batch_normalization(


def ctc_loss(target, output, target_length, output_length, mask_index=0):
raise NotImplementedError(
"`ctc_loss` is not supported with openvino backend"
target = get_ov_output(target)
output = get_ov_output(output)
target_length = get_ov_output(target_length)
output_length = get_ov_output(output_length)
ctc_loss_ = ov_opset.ctc_loss(
output, output_length, target, target_length, blank_index=mask_index
)
ctc_loss_ = ov_opset.convert(ctc_loss_, OPENVINO_DTYPES[backend.floatx()])
return OpenVINOKerasTensor(ctc_loss_.output(0))


def ctc_decode(
Expand Down Expand Up @@ -499,9 +693,46 @@ def dot_product_attention(
flash_attention=None,
attn_logits_soft_cap=None,
):
raise NotImplementedError(
"`dot_product_attention` is not supported with openvino backend"
if bias is not None:
raise NotImplementedError(
"`dot_product_attention` with `bias` is not supported "
"with openvino backend"
)
if flash_attention is not None:
raise NotImplementedError(
"`dot_product_attention` with `flash_attention` is not supported "
"with openvino backend"
)
if attn_logits_soft_cap is not None:
raise NotImplementedError(
"`dot_product_attention` with `attn_logits_soft_cap` is not "
"supported with openvino backend"
)
query = get_ov_output(query)
key = get_ov_output(key)
value = get_ov_output(value)
if query.get_element_type() != key.get_element_type():
ov_type = OPENVINO_DTYPES[backend.floatx()]
query = ov_opset.convert(query, ov_type).output(0)
key = ov_opset.convert(key, ov_type).output(0)
if value.get_element_type() != query.get_element_type():
value = ov_opset.convert(value, query.get_element_type()).output(0)
axes_const = ov_opset.constant([0, 2, 1, 3], Type.i32).output(0)

query = ov_opset.transpose(query, axes_const)
key = ov_opset.transpose(key, axes_const)
value = ov_opset.transpose(value, axes_const)
mask = get_ov_output(mask) if mask is not None else None
scale = (
get_ov_output(scale, query.get_element_type())
if scale is not None
else None
)
dpa = ov_opset.scaled_dot_product_attention(
query, key, value, attention_mask=mask, scale=scale, causal=is_causal
)
dpa = ov_opset.transpose(dpa, axes_const)
return OpenVINOKerasTensor(dpa.output(0))


def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
Expand Down
Loading
Loading