-
Notifications
You must be signed in to change notification settings - Fork 95
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Create docs, tests and tutorial for ambiguity measure * update tutorial * Add link to colab tutorial * update style * add ambiguity measure * Apply suggestions from code review Co-authored-by: ad-daniel <44834743+ad-daniel@users.noreply.github.com> * add ambiguity measure to package list and changelog * Apply suggestions from code review Co-authored-by: ad-daniel <44834743+ad-daniel@users.noreply.github.com> Co-authored-by: ad-daniel <44834743+ad-daniel@users.noreply.github.com>
- Loading branch information
1 parent
d51b145
commit 423fcd6
Showing
12 changed files
with
2,295 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
## ambiguity_measure module | ||
|
||
The *ambiguity_measure* module contains the *AmbiguityMeasure* class. | ||
|
||
### Class AmbiguityMeasure | ||
Bases: `object` | ||
|
||
The *AmbiguityMeasure* class is a tool that allows to obtain an ambiguity measure of vision-based models that output pixel-wise value estimates. | ||
This tool can be used in combination with vision-based manipulation models such as Transporter Nets [[1]](#transporter-paper). | ||
|
||
The [AmbiguityMeasure](../../src/opendr/utils/ambiguity_measure/ambiguity_measure.py) class has the following public methods: | ||
|
||
#### `AmbiguityMeasure` constructor | ||
```python | ||
AmbiguityMeasure(self, threshold, temperature) | ||
``` | ||
|
||
Constructor parameters: | ||
|
||
- **threshold**: *float, default=0.5*\ | ||
Ambiguity threshold, should be in [0, 1). | ||
- **temperature**: *float, default=1.0*\ | ||
Temperature of the sigmoid function. | ||
Should be > 0. | ||
Higher temperatures will result in higher ambiguity measures. | ||
|
||
#### `AmbiguityMeasure.get_ambiguity_measure` | ||
```python | ||
AmbiguityMeasure.get_ambiguity_measure(self, heatmap) | ||
``` | ||
|
||
This method allows to obtain an ambiguity measure of the output of a model. | ||
|
||
Parameters: | ||
|
||
- **heatmap**: *np.ndarray*\ | ||
Pixel-wise value estimates. | ||
These can be obtained using from for example a Transporter Nets model [[1]](#transporter-paper). | ||
|
||
#### Demos and tutorial | ||
|
||
A demo showcasing the usage and functionality of the *AmbiguityMeasure* is available [here](https://colab.research.google.com/github/opendr-eu/opendr/blob/ambiguity_measure/projects/python/utils/ambiguity_measure/ambiguity_measure_tutorial.ipynb). | ||
|
||
|
||
#### Examples | ||
|
||
* **Ambiguity measure example** | ||
|
||
This example shows how to obtain the ambiguity measure from pixel-wise value estimates. | ||
|
||
```python | ||
import numpy as np | ||
from opendr.utils.ambiguity_measure.ambiguity_measure import AmbiguityMeasure | ||
|
||
# Simulate image and value pixel-wise value estimates (normally you would get this from a model such as Transporter) | ||
img = 255 * np.random.random((128, 128, 3)) | ||
img = np.asarray(img, dtype="uint8") | ||
heatmap = 10 * np.random.random((128, 128)) | ||
|
||
# Initialize ambiguity measure | ||
am = AmbiguityMeasure(threshold=0.1, temperature=3) | ||
|
||
# Get ambiguity measure of the heatmap | ||
ambiguous, locs, maxima, probs = am.get_ambiguity_measure(heatmap) | ||
|
||
# Plot ambiguity measure | ||
am.plot_ambiguity_measure(heatmap, locs, probs, img) | ||
``` | ||
|
||
#### References | ||
<a name="transporter-paper" href="https://proceedings.mlr.press/v155/zeng21a/zeng21a.pdf">[1]</a> | ||
Zeng, A., Florence, P., Tompson, J., Welker, S., Chien, J., Attarian, M., ... & Lee, J. (2021, October). | ||
Transporter networks: Rearranging the visual world for robotic manipulation. | ||
In Conference on Robot Learning (pp. 726-747). | ||
PMLR. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1,796 changes: 1,796 additions & 0 deletions
1,796
projects/python/utils/ambiguity_measure/ambiguity_measure_tutorial.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
## Utils Module | ||
|
||
This module contains utility tools of the OpenDR toolkit, such as the | ||
[hyperparameter tuning tool](hyperparameter_tuner/hyperparameter_tuner.py). | ||
[hyperparameter tuning tool](hyperparameter_tuner/hyperparameter_tuner.py) and the [AmbiguityMeasure tool](ambiguity_measure/ambiguity_measure.py). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# OpenDR Ambiguity Measure | ||
|
||
This folder contains a tool for obtaining ambiguity measures for pixel-wise values estimates. | ||
This tool can be used in combination with vision-based manipulation models such as Transporter Nets [[1]](#transporter-paper). | ||
The contents of the file `persistence.py` were adapted from [persitence.py](https://git.sthu.org/?p=persistence.git;a=blob;f=imagepers.py) and | ||
[union_find.py](https://git.sthu.org/?p=persistence.git;a=blob;f=union_find.py) created by Stefan Huber. | ||
|
||
|
||
#### References | ||
<a name="transporter-paper" href="https://proceedings.mlr.press/v155/zeng21a/zeng21a.pdf">[1]</a> | ||
Zeng, A., Florence, P., Tompson, J., Welker, S., Chien, J., Attarian, M., ... & Lee, J. (2021, October). | ||
Transporter networks: Rearranging the visual world for robotic manipulation. | ||
In Conference on Robot Learning (pp. 726-747). | ||
PMLR. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from opendr.utils.ambiguity_measure.ambiguity_measure import AmbiguityMeasure | ||
|
||
__all__ = ["AmbiguityMeasure"] |
189 changes: 189 additions & 0 deletions
189
src/opendr/utils/ambiguity_measure/ambiguity_measure.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
# Copyright 2020-2022 OpenDR European Project | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import numpy as np | ||
from opendr.utils.ambiguity_measure.persistence import get_persistence | ||
from opendr.engine.data import Image | ||
from matplotlib import pyplot as plt, transforms, cm | ||
from copy import deepcopy | ||
from typing import Optional, Union, List | ||
|
||
|
||
class AmbiguityMeasure(object): | ||
""" | ||
AmbiguityMeasure tool. | ||
This tool can be used to obtain an ambiguity measure of the output of vision-based manipulation models, such as | ||
Transporter Nets and CLIPort. | ||
""" | ||
|
||
def __init__(self, threshold: float = 0.5, temperature: float = 1.0): | ||
""" | ||
Constructor of AmbiguityMeasure | ||
:param threshold: Ambiguity threshold, should be in [0, 1). | ||
:type threshold: float | ||
:param temperature: Temperature of the sigmoid function. | ||
:type temperature: float | ||
""" | ||
assert threshold >= 0 < 1, "Threshold should be in [0, 1)." | ||
assert temperature > 0, "Temperature should be greater than 0." | ||
self._threshold = threshold | ||
self._temperature = temperature | ||
|
||
def get_ambiguity_measure(self, heatmap: np.ndarray): | ||
""" | ||
Get Ambiguity Measure. | ||
:param heatmap: Pixel-wise value estimates. | ||
:type heatmap: np.ndarray | ||
:return: Tuple[ambiguous, locs, maxima, probs] | ||
- ambiguous: Whether or not output was ambiguous. | ||
- locs: Pixel locations of significant local maxima. | ||
- maxima: Values corresponding to local maxima. | ||
- probs: Probability mass function based on local maxima. | ||
:rtype: Tuple[ambiguous, locs, maxima, probs] | ||
- ambiguous: bool | ||
- locs: list | ||
- maxima: list | ||
- probs: list | ||
""" | ||
# Calculate persistence to find local maxima | ||
persistence = get_persistence(heatmap) | ||
|
||
maxima = [] | ||
locs = [] | ||
for i, homclass in enumerate(persistence): | ||
p_birth, _, _, _ = homclass | ||
locs.append(p_birth) | ||
maxima.append(heatmap[p_birth[0], p_birth[1]]) | ||
probs = self.__softmax(np.asarray(maxima)) | ||
ambiguous = 1.0 - max(probs) < self._threshold | ||
return ambiguous, locs, maxima, probs | ||
|
||
def plot_ambiguity_measure( | ||
self, | ||
heatmap: np.ndarray, | ||
locs: List[List[int]], | ||
probs: Union[List[float], np.ndarray], | ||
img: Image = None, | ||
img_offset: float = -250.0, | ||
view_init: List[int] = [30, 30], | ||
plot_threshold: float = 0.05, | ||
title: str = "Ambiguity Measure", | ||
save_path: Optional[str] = None, | ||
): | ||
""" | ||
Plot the obtained ambiguity measure. | ||
:param heatmap: Pixel-wise value estimates. | ||
:type heatmap: np.ndarray | ||
:param locs: Pixel locations of significant local maxima. | ||
:type locs: List[List[int]] | ||
:param probs: Probability mass function based on local maxima. | ||
:type probs: List[float] | ||
:param img: Top view input image. | ||
:type img: Union[np.ndarray, Image] | ||
:param img_offset: Specifies the distance between value estimates and image. | ||
:type img_offset: float | ||
:param view_init: Set the elevation and azimuth of the axes in degrees (not radians). | ||
:type view_init: List[float] | ||
:param plot_threshold: Threshold for plotting probabilities. | ||
Probabilities lower than this value will not be plotted. | ||
:param title: Title of the plot. | ||
:type title: str | ||
:param save_path: Path for saving figure, if None, | ||
:type plot_threshold: float | ||
""" | ||
fig = plt.figure() | ||
ax = plt.axes(projection="3d") | ||
ax.computed_zorder = False | ||
trans_offset = transforms.offset_copy(ax.transData, fig=fig, y=2, units="dots") | ||
X, Y = np.mgrid[0:heatmap.shape[0], 0:heatmap.shape[1]] | ||
Z = heatmap | ||
ax.set_title(title) | ||
ax.plot_surface(X, Y, Z, cmap=cm.viridis, linewidth=0, antialiased=False, shade=False, zorder=-1) | ||
|
||
if img is not None: | ||
if type(img) is Image: | ||
img = np.moveaxis(img.numpy(), 0, -1) | ||
|
||
img = deepcopy(img) | ||
if np.max(img) > 1: | ||
img = img / 255 | ||
x_image, y_image = np.mgrid[0:img.shape[0], 0:img.shape[1]] | ||
ax.plot_surface( | ||
x_image, | ||
y_image, | ||
np.ones(img.shape[:2]) * -img_offset, | ||
rstride=1, | ||
cstride=1, | ||
facecolors=img, | ||
shade=False, | ||
) | ||
|
||
ax.set_zlim(-img_offset - 1, 50) | ||
ax.view_init(view_init[0], view_init[1]) | ||
for loc, value in zip(locs, probs): | ||
if value > plot_threshold: | ||
ax.plot3D([loc[0]], [loc[1]], [value], "r.", zorder=9) | ||
ax.plot3D([loc[0]], [loc[1]], [-img_offset], "r.", zorder=-2) | ||
ax.text( | ||
loc[0], | ||
loc[1], | ||
value, | ||
f"{value:.2f}", | ||
zorder=10, | ||
transform=trans_offset, | ||
horizontalalignment="center", | ||
verticalalignment="bottom", | ||
c="r", | ||
fontsize="large", | ||
) | ||
ax.grid(False) | ||
ax.set_axis_off() | ||
ax.set_xticklabels([]) | ||
ax.set_yticklabels([]) | ||
ax.set_zticklabels([]) | ||
if save_path: | ||
plt.savefig(save_path) | ||
plt.show() | ||
|
||
@property | ||
def threshold(self): | ||
""" | ||
Getter of threshold. | ||
:return: Threshold value. | ||
:rtype: float | ||
""" | ||
return self._threshold | ||
|
||
@threshold.setter | ||
def threshold(self, value: float): | ||
""" | ||
Setter of threshold. | ||
:param threshold: Threshold value. | ||
:type threshold: float | ||
""" | ||
if type(value) != float: | ||
raise TypeError("threshold should be a float") | ||
else: | ||
self._threshold = value | ||
|
||
def __softmax(self, x): | ||
x /= self._temperature | ||
e_x = np.exp(x - np.max(x)) | ||
return e_x / e_x.sum() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
[runtime] | ||
# 'python' key expects a value using the Python requirements file format | ||
# https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format | ||
python=numpy | ||
matplotlib | ||
wheel | ||
|
||
opendr=opendr-toolkit-engine |
Oops, something went wrong.