Commit 431b5e6

fix no fa test

li126com committed Dec 3, 2024
1 parent 715fbe9
Showing 1 changed file with 6 additions and 5 deletions.
tests/test_training/test_forward_output_no_fa.py (6 additions, 5 deletions)
@@ -56,7 +56,7 @@
             checkpoint=True,
             num_attention_heads=32,
             embed_split_hidden=True,
-            vocab_size=103168,
+            vocab_size=92544,
             embed_grad_scale=1,
             parallel_output=False,
             hidden_size=4096,
@@ -68,8 +68,9 @@
             layer_norm_epsilon=1e-5,
             use_flash_attn=False,
             num_chunks=1,
+            no_bias=True,
         ),
-        model_type="INTERNLM",
+        model_type="INTERNLM2_PUBLIC",
         alert_address=None,
         monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
         grad_scaler=dict(
@@ -178,7 +179,7 @@ def train_check_output(args):

     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)

-    train_dl, dataset_types = build_train_loader_with_data_type()
+    _, dataset_types = build_train_loader_with_data_type()

     metric = AccPerplex(
         device=get_current_device(),
@@ -226,9 +227,9 @@ def train_check_output(args):

     if gpc.is_rank_for_log():
         standard_output_with_fa = torch.load(
-            os.path.join(share_path, "quailty_assurance/7B_no_flash_attention/output_with_fa.pt")
+            os.path.join(share_path, "quailty_assurance/7B_no_flash_attention/output_with_fa_internlm2.pt")
         )
-        tensor1 = standard_output_with_fa
+        tensor1 = standard_output_with_fa[0][0]
         tensor2 = output[0][0][0]

         if torch.equal(tensor1, tensor2):
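
For context, a minimal sketch of the comparison step as it reads after this change. The reference file name and the [0][0] indexing come from the hunk above; the function name, the share_path argument, and the allclose fallback after the exact-equality check are assumptions for illustration, not the test's confirmed logic.

import os

import torch


def check_output_against_reference(output, share_path):
    # Reference forward output produced with flash attention enabled (path taken from the diff above).
    standard_output_with_fa = torch.load(
        os.path.join(share_path, "quailty_assurance/7B_no_flash_attention/output_with_fa_internlm2.pt")
    )

    # Index both sides down to the same leaf tensor, mirroring the updated test.
    tensor1 = standard_output_with_fa[0][0]
    tensor2 = output[0][0][0]

    if torch.equal(tensor1, tensor2):
        return True
    # Assumed fallback: tolerate small numerical drift instead of requiring bit-exact equality.
    return torch.allclose(tensor1.float(), tensor2.float(), rtol=1e-5, atol=1e-5)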
