diff --git a/tests/models/clap/test_feature_extraction_clap.py b/tests/models/clap/test_feature_extraction_clap.py index c4a85606e0bfca..733dd66681c473 100644 --- a/tests/models/clap/test_feature_extraction_clap.py +++ b/tests/models/clap/test_feature_extraction_clap.py @@ -407,22 +407,22 @@ def test_integration_fusion_long_input(self): EXPECTED_INPUT_FEATURES = torch.tensor( [ [ - -11.1830, -10.1894, -8.6051, -4.8578, -1.3268, -8.4606, -14.5453, - -9.2017, 0.5781, 16.2129, 14.8289, 3.6326, -3.8794, -6.5544, - -2.4408, 1.9531, 6.0967, 1.7590, -7.6730, -6.1571, 2.0052, - 16.6694, 20.6447, 21.2145, 13.4972, 15.9043, 16.8987, 4.1766, - 11.9428, 21.2372, 12.3016, 4.8604, 6.7241, 1.8543, 4.9235, - 5.3188, -0.9897, -1.2416, -6.5864, 2.9529, 2.9274, 6.4753, - 10.2300, 11.2127, 3.4042, -1.0055, -6.0475, -6.7524, -3.9801, - -1.4434, 0.4740, -0.1584, -4.5457, -8.5746, -8.8428, -13.1475, - -9.6079, -8.5798, -4.1143, -3.7966, -7.1651, -6.1517, -8.0258, + -11.1830, -10.1894, -8.6051, -4.8578, -1.3268, -8.4606, -14.5453, + -9.2017, 0.5781, 16.2129, 14.8289, 3.6326, -3.8794, -6.5544, + -2.4408, 1.9531, 6.0967, 1.7590, -7.6730, -6.1571, 2.0052, + 16.6694, 20.6447, 21.2145, 13.4972, 15.9043, 16.8987, 4.1766, + 11.9428, 21.2372, 12.3016, 4.8604, 6.7241, 1.8543, 4.9235, + 5.3188, -0.9897, -1.2416, -6.5864, 2.9529, 2.9274, 6.4753, + 10.2300, 11.2127, 3.4042, -1.0055, -6.0475, -6.7524, -3.9801, + -1.4434, 0.4740, -0.1584, -4.5457, -8.5746, -8.8428, -13.1475, + -9.6079, -8.5798, -4.1143, -3.7966, -7.1651, -6.1517, -8.0258, -12.1486 ], [ - -10.2017, -7.9924, -5.9517, -3.9372, -1.9735, -4.3130, 16.1647, - 25.0592, 23.5532, 14.4974, -7.0778, -10.2262, 6.4782, 20.3454, - 19.4269, 1.7976, -16.5070, 4.9380, 12.3390, 6.9285, -13.6325, - -8.5298, 1.0839, -5.9629, -8.4812, 3.1331, -2.0963, -16.6046, + -10.2017, -7.9924, -5.9517, -3.9372, -1.9735, -4.3130, 16.1647, + 25.0592, 23.5532, 14.4974, -7.0778, -10.2262, 6.4782, 20.3454, + 19.4269, 1.7976, -16.5070, 4.9380, 12.3390, 6.9285, -13.6325, + -8.5298, 1.0839, -5.9629, -8.4812, 3.1331, -2.0963, -16.6046, -14.0070, -17.5707, -13.2080, -17.2168, -17.7770, -12.1111, -18.6184, -17.1897, -13.9801, -12.0426, -23.5400, -25.6823, -23.5813, -18.7847, -20.5473, -25.6458, -19.7585, -27.6007, -28.9276, -24.8948, -25.4458, @@ -431,27 +431,27 @@ def test_integration_fusion_long_input(self): -29.6947 ], [ - -9.2083, -7.2966, -6.2097, -7.9957, -2.9279, -11.1844, -6.1487, - 5.0738, 19.2957, 21.4577, 14.6803, -3.3148, -6.3328, -2.3537, - 6.9511, 15.2963, 14.6618, 5.2078, -0.0868, 1.1920, 18.1982, - 20.8467, 10.8038, 2.2521, 7.6906, 7.7427, -1.2541, -5.0018, - 0.9809, -2.1582, -5.4576, -5.4758, -11.8883, -9.0605, -8.4639, - -9.9899, -0.0543, -5.1628, 0.0481, -4.1505, -4.8141, -7.8235, - -9.0621, -10.1742, -8.9596, -11.5377, -16.5596, -17.1852, -17.5027, - -20.9322, -23.9538, -25.2600, -25.3426, -27.4534, -26.8857, -22.7851, - -25.8286, -24.8395, -23.8889, -24.2093, -26.5415, -23.7280, -25.6849, - -22.3628 + -9.2078, -7.2963, -6.2095, -7.9959, -2.9280, -11.1843, -6.1490, + 5.0733, 19.2957, 21.4578, 14.6803, -3.3153, -6.3334, -2.3542, + 6.9509, 15.2965, 14.6620, 5.2075, -0.0873, 1.1919, 18.1986, + 20.8470, 10.8035, 2.2516, 7.6905, 7.7427, -1.2543, -5.0018, + 0.9809, -2.1584, -5.4580, -5.4760, -11.8888, -9.0605, -8.4638, + -9.9897, -0.0540, -5.1629, 0.0483, -4.1504, -4.8140, -7.8236, + -9.0622, -10.1742, -8.9597, -11.5380, -16.5603, -17.1858, -17.5032, + -20.9326, -23.9543, -25.2602, -25.3429, -27.4536, -26.8859, -22.7852, + -25.8288, -24.8399, -23.8893, -24.2096, -26.5415, -23.7281, -25.6851, + -22.3629 ], [ - 1.3448, 2.9883, 4.0366, -0.8019, -10.4191, -10.0883, -4.3812, - 0.8136, 2.1579, 0.0832, 1.0949, -0.9759, -5.5319, -4.6009, - -6.5452, -14.9155, -20.1584, -9.3611, -2.4271, 1.4031, 4.9910, - 8.6916, 8.6785, 10.1973, 9.9029, 5.3840, 7.5336, 5.2803, - 2.8144, -0.3138, 2.2216, 5.7328, 7.5574, 7.7402, 1.0681, - 3.1049, 7.0742, 6.5588, 7.3712, 5.7881, 8.6874, 8.7725, - 2.8133, -4.5809, -6.1317, -5.1719, -5.0192, -9.0977, -10.9391, - -6.0769, 1.6016, -0.8965, -7.2252, -7.8632, -11.4468, -11.7446, - -10.7447, -7.0601, -2.7748, -4.1798, -2.8433, -3.1352, 0.8097, + 1.3448, 2.9883, 4.0366, -0.8019, -10.4191, -10.0883, -4.3812, + 0.8136, 2.1579, 0.0832, 1.0949, -0.9759, -5.5319, -4.6009, + -6.5452, -14.9155, -20.1584, -9.3611, -2.4271, 1.4031, 4.9910, + 8.6916, 8.6785, 10.1973, 9.9029, 5.3840, 7.5336, 5.2803, + 2.8144, -0.3138, 2.2216, 5.7328, 7.5574, 7.7402, 1.0681, + 3.1049, 7.0742, 6.5588, 7.3712, 5.7881, 8.6874, 8.7725, + 2.8133, -4.5809, -6.1317, -5.1719, -5.0192, -9.0977, -10.9391, + -6.0769, 1.6016, -0.8965, -7.2252, -7.8632, -11.4468, -11.7446, + -10.7447, -7.0601, -2.7748, -4.1798, -2.8433, -3.1352, 0.8097, 6.4212 ] ] @@ -461,12 +461,12 @@ def test_integration_fusion_long_input(self): input_speech = torch.cat([torch.tensor(x) for x in self._load_datasamples(5)]) feature_extractor = ClapFeatureExtractor() for padding, EXPECTED_VALUES, block_idx in zip( - ["repeat", "repeatpad", None, "pad"], EXPECTED_INPUT_FEATURES, [0, 1, 3, 2] + ["repeat", "repeatpad", None, "pad"], EXPECTED_INPUT_FEATURES, [1, 2, 0, 3] ): set_seed(987654321) input_features = feature_extractor(input_speech, return_tensors="pt", padding=padding).input_features self.assertEqual(input_features.shape, (1, 4, 1001, 64)) - self.assertTrue(torch.allclose(input_features[0, block_idx, MEL_BIN], EXPECTED_VALUES, atol=1e-4)) + self.assertTrue(torch.allclose(input_features[0, block_idx, MEL_BIN], EXPECTED_VALUES, atol=1e-3)) def test_integration_rand_trunc_long_input(self): # fmt: off @@ -485,8 +485,8 @@ def test_integration_rand_trunc_long_input(self): -22.7864 ], [ - -35.7719, -27.2566, -23.6964, -27.5521, 0.2510, 7.4391, 1.3917, - -13.3417, -28.1758, -17.0856, -5.7723, -0.8000, -7.8832, -15.5548, + -35.7719, -27.2566, -23.6964, -27.5521, 0.2510, 7.4391, 1.3917, + -13.3417, -28.1758, -17.0856, -5.7723, -0.8000, -7.8832, -15.5548, -30.5935, -24.7571, -13.7009, -10.3432, -21.2464, -24.8118, -19.4080, -14.9779, -11.7991, -18.4485, -20.1982, -17.3652, -20.6328, -28.2967, -25.7819, -21.8962, -28.5083, -29.5719, -30.2120, -35.7033, -31.8218, @@ -509,11 +509,11 @@ def test_integration_rand_trunc_long_input(self): -27.1716 ], [ - -33.2015, -28.7741, -21.9457, -23.4888, -32.1072, -8.6307, 3.2724, - 5.9157, -0.9221, -30.1814, -31.0015, -27.4508, -27.0477, -9.5342, - 0.3221, 0.6511, -7.1596, -25.9707, -32.8924, -32.2300, -13.8974, - -0.4895, 0.9168, -10.7663, -27.1176, -35.0829, -11.6859, -4.8855, - -11.8898, -26.6167, -5.6192, -3.8443, -19.7947, -14.4101, -8.6236, + -33.2015, -28.7741, -21.9457, -23.4888, -32.1072, -8.6307, 3.2724, + 5.9157, -0.9221, -30.1814, -31.0015, -27.4508, -27.0477, -9.5342, + 0.3221, 0.6511, -7.1596, -25.9707, -32.8924, -32.2300, -13.8974, + -0.4895, 0.9168, -10.7663, -27.1176, -35.0829, -11.6859, -4.8855, + -11.8898, -26.6167, -5.6192, -3.8443, -19.7947, -14.4101, -8.6236, -21.2458, -21.0801, -17.9136, -24.4663, -18.6333, -24.8085, -15.5854, -15.4344, -11.5046, -22.3625, -27.3387, -32.4353, -30.9670, -31.3789, -35.4044, -34.4591, -25.2433, -28.0773, -33.8736, -33.0224, -33.3155,