@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import itertools
 from collections.abc import Iterable
 from typing import Optional, Union
 
@@ -39,8 +38,10 @@ def __init__(self, config: RobertaConfig):
                                                 config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
-        self.position_ids = nn.Parameter(
-            torch.empty((1, config.max_position_embeddings)), )
+        self.register_buffer(
+            "position_ids",
+            torch.arange(config.max_position_embeddings).unsqueeze(0),
+        )
 
         self.position_embedding_type = config.position_embedding_type
         if self.position_embedding_type != "absolute":
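
Note on the hunk above: nn.Parameter would register position_ids as a trainable weight and, via torch.empty, leave it uninitialized. register_buffer instead stores the precomputed arange tensor so that it stays in the module's state_dict and moves with .to()/.cuda(), but never appears in parameters(), so the optimizer cannot touch it. A minimal sketch of the difference, using a hypothetical toy module:

import torch
from torch import nn

class Embeddings(nn.Module):  # hypothetical stand-in for the real module
    def __init__(self, max_position_embeddings: int = 8):
        super().__init__()
        # Buffer: checkpointed and device-moved, but not trainable.
        self.register_buffer(
            "position_ids",
            torch.arange(max_position_embeddings).unsqueeze(0),
        )

m = Embeddings()
print(list(m.state_dict().keys()))  # ['position_ids'] -> still serialized
print(list(m.parameters()))         # [] -> invisible to the optimizer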
@@ -238,27 +239,3 @@ def create_position_ids_from_input_ids(input_ids,
                             past_key_values_length) * mask
 
     return incremental_indices.long() + padding_idx
-
-
-def roberta_task_weights_filter(
-    all_weights: Iterable[tuple[str, torch.Tensor]]
-) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str,
-                                                              torch.Tensor]]]:
-    """
-    Separate task-specific weights that are applied on top
-    of the encoder-decoder bert base.
-    To do so, return two generators over the original iterator.
-    Also, remove the "roberta." prefix to make it loadable
-    from vanilla BertModel.
-    """
-    # Copy of a lazy iterator without in-memory overhead so both
-    # iterators can be iterated upon independently.
-    all_weights1, all_weights2 = itertools.tee(all_weights)
-
-    def encoder_decoder_weights():
-        for name, weight in all_weights1:
-            if name.startswith("roberta."):
-                yield (name[len("roberta."):], weight)
-
-    return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
-                                       if not n.startswith("roberta."))
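
For reference, the two context lines kept at the top of this hunk are the tail of create_position_ids_from_input_ids. Below is a self-contained sketch of that helper with a worked example; the lines not shown in this diff are filled in from the standard HF-style RoBERTa helper, so treat them as an assumption here:

import torch

def create_position_ids_from_input_ids(input_ids,
                                       padding_idx,
                                       past_key_values_length=0):
    # Assumed body: 1 for real tokens, 0 for padding.
    mask = input_ids.ne(padding_idx).int()
    # Running count of real tokens; multiplying by mask zeroes pad slots.
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) +
                           past_key_values_length) * mask
    return incremental_indices.long() + padding_idx

ids = torch.tensor([[5, 6, 1, 1]])  # hypothetical ids; 1 is the pad id
print(create_position_ids_from_input_ids(ids, padding_idx=1))
# tensor([[2, 3, 1, 1]]) -> real tokens start at padding_idx + 1,
# padding positions stay at padding_idx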
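
The deleted roberta_task_weights_filter relied on itertools.tee to fan one lazy (name, tensor) iterator out into two independently consumable streams: encoder weights with the "roberta." prefix stripped, and the task-specific remainder. A minimal sketch of that pattern, with made-up weight names:

import itertools

weights = iter([
    ("roberta.encoder.layer.0.weight", 0),  # hypothetical entries
    ("classifier.dense.weight", 1),
])

# tee() yields two iterators over the same underlying stream; it buffers
# only the items one side has consumed ahead of the other.
left, right = itertools.tee(weights)
encoder = ((n[len("roberta."):], w) for n, w in left
           if n.startswith("roberta."))
task = ((n, w) for n, w in right if not n.startswith("roberta."))

print(list(encoder))  # [('encoder.layer.0.weight', 0)]
print(list(task))     # [('classifier.dense.weight', 1)]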