 
 import torch
 from torch import nn
-
-# TODO(yinfan.1024): How to handle that, still use Qwen2Config?
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionType
@@ -180,7 +178,6 @@ def __init__(
         for ut_step in range(total_ut_steps):
             base_layer_idx = extract_layer_index(prefix)
             unique_layer_idx = ut_step * total_layers + base_layer_idx
-            # print(unique_layer_idx)
 
             unique_prefix = prefix.replace(
                 f"layers.{base_layer_idx}", f"layers.{unique_layer_idx}"
@@ -343,7 +340,7 @@ def __init__(
             prefix=f"{prefix}.embed_tokens",
         )
 
-        # Use the provided decoder layer type or default to Qwen2DecoderLayer
+        # Use the provided decoder layer type or default to OuroDecoderLayer
         decoder_layer_type = decoder_layer_type or OuroDecoderLayer
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
@@ -380,7 +377,6 @@ def forward(
 
         # Get total_ut_steps from config, default to 4 if not specified
         total_ut_steps = getattr(self.config, "total_ut_steps", 4)
-        # print(total_ut_steps)
 
         for current_ut in range(total_ut_steps):
            for layer in self.layers[self.start_layer : self.end_layer]:
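For readers unfamiliar with the looped-decoder pattern touched by the `forward()` hunk above, here is a minimal standalone sketch (not part of this PR; only `total_ut_steps`, `start_layer`, `end_layer`, the unique-index formula, and the nested loop mirror the diff, everything else is illustrative stand-in code): the same stack of decoder layers is re-applied `total_ut_steps` times to the hidden states.

```python
import torch
from torch import nn


class LoopedDecoderStack(nn.Module):
    """Illustrative stand-in: re-applies one shared layer stack several times,
    mirroring the total_ut_steps loop in the diff above."""

    def __init__(self, layers: nn.ModuleList, total_ut_steps: int = 4):
        super().__init__()
        self.layers = layers
        self.total_ut_steps = total_ut_steps
        self.start_layer, self.end_layer = 0, len(layers)
        # Each (step, layer) pair gets a unique index, as in the __init__ hunk:
        # unique_layer_idx = ut_step * total_layers + base_layer_idx
        self.unique_indices = [
            step * len(layers) + idx
            for step in range(total_ut_steps)
            for idx in range(len(layers))
        ]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Run the same (weight-tied) decoder layers total_ut_steps times.
        for current_ut in range(self.total_ut_steps):
            for layer in self.layers[self.start_layer : self.end_layer]:
                hidden_states = layer(hidden_states)
        return hidden_states


# Tiny usage example with stand-in layers.
stack = LoopedDecoderStack(nn.ModuleList([nn.Linear(16, 16) for _ in range(2)]))
print(stack(torch.randn(1, 8, 16)).shape)  # torch.Size([1, 8, 16])
```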