Skip to content

Commit 1ec0fec

Browse files
authored
[image-processing] deprecate plot_keypoint_matching, make visualize_keypoint_matching the standard (#39830)
* fix: deprecate plot_keypoint_matching and make visualize_keypoint_matching for all Keypoint Matching models * refactor: added copied from * fix: make style * fix: repo consistency * fix: make style * docs: added missing method in SuperGlue docs
1 parent 7b4d984 commit 1ec0fec

File tree

6 files changed

+225
-36
lines changed

6 files changed

+225
-36
lines changed

docs/source/en/model_doc/lightglue.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
107107

108108
```py
109109
# Easy visualization using the built-in plotting method
110-
processor.plot_keypoint_matching(images, processed_outputs)
110+
processor.visualize_keypoint_matching(images, processed_outputs)
111111
```
112112

113113
<div class="flex justify-center">
@@ -128,7 +128,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
128128

129129
- preprocess
130130
- post_process_keypoint_matching
131-
- plot_keypoint_matching
131+
- visualize_keypoint_matching
132132

133133
<frameworkcontent>
134134
<pt>

docs/source/en/model_doc/superglue.md

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -103,38 +103,11 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
103103
print(f"Keypoint at {keypoint0.numpy()} matches with keypoint at {keypoint1.numpy()} with score {matching_score}")
104104
```
105105

106-
- The example below demonstrates how to visualize matches between two images.
106+
- Visualize the matches between the images using the built-in plotting functionality.
107107

108108
```py
109-
import matplotlib.pyplot as plt
110-
import numpy as np
111-
112-
# Create side by side image
113-
merged_image = np.zeros((max(image1.height, image2.height), image1.width + image2.width, 3))
114-
merged_image[: image1.height, : image1.width] = np.array(image1) / 255.0
115-
merged_image[: image2.height, image1.width :] = np.array(image2) / 255.0
116-
plt.imshow(merged_image)
117-
plt.axis("off")
118-
119-
# Retrieve the keypoints and matches
120-
output = processed_outputs[0]
121-
keypoints0 = output["keypoints0"]
122-
keypoints1 = output["keypoints1"]
123-
matching_scores = output["matching_scores"]
124-
125-
# Plot the matches
126-
for keypoint0, keypoint1, matching_score in zip(keypoints0, keypoints1, matching_scores):
127-
plt.plot(
128-
[keypoint0[0], keypoint1[0] + image1.width],
129-
[keypoint0[1], keypoint1[1]],
130-
color=plt.get_cmap("RdYlGn")(matching_score.item()),
131-
alpha=0.9,
132-
linewidth=0.5,
133-
)
134-
plt.scatter(keypoint0[0], keypoint0[1], c="black", s=2)
135-
plt.scatter(keypoint1[0] + image1.width, keypoint1[1], c="black", s=2)
136-
137-
plt.savefig("matched_image.png", dpi=300, bbox_inches='tight')
109+
# Easy visualization using the built-in plotting method
110+
processor.visualize_keypoint_matching(images, processed_outputs)
138111
```
139112

140113
<div class="flex justify-center">
@@ -155,6 +128,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
155128

156129
- preprocess
157130
- post_process_keypoint_matching
131+
- visualize_keypoint_matching
158132

159133
<frameworkcontent>
160134
<pt>

src/transformers/models/efficientloftr/image_processing_efficientloftr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ def visualize_keypoint_matching(
408408
images (`ImageInput`):
409409
Image pairs to plot. Same as `EfficientLoFTRImageProcessor.preprocess`. Expects either a list of 2
410410
images or a list of pairs of images, with pixel values ranging from 0 to 255.
411-
outputs (List[Dict[str, torch.Tensor]]]):
411+
keypoint_matching_output (List[Dict[str, torch.Tensor]]):
412412
A post processed keypoint matching output
413413
414414
Returns:

src/transformers/models/lightglue/image_processing_lightglue.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1818
# See the License for the specific language governing permissions and
1919
# limitations under the License.
20+
import warnings
2021
from typing import Optional, Union
2122

2223
import numpy as np
@@ -44,6 +45,9 @@
4445
from .modeling_lightglue import LightGlueKeypointMatchingOutput
4546

4647

48+
if is_vision_available():
49+
from PIL import Image, ImageDraw
50+
4751
if is_vision_available():
4852
import PIL
4953

@@ -402,18 +406,88 @@ def post_process_keypoint_matching(
402406

403407
return results
404408

409+
def visualize_keypoint_matching(
410+
self,
411+
images: ImageInput,
412+
keypoint_matching_output: list[dict[str, torch.Tensor]],
413+
) -> list["Image.Image"]:
414+
"""
415+
Plots the image pairs side by side with the detected keypoints as well as the matching between them.
416+
417+
Args:
418+
images (`ImageInput`):
419+
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2
420+
images or a list of pairs of images, with pixel values ranging from 0 to 255.
421+
keypoint_matching_output (List[Dict[str, torch.Tensor]]):
422+
A post processed keypoint matching output
423+
424+
Returns:
425+
`List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
426+
keypoints as well as the matching between them.
427+
"""
428+
images = validate_and_format_image_pairs(images)
429+
images = [to_numpy_array(image) for image in images]
430+
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
431+
432+
results = []
433+
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
434+
height0, width0 = image_pair[0].shape[:2]
435+
height1, width1 = image_pair[1].shape[:2]
436+
plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
437+
plot_image[:height0, :width0] = image_pair[0]
438+
plot_image[:height1, width0:] = image_pair[1]
439+
440+
plot_image_pil = Image.fromarray(plot_image)
441+
draw = ImageDraw.Draw(plot_image_pil)
442+
443+
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
444+
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
445+
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
446+
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
447+
):
448+
color = self._get_color(matching_score)
449+
draw.line(
450+
(keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
451+
fill=color,
452+
width=3,
453+
)
454+
draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
455+
draw.ellipse(
456+
(keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
457+
fill="black",
458+
)
459+
460+
results.append(plot_image_pil)
461+
return results
462+
463+
def _get_color(self, score):
464+
"""Maps a score to a color."""
465+
r = int(255 * (1 - score))
466+
g = int(255 * score)
467+
b = 0
468+
return (r, g, b)
469+
405470
def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
406471
"""
407472
Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
408473
matplotlib to be installed.
409474
475+
.. deprecated::
476+
`plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
477+
410478
Args:
411479
images (`ImageInput`):
412480
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
413481
a list of list of 2 images list with pixel values ranging from 0 to 255.
414-
outputs ([`LightGlueKeypointMatchingOutput`]):
482+
keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
415483
Raw outputs of the model.
416484
"""
485+
warnings.warn(
486+
"`plot_keypoint_matching` is deprecated and will be removed in a future version of Transformers. "
487+
"Use `visualize_keypoint_matching` instead.",
488+
FutureWarning,
489+
)
490+
417491
if is_matplotlib_available():
418492
import matplotlib.pyplot as plt
419493
else:

src/transformers/models/lightglue/modular_lightglue.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
1415
from dataclasses import dataclass
1516
from typing import Callable, Optional, Union
1617

@@ -20,7 +21,7 @@
2021
from torch.nn.utils.rnn import pad_sequence
2122

2223
from ...configuration_utils import PretrainedConfig
23-
from ...image_utils import ImageInput, to_numpy_array
24+
from ...image_utils import ImageInput, is_vision_available, to_numpy_array
2425
from ...modeling_flash_attention_utils import FlashAttentionKwargs
2526
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
2627
from ...processing_utils import Unpack
@@ -35,6 +36,10 @@
3536
from ..superpoint import SuperPointConfig
3637

3738

39+
if is_vision_available():
40+
from PIL import Image, ImageDraw
41+
42+
3843
logger = logging.get_logger(__name__)
3944

4045

@@ -220,18 +225,90 @@ def post_process_keypoint_matching(
220225
) -> list[dict[str, torch.Tensor]]:
221226
return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
222227

228+
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->LightGlue
229+
def visualize_keypoint_matching(
230+
self,
231+
images: ImageInput,
232+
keypoint_matching_output: list[dict[str, torch.Tensor]],
233+
) -> list["Image.Image"]:
234+
"""
235+
Plots the image pairs side by side with the detected keypoints as well as the matching between them.
236+
237+
Args:
238+
images (`ImageInput`):
239+
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2
240+
images or a list of pairs of images, with pixel values ranging from 0 to 255.
241+
keypoint_matching_output (List[Dict[str, torch.Tensor]]):
242+
A post processed keypoint matching output
243+
244+
Returns:
245+
`List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
246+
keypoints as well as the matching between them.
247+
"""
248+
images = validate_and_format_image_pairs(images)
249+
images = [to_numpy_array(image) for image in images]
250+
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
251+
252+
results = []
253+
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
254+
height0, width0 = image_pair[0].shape[:2]
255+
height1, width1 = image_pair[1].shape[:2]
256+
plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
257+
plot_image[:height0, :width0] = image_pair[0]
258+
plot_image[:height1, width0:] = image_pair[1]
259+
260+
plot_image_pil = Image.fromarray(plot_image)
261+
draw = ImageDraw.Draw(plot_image_pil)
262+
263+
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
264+
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
265+
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
266+
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
267+
):
268+
color = self._get_color(matching_score)
269+
draw.line(
270+
(keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
271+
fill=color,
272+
width=3,
273+
)
274+
draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
275+
draw.ellipse(
276+
(keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
277+
fill="black",
278+
)
279+
280+
results.append(plot_image_pil)
281+
return results
282+
283+
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color
284+
def _get_color(self, score):
285+
"""Maps a score to a color."""
286+
r = int(255 * (1 - score))
287+
g = int(255 * score)
288+
b = 0
289+
return (r, g, b)
290+
223291
def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
224292
"""
225293
Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
226294
matplotlib to be installed.
227295
296+
.. deprecated::
297+
`plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
298+
228299
Args:
229300
images (`ImageInput`):
230301
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
231302
a list of list of 2 images list with pixel values ranging from 0 to 255.
232-
outputs ([`LightGlueKeypointMatchingOutput`]):
303+
keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
233304
Raw outputs of the model.
234305
"""
306+
warnings.warn(
307+
"`plot_keypoint_matching` is deprecated and will be removed in a future version of Transformers. "
308+
"Use `visualize_keypoint_matching` instead.",
309+
FutureWarning,
310+
)
311+
235312
if is_matplotlib_available():
236313
import matplotlib.pyplot as plt
237314
else:

src/transformers/models/superglue/image_processing_superglue.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
if is_vision_available():
4949
import PIL
50+
from PIL import Image, ImageDraw
5051

5152
logger = logging.get_logger(__name__)
5253

@@ -406,5 +407,68 @@ def post_process_keypoint_matching(
406407

407408
return results
408409

410+
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->SuperGlue
411+
def visualize_keypoint_matching(
412+
self,
413+
images: ImageInput,
414+
keypoint_matching_output: list[dict[str, torch.Tensor]],
415+
) -> list["Image.Image"]:
416+
"""
417+
Plots the image pairs side by side with the detected keypoints as well as the matching between them.
418+
419+
Args:
420+
images (`ImageInput`):
421+
Image pairs to plot. Same as `SuperGlueImageProcessor.preprocess`. Expects either a list of 2
422+
images or a list of pairs of images, with pixel values ranging from 0 to 255.
423+
keypoint_matching_output (List[Dict[str, torch.Tensor]]):
424+
A post processed keypoint matching output
425+
426+
Returns:
427+
`List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
428+
keypoints as well as the matching between them.
429+
"""
430+
images = validate_and_format_image_pairs(images)
431+
images = [to_numpy_array(image) for image in images]
432+
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
433+
434+
results = []
435+
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
436+
height0, width0 = image_pair[0].shape[:2]
437+
height1, width1 = image_pair[1].shape[:2]
438+
plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
439+
plot_image[:height0, :width0] = image_pair[0]
440+
plot_image[:height1, width0:] = image_pair[1]
441+
442+
plot_image_pil = Image.fromarray(plot_image)
443+
draw = ImageDraw.Draw(plot_image_pil)
444+
445+
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
446+
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
447+
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
448+
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
449+
):
450+
color = self._get_color(matching_score)
451+
draw.line(
452+
(keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
453+
fill=color,
454+
width=3,
455+
)
456+
draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
457+
draw.ellipse(
458+
(keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
459+
fill="black",
460+
)
461+
462+
results.append(plot_image_pil)
463+
return results
464+
465+
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color
466+
def _get_color(self, score):
467+
"""Maps a score to a color."""
468+
r = int(255 * (1 - score))
469+
g = int(255 * score)
470+
b = 0
471+
return (r, g, b)
472+
409473

410474
__all__ = ["SuperGlueImageProcessor"]

0 commit comments

Comments
 (0)