Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call the HuggingFace token a token consistently #466

Merged
merged 6 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ run/
__pycache__/
web/
secrets/
.vscode/
config/secrets.toml
20 changes: 6 additions & 14 deletions plugins/huggingface/modelgauge/suts/huggingface_inference.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
from typing import List, Optional

from huggingface_hub import ( # type: ignore
ChatCompletionOutput,
get_inference_endpoint,
InferenceClient,
InferenceEndpointStatus,
get_inference_endpoint,
)
from huggingface_hub.utils import HfHubHTTPError # type: ignore
from pydantic import BaseModel
from typing import List, Optional

from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.prompt import TextPrompt
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
from modelgauge.secret_values import InjectSecret
from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse
from modelgauge.sut_capabilities import AcceptsTextPrompt
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS
from pydantic import BaseModel


class ChatMessage(BaseModel):
Expand All @@ -28,16 +30,6 @@ class HuggingFaceInferenceChatRequest(BaseModel):
top_p: Optional[float] = None


class HuggingFaceInferenceToken(RequiredSecret):
    """Required secret: a Hugging Face access token for inference endpoints."""

    @classmethod
    def description(cls) -> SecretDescription:
        """Describe where this token lives in secrets config and how to obtain one."""
        how_to_get = "You can create tokens at https://huggingface.co/settings/tokens."
        return SecretDescription(scope="hugging_face", key="token", instructions=how_to_get)


@modelgauge_sut(capabilities=[AcceptsTextPrompt])
class HuggingFaceInferenceSUT(PromptResponseSUT[HuggingFaceInferenceChatRequest, ChatCompletionOutput]):
"""A Hugging Face SUT that is hosted on a dedicated inference endpoint."""
Expand Down
7 changes: 4 additions & 3 deletions plugins/huggingface/tests/test_huggingface_inference.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from unittest.mock import Mock, patch

import pytest
from huggingface_hub import InferenceEndpointStatus # type: ignore
from huggingface_hub.utils import HfHubHTTPError # type: ignore
from pydantic import BaseModel
from unittest.mock import Mock, patch

from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.prompt import SUTOptions, TextPrompt
from modelgauge.sut import SUTCompletion, SUTResponse
from modelgauge.suts.huggingface_inference import (
ChatMessage,
HuggingFaceInferenceChatRequest,
HuggingFaceInferenceSUT,
HuggingFaceInferenceToken,
)
from pydantic import BaseModel


@pytest.fixture
Expand Down
11 changes: 11 additions & 0 deletions src/modelgauge/auth/huggingface_inference_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from modelgauge.secret_values import RequiredSecret, SecretDescription


class HuggingFaceInferenceToken(RequiredSecret):
    """Required secret: a Hugging Face access token for inference endpoints."""

    @classmethod
    def description(cls) -> SecretDescription:
        """Describe where this token lives in secrets config and how to obtain one."""
        return SecretDescription(
            instructions="You can create tokens at https://huggingface.co/settings/tokens.",
            key="token",
            scope="hugging_face",
        )
11 changes: 11 additions & 0 deletions src/modelgauge/auth/together_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from modelgauge.secret_values import RequiredSecret, SecretDescription


class TogetherApiKey(RequiredSecret):
    """Required secret: the Together AI API key."""

    @classmethod
    def description(cls) -> SecretDescription:
        """Describe the secrets-config location of this key and where to obtain one."""
        how_to_get = "See https://api.together.xyz/settings/api-keys"
        return SecretDescription(scope="together", key="api_key", instructions=how_to_get)
11 changes: 11 additions & 0 deletions src/modelgauge/auth/vllm_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from modelgauge.secret_values import RequiredSecret, SecretDescription


class VllmApiKey(RequiredSecret):
    """Required secret: the API key for the vLLM-hosted annotator endpoint."""

    @classmethod
    def description(cls) -> SecretDescription:
        """Describe the secrets-config location of this key and how to request access."""
        return SecretDescription(
            instructions="Contact MLCommons admin for access.",
            key="api_key",
            scope="vllm",
        )
28 changes: 4 additions & 24 deletions src/modelgauge/suts/together_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,16 @@
from pydantic import BaseModel, Field
from requests.adapters import HTTPAdapter, Retry # type:ignore

from modelgauge.auth.together_key import TogetherApiKey
from modelgauge.general import APIException
from modelgauge.prompt import ChatPrompt, ChatRole, SUTOptions, TextPrompt
from modelgauge.prompt_formatting import format_chat
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
from modelgauge.sut import (
PromptResponseSUT,
SUTCompletion,
SUTResponse,
TokenProbability,
TopTokens,
)
from modelgauge.sut_capabilities import (
AcceptsChatPrompt,
AcceptsTextPrompt,
ProducesPerTokenLogProbabilities,
)
from modelgauge.secret_values import InjectSecret
from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse, TokenProbability, TopTokens
from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS


