 
 import itertools
 from abc import abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
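The import change above leans on PEP 585: from Python 3.9 onward the built-in `dict`, `list`, and `tuple` can be parameterized directly, so only `Optional` still needs `typing`. A minimal sketch of the resulting annotation style (the function below is illustrative, not part of this file):

```python
from typing import Optional

# PEP 585: built-in containers are generic on Python >= 3.9, so
# dict[str, tuple[int, int]] replaces typing.Dict[str, Tuple[int, int]].
def find_shard(offsets: dict[str, tuple[int, int]],
               shard_id: str) -> Optional[tuple[int, int]]:
    # Return the (offset, size) pair for shard_id, or None if absent.
    return offsets.get(shard_id)
```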
@@ -47,8 +47,8 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
 
 
 def adjust_bitsandbytes_4bit_shard(param: Parameter,
-                                   shard_offsets: Dict[str, Tuple[int, int]],
-                                   loaded_shard_id: str) -> Tuple[int, int]:
+                                   shard_offsets: dict[str, tuple[int, int]],
+                                   loaded_shard_id: str) -> tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
 
     total, _ = shard_offsets["total"]
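For context on the annotation just above: `shard_offsets` maps a shard name to an `(offset, size)` pair, and the body reads the fused extent from a `"total"` entry. A hypothetical value of that shape (keys and sizes are made up for illustration):

```python
# Hypothetical dict[str, tuple[int, int]] value for shard_offsets;
# the shard names and sizes are illustrative, not from a real model.
shard_offsets = {
    "q": (0, 4096),
    "k": (4096, 1024),
    "v": (5120, 1024),
    "total": (0, 6144),
}
total, _ = shard_offsets["total"]  # as read in the function body above
```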
@@ -90,7 +90,7 @@ class LinearMethodBase(QuantizeMethodBase):
     @abstractmethod
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         """Create weights for a linear layer.
@@ -123,7 +123,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
 
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         weight = Parameter(torch.empty(sum(output_partition_sizes),
@@ -179,7 +179,8 @@ def __init__(
         self.quant_method = quant_config.get_quant_method(self,
                                                           prefix=prefix)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self,
+                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
         raise NotImplementedError
 
 
@@ -240,9 +241,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
 
-    def forward(
-        self, x: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    def forward(self,
+                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
@@ -288,7 +288,7 @@ def __init__(self,
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 output_sizes: Optional[List[int]] = None,
+                 output_sizes: Optional[list[int]] = None,
                  prefix: str = ""):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config, prefix)
@@ -374,7 +374,7 @@ def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
             loaded_weight = loaded_weight.reshape(1)
         param.load_column_parallel_weight(loaded_weight=loaded_weight)
 
-    def forward(self, input_):
+    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
@@ -422,7 +422,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
     def __init__(self,
                  input_size: int,
-                 output_sizes: List[int],
+                 output_sizes: list[int],
                  bias: bool = True,
                  gather_output: bool = False,
                  skip_bias_add: bool = False,
@@ -500,7 +500,7 @@ def weight_loader(self,
             current_shard_offset = 0
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                             False)
-            shard_offsets: List[Tuple[int, int, int]] = []
+            shard_offsets: list[tuple[int, int, int]] = []
             for i, output_size in enumerate(self.output_sizes):
                 shard_offsets.append((i, current_shard_offset, output_size))
                 current_shard_offset += output_size
@@ -602,7 +602,7 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
         """
 
         current_shard_offset = 0
-        shard_offsets: List[Tuple[int, int, int]] = []
+        shard_offsets: list[tuple[int, int, int]] = []
         for i, output_size in enumerate(self.output_sizes):
             shard_offsets.append((i, current_shard_offset, output_size))
             current_shard_offset += output_size
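Both hunks above use the same accumulation pattern to turn `self.output_sizes` into `(shard_id, offset, size)` triples. A standalone version with made-up sizes, e.g. two equal outputs for a merged gate/up projection:

```python
# Standalone copy of the offset-accumulation loop from the diff above;
# the output_sizes values are hypothetical.
output_sizes = [11008, 11008]
shard_offsets: list[tuple[int, int, int]] = []
current_shard_offset = 0
for i, output_size in enumerate(output_sizes):
    shard_offsets.append((i, current_shard_offset, output_size))
    current_shard_offset += output_size
assert shard_offsets == [(0, 0, 11008), (1, 11008, 11008)]
```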
@@ -1124,7 +1124,7 @@ def weight_loader_v2(self, param: BasevLLMParameter,
 
         param.load_row_parallel_weight(loaded_weight=loaded_weight)
 
-    def forward(self, input_):
+    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
         if self.input_is_parallel:
             input_parallel = input_
         else:
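With these annotations, every `forward` touched here advertises an `(output, output_bias)` pair; the bias slot is non-None only when `skip_bias_add` defers the bias addition to the caller, as the `bias = self.bias if not self.skip_bias_add else None` lines show. A minimal sketch of that contract with a stand-in layer (constructing the real layers needs vLLM's distributed and quantization setup, so this is schematic):

```python
from typing import Optional

import torch

# Stand-in layer mirroring the (output, output_bias) contract; it is not
# one of the vLLM classes in this file.
class TinyLinear(torch.nn.Linear):

    def __init__(self, in_f: int, out_f: int, skip_bias_add: bool = False):
        super().__init__(in_f, out_f)
        self.skip_bias_add = skip_bias_add

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, Optional[torch.nn.Parameter]]:
        if self.skip_bias_add:
            # Defer the bias: hand it back so the caller can fuse it later.
            return torch.nn.functional.linear(x, self.weight), self.bias
        return super().forward(x), None

layer = TinyLinear(8, 16, skip_bias_add=True)
output, output_bias = layer(torch.randn(4, 8))
if output_bias is not None:
    output = output + output_bias  # caller-side bias add
```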