Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update argilla integration to use argilla_sdk v2 #705

Merged
merged 22 commits into from
Jul 30, 2024
Merged

Conversation

alvarobartt
Copy link
Member

@alvarobartt alvarobartt commented Jun 6, 2024

Description

This PR renames and updates Argilla to ArgillaBase, since now the client in argilla_sdk (later to be renamed to argilla only as per a recent discussion with @frascuchon) is named Argilla too. Besides that, the code has been updated to use the latest Python client instead not only for ArgillaBase but also for the subclasses TextGenerationToArgilla and PreferenceToArgilla.

Warning

This change here implies that the argilla server version should be 1.27.0 or higher, otherwise the argilla_sdk won't work.

Closes argilla-io/argilla#4880

@alvarobartt alvarobartt added this to the 1.2.0 milestone Jun 6, 2024
@alvarobartt alvarobartt self-assigned this Jun 6, 2024
Copy link
Contributor

@burtenshaw burtenshaw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice start. Just a few high level comments.

src/distilabel/steps/argilla/base.py Outdated Show resolved Hide resolved
src/distilabel/steps/argilla/text_generation.py Outdated Show resolved Hide resolved
@alvarobartt
Copy link
Member Author

alvarobartt commented Jun 6, 2024

Edit: the issue was with the Argilla Server version as I was using 1.26.0 while 1.27.0 or higher was required 👍🏻

As an update @burtenshaw @frascuchon I've installed argilla_sdk from the latest version of argilla-python in main and running the code below, leads to the following exception, complaining about not finding the /records/bulk endpoint, could you try to reproduce on your end?


Install as pip install "distilabel[hf-inference-endpoints] @ git+https://github.com/argilla-io/distilabel.git@argilla-2.0" and then run the code below with your personal HF_TOKEN and your Argilla credentials (feel free to use dev):

from uuid import uuid4

from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline.local import Pipeline
from distilabel.steps import (
    LoadDataFromDicts,
    TextGenerationToArgilla,
)
from distilabel.steps.tasks import TextGeneration

if __name__ == "__main__":
    with Pipeline(name="my-pipeline") as pipeline:
        load_dataset = LoadDataFromDicts(
            name="load_dataset",
            data=[
                {
                    "instruction": "Write a short story about a dragon that saves a princess from a tower.",
                },
            ],
        )

        text_generation = TextGeneration(
            name="text_generation",
            llm=InferenceEndpointsLLM(
                model_id="meta-llama/Meta-Llama-3-8B-Instruct",
                tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
                api_key="...",  # type: ignore
            ),
            num_generations=4,
            group_generations=True,
        )

        text_generation_to_argilla = TextGenerationToArgilla(
            name="text_generation_to_argilla",
            api_url="...",
            api_key="...",  # type: ignore
            dataset_name=f"text-generation-{uuid4()}",
            dataset_workspace="admin",
        )

        (  # type: ignore
            load_dataset
            >> text_generation
            >> text_generation_to_argilla
        )

    pipeline.run(
        parameters={
            text_generation.name: {  # type: ignore
                "llm": {
                    "generation_kwargs": {
                        "max_new_tokens": 512,
                        "temperature": 0.7,
                    },
                },
            },
        }
    )

The logs then look like:

image

@alvarobartt alvarobartt marked this pull request as ready for review June 7, 2024 11:11
@alvarobartt alvarobartt linked an issue Jun 7, 2024 that may be closed by this pull request
@alvarobartt alvarobartt modified the milestones: 1.2.0, 1.3.0 Jun 11, 2024
Base automatically changed from develop to main June 18, 2024 12:36
@gabrielmbmb gabrielmbmb changed the base branch from main to develop June 19, 2024 11:21
@gabrielmbmb gabrielmbmb modified the milestones: 1.3.0, 1.4.0 Jul 3, 2024
Copy link

Documentation for this PR has been built. You can view it at: https://distilabel.argilla.io/pr-705/

@gabrielmbmb gabrielmbmb added the enhancement New feature or request label Jul 30, 2024
Copy link

codspeed-hq bot commented Jul 30, 2024

CodSpeed Performance Report

Merging #705 will not alter performance

Comparing argilla-2.0 (b0a6b71) with develop (be61d20)

Summary

✅ 1 untouched benchmarks

@gabrielmbmb gabrielmbmb merged commit 18dc02c into develop Jul 30, 2024
7 checks passed
@gabrielmbmb gabrielmbmb deleted the argilla-2.0 branch July 30, 2024 16:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
argilla enhancement New feature or request
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

[FEATURE] Upgrade distilabel to use Argilla 2.0
4 participants