Add point cloud encoding policy - pct_policy #12

Merged (1 commit, Aug 13, 2024)

73 changes: 73 additions & 0 deletions iris/policies/layers/keras_image_encoder_layer.py
@@ -0,0 +1,73 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A keras layer for encoding image into patches."""

from typing import Tuple

import tensorflow as tf


class ImageEncoder(tf.keras.layers.Layer):
"""Keras layer for encoding image into patches."""

def __init__(self,
patch_height: int,
patch_width: int,
stride_height: int,
stride_width: int,
normalize_positions: bool = True) -> None:
"""Initializes Keras layer for encoding image into patches.

Args:
patch_height: Height of image patch for encoding.
patch_width: Width of image patch for encoding.
stride_height: Stride (shift) height for consecutive image patches.
stride_width: Stride (shift) width for consecutive image patches.
normalize_positions: True to normalize patch center positions.
"""
super().__init__()
self._patch_height = patch_height
self._patch_width = patch_width
self._stride_height = stride_height
self._stride_width = stride_width
self._normalize_positions = normalize_positions

def call(self, images: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
batch_shape, image_height, image_width, channels = images.shape
if batch_shape is None:
batch_shape = tf.shape(images)[0]
patches = tf.image.extract_patches(
images,
sizes=[1, self._patch_height, self._patch_width, 1],
strides=[1, self._stride_height, self._stride_width, 1],
rates=[1, 1, 1, 1],
padding='VALID')
encoding = tf.reshape(
patches,
[batch_shape, -1, self._patch_height * self._patch_width * channels])
pos_x = tf.range(self._patch_height // 2, image_height, self._stride_height)
pos_y = tf.range(self._patch_width // 2, image_width, self._stride_width)
if self._normalize_positions:
pos_x /= image_height
pos_y /= image_width
x, y = tf.meshgrid(pos_x, pos_y)
x = tf.transpose(x)
y = tf.transpose(y)
centers = tf.stack([x, y], axis=-1)
centers = tf.reshape(centers, (-1, 2))
centers = tf.tile(centers, (batch_shape, 1))
centers = tf.reshape(centers, (batch_shape, -1, 2))
centers = tf.cast(centers, 'float32')
return encoding, centers
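
For reference, a minimal standalone usage sketch (not part of this PR), assuming eager TF2 and calling the layer directly rather than through a Model; the shapes mirror the test below.

# Usage sketch (assumption: eager TF2). With VALID padding the patch grid is
#   ((H - patch_height) // stride_height + 1) x ((W - patch_width) // stride_width + 1).
import numpy as np
import tensorflow as tf
from iris.policies.layers import keras_image_encoder_layer

images = np.arange(2 * 5 * 6 * 2, dtype=np.float32).reshape((2, 5, 6, 2))
encoder = keras_image_encoder_layer.ImageEncoder(
    patch_height=2, patch_width=2, stride_height=1, stride_width=1)
encoding, centers = encoder(tf.constant(images))
print(encoding.shape)  # (2, 20, 8): a 4 x 5 patch grid, 2*2*2 values per patch.
print(centers.shape)   # (2, 20, 2): per-patch (row, col) centers, normalized by H and W.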
40 changes: 40 additions & 0 deletions iris/policies/layers/keras_image_encoder_layer_test.py
@@ -0,0 +1,40 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from iris.policies.layers import keras_image_encoder_layer
import numpy as np
import tensorflow as tf
from absl.testing import absltest


class ImageEncoderTest(absltest.TestCase):

def test_layer_output(self):
"""Tests the output of ImageEncoder layer."""
input_layer = tf.keras.layers.Input(
batch_input_shape=(2, 5, 6, 2), dtype="float", name="input")
output_layer = keras_image_encoder_layer.ImageEncoder(
patch_height=2,
patch_width=2,
stride_height=1,
stride_width=1)(input_layer)
model = tf.keras.models.Model(inputs=[input_layer], outputs=[output_layer])
images = np.arange(2*5*6*2).reshape((2, 5, 6, 2))
encoding, centers = model.predict(images)[0]
self.assertEqual(encoding.shape, (2, 20, 8))
self.assertEqual(centers.shape, (2, 20, 2))


if __name__ == "__main__":
absltest.main()
75 changes: 75 additions & 0 deletions iris/policies/layers/keras_masking_attention_layer.py
@@ -0,0 +1,75 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A keras layer for masking based attention."""

from typing import Callable
import tensorflow as tf


class FavorMaskingAttention(tf.keras.layers.Layer):
"""A keras layer for masking based attention.

A layer that creates a representation of the RGB(D)-image using attention
mechanism from https://arxiv.org/abs/2009.14794. It leverages Performer-ReLU
(go/performer) attention module in order to bypass explicit materialization of
the L x L attention tensor, where L is the number of patches (potentially even
individual pixels). This reduces time complexity of the attention module from
quadratic to linear in L and provides a gateway to processing high-resolution
images, where explicitly calculating attention tensor is not feasible. The
ranking procedure is adopted from https://arxiv.org/abs/2003.08165, where
scores of patches are defined as sums of the entries of the corresponding
column in the attention tensor. After ranking, top K tokens are preserved and
the rest of them are masked by 0.
"""

def __init__(
self,
kernel_transformation: Callable[..., tf.Tensor],
top_k: int = 5) -> None: # pytype: disable=annotation-type-mismatch
"""Initializes FavorMaskingAttention layer.

Args:
kernel_transformation: Transformation used to get finite kernel features.
      top_k: Number of top patches that will be chosen to "summarize" the
        entire image.
"""
super().__init__()
self._kernel_transformation = kernel_transformation
self._top_k = top_k

def call(self,
queries: tf.Tensor,
keys: tf.Tensor,
values: tf.Tensor) -> tf.Tensor:
queries_prime = self._kernel_transformation(
data=tf.expand_dims(queries, axis=2),
is_query=True)
queries_prime = tf.squeeze(queries_prime, axis=2)
keys_prime = self._kernel_transformation(
data=tf.expand_dims(keys, axis=2),
is_query=False)
keys_prime = tf.squeeze(keys_prime, axis=2)
_, length, _ = queries_prime.shape
all_ones = tf.ones([1, length])
reduced_queries_prime = tf.matmul(all_ones, queries_prime)
scores = tf.matmul(reduced_queries_prime, keys_prime, transpose_b=True)
scores = tf.reshape(scores, (-1, length))
sorted_idxs = tf.argsort(scores, axis=-1, direction='DESCENDING')
cutoff = tf.gather(
scores, sorted_idxs[:, self._top_k], axis=1, batch_dims=1)
cond = scores > tf.expand_dims(cutoff, -1)
return tf.where(tf.expand_dims(cond, -1),
values,
tf.zeros_like(values))
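
For context on the scoring trick described in the docstring, a small NumPy check (not part of this PR) that the all-ones matmul used above equals the column sums of the kernelized attention matrix Q'K'^T, without ever materializing the L x L matrix; Q' and K' stand for the kernel-transformed queries and keys.

# Sketch of the linear-time identity: sum_i (Q'K'^T)[i, j] == ((1^T Q') K'^T)[j].
import numpy as np

rng = np.random.default_rng(0)
length, dim = 6, 4
q_prime = rng.random((length, dim))  # kernel features of the queries
k_prime = rng.random((length, dim))  # kernel features of the keys

attention = q_prime @ k_prime.T               # L x L, built here only to verify the identity
column_sums = attention.sum(axis=0)           # quadratic in L
linear_scores = (np.ones((1, length)) @ q_prime) @ k_prime.T  # linear in L
np.testing.assert_allclose(column_sums, linear_scores.ravel())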
46 changes: 46 additions & 0 deletions iris/policies/layers/keras_masking_attention_layer_test.py
@@ -0,0 +1,46 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from iris.policies.layers import keras_masking_attention_layer
from lingvo.core import favor_attention as favor
import numpy as np
import tensorflow as tf
from absl.testing import absltest


class FavorMaskingAttentionTest(absltest.TestCase):

def test_layer_output(self):
"""Tests the output of RankingAttention layer."""
query_layer = tf.keras.layers.Input(
batch_input_shape=(2, 3, 4), dtype="float", name="query")
key_layer = tf.keras.layers.Input(
batch_input_shape=(2, 3, 4), dtype="float", name="keys")
value_layer = tf.keras.layers.Input(
batch_input_shape=(2, 3, 4), dtype="float", name="values")
output_layer = keras_masking_attention_layer.FavorMaskingAttention(
kernel_transformation=favor.relu_kernel_transformation,
top_k=2)(query_layer, key_layer, value_layer)
model = tf.keras.models.Model(
inputs=[query_layer, key_layer, value_layer], outputs=[output_layer])
queries = np.arange(2 * 3 * 4).reshape((2, 3, 4))
top_values = model.predict((queries, queries, queries))
self.assertEqual(top_values.shape, (2, 3, 4))
true_values = np.arange(2 * 3 * 4).reshape((2, 3, 4))
true_values[:, 0, :] = 0
np.testing.assert_array_almost_equal(top_values, true_values, 1)


if __name__ == "__main__":
absltest.main()
40 changes: 40 additions & 0 deletions iris/policies/layers/keras_positional_encoding_layer.py
@@ -0,0 +1,40 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A keras layer for positional encoding."""

import tensorflow as tf


class PositionalEncoding(tf.keras.layers.Layer):
"""Keras layer for positional encoding."""

  def call(self,
           seq_len: int,
           encoding_dimension: int) -> tf.Tensor:
num_freq = encoding_dimension // 2
indices = tf.expand_dims(tf.range(seq_len), 0)
indices = tf.tile(indices, [num_freq, 1])
freq_fn = lambda k: 1.0/(10000 ** (2*k/encoding_dimension))
freq = tf.keras.layers.Lambda(freq_fn)(tf.range(num_freq))
freq = tf.expand_dims(freq, 1)
freq = tf.tile(freq, [1, seq_len])
args = tf.multiply(freq, tf.cast(indices, dtype=tf.float64))
sin_enc = tf.math.sin(args)
    cos_enc = tf.math.cos(args)  # Cosine features for the second half of the encoding.
encoding = tf.keras.layers.Concatenate(axis=0)([sin_enc, cos_enc])
encoding = tf.expand_dims(tf.transpose(encoding), 0)
return encoding
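
For reference, a NumPy sketch (not part of this PR) of the closed form this layer computes; note that the sine and cosine halves are concatenated along the feature axis rather than interleaved as in the original Transformer encoding.

# Sketch: sine terms fill the first d/2 features, cosine terms the last d/2.
import numpy as np

seq_len, d = 7, 4
pos = np.arange(seq_len)[:, None]             # (seq_len, 1)
k = np.arange(d // 2)[None, :]                # (1, d/2)
angles = pos / np.power(10000.0, 2 * k / d)   # (seq_len, d/2)
encoding = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)[None]
print(encoding.shape)  # (1, 7, 4), matching the layer's output shape in the test below.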
27 changes: 27 additions & 0 deletions iris/policies/layers/keras_positional_encoding_layer_test.py
@@ -0,0 +1,27 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from iris.policies.layers import keras_positional_encoding_layer
from absl.testing import absltest


class PositionalEncodingTest(absltest.TestCase):

def test_layer_output(self):
"""Tests the output of PositionalEncoding layer."""
encoding = keras_positional_encoding_layer.PositionalEncoding()(7, 4)
self.assertEqual(encoding.shape, (1, 7, 4))

if __name__ == "__main__":
absltest.main()
70 changes: 70 additions & 0 deletions iris/policies/layers/keras_ranking_attention_layer.py
@@ -0,0 +1,70 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A keras layer for ranking based attention."""

from typing import Callable
import tensorflow as tf


class FavorRankingAttention(tf.keras.layers.Layer):
"""A keras layer for ranking based attention.

A layer that creates a representation of the RGB(D)-image using attention
mechanism from https://arxiv.org/abs/2009.14794. It leverages Performer-ReLU
(go/performer) attention module in order to bypass explicit materialization of
the L x L attention tensor, where L is the number of patches (potentially even
individual pixels). This reduces time complexity of the attention module from
quadratic to linear in L and provides a gateway to processing high-resolution
images, where explicitly calculating attention tensor is not feasible. The
ranking procedure is adopted from https://arxiv.org/abs/2003.08165, where
scores of patches are defined as sums of the entries of the corresponding
column in the attention tensor.
"""

def __init__(
self,
kernel_transformation: Callable[..., tf.Tensor],
top_k: int = 5) -> None: # pytype: disable=annotation-type-mismatch
"""Initializes FavorRankingAttention layer.

Args:
kernel_transformation: Transformation used to get finite kernel features.
      top_k: Number of top patches that will be chosen to "summarize" the
        entire image.
"""
super().__init__()
self._kernel_transformation = kernel_transformation
self._top_k = top_k

def call(self,
queries: tf.Tensor,
keys: tf.Tensor,
values: tf.Tensor) -> tf.Tensor:
queries_prime = self._kernel_transformation(
data=tf.expand_dims(queries, axis=1),
is_query=True)
queries_prime = tf.squeeze(queries_prime, axis=1)
keys_prime = self._kernel_transformation(
data=tf.expand_dims(keys, axis=1),
is_query=False)
keys_prime = tf.squeeze(keys_prime, axis=1)
_, length, _ = queries_prime.shape
all_ones = tf.ones([1, length])
reduced_queries_prime = tf.matmul(all_ones, queries_prime)
scores = tf.matmul(reduced_queries_prime, keys_prime, transpose_b=True)
scores = tf.reshape(scores, (-1, length))
sorted_idxs = tf.argsort(scores, axis=-1, direction='DESCENDING')
top_idxs = sorted_idxs[:, :self._top_k]
return tf.gather(values, top_idxs, axis=1, batch_dims=1)
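
A standalone shape sketch (not part of this PR), modeled on the masking-attention test earlier in the diff; it assumes eager TF2 and that lingvo's relu_kernel_transformation is available, as in those tests.

# Sketch: unlike FavorMaskingAttention, the ranking layer gathers the top_k
# value rows, so the sequence dimension shrinks from L to top_k.
from iris.policies.layers import keras_ranking_attention_layer
from lingvo.core import favor_attention as favor
import numpy as np
import tensorflow as tf

queries = tf.constant(np.arange(2 * 3 * 4, dtype=np.float32).reshape((2, 3, 4)))
layer = keras_ranking_attention_layer.FavorRankingAttention(
    kernel_transformation=favor.relu_kernel_transformation, top_k=2)
top_values = layer(queries, queries, queries)
print(top_values.shape)  # (2, 2, 4): only the 2 highest-scoring patches survive.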