class TogetherApiKey(RequiredSecret):
    """Required secret: the Together AI API key."""

    @classmethod
    def description(cls) -> SecretDescription:
        """Describe the secrets-config location of this key and where to obtain one."""
        return SecretDescription(
            instructions="See https://api.together.xyz/settings/api-keys",
            key="api_key",
            scope="together",
        )


_SYSTEM_ROLE = "system"
_USER_ROLE = "user"
_ASSISTANT_ROLE = "assistant"
Expand Down
82 changes: 23 additions & 59 deletions src/modelgauge/tests/safe_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,78 +5,42 @@

from modelgauge.aggregations import get_measurement_stats_by_key
from modelgauge.annotator import Annotator
from modelgauge.annotators.llama_guard_annotator import (
LlamaGuardAnnotation,
LlamaGuardAnnotator,
LlamaGuardVersion,
)
from modelgauge.annotators.llama_guard_annotator import LlamaGuardAnnotation, LlamaGuardAnnotator, LlamaGuardVersion
from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.auth.together_key import TogetherApiKey
from modelgauge.auth.vllm_key import VllmApiKey
from modelgauge.base_test import PromptResponseTest
from modelgauge.dependency_helper import DependencyHelper
from modelgauge.external_data import ExternalData, WebData
from modelgauge.prompt import SUTOptions, TextPrompt
from modelgauge.secret_values import (
InjectSecret,
RequiredSecret,
SecretDescription,
)
from modelgauge.single_turn_prompt_response import (
TestItem,
PromptWithContext,
TestItemAnnotations,
MeasuredTestItem,
)
from modelgauge.secret_values import InjectSecret
from modelgauge.single_turn_prompt_response import MeasuredTestItem, PromptWithContext, TestItem, TestItemAnnotations
from modelgauge.sut_capabilities import AcceptsTextPrompt
from modelgauge.suts.together_client import TogetherApiKey
from modelgauge.test_decorator import modelgauge_test
from modelgauge.test_registry import TESTS
from modelgauge.tests.safe import SafeTestItemContext, SafeTestResult, PersonaResult
from modelgauge.tests.safe import PersonaResult, SafeTestItemContext, SafeTestResult

try:
from modelgauge.annotators.ensemble_annotator import EnsembleAnnotator # type: ignore
from modelgauge.annotators.llama_guard_2_lora_annotator import ( # type: ignore
LlamaGuard2LoRAAnnotator, # type: ignore
LlamaGuard2LoRAConfig, # type: ignore
)
from modelgauge.annotators.llama_3_70b_chat_annotator import LLAMA_3_70B_CONFIG # type: ignore
from modelgauge.annotators.llama_guard_2_lora_annotator import LlamaGuard2LoRAAnnotator # type: ignore
from modelgauge.annotators.llama_guard_2_lora_annotator import LlamaGuard2LoRAConfig # type: ignore
from modelgauge.annotators.local_lg2_lora_annotator import config as LG2_LORA_CONFIG # type: ignore
from modelgauge.annotators.mistral_8x22b_instruct_annotator import ( # type: ignore
MISTRAL_8x22B_CONFIG, # type: ignore
)
from modelgauge.annotators.prompt_engineered_annotator import ( # type: ignore
PromptEngineeredAnnotator, # type: ignore
)
from modelgauge.annotators.wildguard_annotator import ( # type: ignore
WildguardAnnotator, # type: ignore
WILDGUARD_ANNOTATOR_CONFIG,
)
from modelgauge.annotators.mistral_8x22b_instruct_annotator import (
MISTRAL_8x22B_CONFIG,
) # type: ignore
from modelgauge.annotators.prompt_engineered_annotator import (
PromptEngineeredAnnotator,
) # type: ignore
from modelgauge.annotators.wildguard_annotator import WILDGUARD_ANNOTATOR_CONFIG # type: ignore
from modelgauge.annotators.wildguard_annotator import WildguardAnnotator # type: ignore
from modelgauge.safety_model_response import SafetyModelResponse # type: ignore

PRIVATE_ANNOTATORS_AVAILABLE = True
except ImportError:
PRIVATE_ANNOTATORS_AVAILABLE = False


# TODO: Make optional?
class VllmApiKey(RequiredSecret):
    """Required secret: the API key for the vLLM-hosted annotator endpoint."""

    @classmethod
    def description(cls) -> SecretDescription:
        """Describe the secrets-config location of this key and how to request access."""
        how_to_get = "Contact MLCommons admin for access."
        return SecretDescription(scope="vllm", key="api_key", instructions=how_to_get)


class HuggingFaceKey(RequiredSecret):
    """Required secret: a HuggingFace API key used by the private annotators."""

    @classmethod
    def description(cls) -> SecretDescription:
        """Describe the secrets-config location of this key and how to obtain one."""
        how_to_get = "Add your HuggingFace token to secrets, or contact MLCommons admin."
        return SecretDescription(scope="huggingface", key="api_key", instructions=how_to_get)


