diff --git a/pyproject.toml b/pyproject.toml index 0e165f8e1..8a92b468e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -259,7 +259,9 @@ filterwarnings = [ # should be addressed in pandas 'ignore:np.find_common_type is deprecated.*:DeprecationWarning', 'ignore:.*distutils Version classes are deprecated.*', - 'ignore:.*`resume_download` is deprecated.*' + # TODO: add more context but linked to sentence-transformers + 'ignore:.*`resume_download` is deprecated.*', + 'ignore:.*`clean_up_tokenization_spaces` was not set.*', ] addopts = "--doctest-modules" doctest_optionflags = "NORMALIZE_WHITESPACE ELLIPSIS" diff --git a/skrub/tests/test_sentence_encoder.py b/skrub/tests/test_sentence_encoder.py index 3147f8739..fea945dd4 100644 --- a/skrub/tests/test_sentence_encoder.py +++ b/skrub/tests/test_sentence_encoder.py @@ -8,6 +8,8 @@ pytest.importorskip("sentence_transformers") +MODEL_NAME = "sentence-transformers/paraphrase-albert-small-v2" + def test_missing_import_error(): try: @@ -17,7 +19,7 @@ def test_missing_import_error(): else: return - st = SentenceEncoder() + st = SentenceEncoder(model_name_or_path=MODEL_NAME) x = pd.Series(["oh no"]) with pytest.raises(ImportError, match="Missing optional dependency"): st.fit(x) @@ -25,7 +27,7 @@ def test_missing_import_error(): def test_sentence_encoder(df_module): X = df_module.make_column("", ["hello sir", "hola que tal"]) - encoder = SentenceEncoder(n_components=2) + encoder = SentenceEncoder(model_name_or_path=MODEL_NAME, n_components=2) X_out = encoder.fit_transform(X) assert X_out.shape == (2, 2) @@ -36,18 +38,18 @@ def test_sentence_encoder(df_module): @pytest.mark.parametrize("X", [["hello"], "hello"]) def test_not_a_series(X): with pytest.raises(ValueError): - SentenceEncoder().fit(X) + SentenceEncoder(model_name_or_path=MODEL_NAME).fit(X) def test_not_a_series_with_string(df_module): X = df_module.make_column("", [1, 2, 3]) with pytest.raises(RejectColumn): - SentenceEncoder().fit(X) + SentenceEncoder(model_name_or_path=MODEL_NAME).fit(X) def test_missing_value(df_module): X = df_module.make_column("", [None, None, "hey"]) - encoder = SentenceEncoder(n_components="all") + encoder = SentenceEncoder(model_name_or_path=MODEL_NAME, n_components="all") X_out = encoder.fit_transform(X) assert X_out.shape == (3, 384) @@ -57,17 +59,17 @@ def test_missing_value(df_module): def test_n_components(df_module): X = df_module.make_column("", ["hello sir", "hola que tal"]) - encoder = SentenceEncoder(n_components="all") + encoder = SentenceEncoder(model_name_or_path=MODEL_NAME, n_components="all") X_out = encoder.fit_transform(X) assert X_out.shape[1] == 384 assert encoder.n_components_ == 384 - encoder = SentenceEncoder(n_components=2) + encoder = SentenceEncoder(model_name_or_path=MODEL_NAME, n_components=2) X_out = encoder.fit_transform(X) assert X_out.shape[1] == 2 assert encoder.n_components_ == 2 - encoder = SentenceEncoder(n_components=30) + encoder = SentenceEncoder(model_name_or_path=MODEL_NAME, n_components=30) with pytest.warns(UserWarning): X_out = encoder.fit_transform(X) assert not hasattr(encoder, "pca_") @@ -77,19 +79,21 @@ def test_n_components(df_module): def test_wrong_parameters(): with pytest.raises(ValueError, match="Got n_components='yes'"): - SentenceEncoder(n_components="yes")._check_params() + SentenceEncoder( + model_name_or_path=MODEL_NAME, n_components="yes" + )._check_params() with pytest.raises(ValueError, match="Got batch_size=-10"): - SentenceEncoder(batch_size=-10)._check_params() + SentenceEncoder(model_name_or_path=MODEL_NAME, batch_size=-10)._check_params() with pytest.raises(ValueError, match="Got model_name_or_path=1"): SentenceEncoder(model_name_or_path=1)._check_params() with pytest.raises(ValueError, match="Got norm=l3"): - SentenceEncoder(norm="l3")._check_params() + SentenceEncoder(model_name_or_path=MODEL_NAME, norm="l3")._check_params() with pytest.raises(ValueError, match="Got cache_folder=1"): - SentenceEncoder(cache_folder=1)._check_params() + SentenceEncoder(model_name_or_path=MODEL_NAME, cache_folder=1)._check_params() with pytest.raises(ValueError, match="Got model_name_or_path=1"): SentenceEncoder(model_name_or_path=1)._check_params() @@ -103,7 +107,7 @@ def test_wrong_model_name(): def test_transform_equal_fit_transform(df_module): x = df_module.make_column("", ["hello again"]) - encoder = SentenceEncoder() + encoder = SentenceEncoder(model_name_or_path=MODEL_NAME) X_out = encoder.fit_transform(x) X_out_2 = encoder.transform(x) df_module.assert_frame_equal(X_out, X_out_2)