forked from huggingface/transformers
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
steven
committed
Feb 12, 2025
1 parent
57d0f02
commit 0f52278
Showing
1 changed file
with
77 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from typing import Optional, Tuple, Union | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from ...configuration_utils import PretrainedConfig | ||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward | ||
from ..superglue.modeling_superglue import KeypointMatchingOutput | ||
from ..superpoint.modeling_superpoint import SuperPointPreTrainedModel | ||
|
||
|
||
class EfficientLoFTRConfig(PretrainedConfig): | ||
model_type = "efficientloftr" | ||
|
||
class EfficientLoFTRBackbone(nn.Module): | ||
def __init__(self, config: EfficientLoFTRConfig): | ||
super().__init__() | ||
|
||
self.config = config | ||
|
||
class EfficientLoFTRPreTrainedModel(SuperPointPreTrainedModel): | ||
config_class = EfficientLoFTRConfig | ||
base_model_prefix = "efficientloftr" | ||
|
||
|
||
EFFICIENTLOFTR_START_DOCSTRING = r""" | ||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it | ||
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and | ||
behavior. | ||
Parameters: | ||
config ([`EfficientLoFTRConfig`]): Model configuration class with all the parameters of the model. | ||
Initializing with a config file does not load the weights associated with the model, only the | ||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. | ||
""" | ||
|
||
EFFICIENTLOFTR_INPUTS_DOCSTRING = r""" | ||
Args: | ||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): | ||
Pixel values. Pixel values can be obtained using [`SuperGlueImageProcessor`]. See | ||
[`SuperGlueImageProcessor.__call__`] for details. | ||
output_attentions (`bool`, *optional*): | ||
Whether or not to return the attentions tensors. See `attentions` under returned tensors for more detail. | ||
output_hidden_states (`bool`, *optional*): | ||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for | ||
more detail. | ||
return_dict (`bool`, *optional*): | ||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. | ||
""" | ||
|
||
|
||
@add_start_docstrings( | ||
"SuperGlue model taking images as inputs and outputting the matching of them.", | ||
EFFICIENTLOFTR_START_DOCSTRING, | ||
) | ||
class EfficientLoFTRForKeypointMatching(EfficientLoFTRPreTrainedModel): | ||
""" | ||
TODO | ||
""" | ||
|
||
def __init__(self, config: EfficientLoFTRConfig) -> None: | ||
super().__init__(config) | ||
|
||
self.backbone = EfficientLoFTRBackbone(config) | ||
|
||
self.post_init() | ||
|
||
@add_start_docstrings_to_model_forward(EFFICIENTLOFTR_INPUTS_DOCSTRING) | ||
def forward( | ||
self, | ||
pixel_values: torch.FloatTensor, | ||
labels: Optional[torch.LongTensor] = None, | ||
output_attentions: Optional[bool] = None, | ||
output_hidden_states: Optional[bool] = None, | ||
return_dict: Optional[bool] = None, | ||
) -> Union[Tuple, KeypointMatchingOutput]: | ||
pass |