|
121 | 121 | _EMBEDDING_MODELS = { |
122 | 122 | # [Text-only] |
123 | 123 | "BertModel": ("bert", "BertEmbeddingModel"), |
124 | | - "RobertaModel": ("roberta", "RobertaEmbeddingModel"), |
125 | | - "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), |
126 | | - "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), |
127 | | - "NomicBertModel": ("bert", "NomicBertEmbeddingModel"), |
128 | | - "GteModel": ("bert", "GteEmbeddingModel"), |
129 | 124 | "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), |
130 | 125 | "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), |
131 | 126 | "GlmForCausalLM": ("glm", "GlmForCausalLM"), |
132 | 127 | "GritLM": ("gritlm", "GritLM"), |
| 128 | + "GteModel": ("bert", "GteEmbeddingModel"), |
133 | 129 | "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"), |
134 | 130 | "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501 |
135 | 131 | "LlamaModel": ("llama", "LlamaForCausalLM"), |
|
139 | 135 | if arch == "LlamaForCausalLM" |
140 | 136 | }, |
141 | 137 | "MistralModel": ("llama", "LlamaForCausalLM"), |
| 138 | + "NomicBertModel": ("bert", "NomicBertEmbeddingModel"), |
142 | 139 | "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), |
143 | 140 | "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"), |
144 | 141 | "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), |
145 | 142 | "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), |
146 | 143 | "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"), |
| 144 | + "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), |
| 145 | + "RobertaModel": ("roberta", "RobertaEmbeddingModel"), |
147 | 146 | "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), |
| 147 | + "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), |
148 | 148 | # [Multimodal] |
149 | 149 | "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 |
150 | 150 | "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), |
|
0 commit comments