class Locale(str, Enum):
EN_US = "EN_US" # English, United States
FR_FR = "FR_FR" # French, France
Expand Down Expand Up @@ -147,7 +111,7 @@ def __init__(
persona_types: List[SafePersonasVersion1],
together_api_key: TogetherApiKey,
vllm_api_key: Optional[VllmApiKey] = None,
huggingface_key: Optional[HuggingFaceKey] = None,
huggingface_inference_token: Optional[HuggingFaceInferenceToken] = None,
use_private_annotators: bool = False,
):
"""Uses llamaguard 2 if use_private_annotators is False."""
Expand All @@ -168,7 +132,7 @@ def __init__(
self.together_api_key = together_api_key
if self.use_private_annotators:
self._configure_vllm_annotators(vllm_api_key) # type: ignore
self._configure_huggingface_annotators(huggingface_key) # type: ignore
self._configure_huggingface_annotators(huggingface_inference_token) # type: ignore
self._configure_together_annotators(together_api_key)

# TODO: Encapsulate multipart secrets (e.g. key and URL).
Expand All @@ -182,8 +146,8 @@ def _configure_vllm_annotators(self, vllm_api_key: VllmApiKey):
LG2_LORA_CONFIG.api_key = self.vllm_api_key.value
LG2_LORA_CONFIG.base_url = self.vllm_endpoint_url

def _configure_huggingface_annotators(self, huggingface_key: HuggingFaceKey):
WILDGUARD_ANNOTATOR_CONFIG.api_key = huggingface_key.value
def _configure_huggingface_annotators(self, huggingface_inference_token: HuggingFaceInferenceToken):
WILDGUARD_ANNOTATOR_CONFIG.token = huggingface_inference_token.value
assert WILDGUARD_ANNOTATOR_CONFIG.is_valid(), "HuggingFace configuration is missing a token or endpoint URL."

def _configure_together_annotators(self, together_api_key: TogetherApiKey):
Expand Down Expand Up @@ -328,7 +292,7 @@ def aggregate_measurements(self, items: List[MeasuredTestItem]) -> SafeTestResul
if PRIVATE_ANNOTATORS_AVAILABLE:
for hazard in SafeTestVersion1.hazards:
VLLM_API_KEY = InjectSecret(VllmApiKey)
HUGGINGFACE_KEY = InjectSecret(HuggingFaceKey) # was: os.getenv("HF_TOKEN", "")
HUGGINGFACE_INFERENCE_TOKEN = InjectSecret(HuggingFaceInferenceToken)
TESTS.register(
SafeTestVersion1,
f"safe-{hazard}-1.0-private",
Expand All @@ -337,6 +301,6 @@ def aggregate_measurements(self, items: List[MeasuredTestItem]) -> SafeTestResul
ALL_PERSONAS,
TOGETHER_KEY,
vllm_api_key=VLLM_API_KEY,
huggingface_key=HUGGINGFACE_KEY,
huggingface_inference_token=HUGGINGFACE_INFERENCE_TOKEN,
use_private_annotators=True,
)
24 changes: 9 additions & 15 deletions tests/modelgauge_tests/test_safe.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import pytest

from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.auth.together_key import TogetherApiKey
from modelgauge.auth.vllm_key import VllmApiKey
from modelgauge.prompt import TextPrompt
from modelgauge.single_turn_prompt_response import (
MeasuredTestItem,
PromptWithContext,
TestItem,
)
from modelgauge.suts.together_client import TogetherApiKey
from modelgauge.single_turn_prompt_response import MeasuredTestItem, PromptWithContext, TestItem
from modelgauge.tests.safe import (
FOLDER_NAME,
PersonaResult,
Expand All @@ -15,18 +13,14 @@
SafeTestItemContext,
SafeTestResult,
)
from modelgauge.tests.safe_v1 import (
HuggingFaceKey,
Locale,
SafeTestVersion1,
SafePersonasVersion1,
VllmApiKey,
)
from modelgauge.tests.safe_v1 import Locale, SafePersonasVersion1, SafeTestVersion1

from modelgauge_tests.fake_dependency_helper import FakeDependencyHelper, make_csv


FAKE_TOGETHER_KEY = TogetherApiKey("some-value")
FAKE_VLLM_KEY = VllmApiKey("some-value")
FAKE_HF_KEY = HuggingFaceKey("some-value")
FAKE_HF_TOKEN = HuggingFaceInferenceToken("some-value")


def _init_safe_test(hazard, persona_types):
Expand All @@ -46,7 +40,7 @@ def _init_safe_test_v1_private(hazard, persona_types):
persona_types,
FAKE_TOGETHER_KEY,
FAKE_VLLM_KEY,
FAKE_HF_KEY,
FAKE_HF_TOKEN,
use_private_annotators=True,
)

Expand Down
Loading