Skip to content

Commit

Permalink
nits
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker committed Nov 9, 2023
1 parent 0caeade commit 0d678ad
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/models/gptsan_japanese/test_modeling_gptsan_japanese.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ def test_logits(self):
outputs = model(input_ids)
output_logits = outputs.logits.detach().cpu().numpy()
# Output of original model created with mesh-tensoflow
# fmt: off
target = [
# fmt: off
[-12.037839889526367, -12.433061599731445, -14.333840370178223, -12.450345993041992, -11.1661376953125,
-11.930137634277344, -10.659740447998047, -12.909574508666992, -13.241043090820312, -13.398579597473145,
-11.107524871826172, -12.3685941696167, -22.97943115234375, -10.481067657470703, -12.484030723571777,
Expand All @@ -242,8 +242,8 @@ def test_logits(self):
-10.113405227661133, -10.546867370605469, -10.04369068145752, -10.907809257507324, -10.504216194152832,
-11.129199028015137, -10.151124000549316, -21.96586799621582, -9.086349487304688, -11.730339050292969,
-10.460667610168457, -10.298049926757812, -10.784148216247559, -10.840693473815918, -22.03152847290039],
# fmt: on
]
# fmt: on
target = np.array(target).flatten()
predict = output_logits[0, :, :20].flatten()

Expand Down Expand Up @@ -341,8 +341,8 @@ def test_spout_generation(self):
input_ids_batch = tokenizer([input_text, input_text], return_tensors="pt").input_ids.to(torch_device)

# spout from uniform and one-hot
# fmt: off
spouts = [
# fmt: off
[0.87882208, 0.38426396, 0.33220248, 0.43890406, 0.16562252,
0.04803985, 0.211572 , 0.23188473, 0.37153068, 0.7836377 ,
0.02160172, 0.38761719, 0.75290772, 0.90198857, 0.34365777,
Expand Down Expand Up @@ -378,8 +378,8 @@ def test_spout_generation(self):
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0.],
# fmt: on
]
# fmt: on

output1 = model.generate(
input_ids=input_ids,
Expand Down

0 comments on commit 0d678ad

Please sign in to comment.