Add scorer support for aggregations, allowing BM25 / vector hybrid search #3408

Open
rbs333 opened this issue Oct 8, 2024 · 0 comments

rbs333 (Contributor) commented Oct 8, 2024

Description: Currently there is no way to set the scorer in an aggregate request, which makes hybrid BM25 / vector queries impractical. This PR adds a scorer() option to AggregateRequest so that such queries can be executed.
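To picture the requested change, here is a minimal, hypothetical sketch of a chainable aggregate-request builder that gains a scorer() method. The class MiniAggregateRequest and its build_args() method are invented for illustration and are not redis-py's implementation; only the SCORER, ADDSCORES, LOAD, APPLY ... AS and DIALECT keywords come from the FT.AGGREGATE command, and the argument ordering chosen here is an assumption made for readability.

```python
from typing import List, Optional


class MiniAggregateRequest:
    """Hypothetical, simplified FT.AGGREGATE argument builder (not redis-py's class)."""

    def __init__(self, query: str = "*") -> None:
        self._query = query
        self._scorer: Optional[str] = None
        self._add_scores = False
        self._load_all = False
        self._applies: List[str] = []
        self._dialect: Optional[int] = None

    def scorer(self, scorer: str) -> "MiniAggregateRequest":
        # The option the issue asks for: choose the text scoring function, e.g. "BM25"
        self._scorer = scorer
        return self

    def add_scores(self) -> "MiniAggregateRequest":
        # Request that each document's text score be exposed as @__score
        self._add_scores = True
        return self

    def load(self, field: str) -> "MiniAggregateRequest":
        # This sketch only supports loading all fields
        if field != "*":
            raise ValueError("this sketch only supports load('*')")
        self._load_all = True
        return self

    def apply(self, **exprs: str) -> "MiniAggregateRequest":
        # Each keyword argument becomes: APPLY <expression> AS <alias>
        for alias, expr in exprs.items():
            self._applies += ["APPLY", expr, "AS", alias]
        return self

    def dialect(self, version: int) -> "MiniAggregateRequest":
        self._dialect = version
        return self

    def build_args(self) -> List[str]:
        # Assemble the arguments that would follow the index name in FT.AGGREGATE
        args = [self._query]
        if self._scorer is not None:
            args += ["SCORER", self._scorer]
        if self._add_scores:
            args.append("ADDSCORES")
        if self._load_all:
            args += ["LOAD", "*"]
        args += self._applies
        if self._dialect is not None:
            args += ["DIALECT", str(self._dialect)]
        return args


req = (
    MiniAggregateRequest("(@description:cat)=>[KNN 3 @vector $vec_param AS dist]")
    .scorer("BM25")
    .add_scores()
    .apply(hybrid_score="@__score + @dist")
    .load("*")
    .dialect(4)
)
print(req.build_args())
```

Because every method returns self, scorer() slots into the same fluent chain as the existing options, which is the shape the test below relies on.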

Example test added for a hybrid query:

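import numpy as np
import pytest
from redis.commands.search import aggregations
from redis.commands.search.field import TextField, VectorField

from .conftest import skip_ifmodversion_lt  # helper from redis-py's test suite


@pytest.mark.redismod
@skip_ifmodversion_lt("2.10.05", "search")
def test_aggregations_hybrid_scoring(client):
    # Index two text fields plus a 2-dimensional FLOAT32 HNSW vector field (cosine distance)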
    client.ft().create_index(
        (
            TextField("name", sortable=True, weight=5.0),
            TextField("description", sortable=True, weight=5.0),
            VectorField(
                "vector",
                "HNSW",
                {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"},
            ),
        )
    )

    client.hset(
        "doc1",
        mapping={
            "name": "cat book",
            "description": "a book about cats",
            "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
        },
    )
    client.hset(
        "doc2",
        mapping={
            "name": "dog book",
            "description": "a book about dogs",
            "vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
        },
    )

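    # Text filter on description, then KNN 3 over the vector field; the distance is aliased as dist
    query_string = "(@description:cat)=>[KNN 3 @vector $vec_param AS dist]"
    req = (
        aggregations.AggregateRequest(query_string)
        .scorer("BM25")  # the new option: select the text scorer
        .add_scores()  # expose the text score as @__score
        .apply(hybrid_score="@__score + @dist")  # combine text score and vector distance
        .load("*")
        .dialect(4)
    )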

    res = (
        client.ft()
        .aggregate(
            req,
            query_params={
                "vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes()
            },
        )
        .rows[0]
    )

    assert len(res) == 6
    assert b"hybrid_score" in res
    assert b"__score" in res
    assert b"__dist" in res
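    # hybrid_score should equal the sum of the two parts; compare approximately since these are floats
    assert float(res[1]) + float(res[3]) == pytest.approx(float(res[5]))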
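The APPLY step in the test just adds two per-document numbers. As a rough, hypothetical sketch of that arithmetic outside Redis: cosine_distance and row_to_dict are illustrative helper names, the distance is assumed to be 1 minus cosine similarity (how the COSINE metric is defined), the row is assumed to use the flat key/value layout that the test indexes positionally, and the BM25 value is a made-up placeholder rather than real server output.

```python
import math
from typing import Dict, List, Sequence


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    # Assumed COSINE metric: 1 - (a . b) / (|a| * |b|)
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def row_to_dict(row: List[bytes]) -> Dict[str, str]:
    # Flat row [key1, value1, key2, value2, ...] -> {key: value}
    return {row[i].decode(): row[i + 1].decode() for i in range(0, len(row), 2)}


# Distance between the test's query vector and doc1's vector
dist = cosine_distance([0.11, 0.21], [0.1, 0.2])

# Placeholder BM25 score: illustrative only, not real server output
bm25 = 0.5

# Hypothetical row shaped like an aggregate result with ADDSCORES and the APPLY alias
row = [
    b"__score", repr(bm25).encode(),
    b"dist", repr(dist).encode(),
    b"hybrid_score", repr(bm25 + dist).encode(),
]
fields = row_to_dict(row)

# Look values up by name instead of by position, then check the sum approximately
hybrid_ok = math.isclose(
    float(fields["__score"]) + float(fields["dist"]),
    float(fields["hybrid_score"]),
)
print(fields, hybrid_ok)
```

Looking fields up by name keeps such a check independent of the order in which the server returns them. Note also that a cosine distance shrinks as vectors get more similar while a BM25 score grows with relevance, so a plain sum is only the simplest possible combination.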

PR: #3409
