From 87e5f14d46240ba1e8706923975d0e306b976265 Mon Sep 17 00:00:00 2001 From: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com> Date: Tue, 30 Apr 2024 00:13:16 +0000 Subject: [PATCH] support conv3d on cpu for TF --- keras/src/backend/tensorflow/nn.py | 6 ++++++ keras/src/ops/nn_test.py | 16 +++++++++++----- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/keras/src/backend/tensorflow/nn.py b/keras/src/backend/tensorflow/nn.py index cb6bff3b444..e590d10e2d3 100644 --- a/keras/src/backend/tensorflow/nn.py +++ b/keras/src/backend/tensorflow/nn.py @@ -252,6 +252,12 @@ def _conv_xla(): # If kernel's in_channel does not match input's channels, it indicates # convolution is broken down into groups. return _conv_xla() + if data_format == "channels_first" and len(inputs.shape) == 5: + inputs = convert_to_tensor(inputs) + if inputs.device.split(":")[-2] == "CPU": + inputs = tf.transpose(inputs, perm=(0, 2, 3, 4, 1)) + data_format = "channels_last" + return tf.transpose(_conv(), perm=(0, 4, 1, 2, 3)) return _conv() diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index 2c825c4f416..4f85e96cd8c 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -1445,23 +1445,29 @@ def test_conv_2d_group_2(self, strides, dilation_rate): ) self.assertAllClose(outputs, expected) - @parameterized.product(strides=(1, (1, 1, 1), 2), padding=("valid", "same")) - def test_conv_3d(self, strides, padding): - if backend.config.image_data_format() == "channels_last": + @parameterized.product( + strides=(1, (1, 1, 1), 2), + padding=("valid", "same"), + data_format=("channels_first", "channels_last"), + ) + def test_conv_3d(self, strides, padding, data_format): + if data_format == "channels_last": input_shape = (2, 8, 8, 8, 3) else: input_shape = (2, 3, 8, 8, 8) inputs_3d = np.arange(3072, dtype=float).reshape(input_shape) kernel = np.arange(162, dtype=float).reshape([3, 3, 3, 3, 2]) - outputs = knn.conv(inputs_3d, kernel, strides, padding=padding) + outputs = knn.conv( + inputs_3d, kernel, strides, padding=padding, data_format=data_format + ) expected = np_conv3d( inputs_3d, kernel, bias_weights=np.zeros((2,)), strides=strides, padding=padding, - data_format=backend.config.image_data_format(), + data_format=data_format, dilation_rate=1, groups=1, )