Skip to content

Commit c3a445e

Browse files
Fei Yufacebook-github-bot
authored andcommitted
fix cuda test criteria and dedupe similar cpu tests to OSS (#3397)
Summary: as title, follow up addressing OSS comments, i.e. https://www.internalfb.com/diff/D67302872?dst_version_fbid=815631520895451&transaction_fbid=1694362554586457 & https://www.internalfb.com/diff/D67302872?dst_version_fbid=815631520895451&transaction_fbid=683400918136325 Pull Request resolved: #3397 Reviewed By: TroyGarden Differential Revision: D83306596 fbshipit-source-id: 8c609c6d78b0983e2fb895bf24cf06cea413d308
1 parent 8abfd90 commit c3a445e

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

torchrec/modules/tests/test_itep_embedding_modules.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def generate_expected_address_lookup_buffer(
192192

193193
# pyre-ignore[56]: Pyre was not able to infer the type of argument
194194
@unittest.skipIf(
195-
torch.cuda.device_count() <= 1,
195+
torch.cuda.device_count() < 1,
196196
"Not enough GPUs, this test requires at least two GPUs",
197197
)
198198
def test_init_itep_module(self) -> None:
@@ -245,7 +245,7 @@ def test_init_itep_module_without_pruned_table(self) -> None:
245245

246246
# pyre-ignore[56]: Pyre was not able to infer the type of argument
247247
@unittest.skipIf(
248-
torch.cuda.device_count() <= 1,
248+
torch.cuda.device_count() < 1,
249249
"Not enough GPUs, this test requires at least one GPU",
250250
)
251251
def test_train_forward(self) -> None:
@@ -270,7 +270,7 @@ def test_train_forward(self) -> None:
270270

271271
# pyre-ignore[56]: Pyre was not able to infer the type of argument
272272
@unittest.skipIf(
273-
torch.cuda.device_count() <= 1,
273+
torch.cuda.device_count() < 1,
274274
"Not enough GPUs, this test requires at least one GPU",
275275
)
276276
def test_train_forward_vbe(self) -> None:
@@ -295,7 +295,7 @@ def test_train_forward_vbe(self) -> None:
295295

296296
# pyre-ignore[56]: Pyre was not able to infer the type of argument
297297
@unittest.skipIf(
298-
torch.cuda.device_count() <= 1,
298+
torch.cuda.device_count() < 1,
299299
"Not enough GPUs, this test requires at least one GPU",
300300
)
301301
# Mock out reset_weight_momentum to count calls
@@ -329,7 +329,7 @@ def test_check_pruning_schedule(
329329

330330
# pyre-ignore[56]: Pyre was not able to infer the type of argument
331331
@unittest.skipIf(
332-
torch.cuda.device_count() <= 1,
332+
torch.cuda.device_count() < 1,
333333
"Not enough GPUs, this test requires at least one GPU",
334334
)
335335
# Mock out reset_weight_momentum to count calls
@@ -365,7 +365,7 @@ def test_eval_forward(
365365

366366
# pyre-ignore[56]: Pyre was not able to infer the type of argument
367367
@unittest.skipIf(
368-
torch.cuda.device_count() <= 1,
368+
torch.cuda.device_count() < 1,
369369
"Not enough GPUs, this test requires at least two GPUs",
370370
)
371371
def test_iter_increment_per_forward(self) -> None:
@@ -397,7 +397,7 @@ def test_iter_increment_per_forward(self) -> None:
397397

398398
# pyre-ignore[56]: Pyre was not able to infer the type of argument
399399
@unittest.skipIf(
400-
torch.cuda.device_count() <= 1,
400+
torch.cuda.device_count() < 1,
401401
"Not enough GPUs, this test requires at least one GPU",
402402
)
403403
def test_iter_passed_as_int_to_itep_module(self) -> None:
@@ -440,7 +440,7 @@ def mock_forward(features: KeyedJaggedTensor, iter_val: int) -> List[Tensor]:
440440

441441
# pyre-ignore[56]: Pyre was not able to infer the type of argument
442442
@unittest.skipIf(
443-
torch.cuda.device_count() <= 1,
443+
torch.cuda.device_count() < 1,
444444
"Not enough GPUs, this test requires at least one GPU",
445445
)
446446
def test_blank_line_formatting_preserved(self) -> None:
@@ -472,7 +472,7 @@ def test_blank_line_formatting_preserved(self) -> None:
472472

473473
# pyre-ignore[56]: Pyre was not able to infer the type of argument
474474
@unittest.skipIf(
475-
torch.cuda.device_count() <= 1,
475+
torch.cuda.device_count() < 1,
476476
"Not enough GPUs, this test requires at least one GPU",
477477
)
478478
def test_iter_boundary_values_with_pruning_logic(self) -> None:
@@ -523,7 +523,7 @@ def mock_forward(
523523
torch.cuda.device_count() <= 1,
524524
"Not enough GPUs, this test requires at least one GPU",
525525
)
526-
def test_error_handling_invalid_iter_tensor_values(self) -> None:
526+
def test_error_handling_invalid_iter_tensor_values_cuda(self) -> None:
527527
"""Test behavior with invalid iter tensor values."""
528528
itep_module = GenericITEPModule(
529529
table_name_to_unpruned_hash_sizes=self._table_name_to_unpruned_hash_sizes,
@@ -567,7 +567,7 @@ def test_error_handling_invalid_iter_tensor_values(self) -> None:
567567

568568
# pyre-ignore[56]: Pyre was not able to infer the type of argument
569569
@unittest.skipIf(
570-
torch.cuda.device_count() <= 1,
570+
torch.cuda.device_count() < 1,
571571
"Not enough GPUs, this test requires at least one GPU",
572572
)
573573
def test_iter_consistency_across_training_steps(self) -> None:
@@ -626,7 +626,7 @@ def track_iter_forward(
626626

627627
# pyre-ignore[56]: Pyre was not able to infer the type of argument
628628
@unittest.skipIf(
629-
torch.cuda.device_count() <= 1,
629+
torch.cuda.device_count() < 1,
630630
"Not enough GPUs, this test requires at least one GPU",
631631
)
632632
def test_performance_iter_conversion_overhead(self) -> None:
@@ -680,7 +680,7 @@ def test_performance_iter_conversion_overhead(self) -> None:
680680

681681
# pyre-ignore[56]: Pyre was not able to infer the type of argument
682682
@unittest.skipIf(
683-
torch.cuda.device_count() <= 1,
683+
torch.cuda.device_count() < 1,
684684
"Not enough GPUs, this test requires at least one GPU",
685685
)
686686
def test_iter_type_conversion_edge_cases(self) -> None:
@@ -744,7 +744,7 @@ def capture_iter_forward(
744744

745745
# pyre-ignore[56]: Pyre was not able to infer the type of argument
746746
@unittest.skipIf(
747-
torch.cuda.device_count() <= 1,
747+
torch.cuda.device_count() < 1,
748748
"Not enough GPUs, this test requires at least one GPU",
749749
)
750750
def test_concurrent_forward_passes_iter_safety(self) -> None:

0 commit comments

Comments
 (0)