Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
arxyzan committed Aug 20, 2023
1 parent e0550a8 commit ea23c47
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions hezar/models/model_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


@dataclass
class ModelOutputs:
class ModelOutput:
"""
Base class for all model outputs (named based on tasks)
Expand Down Expand Up @@ -41,30 +41,30 @@ def items(self):


@dataclass
class LanguageModelingOutputs(ModelOutputs):
class LanguageModelingOutput(ModelOutput):
last_hidden_state: Optional[torch.FloatTensor] = None
hidden_state: Optional[torch.FloatTensor] = None
attentions: Optional[torch.FloatTensor] = None


@dataclass
class TextClassificationOutputs(ModelOutputs):
class TextClassificationOutput(ModelOutput):
labels: Optional[List[str]] = None
probs: Optional[List[float]] = None


@dataclass
class SequenceLabelingOutputs(ModelOutputs):
class SequenceLabelingOutput(ModelOutput):
tokens: Optional[List[List[str]]] = None
tags: Optional[List[List[str]]] = None
probs: Optional[List[List[float]]] = None


@dataclass
class Text2TextOutputs(ModelOutputs):
class Text2TextOutput(ModelOutput):
generated_texts: Optional[List[str]] = None


@dataclass
class SpeechRecognitionOutputs(ModelOutputs):
class SpeechRecognitionOutput(ModelOutput):
transcription: Optional[List[str]] = None

0 comments on commit ea23c47

Please sign in to comment.