Commit: …ion for very long contexts
@@ -68,9 +68,11 @@ def __init__(self, model_config_path):

         # Optional settings

-        self.max_seq_len = 2048  # Reduce to save memory. Can also be increased, but the pretrained models produce degenerate output after 2048 tokens in any case. Should be possible to finetune for longer sequence lengths.
+        self.max_seq_len = 2048  # Reduce to save memory. Can also be increased, ideally while also using compress_pos_emb and a compatible model/LoRA
+        self.max_input_len = 2048  # Maximum length of input IDs in a single forward pass. Sequences longer than this will be processed in multiple steps
+        self.max_attention_size = 2048**2  # Sequences will be processed in chunks to keep the size of the attention weights matrix <= this
         self.compress_pos_emb = 1.0  # Increase to compress positional embeddings applied to sequence
-        self.gpu_peer_fix = False  # Apparently Torch can have problems transferring tensors directly from one GPU to another sometimes. Enable this to move tensors via system RAM instead, where needed
+        self.gpu_peer_fix = False  # Apparently Torch can have problems transferring tensors directly from one GPU to another sometimes. Enable this to explicitly move tensors via system RAM instead, where needed
         self.auto_map = None  # List of floats with memory allocation in GB, per CUDA device, overrides device_map

         # Tuning
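A rough sketch of how these options might be combined for a longer context window. ExLlamaConfig is the class whose __init__ is shown in the hunk above; the path, sequence length, and scaling factor below are illustrative assumptions, not values from the commit:

    config = ExLlamaConfig("config.json")    # placeholder path to the model's config.json
    config.max_seq_len = 8192                # total context length to allow in the cache
    config.compress_pos_emb = 4.0            # 8192 / 2048, for a model/LoRA tuned with linear position scaling
    config.max_input_len = 2048              # each forward pass still handles at most 2048 input tokens
    config.max_attention_size = 2048 ** 2    # chunks shrink further once (past_len + q_len) * q_len would exceed this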
@@ -779,7 +781,7 @@ def __init__(self, config):
             device_buffers = {}
             self.buffers.append(device_buffers)

-            temp_state = torch.zeros((config.max_seq_len, config.intermediate_size), dtype = torch.float16, device = dev)
+            temp_state = torch.zeros((config.max_input_len, config.intermediate_size), dtype = torch.float16, device = dev)
             temp_mlp = torch.zeros((config.fused_mlp_thd * 2, config.intermediate_size), dtype = torch.float16, device = dev)
             temp_zeros_float = torch.zeros((1, 65536), dtype = torch.float32, device = dev)
             temp_dq = torch.zeros((1, max_dq_buffer_size), dtype = torch.float16, device = dev)
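Sizing temp_state by max_input_len rather than max_seq_len is what keeps this scratch buffer constant when the context window grows. A back-of-the-envelope sketch (the intermediate_size of 13824 is an assumed Llama-13B value, not taken from the commit):

    bytes_fp16 = 2
    intermediate_size = 13824                 # assumption: Llama-13B MLP width
    for rows in (2048, 16384):                # max_input_len vs. a hypothetical 16k max_seq_len
        mib = rows * intermediate_size * bytes_fp16 / 2 ** 20
        print(f"temp_state with {rows} rows: {mib:.0f} MiB")
    # ~54 MiB per device when sized by max_input_len, vs. ~432 MiB if it still scaled with a 16k max_seq_len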
@@ -800,7 +802,70 @@ def __init__(self, config):
         torch.cuda.empty_cache()


-    def forward(self, input_ids, cache, last_id_only = True, preprocess_only = False, lora = None, output_device = None, input_mask = None):
+    def forward(self,
+                input_ids,
+                cache,
+                last_id_only = True,
+                preprocess_only = False,
+                lora = None,
+                output_device = None,
+                input_mask = None):
+
+        q_len = input_ids.shape[-1]
+        remaining_q_len = q_len
+
+        # Split forward pass
+
+        result = None
+
+        chunk_begin = 0
+        while chunk_begin < q_len:
+
+            # Limit chunk_size to max_input_len
+
+            chunk_size = min(remaining_q_len, self.config.max_input_len)
+
+            # Limit chunk_size to keep size of attention operation <= max_attention_size
+
+            past_len = cache.current_seq_len
+            attn_size = (past_len + remaining_q_len) * remaining_q_len
+            max_a = self.config.max_attention_size
+            if attn_size > max_a:
+                cs = (math.sqrt(past_len ** 2 + 4 * max_a) - past_len) / 2
+                chunk_size = math.floor(cs)
+
+            # Process chunk
+
+            chunk_end = min(chunk_begin + chunk_size, q_len)
+
+            _last_id_only = last_id_only
+            _preprocess_only = preprocess_only or (chunk_end < q_len and last_id_only)
+
+            r = self._forward(input_ids[:, chunk_begin : chunk_end],
+                              cache,
+                              _last_id_only,
+                              _preprocess_only,
+                              lora,
+                              output_device,
+                              input_mask)
+
+            if not _preprocess_only:
+                result = r if result is None else torch.cat((result, r), dim = -1)
+
+            chunk_begin = chunk_end
+            remaining_q_len -= chunk_size
+
+        return result
+
+
+    def _forward(self,
+                 input_ids,
+                 cache,
+                 last_id_only = True,
+                 preprocess_only = False,
+                 lora = None,
+                 output_device = None,
+                 input_mask = None):

         assert input_mask is None or input_mask.shape == input_ids.shape
turboderp (Author, Owner) commented:

You might be interested to know that this could be a driver issue. If you're using 2x RTX 4090, there has been a driver bug in Linux causing corrupt results. It seems that although Nvidia doesn't want their latest-generation consumer cards to be able to do peer-to-peer transfers, they forgot to fully disable it in the driver, so the data transfer would appear to succeed but you'd be reading uninitialised memory on the next card. Driver version 525.105.17 and above should fix this.
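If updating the driver isn't an option, the gpu_peer_fix setting shown in the first hunk is meant to route inter-GPU copies through system RAM, which should sidestep the peer-to-peer path entirely (a hedged suggestion based on that option's own comment, not something confirmed in this thread):

    config.gpu_peer_fix = True    # copy tensors via system RAM instead of direct GPU-to-GPU transfers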