diff --git a/docs/source/en/main_classes/pipelines.md b/docs/source/en/main_classes/pipelines.md index 9e699f7d2027..84860c8514ca 100644 --- a/docs/source/en/main_classes/pipelines.md +++ b/docs/source/en/main_classes/pipelines.md @@ -427,6 +427,23 @@ Pipelines available for natural language processing tasks include the following. - __call__ - all +The TextGenerationPipeline supports optional safety checking through the `safety_config` parameter. See the [Safe Generation example](https://github.com/huggingface/transformers/tree/main/examples/safe_generation) for implementing custom safety checkers. + +**Example**: +```python +from transformers import pipeline +from transformers.generation.safety import SafetyConfig +from examples.safe_generation.checkers import BasicToxicityChecker + +# Create safety checker +checker = BasicToxicityChecker(threshold=0.7) +config = SafetyConfig.from_checker(checker) + +# Use with text generation pipeline +pipe = pipeline("text-generation", model="gpt2") +result = pipe("Hello", safety_config=config, max_new_tokens=50) +``` + ### Text2TextGenerationPipeline [[autodoc]] Text2TextGenerationPipeline diff --git a/examples/safe_generation/README.md b/examples/safe_generation/README.md new file mode 100644 index 000000000000..80659d87e638 --- /dev/null +++ b/examples/safe_generation/README.md @@ -0,0 +1,254 @@ +# Safe Generation Example Implementations + +This directory contains reference implementations of safety checkers for the transformers safe generation feature. + +## Overview + +The core transformers library provides **infrastructure only**: +- `SafetyChecker` abstract base class +- `SafetyLogitsProcessor` and `SafetyStoppingCriteria` +- `SafetyConfig` configuration system +- `SafetyResult` and `SafetyViolation` data structures + +**Concrete implementations** like `BasicToxicityChecker` are provided here as examples. + +This follows the same pattern as watermarking in transformers - the core provides infrastructure, users provide or choose implementations. 
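+
+To illustrate how these pieces fit together, here is a minimal sketch of a custom checker built only on the infrastructure classes above. The `KeywordChecker` name and its word list are invented for this example and are not part of the library; a real checker would use a proper classifier, as in `BasicToxicityChecker` below.
+
+```python
+from transformers.generation.safety import SafetyChecker, SafetyConfig, SafetyResult, SafetyViolation
+
+
+class KeywordChecker(SafetyChecker):
+    """Toy checker that flags text containing any word from a fixed blocklist."""
+
+    def __init__(self, blocked_words=("badword",)):
+        self.blocked_words = [w.lower() for w in blocked_words]
+
+    @property
+    def supported_categories(self) -> list[str]:
+        return ["blocked_keyword"]
+
+    def check_safety(self, text, **kwargs):
+        # Handle batched input by checking each text individually
+        if isinstance(text, list):
+            return [self.check_safety(t) for t in text]
+        hits = [w for w in self.blocked_words if w in text.lower()]
+        violations = [
+            SafetyViolation(category="blocked_keyword", confidence=1.0, severity="high", description=f"Found '{w}'")
+            for w in hits
+        ]
+        return SafetyResult(is_safe=not hits, confidence=1.0, violations=violations, metadata={})
+
+
+# The resulting config can be passed to generate() or a text-generation pipeline
+config = SafetyConfig.from_checker(KeywordChecker(["badword"]))
+```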
+ +## Usage + +### Basic Usage with Pipeline + +```python +from examples.safe_generation import BasicToxicityChecker +from transformers import pipeline +from transformers.generation.safety import SafetyConfig + +# Create a safety checker +checker = BasicToxicityChecker(threshold=0.7) + +# Option 1: Use with SafetyConfig +config = SafetyConfig.from_checker(checker) +pipe = pipeline("text-generation", model="gpt2", safety_config=config) + +# Option 2: Direct generation with model +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("gpt2") +tokenizer = AutoTokenizer.from_pretrained("gpt2") + +# Attach tokenizer to model (required for safety processors) +model.tokenizer = tokenizer + +inputs = tokenizer("Hello, I want to", return_tensors="pt") +outputs = model.generate(**inputs, safety_config=config, max_new_tokens=20) +print(tokenizer.decode(outputs[0])) +``` + +### Using Preset Configurations + +SafetyConfig provides three preset configurations for different safety/performance trade-offs: + +```python +from examples.safe_generation import BasicToxicityChecker +from transformers.generation.safety import SafetyConfig, STRICT_PRESET, MODERATE_PRESET, LENIENT_PRESET + +checker = BasicToxicityChecker(threshold=0.7) + +# STRICT preset - Maximum safety, more overhead +# - Smaller caches (50 entries, 500 unsafe hash limit) +# - Returns violations and metadata for debugging +config_strict = SafetyConfig.from_checker(checker, **STRICT_PRESET) + +# MODERATE preset - Balanced approach (default) +# - Medium caches (100 entries, 1000 unsafe hash limit) +# - No extra metadata (better performance) +config_moderate = SafetyConfig.from_checker(checker, **MODERATE_PRESET) + +# LENIENT preset - Performance-optimized +# - Larger caches (200 entries, 2000 unsafe hash limit) +# - No extra metadata +config_lenient = SafetyConfig.from_checker(checker, **LENIENT_PRESET) + +# Custom preset - Mix and match +config_custom = SafetyConfig.from_checker( + checker, + cache_size=150, + unsafe_hash_limit=1500, + return_violations=True, # Get detailed violation info + return_metadata=False # Skip extra metadata +) +``` + +**Preset Comparison:** + +| Preset | cache_size | unsafe_hash_limit | return_violations | return_metadata | Use Case | +|--------|-----------|-------------------|-------------------|-----------------|----------| +| STRICT | 50 | 500 | True | True | High-risk applications, debugging | +| MODERATE | 100 | 1000 | False | False | General use (balanced) | +| LENIENT | 200 | 2000 | False | False | Performance-critical, trusted content | + +### Customizing the BasicToxicityChecker + +```python +from examples.safe_generation import BasicToxicityChecker + +# Use different threshold +strict_checker = BasicToxicityChecker(threshold=0.5) # More strict + +# Use different model +custom_checker = BasicToxicityChecker( + model_name="unitary/toxic-bert", + threshold=0.7, + device="cuda" # Force specific device +) +``` + +## Implementing Custom Safety Checkers + +You can create your own safety checkers by inheriting from `SafetyChecker`: + +```python +from transformers.generation.safety import SafetyChecker, SafetyResult, SafetyViolation + +class MyCustomChecker(SafetyChecker): + def __init__(self, threshold=0.7): + self.threshold = threshold + # Your initialization here + + @property + def supported_categories(self) -> list[str]: + """Declare what categories this checker detects.""" + return ["violence", "hate_speech", "custom_category"] + + def check_safety(self, 
text, **kwargs): + """ + Check text for safety violations. + + Args: + text: Single string or list of strings to check + + Returns: + SafetyResult or list of SafetyResult + """ + if isinstance(text, str): + return self._check_single(text) + return [self._check_single(t) for t in text] + + def _check_single(self, text: str) -> SafetyResult: + # Your safety checking logic here + is_safe = True # Your determination + violations = [] # List of SafetyViolation objects + + if not is_safe: + violations.append(SafetyViolation( + category="your_category", + confidence=0.9, + severity="high", + description="Why this is unsafe" + )) + + return SafetyResult( + is_safe=is_safe, + confidence=0.95, + violations=violations, + metadata={"checker": "MyCustomChecker"} + ) +``` + +### LLM-Based Safety Checker Example + +```python +import openai +import json +from transformers.generation.safety import SafetyChecker, SafetyResult, SafetyViolation + +class LLMSafetyChecker(SafetyChecker): + """Safety checker using an LLM API for multi-category detection.""" + + def __init__(self, api_key: str, categories: list[str] = None): + self.client = openai.OpenAI(api_key=api_key) + self.categories = categories or ["violence", "self_harm", "hate_speech", "sexual"] + + @property + def supported_categories(self) -> list[str]: + return self.categories + + def check_safety(self, text, **kwargs): + if isinstance(text, str): + return self._check_single(text) + return [self._check_single(t) for t in text] + + def _check_single(self, text: str) -> SafetyResult: + prompt = f"""Evaluate if this text contains any of: {', '.join(self.categories)} + +Text: {text} + +Respond with JSON: {{"is_safe": true/false, "category": "...", "confidence": 0.0-1.0}}""" + + try: + response = self.client.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": prompt}], + response_format={"type": "json_object"} + ) + result = json.loads(response.choices[0].message.content) + + violations = [] + if not result["is_safe"]: + violations.append(SafetyViolation( + category=result.get("category", "unknown"), + confidence=result["confidence"], + severity="high" if result["confidence"] > 0.8 else "medium", + description=f"Detected {result['category']} content" + )) + + return SafetyResult( + is_safe=result["is_safe"], + confidence=result["confidence"], + violations=violations, + metadata={"model": "gpt-4", "categories_checked": self.categories} + ) + except Exception as e: + # Fail-safe: assume unsafe on error + return SafetyResult( + is_safe=False, + confidence=0.0, + violations=[SafetyViolation("error", 0.0, "high", str(e))], + metadata={"error": str(e)} + ) + +# Usage +llm_checker = LLMSafetyChecker(api_key="your-api-key") +config = SafetyConfig.from_checker(llm_checker) +``` + +## Performance Optimization + +For high-latency checkers (like LLM APIs), use SafetyConfig.from_checker() with custom performance settings: + +```python +from transformers.generation.safety import SafetyConfig + +# For high-latency checkers, optimize with larger caches and sliding windows +config = SafetyConfig.from_checker( + your_checker, # Your checker instance + cache_size=500, # Large cache for API responses + unsafe_hash_limit=5000, # Track more unsafe patterns + sliding_window_size=512, # Limit tokens sent to API + incremental_checking=True, # Avoid re-processing same content + return_violations=False, # Disable for better performance + return_metadata=False # Disable for better performance +) +``` + +## Files in This Directory + +- `checkers.py`: 
Reference implementation of `BasicToxicityChecker` +- `__init__.py`: Exports for easy importing +- `README.md`: This file - usage guide and examples + +## Further Reading + +- [Safe Generation Design Document](../../docs/0.safe_generation_design.md) +- [Extensibility and Checker Strategy](../../docs/6.extensibility_and_checker_strategy.md) +- [Core Safety Infrastructure](../../docs/1.core_safety_infrastructure.md) diff --git a/examples/safe_generation/__init__.py b/examples/safe_generation/__init__.py new file mode 100644 index 000000000000..e42775addde3 --- /dev/null +++ b/examples/safe_generation/__init__.py @@ -0,0 +1,43 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Safe Generation Example Implementations + +This module provides reference implementations of safety checkers for the transformers +safe generation feature. These are example implementations that users can use directly +or adapt for their specific needs. + +The core transformers library provides only the infrastructure (SafetyChecker abstract base, +processors, configuration). Concrete implementations like BasicToxicityChecker are provided +here as examples to demonstrate how to implement custom safety checkers. + +Example usage: + from examples.safe_generation import BasicToxicityChecker + from transformers import pipeline + from transformers.generation.safety import SafetyConfig + + # Create a safety checker + checker = BasicToxicityChecker(threshold=0.7) + + # Use with pipeline + config = SafetyConfig.from_checker(checker) + pipe = pipeline("text-generation", model="gpt2", safety_config=config) +""" + +from .checkers import BasicToxicityChecker + + +__all__ = ["BasicToxicityChecker"] diff --git a/examples/safe_generation/checkers.py b/examples/safe_generation/checkers.py new file mode 100644 index 000000000000..f634a34bfda6 --- /dev/null +++ b/examples/safe_generation/checkers.py @@ -0,0 +1,231 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional, Union + +import torch +import torch.nn.functional as F + +from transformers import AutoModelForSequenceClassification, AutoTokenizer +from transformers.generation.safety import SafetyChecker, SafetyResult, SafetyViolation +from transformers.utils import is_torch_available, logging + + +if not is_torch_available(): + raise ImportError("PyTorch is required to use safety checkers. 
Please install PyTorch: pip install torch") + + +logger = logging.get_logger(__name__) + + +class BasicToxicityChecker(SafetyChecker): + """ + Toxicity checker using the s-nlp/roberta_toxicity_classifier model. + + This checker uses a pre-trained RoBERTa model to detect toxic content in text. It supports both + single text and batch processing, with configurable thresholds and automatic device selection. + + This is a reference implementation provided in the examples directory to demonstrate how to + implement custom safety checkers. The core transformers library provides only the infrastructure + (SafetyChecker abstract base class, processors, configuration). + + Args: + model_name (`str`, *optional*, defaults to `"s-nlp/roberta_toxicity_classifier"`): + The name of the pre-trained model to use for toxicity detection. + threshold (`float`, *optional*, defaults to `0.7`): + The toxicity score threshold above which content is considered unsafe. + device (`str`, *optional*): + The device to run the model on. If None, automatically selects CUDA if available, else CPU. + + Examples: + ```python + >>> from examples.safe_generation import BasicToxicityChecker + >>> from transformers.generation.safety import SafetyConfig + >>> from transformers import pipeline + + >>> # Create checker + >>> checker = BasicToxicityChecker(threshold=0.7) + + >>> # Use with SafetyConfig + >>> config = SafetyConfig.from_checker(checker) + >>> pipe = pipeline("text-generation", model="gpt2", safety_config=config) + ``` + """ + + def __init__( + self, + model_name: str = "s-nlp/roberta_toxicity_classifier", + threshold: float = 0.7, + device: Optional[str] = None, + ): + self.model_name = model_name + self.threshold = threshold + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + + # Load model and tokenizer with error handling + try: + logger.info(f"Loading toxicity model: {model_name}") + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModelForSequenceClassification.from_pretrained(model_name) + self.model.to(self.device) + self.model.eval() + logger.info(f"Successfully loaded toxicity model on {self.device}") + except Exception as e: + raise RuntimeError( + f"Failed to load toxicity model '{model_name}'. " + f"Please ensure the model exists and you have internet connectivity. " + f"Original error: {e}" + ) + + @property + def supported_categories(self) -> list[str]: + """Return list of safety categories this checker supports.""" + return ["toxicity"] + + def check_safety(self, text: Union[str, list[str]], **kwargs) -> Union[SafetyResult, list[SafetyResult]]: + """ + Check text(s) for toxicity violations. + + Args: + text (`Union[str, List[str]]`): + Single text string or list of texts to check for toxicity. + **kwargs: + Additional parameters (currently unused). + + Returns: + `Union[SafetyResult, List[SafetyResult]]`: + SafetyResult for single text input, List[SafetyResult] for multiple texts. + """ + if isinstance(text, str): + return self._check_single_text(text, **kwargs) + elif isinstance(text, list): + return [self._check_single_text(t, **kwargs) for t in text] + else: + raise TypeError(f"Expected string or list of strings, got {type(text)}") + + def _check_single_text(self, text: str, **kwargs) -> SafetyResult: + """ + Check single text for toxicity. + + Args: + text (`str`): Text to check for toxicity. + **kwargs: Additional parameters (currently unused). + + Returns: + `SafetyResult`: Result of the safety check. 
+ """ + # Input validation + if not isinstance(text, str): + raise TypeError(f"Expected string input, got {type(text)}") + + # Handle empty text + if not text.strip(): + return SafetyResult( + is_safe=True, + confidence=1.0, + violations=[], + metadata={"reason": "empty_text", "model_name": self.model_name}, + ) + + # Handle very long text + original_length = len(text) + max_length = 10000 # Reasonable limit + if len(text) > max_length: + text = text[:max_length] + logger.warning(f"Text truncated from {original_length} to {max_length} characters") + + # Tokenize and run inference + try: + inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True).to( + self.device + ) + + with torch.no_grad(): + outputs = self.model(**inputs) + probabilities = F.softmax(outputs.logits, dim=-1) + + # Extract toxicity probability (assuming binary classification: [non-toxic, toxic]) + toxicity_score = probabilities[0][1].item() # Toxic class probability + + except Exception as e: + logger.error(f"Error during toxicity inference: {e}") + raise RuntimeError(f"Toxicity detection failed: {e}") + + # Determine safety + is_safe = toxicity_score < self.threshold + violations = [] + + if not is_safe: + violations.append( + SafetyViolation( + category="toxicity", + confidence=toxicity_score, + severity=self._get_severity(toxicity_score), + description=f"Detected toxic content with {toxicity_score:.2%} confidence", + ) + ) + + # Prepare metadata + metadata = { + "model_name": self.model_name, + "toxicity_score": toxicity_score, + "threshold": self.threshold, + "device": self.device, + } + + if original_length > max_length: + metadata["truncated"] = True + metadata["original_length"] = original_length + metadata["processed_length"] = max_length + + return SafetyResult( + is_safe=is_safe, + confidence=max(toxicity_score, 1.0 - toxicity_score), + violations=violations, + metadata=metadata, + ) + + def _get_severity(self, score: float) -> str: + """ + Determine severity based on toxicity score. + + Args: + score (`float`): Toxicity score from 0.0 to 1.0. + + Returns: + `str`: Severity level ("low", "medium", "high", "critical"). + """ + if score >= 0.95: + return "critical" + elif score >= 0.85: + return "high" + elif score >= 0.75: + return "medium" + else: + return "low" + + def get_config(self) -> dict[str, Any]: + """ + Return checker configuration for serialization. + + Returns: + `Dict[str, Any]`: Dictionary containing the checker's configuration. + """ + return { + "checker_type": "BasicToxicityChecker", + "model_name": self.model_name, + "threshold": self.threshold, + "device": self.device, + } diff --git a/examples/safety_generation_example.py b/examples/safety_generation_example.py new file mode 100644 index 000000000000..885543fdc4bc --- /dev/null +++ b/examples/safety_generation_example.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +""" +Example: Safe vs Regular Text Generation with Transformers Safety + +This example demonstrates how to compare regular generation to generation with +real-time safety filtering (toxicity) using Transformers' safety utilities. 
+ +""" + +import os +import platform +import sys +from pathlib import Path + +import torch + + +# Add safe_generation to path to import BasicToxicityChecker +sys.path.insert(0, str(Path(__file__).parent / "safe_generation")) + +from safe_generation import BasicToxicityChecker + +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +from transformers.generation.safety import SafetyConfig, SafetyLogitsProcessor, SafetyStoppingCriteria + + +def get_device(): + if torch.backends.mps.is_available() and platform.system() == "Darwin": + return torch.device("mps") + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def main(): + """Compare regular generation vs safety-filtered generation.""" + + # Load a pre-trained model and tokenizer + model_name = os.environ.get("SAFETY_EXAMPLE_MODEL", "Qwen/Qwen2.5-3B") + print(f"Loading model: {model_name}") + print("(Using Qwen for consistent text generation)") + + device = get_device() + model = AutoModelForCausalLM.from_pretrained(model_name).to(device) + model.eval() + tokenizer = AutoTokenizer.from_pretrained(model_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Build safety components with strict threshold + toxicity_threshold = 0.7 # Strict threshold for demo + safety_checker = BasicToxicityChecker( + model_name="s-nlp/roberta_toxicity_classifier", threshold=toxicity_threshold, device="cpu" + ) + + # Create safety config from checker + safety_config = SafetyConfig.from_checker(safety_checker) + + from transformers.generation.logits_process import LogitsProcessorList + from transformers.generation.stopping_criteria import StoppingCriteriaList + + safety_processor = SafetyLogitsProcessor( + safety_checker=safety_checker, + tokenizer=tokenizer, + safety_config=safety_config, + check_interval=1, + ) + safety_stopper = SafetyStoppingCriteria( + safety_checker=safety_checker, + tokenizer=tokenizer, + safety_config=safety_config, + check_final_only=False, + ) + + safety_processors = LogitsProcessorList([safety_processor]) + safety_criteria = StoppingCriteriaList([safety_stopper]) + + # Base generation configuration - deterministic for comparison + base_gen = GenerationConfig( + max_new_tokens=30, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + ) + + prompts = [ + 'He went on a long tirade spouting insults and obscenities: "', + ] + + print("\n" + "=" * 60) + print("SAFE TEXT GENERATION EXAMPLES (Regular vs Safe)") + print("=" * 60) + print("(Errors will raise; no soft fallbacks)") + + for i, prompt in enumerate(prompts, 1): + print(f"\nExample {i}: '{prompt}'") + print("-" * 40) + + # Tokenize input and move to model device + inputs = tokenizer(prompt, return_tensors="pt").to(device) + + # 1) Regular generation (no safety) + with torch.no_grad(): + outputs_regular = model.generate( + **inputs, + generation_config=base_gen, + ) + text_regular = tokenizer.decode(outputs_regular[0], skip_special_tokens=True) + cont_regular = text_regular[len(prompt) :].strip() + print(f"Regular: {cont_regular}") + + # 2) Safe generation (with safety filtering) + with torch.no_grad(): + outputs_safe = model.generate( + **inputs, + generation_config=base_gen, + logits_processor=safety_processors, + stopping_criteria=safety_criteria, + ) + text_safe = tokenizer.decode(outputs_safe[0], skip_special_tokens=True) + cont_safe = text_safe[len(prompt) :].strip() + print(f"Safe: {cont_safe}") + + if cont_regular != cont_safe: + print("Safety filtering applied - outputs differ") + else: + 
print("No safety filtering needed - outputs identical") + + # Verify safety checker would detect issues in the output + regular_safety_result = safety_checker.check_safety(text_regular) + if not regular_safety_result.is_safe: + print(" WARNING: Safety checker detected violations in output but filtering didn't occur!") + print(f" Violations: {[v.category for v in regular_safety_result.violations]}") + print(f" Confidence: {regular_safety_result.confidence:.3f}") + + print("\n" + "=" * 60) + print("HOW IT WORKS:") + print("=" * 60) + print( + """ +1. SafetyLogitsProcessor blocks ALL tokens when unsafe content is detected +2. SafetyStoppingCriteria can halt generation if unsafe content is detected +3. Both work during generation, stopping output when safety violations occur +4. Deterministic generation allows direct comparison of safe vs regular outputs + """ + ) + + print("\nDifferent Safety Levels:") + print("- strict: threshold=0.5 (more restrictive)") + print("- moderate: threshold=0.7 (balanced)") + print("- lenient: threshold=0.9 (less restrictive)") + print("\nCurrent demo uses: threshold=0.7 for reliable blocking") + print("\nTo use predefined presets:") + print("from transformers.generation.safety import STRICT_PRESET") + print("config = SafetyConfig.from_checker(checker, **STRICT_PRESET)") + + +if __name__ == "__main__": + main() diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index 92ef3184e773..49b8e03e2b6f 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -88,6 +88,16 @@ _import_structure["continuous_batching"] = [ "ContinuousMixin", ] + _import_structure["safety"] = [ + "SafetyChecker", + "SafetyResult", + "SafetyViolation", + "SafetyMetrics", + "SafetyState", + "SafetyConfig", + "SafetyLogitsProcessor", + "SafetyStoppingCriteria", + ] _import_structure["utils"] = [ "GenerationMixin", "GenerateBeamDecoderOnlyOutput", diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 7be052a9a946..b9fcd91e489a 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -244,6 +244,10 @@ class GenerationConfig(PushToHubMixin): Arguments used to watermark the model outputs by adding a small bias to randomly selected set of "green" tokens. See the docs of [`SynthIDTextWatermarkingConfig`] and [`WatermarkingConfig`] for more details. If passed as `Dict`, it will be converted to a `WatermarkingConfig` internally. + safety_config (`SafetyConfig` or `dict`, *optional*): + Configuration for content safety filtering during generation. Enables real-time detection and suppression + of unsafe content like toxicity, hate speech, etc. See [`SafetyConfig`] for more details. If passed as + `Dict`, it will be converted to a `SafetyConfig` internally. 
> Parameters that define the output variables of generate @@ -388,6 +392,22 @@ def __init__(self, **kwargs): else: self.watermarking_config = WatermarkingConfig.from_dict(watermarking_config) + # Safety configuration for content filtering during generation + safety_config = kwargs.pop("safety_config", None) + if safety_config is None: + self.safety_config = None + elif hasattr(safety_config, "enabled"): # Duck typing for SafetyConfig + self.safety_config = safety_config + else: + # Lazy import to avoid circular dependencies + try: + from .safety import SafetyConfig + + self.safety_config = SafetyConfig.from_dict(safety_config) + except ImportError: + logger.warning("SafetyConfig requested but safety module not available") + self.safety_config = None + # Parameters that define the output variables of `generate` self.num_return_sequences = kwargs.pop("num_return_sequences", 1) self.output_attentions = kwargs.pop("output_attentions", False) diff --git a/src/transformers/generation/safety/__init__.py b/src/transformers/generation/safety/__init__.py new file mode 100644 index 000000000000..095aed1eec5d --- /dev/null +++ b/src/transformers/generation/safety/__init__.py @@ -0,0 +1,40 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import is_torch_available +from .base import SafetyChecker, SafetyMetrics, SafetyResult, SafetyState, SafetyViolation +from .configuration import LENIENT_PRESET, MODERATE_PRESET, STRICT_PRESET, SafetyConfig + + +if is_torch_available(): + from .processors import SafetyLogitsProcessor, SafetyStoppingCriteria +else: + SafetyLogitsProcessor = None + SafetyStoppingCriteria = None + + +__all__ = [ + "SafetyChecker", + "SafetyResult", + "SafetyViolation", + "SafetyMetrics", + "SafetyState", + "SafetyConfig", + "STRICT_PRESET", + "MODERATE_PRESET", + "LENIENT_PRESET", +] + +if is_torch_available(): + __all__ += ["SafetyLogitsProcessor", "SafetyStoppingCriteria"] diff --git a/src/transformers/generation/safety/base.py b/src/transformers/generation/safety/base.py new file mode 100644 index 000000000000..1a92e8f07eb1 --- /dev/null +++ b/src/transformers/generation/safety/base.py @@ -0,0 +1,366 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import threading +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Optional, Union + + +@dataclass +class SafetyViolation: + """ + Represents a single safety violation detected in text. + + Args: + category (`str`): + The category of safety violation (e.g., "toxicity", "bias", "pii"). + confidence (`float`): + Confidence score for the violation detection, ranging from 0.0 to 1.0. + severity (`str`, *optional*, defaults to `"medium"`): + Severity level of the violation. One of "low", "medium", "high", "critical". + description (`str`, *optional*, defaults to `""`): + Human-readable description of the violation. + span (`Tuple[int, int]`, *optional*): + Character span in the original text where the violation occurs, if applicable. + """ + + category: str + confidence: float + severity: str = "medium" + description: str = "" + span: Optional[tuple[int, int]] = None + + +@dataclass +class SafetyResult: + """ + Result of a safety checking operation. + + Args: + is_safe (`bool`): + Whether the checked text is considered safe overall. + confidence (`float`): + Overall confidence in the safety assessment, ranging from 0.0 to 1.0. + violations (`List[SafetyViolation]`): + List of safety violations detected in the text. + metadata (`Dict[str, Any]`): + Additional checker-specific information and context. + """ + + is_safe: bool + confidence: float + violations: list[SafetyViolation] + metadata: dict[str, Any] + + +@dataclass +class SafetyMetrics: + """ + Metrics collection for safety operations monitoring and analysis. + + Tracks performance and usage statistics for safety checking operations, + enabling production monitoring and optimization. + + Args: + total_generations (`int`, defaults to 0): + Total number of generations attempted. + blocked_generations (`int`, defaults to 0): + Number of generations blocked due to safety violations. + suppression_events (`int`, defaults to 0): + Number of token suppression events during generation. + cache_hits (`int`, defaults to 0): + Number of cache hits for safety check results. + cache_misses (`int`, defaults to 0): + Number of cache misses requiring new safety checks. + total_safety_check_time_ms (`float`, defaults to 0.0): + Cumulative time spent on safety checks in milliseconds. + safety_check_count (`int`, defaults to 0): + Total number of safety checks performed. 
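+
+    Example (illustrative sketch; every method and property shown is defined on this class):
+
+    ```python
+    metrics = SafetyMetrics()
+    metrics.record_generation_attempt()
+    metrics.record_safety_check(check_time_ms=12.5)
+    metrics.record_cache_miss()
+    print(metrics.to_dict()["avg_safety_check_time_ms"])  # 12.5
+    ```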
+ """ + + total_generations: int = 0 + blocked_generations: int = 0 + suppression_events: int = 0 + cache_hits: int = 0 + cache_misses: int = 0 + total_safety_check_time_ms: float = 0.0 + safety_check_count: int = 0 + + def __post_init__(self): + """Initialize thread safety lock after dataclass fields.""" + self._lock = threading.Lock() + + @property + def cache_hit_rate(self) -> float: + """Calculate cache hit rate as a percentage.""" + total_cache_ops = self.cache_hits + self.cache_misses + if total_cache_ops == 0: + return 0.0 + return (self.cache_hits / total_cache_ops) * 100.0 + + @property + def avg_safety_check_time_ms(self) -> float: + """Calculate average safety check time in milliseconds.""" + if self.safety_check_count == 0: + return 0.0 + return self.total_safety_check_time_ms / self.safety_check_count + + @property + def block_rate(self) -> float: + """Calculate generation block rate as a percentage.""" + if self.total_generations == 0: + return 0.0 + return (self.blocked_generations / self.total_generations) * 100.0 + + def record_safety_check(self, check_time_ms: float) -> None: + """Record a safety check operation with timing.""" + with self._lock: + self.safety_check_count += 1 + self.total_safety_check_time_ms += check_time_ms + + def record_cache_hit(self) -> None: + """Record a cache hit event.""" + with self._lock: + self.cache_hits += 1 + + def record_cache_miss(self) -> None: + """Record a cache miss event.""" + with self._lock: + self.cache_misses += 1 + + def record_generation_attempt(self) -> None: + """Record a generation attempt.""" + with self._lock: + self.total_generations += 1 + + def record_blocked_generation(self) -> None: + """Record a generation that was blocked due to safety violations.""" + with self._lock: + self.blocked_generations += 1 + + def record_suppression_event(self) -> None: + """Record a token suppression event.""" + with self._lock: + self.suppression_events += 1 + + def to_dict(self) -> dict[str, Union[int, float]]: + """ + Export metrics as dictionary for logging or monitoring systems. + + Returns: + Dict[str, Union[int, float]]: Dictionary containing all metrics. + """ + with self._lock: + return { + "total_generations": self.total_generations, + "blocked_generations": self.blocked_generations, + "suppression_events": self.suppression_events, + "cache_hits": self.cache_hits, + "cache_misses": self.cache_misses, + "cache_hit_rate": self.cache_hit_rate, + "avg_safety_check_time_ms": self.avg_safety_check_time_ms, + "block_rate": self.block_rate, + "safety_check_count": self.safety_check_count, + } + + def reset(self) -> None: + """Reset all metrics to zero for new measurement period.""" + with self._lock: + self.total_generations = 0 + self.blocked_generations = 0 + self.suppression_events = 0 + self.cache_hits = 0 + self.cache_misses = 0 + self.total_safety_check_time_ms = 0.0 + self.safety_check_count = 0 + + def combine(self, other: SafetyMetrics) -> SafetyMetrics: + """ + Combine metrics from another SafetyMetrics instance. + + Args: + other (SafetyMetrics): Another metrics instance to combine with. + + Returns: + SafetyMetrics: New instance with combined metrics. 
+ """ + # Use both locks in consistent order to prevent deadlocks + locks = sorted([self._lock, other._lock], key=lambda x: id(x)) + with locks[0]: + with locks[1]: + return SafetyMetrics( + total_generations=self.total_generations + other.total_generations, + blocked_generations=self.blocked_generations + other.blocked_generations, + suppression_events=self.suppression_events + other.suppression_events, + cache_hits=self.cache_hits + other.cache_hits, + cache_misses=self.cache_misses + other.cache_misses, + total_safety_check_time_ms=self.total_safety_check_time_ms + other.total_safety_check_time_ms, + safety_check_count=self.safety_check_count + other.safety_check_count, + ) + + +class SafetyChecker(ABC): + """ + Abstract base class for all safety checkers. + + Safety checkers are responsible for analyzing text content and detecting various types of safety violations + such as toxicity, bias, personally identifiable information, or other harmful content. + """ + + @abstractmethod + def check_safety(self, text: Union[str, list[str]], **kwargs) -> Union[SafetyResult, list[SafetyResult]]: + """ + Check text(s) for safety violations. + + Args: + text (`Union[str, List[str]]`): + Single text string or list of texts to check for safety violations. + **kwargs: + Additional checker-specific parameters. + + Returns: + `Union[SafetyResult, List[SafetyResult]]`: + SafetyResult for single text input, List[SafetyResult] for multiple texts. + """ + raise NotImplementedError( + f"{self.__class__.__name__} is an abstract class. Only classes inheriting this class can be called." + ) + + @property + @abstractmethod + def supported_categories(self) -> list[str]: + """ + Return list of safety categories this checker supports. + + Returns: + `List[str]`: List of supported safety categories (e.g., ["toxicity", "bias"]). + """ + raise NotImplementedError( + f"{self.__class__.__name__} is an abstract class. Only classes inheriting this class can be called." + ) + + def get_config(self) -> dict[str, Any]: + """ + Return checker configuration for serialization. + + Returns: + `Dict[str, Any]`: Dictionary containing the checker's configuration parameters. + """ + return {"checker_type": self.__class__.__name__} + + +@dataclass +class SafetyState: + """ + Tracks incremental safety checking state for efficient sequence processing. + + This class maintains state information to enable efficient sliding window + and incremental safety checking, avoiding redundant processing of previously + checked content. + + Args: + last_check_position (`int`, *optional*, defaults to `0`): + The position (in tokens) where the last safety check ended. + last_check_result (`Optional[SafetyResult]`, *optional*): + The result of the last safety check performed. + sequence_prefix (`str`, *optional*, defaults to `""`): + The text prefix that has already been checked for safety. + is_safe_so_far (`bool`, *optional*, defaults to `True`): + Whether the sequence has been safe up to the last check position. + window_start_position (`int`, *optional*, defaults to `0`): + The starting position of the current sliding window. + """ + + last_check_position: int = 0 + last_check_result: Optional[SafetyResult] = None + sequence_prefix: str = "" + is_safe_so_far: bool = True + window_start_position: int = 0 + + def should_check_incremental(self, current_position: int, min_new_tokens: int = 5) -> bool: + """ + Determine if an incremental safety check should be performed. 
+ + Args: + current_position (`int`): + Current position in the sequence (in tokens). + min_new_tokens (`int`, *optional*, defaults to `5`): + Minimum number of new tokens before triggering a new check. + + Returns: + `bool`: True if a new safety check should be performed. + """ + # Always check if this is the first check + if self.last_check_position == 0: + return True + + # Check if enough new tokens have been added + new_tokens = current_position - self.last_check_position + return new_tokens >= min_new_tokens + + def update_check_result(self, position: int, result: SafetyResult, sequence_prefix: str = "") -> None: + """ + Update the state with a new safety check result. + + Args: + position (`int`): + The position where this check ended. + result (`SafetyResult`): + The safety check result. + sequence_prefix (`str`, *optional*, defaults to `""`): + The sequence prefix that was checked. + """ + self.last_check_position = position + self.last_check_result = result + self.sequence_prefix = sequence_prefix + self.is_safe_so_far = result.is_safe if result else True + + def get_incremental_text(self, full_text: str, sliding_window_size: int = -1) -> tuple[str, int]: + """ + Extract the portion of text that needs incremental checking. + + Args: + full_text (`str`): + The complete sequence text. + sliding_window_size (`int`, *optional*, defaults to `-1`): + Size of sliding window in characters. -1 means no sliding window. + + Returns: + `tuple[str, int]`: The text portion to check and its start position. + """ + if sliding_window_size == -1: + # No sliding window - return text from last check position + if len(self.sequence_prefix) > 0: + # Find where we left off and return remaining text + remaining_text = full_text[len(self.sequence_prefix) :] + return self.sequence_prefix + remaining_text, 0 + return full_text, 0 + # Use sliding window + if len(full_text) <= sliding_window_size: + return full_text, 0 + window_start = max(0, len(full_text) - sliding_window_size) + self.window_start_position = window_start + return full_text[window_start:], window_start + + def reset(self) -> None: + """Reset the safety state for a new sequence.""" + self.last_check_position = 0 + self.last_check_result = None + self.sequence_prefix = "" + self.is_safe_so_far = True + self.window_start_position = 0 diff --git a/src/transformers/generation/safety/configuration.py b/src/transformers/generation/safety/configuration.py new file mode 100644 index 000000000000..de5f1f8156dd --- /dev/null +++ b/src/transformers/generation/safety/configuration.py @@ -0,0 +1,325 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import warnings +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Optional + + +if TYPE_CHECKING: + from .base import SafetyChecker + + +# Constants for validation warnings +WARNING_CACHE_SIZE_LIMIT = 10000 +WARNING_UNSAFE_HASH_LIMIT = 100000 + + +@dataclass +class SafetyConfig: + """ + Configuration for safety checking in text generation. + + This configuration class stores settings for safety checking and accepts a user-provided + safety checker instance. The transformers library provides the infrastructure + (SafetyChecker abstract base, processors, configuration), while users implement + concrete checkers for their specific safety requirements. + + Args: + enabled (`bool`, *optional*, defaults to `False`): + Whether safety checking is enabled. + checker (`SafetyChecker`, *optional*, defaults to `None`): + The safety checker instance to use. Must be provided by the user. + See examples/safe_generation/ for reference implementations. + device (`str`, *optional*): + Device to run models on. If None, automatically selects CUDA if available. + cache_size (`int`, *optional*, defaults to `100`): + Maximum number of safety check results to cache. Larger values use more memory + but can improve performance for repetitive content. + unsafe_hash_limit (`int`, *optional*, defaults to `1000`): + Maximum number of unsafe sequence hashes to remember. Prevents memory leaks + in long-running applications with many unsafe sequences. + sliding_window_size (`int`, *optional*, defaults to `512`): + Maximum number of tokens to check for safety instead of the full sequence. + Helps improve performance for long sequences while maintaining safety effectiveness. + Set to -1 to disable sliding window (check full sequence). + incremental_checking (`bool`, *optional*, defaults to `True`): + Whether to enable incremental safety checking that tracks state between checks + to avoid redundant processing. Improves performance for long generations. + return_violations (`bool`, *optional*, defaults to `False`): + Whether to return detailed violation information in results. + return_metadata (`bool`, *optional*, defaults to `False`): + Whether to return additional metadata in results. 
+ + Examples: + ```python + # Using a reference implementation from examples directory + # Note: You need to add examples/ to your Python path first: + import sys + from pathlib import Path + sys.path.insert(0, str(Path("examples"))) + + from safe_generation import BasicToxicityChecker + from transformers.generation.safety import SafetyConfig + + # Create checker instance + checker = BasicToxicityChecker(threshold=0.7) + + # Option 1: Create config with from_checker() (recommended) + config = SafetyConfig.from_checker(checker) + + # Option 2: Create config directly + config = SafetyConfig(enabled=True, checker=checker) + + # Use with generation + from transformers import pipeline + pipe = pipeline("text-generation", model="gpt2", safety_config=config) + ``` + """ + + # Checker configuration + enabled: bool = False + checker: Optional[SafetyChecker] = None + + # Device configuration + device: Optional[str] = None + + # Performance configuration + cache_size: int = 100 + unsafe_hash_limit: int = 1000 + sliding_window_size: int = 512 + incremental_checking: bool = True + prefix_lengths: list[int] = field(default_factory=lambda: [100, 75, 50]) + min_text_length_for_prefix: int = 50 + + # Output configuration + return_violations: bool = False + return_metadata: bool = False + + def __post_init__(self): + """Perform immediate validation after initialization.""" + # Basic type checking for critical parameters + if not isinstance(self.cache_size, int): + raise TypeError(f"cache_size must be an integer, got {type(self.cache_size).__name__}") + + if not isinstance(self.unsafe_hash_limit, int): + raise TypeError(f"unsafe_hash_limit must be an integer, got {type(self.unsafe_hash_limit).__name__}") + + # Range validation + if self.cache_size < 1: + raise ValueError("cache_size must be a positive integer") + + if self.unsafe_hash_limit < 1: + raise ValueError("unsafe_hash_limit must be a positive integer") + + # Validate sliding window size + if not isinstance(self.sliding_window_size, int): + raise TypeError(f"sliding_window_size must be an integer, got {type(self.sliding_window_size).__name__}") + + if self.sliding_window_size < -1 or self.sliding_window_size == 0: + raise ValueError("sliding_window_size must be a positive integer or -1 to disable") + + # Validate incremental checking + if not isinstance(self.incremental_checking, bool): + raise TypeError(f"incremental_checking must be a boolean, got {type(self.incremental_checking).__name__}") + + # Validate prefix configuration + if not isinstance(self.prefix_lengths, list): + raise TypeError(f"prefix_lengths must be a list, got {type(self.prefix_lengths).__name__}") + + if not all(isinstance(length, int) and length > 0 for length in self.prefix_lengths): + raise ValueError("All prefix_lengths must be positive integers") + + if not isinstance(self.min_text_length_for_prefix, int) or self.min_text_length_for_prefix < 1: + raise ValueError("min_text_length_for_prefix must be a positive integer") + + def to_dict(self) -> dict[str, Any]: + """ + Convert to dictionary for serialization. + + Note: The checker instance is not serialized. You must recreate it when + deserializing. + + Returns: + `Dict[str, Any]`: Dictionary representation of the configuration. 
+ """ + return { + "enabled": self.enabled, + "device": self.device, + "cache_size": self.cache_size, + "unsafe_hash_limit": self.unsafe_hash_limit, + "sliding_window_size": self.sliding_window_size, + "incremental_checking": self.incremental_checking, + "prefix_lengths": self.prefix_lengths, + "min_text_length_for_prefix": self.min_text_length_for_prefix, + "return_violations": self.return_violations, + "return_metadata": self.return_metadata, + # Note: checker is not serialized - must be provided when deserializing + } + + @classmethod + def from_dict(cls, config_dict: dict[str, Any]) -> SafetyConfig: + """ + Create SafetyConfig from dictionary. + + Args: + config_dict (`Dict[str, Any]`): Dictionary containing configuration parameters. + + Returns: + `SafetyConfig`: Instance created from the dictionary. + """ + return cls(**config_dict) + + def validate(self) -> None: + """ + Validate configuration parameters. + + Raises: + ValueError: If any configuration parameter is invalid. + """ + # Validate enabled is boolean + if not isinstance(self.enabled, bool): + raise ValueError("enabled must be a boolean") + + # Warn about potentially inefficient configurations (validation done in __post_init__) + if self.cache_size > WARNING_CACHE_SIZE_LIMIT: + warnings.warn( + f"cache_size > {WARNING_CACHE_SIZE_LIMIT} may use excessive memory", UserWarning, stacklevel=2 + ) + + if self.unsafe_hash_limit > WARNING_UNSAFE_HASH_LIMIT: + warnings.warn( + f"unsafe_hash_limit > {WARNING_UNSAFE_HASH_LIMIT} may use excessive memory", UserWarning, stacklevel=2 + ) + + # Validate output configuration + if not isinstance(self.return_violations, bool): + raise ValueError("return_violations must be a boolean") + + if not isinstance(self.return_metadata, bool): + raise ValueError("return_metadata must be a boolean") + + def construct_checker(self) -> SafetyChecker: + """ + Retrieve the safety checker from the configuration. + + Returns the user-provided checker instance that was specified when creating + the configuration. + + Returns: + `SafetyChecker`: The safety checker instance. + + Raises: + ValueError: If no checker instance is provided. + + Examples: + ```python + # See examples/safe_generation/ for reference implementations + import sys + from pathlib import Path + sys.path.insert(0, str(Path("examples"))) + + from safe_generation import BasicToxicityChecker + from transformers.generation.safety import SafetyConfig + + # Create checker + checker = BasicToxicityChecker(threshold=0.7) + + # Create config with checker + config = SafetyConfig.from_checker(checker) + + # Construct checker (returns the same instance) + safety_checker = config.construct_checker() + ``` + """ + if self.checker is None: + raise ValueError( + "SafetyConfig requires a checker instance. " + "You must provide a SafetyChecker when creating the configuration. " + "See examples/safe_generation/ for reference implementations:\n\n" + " from examples.safe_generation import BasicToxicityChecker\n" + " checker = BasicToxicityChecker(threshold=0.7)\n" + " config = SafetyConfig.from_checker(checker)\n\n" + "Or implement your own custom checker by inheriting from SafetyChecker." + ) + return self.checker + + @classmethod + def from_checker(cls, checker: SafetyChecker, **kwargs) -> SafetyConfig: + """ + Create a SafetyConfig from a safety checker instance. + + This is the recommended way to create a SafetyConfig. + + Args: + checker (`SafetyChecker`): The safety checker instance to use. 
+ **kwargs: Additional configuration parameters to override defaults. + + Returns: + `SafetyConfig`: A SafetyConfig instance with the provided checker. + + Examples: + ```python + # See examples/safe_generation/ for reference implementations + import sys + from pathlib import Path + sys.path.insert(0, str(Path("examples"))) + + from safe_generation import BasicToxicityChecker + from transformers.generation.safety import SafetyConfig + + # Create checker + checker = BasicToxicityChecker(threshold=0.7) + + # Create config from checker + config = SafetyConfig.from_checker(checker) + + # With additional parameters + config = SafetyConfig.from_checker( + checker, + cache_size=200, + return_violations=True + ) + ``` + """ + return cls(enabled=True, checker=checker, **kwargs) + + +# Preset configuration kwargs for convenience +# These replace the deprecated create_default() method +# Usage: SafetyConfig.from_checker(checker, **STRICT_PRESET) + +STRICT_PRESET = { + "cache_size": 50, + "unsafe_hash_limit": 500, + "return_violations": True, + "return_metadata": True, +} + +MODERATE_PRESET = { + "cache_size": 100, + "unsafe_hash_limit": 1000, + "return_violations": False, + "return_metadata": False, +} + +LENIENT_PRESET = { + "cache_size": 200, + "unsafe_hash_limit": 2000, + "return_violations": False, + "return_metadata": False, +} diff --git a/src/transformers/generation/safety/processors.py b/src/transformers/generation/safety/processors.py new file mode 100644 index 000000000000..f33e6d2d4693 --- /dev/null +++ b/src/transformers/generation/safety/processors.py @@ -0,0 +1,777 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import hashlib +import logging +import time +from collections import OrderedDict +from typing import Optional + +import torch + +from ..logits_process import LogitsProcessor +from ..stopping_criteria import StoppingCriteria +from .base import SafetyChecker, SafetyMetrics, SafetyResult, SafetyState, SafetyViolation +from .configuration import SafetyConfig + + +logger = logging.getLogger(__name__) + +# Configuration constants +DEFAULT_CACHE_SIZE = 100 +DEFAULT_UNSAFE_HASH_LIMIT = 1000 +DEFAULT_CHECK_INTERVAL = 1 + + +class _SafetyCache: + """Simple LRU cache for safety check results.""" + + def __init__(self, max_size: int = DEFAULT_CACHE_SIZE): + self.max_size = max_size + self._cache = OrderedDict() + + def get(self, text: str, use_prefix_matching: bool = False): + """ + Get cached result and move to end for LRU. + + Args: + text: Text to look up (will be hashed to create cache key) + use_prefix_matching: Ignored for simple cache (only supported by prefix cache) + + Returns: + SafetyResult if found, None otherwise + """ + key = _generate_cache_key(text) + if key in self._cache: + value = self._cache.pop(key) + self._cache[key] = value + return value + return None + + def put(self, text: str, value) -> None: + """ + Put result in cache with LRU eviction. 
+ + Args: + text: The text that was checked (will be hashed to create cache key) + value: The SafetyResult to store + """ + key = _generate_cache_key(text) + if len(self._cache) >= self.max_size: + self._cache.popitem(last=False) + self._cache[key] = value + + def __contains__(self, text: str) -> bool: + """Check if text exists in cache.""" + key = _generate_cache_key(text) + return key in self._cache + + +class _PrefixSafetyCache: + """ + Advanced caching system that supports prefix-based caching for efficient sequence checking. + + This cache can reuse safety results for sequences that share common prefixes, + significantly improving performance for incremental checking scenarios. + """ + + def __init__( + self, + max_size: int = DEFAULT_CACHE_SIZE, + prefix_lengths: Optional[list[int]] = None, + min_text_length_for_prefix: int = 50, + ): + self.max_size = max_size + self.prefix_lengths = prefix_lengths if prefix_lengths is not None else [100, 75, 50] + self.min_text_length_for_prefix = min_text_length_for_prefix + self._cache = OrderedDict() # Maps full cache keys to results + self._prefix_map = {} # Maps text prefixes to cache keys that contain them + + def get(self, text: str, use_prefix_matching: bool = True): + """ + Get cached result, optionally using prefix matching for efficiency. + + Args: + text: Text to look up + use_prefix_matching: Whether to try prefix matching if exact match fails + + Returns: + SafetyResult if found, None otherwise + """ + cache_key = _generate_cache_key(text) + + # Try exact match first + if cache_key in self._cache: + result = self._cache.pop(cache_key) + self._cache[cache_key] = result # Move to end for LRU + return result + + # If prefix matching is enabled and exact match failed + if use_prefix_matching: + return self._try_prefix_match(text) + + return None + + def put(self, text: str, result) -> None: + """ + Store result in cache with prefix indexing. + + Args: + text: The text that was checked + result: The SafetyResult to store + """ + cache_key = _generate_cache_key(text) + + # Evict oldest if at capacity + if len(self._cache) >= self.max_size: + old_key, _ = self._cache.popitem(last=False) + self._cleanup_prefix_references(old_key) + + # Store result + self._cache[cache_key] = result + + # Update prefix mapping for common prefixes + if len(text) > self.min_text_length_for_prefix: # Only index prefixes for longer texts + # Use the longest configured prefix length that's not larger than half the text + max_prefix_length = max([length for length in self.prefix_lengths if length <= len(text) // 2], default=0) + if max_prefix_length > 0: + prefix = text[:max_prefix_length] + prefix_key = _generate_cache_key(prefix) + + if prefix_key not in self._prefix_map: + self._prefix_map[prefix_key] = set() + self._prefix_map[prefix_key].add(cache_key) + + def _try_prefix_match(self, text: str): + """ + Try to find a cached result for a prefix of the given text. + + This is useful when we have cached results for shorter versions of the sequence. 
+ """ + if len(text) < self.min_text_length_for_prefix: # Don't use prefix matching for very short texts + return None + + # Try progressively shorter prefixes from configuration + for prefix_len in sorted(self.prefix_lengths, reverse=True): + if len(text) <= prefix_len: + continue + + prefix = text[:prefix_len] + prefix_key = _generate_cache_key(prefix) + + if prefix_key in self._prefix_map: + # Found potential matches - check if any are safe + for candidate_key in self._prefix_map[prefix_key]: + if candidate_key in self._cache: + result = self._cache[candidate_key] + # Only reuse if the cached result was safe + # (unsafe results might not apply to the longer sequence) + if result.is_safe: + # Move to end for LRU + self._cache.move_to_end(candidate_key) + return result + + return None + + def _cleanup_prefix_references(self, removed_cache_key: str) -> None: + """Remove references to evicted cache keys from prefix mapping.""" + keys_to_remove = [] + for prefix_key, cache_keys in self._prefix_map.items(): + if removed_cache_key in cache_keys: + cache_keys.discard(removed_cache_key) + if not cache_keys: # No more references + keys_to_remove.append(prefix_key) + + for key in keys_to_remove: + del self._prefix_map[key] + + def __contains__(self, text: str) -> bool: + """Check if text exists in cache.""" + cache_key = _generate_cache_key(text) + return cache_key in self._cache + + +def _generate_cache_key(text: str) -> str: + """ + Generate a SHA-256 based cache key for text content. + + Uses length prefix for quick rejection of different-sized texts, + followed by SHA-256 hash for collision-resistant uniqueness. + + Args: + text (str): The text content to generate a cache key for. + + Returns: + str: A cache key in the format "length:hash" + """ + text_hash = hashlib.sha256(text.encode("utf-8")).hexdigest() + return f"{len(text)}:{text_hash}" + + +class _SlidingWindowSafetyMixin: + """ + Shared functionality for sliding window safety processing. + + This mixin provides common methods for both SafetyLogitsProcessor and + SafetyStoppingCriteria to handle sliding window text extraction, + incremental checking, and cache management. + """ + + def _get_text_to_check(self, full_text: str, safety_state: SafetyState) -> tuple[str, int]: + """ + Determine what text to check based on sliding window and incremental settings. + + Args: + full_text: The complete sequence text + safety_state: The safety state for this sequence + + Returns: + tuple[str, int]: Text to check and window start position + """ + if self.incremental_checking: + # Use incremental checking with sliding window + return safety_state.get_incremental_text( + full_text, self.sliding_window_size if self.sliding_window_size > 0 else -1 + ) + # Use sliding window without incremental state + if self.sliding_window_size > 0 and len(full_text) > self.sliding_window_size: + # Extract sliding window (character-based approximation) + text_to_check = full_text[-self.sliding_window_size :] + window_start = len(full_text) - self.sliding_window_size + return text_to_check, window_start + return full_text, 0 + + def _should_skip_safety_check( + self, safety_state: SafetyState, current_position: int, min_new_tokens: int = 5 + ) -> tuple[bool, SafetyResult]: + """ + Determine if we should skip the safety check and return cached result. 
+ + Args: + safety_state: The safety state for this sequence + current_position: Current position in tokens + min_new_tokens: Minimum tokens required for new check + + Returns: + tuple[bool, SafetyResult]: Whether to skip check and result to use if skipping + """ + if not self.incremental_checking: + return False, None + + if not safety_state.should_check_incremental(current_position, min_new_tokens): + # Use previous result if available + safety_result = safety_state.last_check_result + if safety_result is not None: + return True, safety_result + return False, None + + def _perform_safety_check(self, text_to_check: str) -> SafetyResult: + """ + Perform safety check with caching and error handling. + + Args: + text_to_check: Text to check for safety violations + + Returns: + SafetyResult: Result of the safety check + """ + # Check cache first (unified interface) + safety_result = self._sequence_cache.get(text_to_check, use_prefix_matching=True) + + if safety_result is not None: + self._metrics.record_cache_hit() + return safety_result + + # Check safety of text portion - record timing + start_time = time.perf_counter() + try: + safety_result = self.safety_checker.check_safety(text_to_check) + except Exception as e: + logger.warning(f"Safety check failed for sequence: {e}") + # Fail safe - assume unsafe if we can't check + safety_result = SafetyResult( + is_safe=False, + confidence=0.0, + violations=[SafetyViolation("unknown", 0.0, "high", "Safety check failed")], + metadata={"error": str(e)}, + ) + + # Record timing and cache miss + end_time = time.perf_counter() + check_time_ms = (end_time - start_time) * 1000 + self._metrics.record_safety_check(check_time_ms) + self._metrics.record_cache_miss() + + # Cache the result + self._sequence_cache.put(text_to_check, safety_result) + return safety_result + + def _update_safety_state( + self, + safety_state: SafetyState, + current_position: int, + safety_result: SafetyResult, + text_to_check: str, + window_start: int, + full_text: str, + ) -> None: + """ + Update safety state with new check result if using incremental checking. + + Args: + safety_state: The safety state to update + current_position: Current position in sequence + safety_result: Result from safety check + text_to_check: Text that was checked + window_start: Start position of the window + full_text: Complete sequence text + """ + if self.incremental_checking: + safety_state.update_check_result( + current_position, safety_result, text_to_check if window_start == 0 else full_text + ) + + +class SafetyLogitsProcessor(LogitsProcessor, _SlidingWindowSafetyMixin): + """ + [`LogitsProcessor`] that blocks generation when unsafe content is detected. + + This processor checks the current sequence for safety violations and blocks + further generation by suppressing all tokens when unsafe content is detected. + It integrates with the transformers safety framework to provide real-time + content blocking. + + Args: + safety_checker ([`SafetyChecker`]): + The safety checker to use for content evaluation. + tokenizer ([`PreTrainedTokenizer`]): + The tokenizer used for decoding sequences. + safety_config ([`SafetyConfig`]): + Configuration for safety checking. + check_interval (`int`, *optional*, defaults to 1): + Check safety every N tokens. Must be positive. + suppress_threshold (`float`, *optional*, defaults to negative infinity): + Logit value for suppressing unsafe tokens. 
+ + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + >>> from transformers.generation.safety import SafetyLogitsProcessor, SafetyConfig + >>> from examples.safe_generation import BasicToxicityChecker + + >>> # Initialize model and tokenizer + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> tokenizer.pad_token = tokenizer.eos_token + + >>> # Create safety checker and config + >>> safety_checker = BasicToxicityChecker() + >>> safety_config = SafetyConfig.from_checker(safety_checker) + >>> safety_processor = SafetyLogitsProcessor( + ... safety_checker=safety_checker, + ... tokenizer=tokenizer, + ... safety_config=safety_config + ... ) + + >>> # Generate with safety filtering + >>> inputs = tokenizer("Tell me about", return_tensors="pt") + >>> outputs = model.generate( + ... **inputs, + ... logits_processor=[safety_processor], + ... max_new_tokens=50, + ... do_sample=True + ... ) + >>> generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + ``` + """ + + def __init__( + self, + safety_checker: SafetyChecker, + tokenizer, + safety_config: SafetyConfig, + check_interval: int = 1, + suppress_threshold: float = -float("inf"), + ): + """ + Initialize the SafetyLogitsProcessor. + + Args: + safety_checker: The safety checker to use for content evaluation + tokenizer: The tokenizer used for decoding sequences + safety_config: Configuration for safety checking + check_interval: Check safety every N tokens (default: 1, must be positive) + suppress_threshold: Logit value for suppressing unsafe tokens + + Raises: + ValueError: If check_interval is not positive + """ + # Input validation + if not isinstance(check_interval, int) or check_interval < 1: + raise ValueError(f"check_interval must be a positive integer, got {check_interval}") + + self.safety_checker = safety_checker + self.tokenizer = tokenizer + self.safety_config = safety_config + self.check_interval = check_interval + self.suppress_threshold = suppress_threshold + self._step_count = 0 + + # Initialize sliding window and incremental checking + self._safety_states = {} # Track safety state per sequence in the batch + self.sliding_window_size = getattr(safety_config, "sliding_window_size", 512) + self.incremental_checking = getattr(safety_config, "incremental_checking", True) + + # Initialize cache with configured size (use prefix cache if incremental checking is enabled) + cache_size = getattr(safety_config, "cache_size", DEFAULT_CACHE_SIZE) + if self.incremental_checking: + prefix_lengths = getattr(safety_config, "prefix_lengths", [100, 75, 50]) + min_text_length_for_prefix = getattr(safety_config, "min_text_length_for_prefix", 50) + self._sequence_cache = _PrefixSafetyCache( + max_size=cache_size, + prefix_lengths=prefix_lengths, + min_text_length_for_prefix=min_text_length_for_prefix, + ) # Advanced prefix-aware cache + else: + self._sequence_cache = _SafetyCache(max_size=cache_size) # Simple LRU cache + self._metrics = SafetyMetrics() # Initialize metrics collection + + def _apply_token_suppression(self, scores: torch.FloatTensor, batch_idx: int, safety_result: SafetyResult) -> None: + """ + Apply token suppression for unsafe content. 
+ + Args: + scores: Token scores tensor to modify + batch_idx: Index in the batch + safety_result: Safety check result + """ + if not safety_result.is_safe: + tokens_to_suppress = self._get_tokens_to_suppress(scores[batch_idx], safety_result) + if len(tokens_to_suppress) > 0: + device = scores.device + if isinstance(tokens_to_suppress, list): + tokens_to_suppress = torch.tensor(tokens_to_suppress, device=device) + scores[batch_idx, tokens_to_suppress] = self.suppress_threshold + self._metrics.record_suppression_event() + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + """ + Apply safety filtering to token scores. + + Args: + input_ids: Current sequence tokens [batch_size, seq_len] + scores: Token logits [batch_size, vocab_size] + + Returns: + Modified scores with unsafe tokens suppressed + """ + self._step_count += 1 + + # Only check safety at specified intervals for performance + if self._step_count % self.check_interval != 0: + return scores + + batch_size = input_ids.shape[0] + modified_scores = scores.clone() + + for i in range(batch_size): + # Convert current sequence to text + current_sequence = input_ids[i] + full_text = self.tokenizer.decode(current_sequence, skip_special_tokens=True) + + # Get or create safety state for this sequence + seq_id = i + if seq_id not in self._safety_states: + self._safety_states[seq_id] = SafetyState() + safety_state = self._safety_states[seq_id] + + # Determine what text to check based on configuration + text_to_check, window_start = self._get_text_to_check(full_text, safety_state) + current_position = len(current_sequence) + + # Check if we can skip the safety check and reuse previous result + should_skip, cached_result = self._should_skip_safety_check(safety_state, current_position) + if should_skip: + self._apply_token_suppression(modified_scores, i, cached_result) + continue + + # Perform safety check with caching and error handling + safety_result = self._perform_safety_check(text_to_check) + + # Update safety state if using incremental checking + self._update_safety_state( + safety_state, current_position, safety_result, text_to_check, window_start, full_text + ) + + # Apply suppression if content is unsafe + self._apply_token_suppression(modified_scores, i, safety_result) + + # Debug logging for unsafe content + if not safety_result.is_safe: + logger.debug(f"Suppressed tokens using blocking strategy for sequence: {full_text[:50]}...") + + return modified_scores + + def _get_tokens_to_suppress(self, scores: torch.FloatTensor, safety_result) -> torch.Tensor: + """ + Determine which tokens to suppress when unsafe content is detected. + + Args: + scores: Token logits for a single sequence [vocab_size] + safety_result: The safety assessment result (unused for blocking strategy) + + Returns: + Tensor of token indices to suppress (all tokens for blocking) + """ + # Block strategy: suppress all tokens to force generation to stop + return torch.arange(scores.size(0), device=scores.device) + + def get_metrics(self) -> SafetyMetrics: + """ + Return current metrics for this processor. + + Returns: + SafetyMetrics: Current metrics collection. + """ + return self._metrics + + def reset_safety_states(self) -> None: + """ + Reset all safety states. Call this when starting a new generation batch. + """ + self._safety_states.clear() + + def _get_text_for_safety_check(self, full_text: str, safety_state: SafetyState) -> tuple[str, int]: + """ + Extract the appropriate text portion for safety checking. 
+ + Args: + full_text: The complete sequence text + safety_state: Current safety state for incremental checking + + Returns: + tuple[str, int]: Text to check and its starting position + """ + if self.incremental_checking: + return safety_state.get_incremental_text( + full_text, self.sliding_window_size if self.sliding_window_size > 0 else -1 + ) + # Simple sliding window without incremental state + if self.sliding_window_size > 0 and len(full_text) > self.sliding_window_size: + window_start = len(full_text) - self.sliding_window_size + return full_text[window_start:], window_start + return full_text, 0 + + +class SafetyStoppingCriteria(StoppingCriteria, _SlidingWindowSafetyMixin): + """ + [`StoppingCriteria`] that halts generation when unsafe content is detected. + + This provides a sequence-level safety check that can stop generation before + unsafe content is returned to the user. It works as a final safety gate + after token-level filtering by SafetyLogitsProcessor. + + Args: + safety_checker ([`SafetyChecker`]): + The safety checker to use for content evaluation. + tokenizer ([`PreTrainedTokenizer`]): + The tokenizer used for decoding sequences. + safety_config ([`SafetyConfig`]): + Configuration for safety checking. + check_final_only (`bool`, *optional*, defaults to `False`): + If True, only check safety on the final call (when all sequences are complete). + If False, check safety on every call during generation. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + >>> from transformers.generation.safety import SafetyStoppingCriteria, SafetyConfig + >>> from examples.safe_generation import BasicToxicityChecker + + >>> # Initialize model and tokenizer + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> tokenizer.pad_token = tokenizer.eos_token + + >>> # Create safety checker and config + >>> safety_checker = BasicToxicityChecker() + >>> safety_config = SafetyConfig.from_checker(safety_checker) + >>> safety_stopping = SafetyStoppingCriteria( + ... safety_checker=safety_checker, + ... tokenizer=tokenizer, + ... safety_config=safety_config + ... ) + + >>> # Generate with safety stopping + >>> inputs = tokenizer("Tell me about", return_tensors="pt") + >>> outputs = model.generate( + ... **inputs, + ... stopping_criteria=[safety_stopping], + ... max_new_tokens=50, + ... do_sample=True + ... ) + >>> generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + ``` + """ + + def __init__( + self, safety_checker: SafetyChecker, tokenizer, safety_config: SafetyConfig, check_final_only: bool = False + ): + """ + Initialize the SafetyStoppingCriteria. 
+ + Args: + safety_checker: The safety checker to use for content evaluation + tokenizer: The tokenizer used for decoding sequences + safety_config: Configuration for safety checking + check_final_only: If True, only check when generation is complete + + Raises: + ValueError: If safety_checker is None + """ + if safety_checker is None: + raise ValueError("safety_checker cannot be None") + + self.safety_checker = safety_checker + self.tokenizer = tokenizer + self.safety_config = safety_config + self.check_final_only = check_final_only + self._unsafe_sequence_hashes = OrderedDict() # Track unsafe sequences by content hash (LRU) + + # Initialize sliding window and incremental checking + self._safety_states = {} # Track safety state per sequence in the batch + self.sliding_window_size = getattr(safety_config, "sliding_window_size", 512) + self.incremental_checking = getattr(safety_config, "incremental_checking", True) + + # Initialize cache with configured size (use prefix cache if incremental checking is enabled) + cache_size = getattr(safety_config, "cache_size", DEFAULT_CACHE_SIZE) + if self.incremental_checking: + prefix_lengths = getattr(safety_config, "prefix_lengths", [100, 75, 50]) + min_text_length_for_prefix = getattr(safety_config, "min_text_length_for_prefix", 50) + self._sequence_cache = _PrefixSafetyCache( + max_size=cache_size, + prefix_lengths=prefix_lengths, + min_text_length_for_prefix=min_text_length_for_prefix, + ) # Advanced prefix-aware cache + else: + self._sequence_cache = _SafetyCache(max_size=cache_size) # Simple LRU cache + # Get configured unsafe hash limit + self._unsafe_hash_limit = getattr(safety_config, "unsafe_hash_limit", DEFAULT_UNSAFE_HASH_LIMIT) + self._metrics = SafetyMetrics() # Initialize metrics collection + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: + """ + Check if generation should stop due to safety violations. 
+ + Args: + input_ids: Current sequences [batch_size, seq_len] + scores: Token scores [batch_size, vocab_size] + + Returns: + Boolean tensor indicating which sequences should stop [batch_size] + """ + batch_size = input_ids.shape[0] + + # Record generation attempts for metrics + for _ in range(batch_size): + self._metrics.record_generation_attempt() + + # Initialize should_stop tensor + should_stop = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device) + + # If check_final_only is True, skip safety checks during generation + # This will be handled by other safety mechanisms or post-generation checks + if self.check_final_only and not kwargs.get("is_final_call", False): + return should_stop + + # Check each sequence for safety violations + for i in range(batch_size): + # Convert current sequence to text + current_sequence = input_ids[i] + full_text = self.tokenizer.decode(current_sequence, skip_special_tokens=True) + + # Get or create safety state for this sequence + seq_id = i + if seq_id not in self._safety_states: + self._safety_states[seq_id] = SafetyState() + safety_state = self._safety_states[seq_id] + + # Determine what text to check based on configuration + text_to_check, window_start = self._get_text_to_check(full_text, safety_state) + current_position = len(current_sequence) + + # Check if we can skip the safety check and reuse previous result (less aggressive for stopping) + should_skip, cached_result = self._should_skip_safety_check( + safety_state, current_position, min_new_tokens=3 + ) + if should_skip and cached_result is not None and not cached_result.is_safe: + should_stop[i] = True + continue + + # Create sequence hash for tracking (using consistent SHA-256 hash) + sequence_hash = _generate_cache_key(text_to_check) + + # Check if this sequence content is already known to be unsafe (LRU) + if sequence_hash in self._unsafe_sequence_hashes: + # Move to end for LRU behavior + self._unsafe_sequence_hashes.move_to_end(sequence_hash) + should_stop[i] = True + continue + + # Perform safety check with caching and error handling + safety_result = self._perform_safety_check(text_to_check) + + # Update safety state if using incremental checking + self._update_safety_state( + safety_state, current_position, safety_result, text_to_check, window_start, full_text + ) + + # If unsafe, mark sequence for stopping + if not safety_result.is_safe: + # Add to unsafe hash dict with LRU eviction to prevent memory leaks + if len(self._unsafe_sequence_hashes) >= self._unsafe_hash_limit: + # Remove oldest entry (LRU eviction) + self._unsafe_sequence_hashes.popitem(last=False) + # Add new entry (will be at the end = most recently used) + self._unsafe_sequence_hashes[sequence_hash] = True # Track by content hash + should_stop[i] = True + self._metrics.record_blocked_generation() + + # Log safety violation for debugging + violation_categories = [v.category for v in safety_result.violations] + logger.warning( + f"Generation stopped for sequence {i} due to safety violations: {violation_categories}. " + f"Text: {full_text[:100]}..." + ) + + return should_stop + + def get_metrics(self) -> SafetyMetrics: + """ + Return current metrics for this stopping criteria. + + Returns: + SafetyMetrics: Current metrics collection. + """ + return self._metrics + + def reset_safety_states(self) -> None: + """ + Reset all safety states. Call this when starting a new generation batch. 
+ """ + self._safety_states.clear() diff --git a/src/transformers/generation/safety/utils.py b/src/transformers/generation/safety/utils.py new file mode 100644 index 000000000000..f639aca3a082 --- /dev/null +++ b/src/transformers/generation/safety/utils.py @@ -0,0 +1,40 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration import SafetyConfig + + +def validate_safety_config(config: SafetyConfig) -> bool: + """ + Validate a safety configuration and return whether it's valid. + + Args: + config (`SafetyConfig`): Configuration to validate. + + Returns: + `bool`: True if configuration is valid, False otherwise. + + Example: + ```python + config = SafetyConfig(enabled=True, thresholds={"toxicity": 0.5}) + if validate_safety_config(config): + print("Configuration is valid") + ``` + """ + try: + config.validate() + return True + except (ValueError, TypeError): + return False diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 64430fefad42..aed403d25665 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1128,6 +1128,65 @@ def _get_candidate_generator( ) return candidate_generator + def _create_safety_processor(self, safety_config, processor_type="logits"): + """ + Create safety processor from configuration. + + Args: + safety_config: SafetyConfig object containing safety settings + processor_type: Type of processor to create ("logits" or "stopping") + + Returns: + SafetyLogitsProcessor or SafetyStoppingCriteria, or None if creation fails + """ + if not safety_config or not getattr(safety_config, "enabled", False): + return None + + # Ensure we have a tokenizer + if not hasattr(self, "tokenizer") or self.tokenizer is None: + logger.warning("Cannot create safety processor: tokenizer not available") + return None + + try: + from .safety import SafetyLogitsProcessor, SafetyStoppingCriteria + + # Get checker from configuration + try: + safety_checker = safety_config.construct_checker() + except ValueError as e: + raise ValueError( + f"Safety configuration error: {e}\n" + "You must provide a SafetyChecker instance in SafetyConfig. " + "See examples/safe_generation/ for reference implementations." 
+ ) from e + + if processor_type == "logits": + return SafetyLogitsProcessor( + safety_checker=safety_checker, + tokenizer=self.tokenizer, + safety_config=safety_config, + check_interval=getattr(safety_config, "check_interval", 1), + ) + elif processor_type == "stopping": + return SafetyStoppingCriteria( + safety_checker=safety_checker, + tokenizer=self.tokenizer, + safety_config=safety_config, + check_final_only=getattr(safety_config, "check_final_only", False), + ) + else: + raise ValueError(f"processor_type must be 'logits' or 'stopping', got '{processor_type}'") + + except ImportError: + logger.warning("Safety module not available - cannot create safety processors") + return None + except ValueError: + # Re-raise ValueError for input validation errors (like invalid processor_type or missing checker) + raise + except Exception as e: + logger.warning(f"Failed to create safety {processor_type} processor: {e}") + return None + def _get_logits_processor( self, generation_config: GenerationConfig, @@ -1285,6 +1344,12 @@ def _get_logits_processor( ) ) + # Add safety processor if enabled + if hasattr(generation_config, "safety_config") and generation_config.safety_config is not None: + safety_processor = self._create_safety_processor(generation_config.safety_config, "logits") + if safety_processor is not None: + processors.append(safety_processor) + # TODO (joao): find a strategy to specify the order of the processors processors = self._merge_criteria_processor_list(processors, logits_processor) @@ -1386,6 +1451,13 @@ def _get_stopping_criteria( criteria.append( ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold) ) + + # Add safety stopping criteria if enabled + if hasattr(generation_config, "safety_config") and generation_config.safety_config is not None: + safety_stopping = self._create_safety_processor(generation_config.safety_config, "stopping") + if safety_stopping is not None: + criteria.append(safety_stopping) + criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) return criteria diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 7950e6faf2da..a57882ef0ac7 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -433,6 +433,11 @@ def _forward(self, model_inputs, **generate_kwargs): if "generation_config" not in generate_kwargs: generate_kwargs["generation_config"] = self.generation_config + # If safety_config is provided, attach tokenizer to model for safety processor creation + # GenerationMixin._create_safety_processor() expects self.tokenizer on the model + if "safety_config" in generate_kwargs and hasattr(self, "tokenizer") and self.tokenizer is not None: + self.model.tokenizer = self.tokenizer + output = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs) if isinstance(output, ModelOutput): diff --git a/tests/generation/test_safety_checkers.py b/tests/generation/test_safety_checkers.py new file mode 100644 index 000000000000..d60f30ea287c --- /dev/null +++ b/tests/generation/test_safety_checkers.py @@ -0,0 +1,261 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest +from pathlib import Path +from unittest.mock import Mock, patch + + +# Add examples directory to Python path to import BasicToxicityChecker +examples_path = Path(__file__).parent.parent.parent / "examples" +if str(examples_path) not in sys.path: + sys.path.insert(0, str(examples_path)) + +from safe_generation import BasicToxicityChecker # noqa: E402 + +from transformers.generation.safety import SafetyResult # noqa: E402 +from transformers.testing_utils import require_torch # noqa: E402 + + +@require_torch +class TestBasicToxicityChecker(unittest.TestCase): + """Test suite for BasicToxicityChecker.""" + + def setUp(self): + """Set up test fixtures.""" + self.mock_tokenizer_patcher = patch("transformers.AutoTokenizer.from_pretrained") + self.mock_model_patcher = patch("transformers.AutoModelForSequenceClassification.from_pretrained") + + self.mock_tokenizer = self.mock_tokenizer_patcher.start() + self.mock_model = self.mock_model_patcher.start() + + # Configure mock tokenizer + mock_tokenizer_instance = Mock() + + # Create a mock that can be unpacked as **kwargs + class MockTokenizerOutput(dict): + def to(self, device): + return self + + mock_tokenizer_instance.return_value = MockTokenizerOutput({"input_ids": Mock(), "attention_mask": Mock()}) + self.mock_tokenizer.return_value = mock_tokenizer_instance + + # Configure mock model + self.mock_model_instance = Mock() + self.mock_model_instance.eval.return_value = None + self.mock_model_instance.to.return_value = None + self.mock_model.return_value = self.mock_model_instance + + def tearDown(self): + """Clean up test fixtures.""" + self.mock_tokenizer_patcher.stop() + self.mock_model_patcher.stop() + + @patch("torch.cuda.is_available", return_value=False) + def test_init_with_defaults(self, mock_cuda): + """Test BasicToxicityChecker initialization with default parameters.""" + checker = BasicToxicityChecker() + + self.assertEqual(checker.model_name, "s-nlp/roberta_toxicity_classifier") + self.assertEqual(checker.threshold, 0.7) + self.assertEqual(checker.device, "cpu") + self.assertEqual(checker.supported_categories, ["toxicity"]) + + @patch("torch.cuda.is_available", return_value=True) + def test_init_with_cuda_available(self, mock_cuda): + """Test BasicToxicityChecker initialization when CUDA is available.""" + checker = BasicToxicityChecker() + self.assertEqual(checker.device, "cuda") + + def test_init_with_custom_params(self): + """Test BasicToxicityChecker initialization with custom parameters.""" + checker = BasicToxicityChecker(model_name="custom/model", threshold=0.8, device="cpu") + + self.assertEqual(checker.model_name, "custom/model") + self.assertEqual(checker.threshold, 0.8) + self.assertEqual(checker.device, "cpu") + + def test_init_model_loading_failure(self): + """Test BasicToxicityChecker handles model loading failures gracefully.""" + # Make model loading fail + self.mock_model.side_effect = Exception("Model not found") + + with self.assertRaises(RuntimeError) as context: + BasicToxicityChecker() + + self.assertIn("Failed to load toxicity model", str(context.exception)) + 
self.assertIn("Model not found", str(context.exception)) + + @patch("torch.no_grad") + @patch("torch.nn.functional.softmax") + def test_safe_text_detection(self, mock_softmax, mock_no_grad): + """Test detection of safe (non-toxic) text.""" + import torch + + # Mock safe prediction (low toxicity score) + mock_outputs = Mock() + mock_outputs.logits = torch.tensor([[2.0, 0.5]]) # Non-toxic >> toxic + self.mock_model_instance.return_value = mock_outputs + + # Mock softmax to return low toxicity probability + mock_softmax.return_value = torch.tensor([[0.8, 0.2]]) # [non-toxic, toxic] + + checker = BasicToxicityChecker(threshold=0.7) + result = checker.check_safety("This is a nice, positive comment") + + self.assertIsInstance(result, SafetyResult) + self.assertTrue(result.is_safe) + self.assertEqual(len(result.violations), 0) + self.assertIn("toxicity_score", result.metadata) + self.assertAlmostEqual(result.metadata["toxicity_score"], 0.2, places=5) + + @patch("torch.no_grad") + @patch("torch.nn.functional.softmax") + def test_toxic_text_detection(self, mock_softmax, mock_no_grad): + """Test detection of toxic text.""" + import torch + + # Mock toxic prediction (high toxicity score) + mock_outputs = Mock() + mock_outputs.logits = torch.tensor([[0.2, 3.0]]) # Non-toxic << toxic + self.mock_model_instance.return_value = mock_outputs + + # Mock softmax to return high toxicity probability + mock_softmax.return_value = torch.tensor([[0.15, 0.85]]) # [non-toxic, toxic] + + checker = BasicToxicityChecker(threshold=0.7) + result = checker.check_safety("This is some toxic harmful content") + + self.assertIsInstance(result, SafetyResult) + self.assertFalse(result.is_safe) + self.assertEqual(len(result.violations), 1) + + violation = result.violations[0] + self.assertEqual(violation.category, "toxicity") + self.assertAlmostEqual(violation.confidence, 0.85, places=5) + self.assertIn("high", violation.severity) # 0.85 should be "high" severity + self.assertIn("85.00%", violation.description) + + def test_batch_processing(self): + """Test batch processing of multiple texts.""" + import torch + + with patch("torch.no_grad"), patch("torch.nn.functional.softmax") as mock_softmax: + # Mock mixed results + mock_outputs = Mock() + mock_outputs.logits = torch.tensor([[2.0, 0.5]]) + self.mock_model_instance.return_value = mock_outputs + mock_softmax.return_value = torch.tensor([[0.8, 0.2]]) # Safe + + checker = BasicToxicityChecker() + results = checker.check_safety(["Safe text", "Another safe text"]) + + self.assertIsInstance(results, list) + self.assertEqual(len(results), 2) + self.assertTrue(all(isinstance(r, SafetyResult) for r in results)) + + def test_empty_text_handling(self): + """Test handling of empty text input.""" + + checker = BasicToxicityChecker() + result = checker.check_safety("") + + self.assertTrue(result.is_safe) + self.assertEqual(result.confidence, 1.0) + self.assertEqual(len(result.violations), 0) + self.assertEqual(result.metadata["reason"], "empty_text") + + def test_whitespace_only_text_handling(self): + """Test handling of whitespace-only text input.""" + + checker = BasicToxicityChecker() + result = checker.check_safety(" \n\t ") + + self.assertTrue(result.is_safe) + self.assertEqual(result.confidence, 1.0) + self.assertEqual(len(result.violations), 0) + self.assertEqual(result.metadata["reason"], "empty_text") + + @patch("safe_generation.checkers.logger") + def test_long_text_truncation(self, mock_logger): + """Test handling of very long text input.""" + import torch + + with 
patch("torch.no_grad"), patch("torch.nn.functional.softmax") as mock_softmax: + mock_outputs = Mock() + mock_outputs.logits = torch.tensor([[2.0, 0.5]]) + self.mock_model_instance.return_value = mock_outputs + mock_softmax.return_value = torch.tensor([[0.8, 0.2]]) + + checker = BasicToxicityChecker() + long_text = "A" * 15000 # Longer than 10000 char limit + result = checker.check_safety(long_text) + + self.assertIn("truncated", result.metadata) + self.assertTrue(result.metadata["truncated"]) + self.assertEqual(result.metadata["original_length"], 15000) + self.assertEqual(result.metadata["processed_length"], 10000) + mock_logger.warning.assert_called_once() + + def test_invalid_input_type(self): + """Test handling of invalid input types.""" + + checker = BasicToxicityChecker() + + with self.assertRaises(TypeError) as context: + checker.check_safety(123) # Not a string or list + + self.assertIn("Expected string or list of strings", str(context.exception)) + + def test_severity_classification(self): + """Test severity classification logic.""" + + checker = BasicToxicityChecker() + + # Test different severity levels + self.assertEqual(checker._get_severity(0.96), "critical") + self.assertEqual(checker._get_severity(0.90), "high") + self.assertEqual(checker._get_severity(0.80), "medium") + self.assertEqual(checker._get_severity(0.65), "low") + + def test_get_config(self): + """Test get_config method returns correct configuration.""" + + checker = BasicToxicityChecker(model_name="test/model", threshold=0.8, device="cpu") + + config = checker.get_config() + expected_config = { + "checker_type": "BasicToxicityChecker", + "model_name": "test/model", + "threshold": 0.8, + "device": "cpu", + } + + self.assertEqual(config, expected_config) + + @patch("torch.no_grad") + def test_inference_error_handling(self, mock_no_grad): + """Test handling of inference errors.""" + + # Make model inference fail + self.mock_model_instance.side_effect = RuntimeError("CUDA out of memory") + + checker = BasicToxicityChecker() + + with self.assertRaises(RuntimeError) as context: + checker.check_safety("test text") + + self.assertIn("Toxicity detection failed", str(context.exception)) diff --git a/tests/generation/test_safety_config.py b/tests/generation/test_safety_config.py new file mode 100644 index 000000000000..018d9d47d012 --- /dev/null +++ b/tests/generation/test_safety_config.py @@ -0,0 +1,383 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from unittest.mock import Mock + +from transformers.generation.safety import ( + LENIENT_PRESET, + MODERATE_PRESET, + STRICT_PRESET, + SafetyChecker, + SafetyConfig, +) + + +class TestSafetyConfig(unittest.TestCase): + """Test suite for SafetyConfig.""" + + def setUp(self): + """Set up mock checker for tests.""" + self.mock_checker = Mock(spec=SafetyChecker) + self.mock_checker.supported_categories = ["toxicity"] + + def test_default_config(self): + """Test SafetyConfig with default values.""" + config = SafetyConfig() + + # Check default values + self.assertFalse(config.enabled) + self.assertIsNone(config.checker) + self.assertIsNone(config.device) + self.assertFalse(config.return_violations) + self.assertFalse(config.return_metadata) + self.assertEqual(config.cache_size, 100) + self.assertEqual(config.unsafe_hash_limit, 1000) + self.assertEqual(config.sliding_window_size, 512) + self.assertTrue(config.incremental_checking) + + def test_from_checker_basic(self): + """Test creating config from checker using from_checker (recommended pattern).""" + config = SafetyConfig.from_checker(self.mock_checker) + + # Verify config was created correctly + self.assertTrue(config.enabled) + self.assertIs(config.checker, self.mock_checker) + self.assertEqual(config.cache_size, 100) # Default + self.assertFalse(config.return_violations) # Default + self.assertFalse(config.return_metadata) # Default + + def test_from_checker_with_preset(self): + """Test creating config from checker with preset parameters.""" + config = SafetyConfig.from_checker(self.mock_checker, **STRICT_PRESET) + + self.assertTrue(config.enabled) + self.assertIs(config.checker, self.mock_checker) + self.assertEqual(config.cache_size, 50) + self.assertEqual(config.unsafe_hash_limit, 500) + self.assertTrue(config.return_violations) + self.assertTrue(config.return_metadata) + + def test_from_checker_with_custom_params(self): + """Test creating config from checker with custom parameters.""" + config = SafetyConfig.from_checker(self.mock_checker, cache_size=200, return_violations=True, device="cuda") + + self.assertTrue(config.enabled) + self.assertIs(config.checker, self.mock_checker) + self.assertEqual(config.cache_size, 200) + self.assertTrue(config.return_violations) + self.assertEqual(config.device, "cuda") + + def test_construct_checker_returns_instance(self): + """Test that construct_checker returns the provided checker instance.""" + config = SafetyConfig.from_checker(self.mock_checker) + retrieved = config.construct_checker() + self.assertIs(retrieved, self.mock_checker) + + def test_construct_checker_error_when_missing(self): + """Test that construct_checker raises helpful error when checker is missing.""" + config = SafetyConfig(enabled=True) + + with self.assertRaises(ValueError) as context: + config.construct_checker() + + error_message = str(context.exception) + self.assertIn("SafetyConfig requires a checker instance", error_message) + self.assertIn("examples/safe_generation", error_message) + self.assertIn("BasicToxicityChecker", error_message) + self.assertIn("from_checker", error_message) + + def test_serialization_round_trip(self): + """Test serialization and deserialization (note: checker not serialized).""" + original_config = SafetyConfig.from_checker( + self.mock_checker, cache_size=150, return_violations=True, device="cpu" + ) + + # Serialize to dict + config_dict = original_config.to_dict() + + # Check dict contents (checker is not serialized) + self.assertEqual(config_dict["enabled"], True) + 
self.assertEqual(config_dict["cache_size"], 150) + self.assertEqual(config_dict["device"], "cpu") + self.assertTrue(config_dict["return_violations"]) + self.assertNotIn("checker", config_dict) + + # Deserialize from dict + restored_config = SafetyConfig.from_dict(config_dict) + + # Check attributes match (except checker which isn't serialized) + self.assertEqual(restored_config.enabled, original_config.enabled) + self.assertEqual(restored_config.cache_size, original_config.cache_size) + self.assertEqual(restored_config.device, original_config.device) + self.assertIsNone(restored_config.checker) # Checker must be re-provided + + # Re-attach checker to restored config + restored_config.checker = self.mock_checker + retrieved = restored_config.construct_checker() + self.assertIs(retrieved, self.mock_checker) + + def test_validation_success(self): + """Test validation with valid configuration.""" + # Valid default config + config = SafetyConfig() + config.validate() # Should not raise + + # Valid config with checker + config = SafetyConfig.from_checker(self.mock_checker, return_violations=True) + config.validate() # Should not raise + + def test_validation_enabled_type(self): + """Test validation of enabled field.""" + config = SafetyConfig(enabled="true") # Wrong type + with self.assertRaises(ValueError) as context: + config.validate() + self.assertIn("enabled must be a boolean", str(context.exception)) + + def test_validation_output_config_types(self): + """Test validation of output configuration types.""" + # Wrong return_violations type + config = SafetyConfig(return_violations="true") + with self.assertRaises(ValueError) as context: + config.validate() + self.assertIn("return_violations must be a boolean", str(context.exception)) + + # Wrong return_metadata type + config = SafetyConfig(return_metadata=1) + with self.assertRaises(ValueError) as context: + config.validate() + self.assertIn("return_metadata must be a boolean", str(context.exception)) + + def test_cache_size_configuration(self): + """Test cache size configuration and validation.""" + # Test default cache size + config = SafetyConfig() + self.assertEqual(config.cache_size, 100) + + # Test custom cache size + config = SafetyConfig(cache_size=50) + self.assertEqual(config.cache_size, 50) + + # Test cache size validation - must be positive integer (caught in __post_init__) + with self.assertRaises(ValueError): + SafetyConfig(cache_size=0) + + with self.assertRaises(ValueError): + SafetyConfig(cache_size=-1) + + with self.assertRaises(TypeError): + SafetyConfig(cache_size=3.14) + + with self.assertRaises(TypeError): + SafetyConfig(cache_size="100") + + def test_unsafe_hash_limit_configuration(self): + """Test unsafe hash limit configuration and validation.""" + # Test default unsafe hash limit + config = SafetyConfig() + self.assertEqual(config.unsafe_hash_limit, 1000) + + # Test custom unsafe hash limit + config = SafetyConfig(unsafe_hash_limit=500) + self.assertEqual(config.unsafe_hash_limit, 500) + + # Test validation - must be positive integer (caught in __post_init__) + with self.assertRaises(ValueError): + SafetyConfig(unsafe_hash_limit=0) + + with self.assertRaises(ValueError): + SafetyConfig(unsafe_hash_limit=-1) + + with self.assertRaises(TypeError): + SafetyConfig(unsafe_hash_limit=2.5) + + with self.assertRaises(TypeError): + SafetyConfig(unsafe_hash_limit="1000") + + def test_large_cache_size_warning(self): + """Test warning for potentially inefficient cache sizes.""" + import warnings + + # Test cache size warning + 
with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + SafetyConfig(cache_size=20000).validate() + self.assertEqual(len(w), 1) + self.assertTrue("cache_size > 10000" in str(w[0].message)) + + # Test unsafe hash limit warning + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + SafetyConfig(unsafe_hash_limit=200000).validate() + self.assertEqual(len(w), 1) + self.assertTrue("unsafe_hash_limit > 100000" in str(w[0].message)) + + def test_preset_constants(self): + """Test that preset constants have expected values.""" + # STRICT_PRESET + self.assertEqual(STRICT_PRESET["cache_size"], 50) + self.assertEqual(STRICT_PRESET["unsafe_hash_limit"], 500) + self.assertTrue(STRICT_PRESET["return_violations"]) + self.assertTrue(STRICT_PRESET["return_metadata"]) + + # MODERATE_PRESET + self.assertEqual(MODERATE_PRESET["cache_size"], 100) + self.assertEqual(MODERATE_PRESET["unsafe_hash_limit"], 1000) + self.assertFalse(MODERATE_PRESET["return_violations"]) + self.assertFalse(MODERATE_PRESET["return_metadata"]) + + # LENIENT_PRESET + self.assertEqual(LENIENT_PRESET["cache_size"], 200) + self.assertEqual(LENIENT_PRESET["unsafe_hash_limit"], 2000) + self.assertFalse(LENIENT_PRESET["return_violations"]) + self.assertFalse(LENIENT_PRESET["return_metadata"]) + + def test_presets_with_from_checker(self): + """Test using presets with from_checker.""" + # Test strict preset + strict_config = SafetyConfig.from_checker(self.mock_checker, **STRICT_PRESET) + self.assertEqual(strict_config.cache_size, 50) + self.assertEqual(strict_config.unsafe_hash_limit, 500) + self.assertTrue(strict_config.return_violations) + self.assertTrue(strict_config.return_metadata) + + # Test moderate preset + moderate_config = SafetyConfig.from_checker(self.mock_checker, **MODERATE_PRESET) + self.assertEqual(moderate_config.cache_size, 100) + self.assertEqual(moderate_config.unsafe_hash_limit, 1000) + self.assertFalse(moderate_config.return_violations) + + # Test lenient preset + lenient_config = SafetyConfig.from_checker(self.mock_checker, **LENIENT_PRESET) + self.assertEqual(lenient_config.cache_size, 200) + self.assertEqual(lenient_config.unsafe_hash_limit, 2000) + self.assertFalse(lenient_config.return_violations) + + def test_serialization_includes_cache_config(self): + """Test that serialization includes cache configuration.""" + config = SafetyConfig(cache_size=75, unsafe_hash_limit=750) + config_dict = config.to_dict() + + self.assertEqual(config_dict["cache_size"], 75) + self.assertEqual(config_dict["unsafe_hash_limit"], 750) + + # Test round-trip + restored_config = SafetyConfig.from_dict(config_dict) + self.assertEqual(restored_config.cache_size, 75) + self.assertEqual(restored_config.unsafe_hash_limit, 750) + + def test_sliding_window_configuration(self): + """Test sliding window configuration parameters.""" + # Test default values + config = SafetyConfig() + self.assertEqual(config.sliding_window_size, 512) + self.assertTrue(config.incremental_checking) + + # Test custom values + config = SafetyConfig(sliding_window_size=256, incremental_checking=False) + self.assertEqual(config.sliding_window_size, 256) + self.assertFalse(config.incremental_checking) + + def test_sliding_window_validation(self): + """Test validation of sliding window parameters.""" + # Test valid sliding window size + config = SafetyConfig(sliding_window_size=100) + config.validate() # Should not raise + + # Test valid disabled sliding window + config = SafetyConfig(sliding_window_size=-1) + 
config.validate() # Should not raise + + # Test invalid sliding window size (0) + with self.assertRaises(ValueError) as context: + SafetyConfig(sliding_window_size=0) + self.assertIn("sliding_window_size must be a positive integer or -1 to disable", str(context.exception)) + + # Test invalid sliding window size (negative but not -1) + with self.assertRaises(ValueError) as context: + SafetyConfig(sliding_window_size=-5) + self.assertIn("sliding_window_size must be a positive integer or -1 to disable", str(context.exception)) + + # Test invalid incremental_checking type + with self.assertRaises(TypeError) as context: + SafetyConfig(incremental_checking="true") + self.assertIn("incremental_checking must be a boolean", str(context.exception)) + + def test_sliding_window_serialization(self): + """Test serialization of sliding window parameters.""" + config = SafetyConfig( + sliding_window_size=256, incremental_checking=False, cache_size=50, unsafe_hash_limit=500 + ) + + # Test to_dict includes sliding window parameters + config_dict = config.to_dict() + self.assertEqual(config_dict["sliding_window_size"], 256) + self.assertEqual(config_dict["incremental_checking"], False) + + # Test round-trip serialization + restored_config = SafetyConfig.from_dict(config_dict) + self.assertEqual(restored_config.sliding_window_size, 256) + self.assertFalse(restored_config.incremental_checking) + self.assertEqual(restored_config.cache_size, 50) + self.assertEqual(restored_config.unsafe_hash_limit, 500) + + def test_sliding_window_edge_cases(self): + """Test edge cases for sliding window configuration.""" + # Test very large sliding window size + config = SafetyConfig(sliding_window_size=10000) + config.validate() # Should be valid + + # Test minimum sliding window size + config = SafetyConfig(sliding_window_size=1) + config.validate() # Should be valid + + # Test both sliding window and incremental checking disabled + config = SafetyConfig(sliding_window_size=-1, incremental_checking=False) + config.validate() # Should be valid + + def test_comprehensive_workflow(self): + """Test a complete workflow with SafetyConfig.""" + # Create configuration using from_checker (recommended approach) + config = SafetyConfig.from_checker( + self.mock_checker, cache_size=50, return_violations=True, return_metadata=True + ) + + # Validate configuration + config.validate() + + # Verify config was created correctly + self.assertTrue(config.enabled) + self.assertIs(config.checker, self.mock_checker) + self.assertEqual(config.cache_size, 50) + self.assertTrue(config.return_violations) + + # Test construct_checker returns same instance + retrieved_checker = config.construct_checker() + self.assertIs(retrieved_checker, self.mock_checker) + + # Serialize and deserialize (note: checker not serialized) + config_dict = config.to_dict() + restored_config = SafetyConfig.from_dict(config_dict) + + # Verify consistency (except checker which isn't serialized) + self.assertEqual(config.enabled, restored_config.enabled) + self.assertEqual(config.cache_size, restored_config.cache_size) + self.assertIsNone(restored_config.checker) # Checker must be re-provided after deserialization + + # Re-attach checker to restored config + restored_config.checker = self.mock_checker + + # Validate restored configuration + restored_config.validate() diff --git a/tests/generation/test_safety_e2e.py b/tests/generation/test_safety_e2e.py new file mode 100644 index 000000000000..c842ec67900a --- /dev/null +++ b/tests/generation/test_safety_e2e.py @@ -0,0 +1,231 @@ 
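One behaviour the round-trip tests above pin down is that the checker instance is deliberately excluded from `to_dict()`/`from_dict()`, so application code that persists a `SafetyConfig` must re-attach a checker after loading. A minimal sketch of that pattern, assuming the remaining serialized fields stay JSON-compatible (the helper names and file path are illustrative, not part of this PR):

```python
import json

from transformers.generation.safety import SafetyConfig


def save_safety_config(config: SafetyConfig, path: str) -> None:
    # to_dict() intentionally drops the checker: checkers hold model weights and are not serializable.
    with open(path, "w") as f:
        json.dump(config.to_dict(), f)


def load_safety_config(path: str, checker) -> SafetyConfig:
    with open(path) as f:
        config = SafetyConfig.from_dict(json.load(f))
    # Re-attach the separately constructed checker, as the round-trip tests above do, then re-validate.
    config.checker = checker
    config.validate()
    return config
```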
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+import unittest
+from unittest.mock import Mock
+
+import torch
+
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+from transformers.generation.safety import SafetyChecker, SafetyConfig, SafetyResult, SafetyViolation
+from transformers.testing_utils import require_torch, slow
+
+
+class TestSafetyEndToEnd(unittest.TestCase):
+    """End-to-end tests for safety-enabled generation with actual models."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    def _create_mock_checker(self):
+        """Create a mock safety checker for testing."""
+        # Create a mock checker that implements the SafetyChecker interface
+        mock_checker = Mock(spec=SafetyChecker)
+        mock_checker.supported_categories = ["toxicity"]
+        return mock_checker
+
+    @require_torch
+    @slow
+    def test_greedy_generation_with_safety(self):
+        """Test that safety works with greedy decoding generation."""
+        # Create mock checker
+        mock_checker = self._create_mock_checker()
+
+        # Mock safe responses
+        mock_checker.check_safety.return_value = SafetyResult(is_safe=True, confidence=0.9, violations=[], metadata={})
+
+        # Load small model for testing
+        model_name = "sshleifer/tiny-gpt2"
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        # Create safety configuration with mock checker
+        safety_config = SafetyConfig.from_checker(mock_checker)
+
+        # Create generation config with safety
+        gen_config = GenerationConfig(
+            max_length=20,
+            do_sample=False,  # Greedy
+            safety_config=safety_config,
+        )
+
+        # Test generation
+        inputs = tokenizer("Hello, world", return_tensors="pt")
+        outputs = model.generate(**inputs, generation_config=gen_config)
+
+        # Verify output is generated
+        self.assertGreater(outputs.shape[1], inputs["input_ids"].shape[1])
+
+        # Verify safety checker was called
+        mock_checker.check_safety.assert_called()
+
+    @require_torch
+    @slow
+    def test_sample_generation_with_safety(self):
+        """Test that safety works with sampling generation."""
+        mock_checker = self._create_mock_checker()
+
+        # Mock safe responses
+        mock_checker.check_safety.return_value = SafetyResult(is_safe=True, confidence=0.9, violations=[], metadata={})
+
+        # Load small model
+        model_name = "sshleifer/tiny-gpt2"
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        # Create safety configuration
+        safety_config = SafetyConfig.from_checker(mock_checker)
+
+        # Test sampling with safety
+        inputs = tokenizer("Hello", return_tensors="pt")
+        outputs = model.generate(
+            **inputs, max_length=15, do_sample=True, temperature=0.8, safety_config=safety_config
+        )
+
+        # Verify generation occurred
+        self.assertGreater(outputs.shape[1], inputs["input_ids"].shape[1])
+        mock_checker.check_safety.assert_called()
+
+    @require_torch
+    @slow
+    def test_beam_search_generation_with_safety(self):
+        """Test that safety works with beam search generation."""
+        mock_checker = self._create_mock_checker()
+
+        # Mock safe responses
+        mock_checker.check_safety.return_value = SafetyResult(is_safe=True, confidence=0.9, violations=[], metadata={})
+
+        # Load small model
+        model_name = "sshleifer/tiny-gpt2"
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        # Create safety configuration
+        safety_config = SafetyConfig.from_checker(mock_checker)
+
+        # Test beam search with safety
+        inputs = tokenizer("The weather is", return_tensors="pt")
+        outputs = model.generate(**inputs, max_length=15, num_beams=2, safety_config=safety_config)
+
+        # Verify generation occurred
+        self.assertGreater(outputs.shape[1], inputs["input_ids"].shape[1])
+        mock_checker.check_safety.assert_called()
+
+    @require_torch
+    @slow
+    def test_safety_blocks_toxic_generation(self):
+        """Test that generation stops when toxic content is detected."""
+        mock_checker = self._create_mock_checker()
+
+        # Mock unsafe response that should stop generation
+        mock_checker.check_safety.return_value = SafetyResult(
+            is_safe=False,
+            confidence=0.85,
+            violations=[SafetyViolation("toxicity", 0.85, "high", "Toxic content detected")],
+            metadata={"toxicity_score": 0.85},
+        )
+
+        # Load small model
+        model_name = "sshleifer/tiny-gpt2"
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        # Create safety configuration
+        safety_config = SafetyConfig.from_checker(mock_checker)
+
+        # Test generation - should stop early due to safety
+        inputs = tokenizer("Test input", return_tensors="pt")
+        outputs = model.generate(
+            **inputs,
+            max_length=50,  # Allow long generation
+            safety_config=safety_config,
+        )
+
+        # Should stop early due to the safety stopping criteria
+        # (the exact length depends on when the safety check triggers)
+        self.assertLessEqual(outputs.shape[1], 50)
+        mock_checker.check_safety.assert_called()
+
+    @require_torch
+    @slow
+    def test_safety_disabled_backward_compatibility(self):
+        """Test that disabling safety doesn't affect normal generation."""
+        # No safety mocks needed - testing disabled safety
+
+        # Load small model
+        model_name = "sshleifer/tiny-gpt2"
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        # Test without safety config (default behavior)
+        inputs = tokenizer("Hello world", return_tensors="pt")
+        outputs_no_safety = model.generate(**inputs, max_length=20, do_sample=False)
+
+        # Test with disabled safety config
+        safety_config = SafetyConfig(enabled=False, checker=None)
+        outputs_disabled_safety = model.generate(**inputs, max_length=20, do_sample=False, safety_config=safety_config)
+
+        # Both runs use greedy decoding with safety effectively off, so they should
+        # behave identically; compare output shapes as a lightweight equivalence check
+        self.assertEqual(outputs_no_safety.shape, outputs_disabled_safety.shape)
+
+    @require_torch
+    @slow
+    def test_performance_impact_measurement(self):
+        """Test that safety overhead is reasonable."""
+        # Load small model
+        model_name = "sshleifer/tiny-gpt2"
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        inputs = tokenizer("Performance test", return_tensors="pt")
+
+        # Measure baseline (no safety)
+        start_time = time.time()
+        for _ in range(3):  # Multiple runs for more stable timing
+            model.generate(**inputs, max_length=20, do_sample=False)
+        baseline_time = time.time() - start_time
+
+        # Set up safety mocks for performance test
+        mock_checker = self._create_mock_checker()
+        mock_checker.check_safety.return_value = SafetyResult(is_safe=True, confidence=0.9, violations=[], metadata={})
+
+        # Measure with safety enabled
+        safety_config = SafetyConfig.from_checker(mock_checker)
+
+        start_time = time.time()
+        for _ in range(3):  # Multiple runs for more stable timing
+            model.generate(**inputs, max_length=20, do_sample=False, safety_config=safety_config)
+        safety_time = time.time() - start_time
+
+        # Calculate overhead percentage
+        overhead_percent = ((safety_time - baseline_time) / baseline_time) * 100
+
+        # Assert that overhead is reasonable (less than 50% for this simple test)
+        # Note: In real usage, overhead would be much lower due to the check_interval optimization
+        self.assertLess(overhead_percent, 50, f"Safety overhead of {overhead_percent:.1f}% is too high")
+
+        print(f"Safety overhead: {overhead_percent:.1f}%")
diff --git a/tests/generation/test_safety_integration.py b/tests/generation/test_safety_integration.py
new file mode 100644
index 000000000000..0496dffe4b58
--- /dev/null
+++ b/tests/generation/test_safety_integration.py
@@ -0,0 +1,498 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import sys +import unittest +from pathlib import Path +from unittest.mock import Mock, patch + +import torch + + +# Add examples directory to Python path to import BasicToxicityChecker +examples_path = Path(__file__).parent.parent.parent / "examples" +if str(examples_path) not in sys.path: + sys.path.insert(0, str(examples_path)) + +from safe_generation import BasicToxicityChecker # noqa: E402 + +from transformers.generation.configuration_utils import GenerationConfig # noqa: E402 +from transformers.generation.safety import ( # noqa: E402 + LENIENT_PRESET, + MODERATE_PRESET, + STRICT_PRESET, + SafetyChecker, + SafetyConfig, + SafetyResult, + SafetyViolation, +) +from transformers.generation.safety.processors import SafetyLogitsProcessor, SafetyStoppingCriteria # noqa: E402 +from transformers.testing_utils import require_torch # noqa: E402 + + +class TestSafetyIntegration(unittest.TestCase): + """Integration tests for the complete safety checking workflow.""" + + def setUp(self): + """Set up mock safety checker for tests.""" + self.mock_checker = Mock(spec=SafetyChecker) + self.mock_checker.check_safety.return_value = SafetyResult( + is_safe=True, confidence=0.9, violations=[], metadata={} + ) + self.mock_checker.supported_categories = ["toxicity"] + + def test_complete_safety_workflow(self): + """Test end-to-end safety checking workflow from configuration to results.""" + # Step 1: Create and validate configuration + config = SafetyConfig.from_checker(self.mock_checker, **STRICT_PRESET) + config.validate() + + # Verify configuration is set up correctly with STRICT preset values + self.assertTrue(config.enabled) + self.assertEqual(config.cache_size, 50) # STRICT_PRESET value + self.assertEqual(config.unsafe_hash_limit, 500) # STRICT_PRESET value + self.assertTrue(config.return_violations) # STRICT_PRESET value + self.assertTrue(config.return_metadata) # STRICT_PRESET value + + # Step 2: Test configuration serialization workflow + config_dict = config.to_dict() + restored_config = SafetyConfig.from_dict(config_dict) + restored_config.validate() + + # Verify serialization preserved configuration (except checker which isn't serialized) + self.assertEqual(config.cache_size, restored_config.cache_size) + self.assertEqual(config.enabled, restored_config.enabled) + self.assertEqual(config.return_violations, restored_config.return_violations) + self.assertIsNone(restored_config.checker) # Checker not serialized + + # Step 3: Test construct_checker returns the provided instance + retrieved_checker = config.construct_checker() + self.assertIs(retrieved_checker, self.mock_checker) + + @require_torch + @patch("transformers.AutoTokenizer.from_pretrained") + @patch("transformers.AutoModelForSequenceClassification.from_pretrained") + def test_config_to_checker_integration(self, mock_model, mock_tokenizer): + """Test creating checker instance and using it with SafetyConfig.""" + # Set up mocks + mock_tokenizer_instance = Mock() + mock_inputs = Mock() + mock_inputs.to.return_value = mock_inputs + mock_tokenizer_instance.return_value = mock_inputs + mock_tokenizer.return_value = mock_tokenizer_instance + + mock_model_instance = Mock() + mock_model_instance.eval.return_value = None + mock_model_instance.to.return_value = None + mock_model.return_value = mock_model_instance + + # User creates checker instance + checker = BasicToxicityChecker(threshold=0.8) + + # Verify checker was created with correct configuration + self.assertEqual(checker.threshold, 0.8) + self.assertEqual(checker.model_name, 
"s-nlp/roberta_toxicity_classifier") # Default + self.assertEqual(checker.supported_categories, ["toxicity"]) + + # Create SafetyConfig from checker instance (recommended pattern) + config = SafetyConfig.from_checker(checker, return_violations=True) + + # Verify config was created correctly + self.assertTrue(config.enabled) + self.assertIs(config.checker, checker) + self.assertTrue(config.return_violations) + + # Test that construct_checker returns the same instance + retrieved_checker = config.construct_checker() + self.assertIs(retrieved_checker, checker) + + # Test checker configuration serialization + checker_config_dict = checker.get_config() + expected_config = { + "checker_type": "BasicToxicityChecker", + "model_name": "s-nlp/roberta_toxicity_classifier", + "threshold": 0.8, + "device": checker.device, + } + self.assertEqual(checker_config_dict, expected_config) + + def test_utility_functions_integration(self): + """Test integration of utility functions with configurations.""" + from transformers.generation.safety.utils import validate_safety_config + + # Test validation utility with various configurations + configs_to_test = [ + SafetyConfig(), # Default + SafetyConfig.from_checker(self.mock_checker, **STRICT_PRESET), + SafetyConfig.from_checker(self.mock_checker, **MODERATE_PRESET), + SafetyConfig.from_checker(self.mock_checker, **LENIENT_PRESET), + ] + + for config in configs_to_test: + self.assertTrue(validate_safety_config(config)) + + # Test with invalid configuration (invalid cache_size) + with self.assertRaises(ValueError): + # __post_init__ will raise ValueError for invalid cache_size + SafetyConfig(cache_size=0) + + def test_safety_result_structure(self): + """Test that SafetyResult and SafetyViolation work correctly together.""" + # Create a violation + violation = SafetyViolation( + category="toxicity", + confidence=0.85, + severity="high", + description="Detected toxic content with 85% confidence", + ) + + # Create a safety result + result = SafetyResult( + is_safe=False, + confidence=0.85, + violations=[violation], + metadata={"model_name": "unitary/toxic-bert", "toxicity_score": 0.85, "threshold": 0.7}, + ) + + # Verify structure + self.assertFalse(result.is_safe) + self.assertEqual(result.confidence, 0.85) + self.assertEqual(len(result.violations), 1) + + violation = result.violations[0] + self.assertEqual(violation.category, "toxicity") + self.assertEqual(violation.confidence, 0.85) + self.assertEqual(violation.severity, "high") + + # Test metadata + self.assertIn("model_name", result.metadata) + self.assertEqual(result.metadata["threshold"], 0.7) + + def test_configuration_levels_produce_different_behaviors(self): + """Test that different preset levels produce appropriate settings.""" + # Test all predefined presets + strict = SafetyConfig.from_checker(self.mock_checker, **STRICT_PRESET) + moderate = SafetyConfig.from_checker(self.mock_checker, **MODERATE_PRESET) + lenient = SafetyConfig.from_checker(self.mock_checker, **LENIENT_PRESET) + + # Verify cache sizes are different and logical (strict < moderate < lenient) + self.assertEqual(strict.cache_size, 50) + self.assertEqual(moderate.cache_size, 100) + self.assertEqual(lenient.cache_size, 200) + self.assertLess(strict.cache_size, moderate.cache_size) + self.assertLess(moderate.cache_size, lenient.cache_size) + + # Verify unsafe hash limits follow same pattern + self.assertEqual(strict.unsafe_hash_limit, 500) + self.assertEqual(moderate.unsafe_hash_limit, 1000) + self.assertEqual(lenient.unsafe_hash_limit, 2000) + 
self.assertLess(strict.unsafe_hash_limit, moderate.unsafe_hash_limit) + self.assertLess(moderate.unsafe_hash_limit, lenient.unsafe_hash_limit) + + # Verify output configuration differences + self.assertTrue(strict.return_violations) + self.assertTrue(strict.return_metadata) + + self.assertFalse(moderate.return_violations) + self.assertFalse(lenient.return_violations) + + def test_error_handling_throughout_workflow(self): + """Test error handling across the complete workflow.""" + # Test configuration validation errors - invalid cache_size + with self.assertRaises(ValueError): + SafetyConfig(cache_size=-1) + + # Test configuration validation errors - invalid unsafe_hash_limit + with self.assertRaises(ValueError): + SafetyConfig(unsafe_hash_limit=0) + + # Test construct_checker without providing checker raises error + config = SafetyConfig(enabled=True) + with self.assertRaises(ValueError) as context: + config.construct_checker() + self.assertIn("SafetyConfig requires a checker instance", str(context.exception)) + + # Test invalid return_violations type + with self.assertRaises(ValueError) as context: + config = SafetyConfig(return_violations="true") # Wrong type + config.validate() + self.assertIn("return_violations must be a boolean", str(context.exception)) + + def test_public_api_imports(self): + """Test that all public API components can be imported correctly.""" + # Test core imports + from transformers.generation.safety import SafetyChecker, SafetyConfig + + # Verify classes are properly available + self.assertTrue(hasattr(SafetyChecker, "check_safety")) + self.assertTrue(hasattr(SafetyChecker, "supported_categories")) + + # Test SafetyConfig factory + config = SafetyConfig.from_checker(self.mock_checker, **MODERATE_PRESET) + self.assertIsInstance(config, SafetyConfig) + + # Test torch-dependent import + from transformers.utils import is_torch_available + + # Note: BasicToxicityChecker is a reference implementation in examples/safe_generation + # Core transformers only provides the SafetyChecker ABC + if is_torch_available(): + # Verify BasicToxicityChecker is available from examples + from safe_generation import BasicToxicityChecker + + self.assertTrue(issubclass(BasicToxicityChecker, SafetyChecker)) + + +class TestGenerationConfigIntegration(unittest.TestCase): + """Tests for safety integration with GenerationConfig and generation pipeline.""" + + def setUp(self): + """Set up mock safety checker for tests.""" + self.mock_checker = Mock(spec=SafetyChecker) + self.mock_checker.check_safety.return_value = SafetyResult( + is_safe=True, confidence=0.9, violations=[], metadata={} + ) + self.mock_checker.supported_categories = ["toxicity"] + + def test_generation_config_accepts_safety_config(self): + """Test that GenerationConfig properly accepts and stores safety_config.""" + safety_config = SafetyConfig.from_checker(self.mock_checker) + + # Test direct parameter + gen_config = GenerationConfig(max_length=100, safety_config=safety_config) + + self.assertIsNotNone(gen_config.safety_config) + self.assertEqual(gen_config.safety_config.enabled, True) + # Check preset fields instead of non-existent thresholds + self.assertEqual(gen_config.safety_config.cache_size, 100) # MODERATE_PRESET default + + # Test None safety_config + gen_config_none = GenerationConfig(max_length=100) + self.assertIsNone(gen_config_none.safety_config) + + # Test update method + gen_config_update = GenerationConfig(max_length=100) + gen_config_update.update(safety_config=safety_config) + 
self.assertIsNotNone(gen_config_update.safety_config) + + @require_torch + @patch("safe_generation.BasicToxicityChecker") + def test_generation_mixin_creates_safety_processors(self, mock_checker_class): + """Test that GenerationMixin creates safety processors when configured.""" + # Mock the checker + mock_checker = Mock() + mock_checker.check_safety.return_value = SafetyResult(is_safe=True, confidence=0.9, violations=[], metadata={}) + mock_checker_class.return_value = mock_checker + + # Create a simple model mock with GenerationMixin methods + from transformers.generation.utils import GenerationMixin + + model = Mock(spec=GenerationMixin) + model.config = Mock() + model.config.vocab_size = 1000 + model.device = torch.device("cpu") + + # Add the methods and required attributes + model._create_safety_processor = GenerationMixin._create_safety_processor.__get__(model) + model.tokenizer = Mock() # Add tokenizer mock + + # Mock tokenizer methods + model.tokenizer.decode = Mock(return_value="test text") + model.tokenizer.convert_tokens_to_ids = Mock(return_value=123) + model.tokenizer.unk_token_id = 0 + + # Test with safety enabled + mock_checker_instance = Mock(spec=SafetyChecker) + safety_config = SafetyConfig.from_checker(mock_checker_instance) + + # Test logits processor creation + logits_processor = model._create_safety_processor(safety_config, "logits") + self.assertIsInstance(logits_processor, SafetyLogitsProcessor) + + # Test stopping criteria creation + stopping_criteria = model._create_safety_processor(safety_config, "stopping") + self.assertIsInstance(stopping_criteria, SafetyStoppingCriteria) + + # Test with safety disabled + disabled_config = SafetyConfig(enabled=False) + self.assertIsNone(model._create_safety_processor(disabled_config, "logits")) + self.assertIsNone(model._create_safety_processor(disabled_config, "stopping")) + + # Test with None config + self.assertIsNone(model._create_safety_processor(None, "logits")) + + @require_torch + @patch("safe_generation.BasicToxicityChecker") + def test_logits_processor_integration(self, mock_checker_class): + """Test integration of safety with logits processor pipeline.""" + # Mock checker + mock_checker = Mock() + mock_checker.check_safety.return_value = SafetyResult( + is_safe=False, + confidence=0.9, + violations=[SafetyViolation("toxicity", 0.9, "high", "Toxic content detected")], + metadata={}, + ) + mock_checker_class.return_value = mock_checker + + # Create processor + safety_config = SafetyConfig.from_checker(self.mock_checker) + + # Mock tokenizer + mock_tokenizer = Mock() + mock_tokenizer.decode.return_value = "test text" + mock_tokenizer.convert_tokens_to_ids.return_value = 123 + mock_tokenizer.unk_token_id = 0 + + processor = SafetyLogitsProcessor( + safety_checker=mock_checker, tokenizer=mock_tokenizer, safety_config=safety_config + ) + + # Create test data + batch_size = 2 + vocab_size = 1000 + sequence_length = 5 + + input_ids = torch.randint(0, vocab_size, (batch_size, sequence_length)) + scores = torch.randn(batch_size, vocab_size) + + # Process scores + processed_scores = processor(input_ids, scores) + + # Verify scores were modified (top tokens should be suppressed) + self.assertFalse(torch.equal(scores, processed_scores)) + + # Verify checker was called + mock_checker.check_safety.assert_called() + + @require_torch + @patch("safe_generation.BasicToxicityChecker") + def test_stopping_criteria_integration(self, mock_checker_class): + """Test integration of safety with stopping criteria pipeline.""" + # Mock checker 
with unsafe result + mock_checker = Mock() + mock_checker.check_safety.return_value = SafetyResult( + is_safe=False, + confidence=0.9, + violations=[SafetyViolation("toxicity", 0.9, "high", "Toxic content")], + metadata={}, + ) + mock_checker_class.return_value = mock_checker + + # Create stopping criteria + safety_config = SafetyConfig.from_checker(self.mock_checker) + + # Mock tokenizer + mock_tokenizer = Mock() + mock_tokenizer.decode.return_value = "test text" + + criteria = SafetyStoppingCriteria( + safety_checker=mock_checker, tokenizer=mock_tokenizer, safety_config=safety_config + ) + + # Create test data + batch_size = 2 + vocab_size = 1000 + sequence_length = 10 + + input_ids = torch.randint(0, vocab_size, (batch_size, sequence_length)) + scores = torch.randn(batch_size, vocab_size) + + # Test stopping decision + should_stop = criteria(input_ids, scores) + + # Should stop due to unsafe content + self.assertTrue(should_stop.any()) + + # Verify checker was called + mock_checker.check_safety.assert_called() + + def test_backward_compatibility(self): + """Test that existing generation code works without safety configuration.""" + # Test GenerationConfig without safety + gen_config = GenerationConfig(max_length=100, temperature=0.8, top_p=0.9) + + self.assertIsNone(gen_config.safety_config) + self.assertEqual(gen_config.max_length, 100) + self.assertEqual(gen_config.temperature, 0.8) + + # Test that to_dict/from_dict works + config_dict = gen_config.to_dict() + restored = GenerationConfig.from_dict(config_dict) + + self.assertEqual(restored.max_length, 100) + self.assertIsNone(restored.safety_config) + + def test_safety_config_serialization_in_generation_config(self): + """Test that safety_config is properly serialized with GenerationConfig.""" + safety_config = SafetyConfig.from_checker(self.mock_checker, return_violations=True) + + gen_config = GenerationConfig(max_length=100, safety_config=safety_config) + + # Test to_dict + config_dict = gen_config.to_dict() + self.assertIn("safety_config", config_dict) + + # Test from_dict + restored = GenerationConfig.from_dict(config_dict) + self.assertIsNotNone(restored.safety_config) + self.assertEqual(restored.safety_config.enabled, True) + self.assertTrue(restored.safety_config.return_violations) + + def test_error_handling_in_generation_integration(self): + """Test error handling in generation pipeline integration.""" + # Test invalid safety config type + with self.assertRaises((TypeError, AttributeError)): + GenerationConfig(safety_config="invalid") + + # Test invalid processor type + from transformers.generation.utils import GenerationMixin + + model = Mock(spec=GenerationMixin) + model._create_safety_processor = GenerationMixin._create_safety_processor.__get__(model) + model.tokenizer = Mock() # Add tokenizer mock + + # Create config with mock checker + safety_config = SafetyConfig.from_checker(self.mock_checker) + + # Should raise ValueError for invalid processor type + with self.assertRaises(ValueError) as context: + model._create_safety_processor(safety_config, "invalid_type") + self.assertIn("processor_type must be 'logits' or 'stopping'", str(context.exception)) + + @require_torch + def test_end_to_end_safety_integration(self): + """Test complete end-to-end safety integration workflow.""" + # Create safety configuration + safety_config = SafetyConfig.from_checker(self.mock_checker) + + # Create generation configuration with safety + gen_config = GenerationConfig(max_length=50, temperature=0.8, safety_config=safety_config) + + # 
Verify safety config is properly stored + self.assertIsNotNone(gen_config.safety_config) + self.assertEqual(gen_config.safety_config.enabled, True) + + # Test serialization round-trip + config_dict = gen_config.to_dict() + restored_config = GenerationConfig.from_dict(config_dict) + + self.assertIsNotNone(restored_config.safety_config) + self.assertEqual(restored_config.safety_config.enabled, True) + self.assertEqual(restored_config.safety_config.cache_size, safety_config.cache_size) + + # Verify non-safety parameters are preserved + self.assertEqual(restored_config.max_length, 50) + self.assertEqual(restored_config.temperature, 0.8) diff --git a/tests/generation/test_safety_processors.py b/tests/generation/test_safety_processors.py new file mode 100644 index 000000000000..793caf52cbe0 --- /dev/null +++ b/tests/generation/test_safety_processors.py @@ -0,0 +1,1205 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import Mock + +import torch + +from transformers.generation.safety import ( + LENIENT_PRESET, + MODERATE_PRESET, + STRICT_PRESET, + SafetyConfig, + SafetyMetrics, + SafetyResult, + SafetyState, + SafetyViolation, +) +from transformers.generation.safety.processors import ( + SafetyLogitsProcessor, + SafetyStoppingCriteria, + _generate_cache_key, +) +from transformers.testing_utils import require_torch + + +@require_torch +class TestSafetyLogitsProcessor(unittest.TestCase): + """Test SafetyLogitsProcessor functionality.""" + + def setUp(self): + """Set up test fixtures.""" + # Mock safety checker + self.mock_checker = Mock() + self.mock_checker.check_safety.return_value = SafetyResult( + is_safe=True, confidence=0.9, violations=[], metadata={} + ) + + # Mock tokenizer + self.mock_tokenizer = Mock() + self.mock_tokenizer.decode.return_value = "test text" + + # Safety config + self.safety_config = SafetyConfig.from_checker(self.mock_checker) + + def test_safe_content_no_suppression(self): + """Test that safe content passes through without modification.""" + processor = SafetyLogitsProcessor( + safety_checker=self.mock_checker, tokenizer=self.mock_tokenizer, safety_config=self.safety_config + ) + + # Test safe content (mock already returns safe result) + input_ids = torch.tensor([[1, 2, 3, 4]]) + scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + original_scores = scores.clone() + + # Process + modified_scores = processor(input_ids, scores) + + # Scores should be unchanged for safe content + torch.testing.assert_close(modified_scores, original_scores) + + # Verify safety check was called + self.mock_checker.check_safety.assert_called_once() + + def test_unsafe_content_blocking(self): + """Test that unsafe content gets all tokens suppressed (blocking).""" + processor = SafetyLogitsProcessor( + safety_checker=self.mock_checker, tokenizer=self.mock_tokenizer, safety_config=self.safety_config + ) + + # Mock unsafe result + self.mock_checker.check_safety.return_value = SafetyResult( + 
is_safe=False, confidence=0.8, violations=[SafetyViolation("toxicity", 0.8, "high")], metadata={} + ) + + # Test data + input_ids = torch.tensor([[1, 2, 3, 4]]) + scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + vocab_size = scores.shape[-1] + + # Process + modified_scores = processor(input_ids, scores) + + # All tokens should be suppressed (blocking strategy) + for i in range(vocab_size): + self.assertEqual(modified_scores[0, i], float("-inf")) + + def test_check_interval(self): + """Test that safety checking respects check_interval parameter.""" + processor = SafetyLogitsProcessor( + safety_checker=self.mock_checker, + tokenizer=self.mock_tokenizer, + safety_config=self.safety_config, + check_interval=3, # Only check every 3rd call + ) + + input_ids = torch.tensor([[1, 2, 3, 4]]) + scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + + # First call (step 1) - no check + processor(input_ids, scores) + self.assertEqual(self.mock_checker.check_safety.call_count, 0) + + # Second call (step 2) - no check + processor(input_ids, scores) + self.assertEqual(self.mock_checker.check_safety.call_count, 0) + + # Third call (step 3) - check should happen + processor(input_ids, scores) + self.assertEqual(self.mock_checker.check_safety.call_count, 1) + + def test_batch_processing(self): + """Test that processor handles batched inputs correctly.""" + processor = SafetyLogitsProcessor( + safety_checker=self.mock_checker, tokenizer=self.mock_tokenizer, safety_config=self.safety_config + ) + + # Mock mixed safety results for batch + def mock_check_safety(text): + if "unsafe" in text: + return SafetyResult( + is_safe=False, confidence=0.8, violations=[SafetyViolation("toxicity", 0.8, "high")], metadata={} + ) + else: + return SafetyResult(is_safe=True, confidence=0.9, violations=[], metadata={}) + + self.mock_checker.check_safety.side_effect = mock_check_safety + + # Mock tokenizer to return different text for different sequences + def mock_decode(sequence, skip_special_tokens=True): + if torch.equal(sequence, torch.tensor([1, 2, 3, 4])): + return "safe text" + else: + return "unsafe text" + + self.mock_tokenizer.decode.side_effect = mock_decode + + # Batch with mixed safety + input_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) # [safe, unsafe] + scores = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]) + + # Process + modified_scores = processor(input_ids, scores) + + # First sequence (safe) should be unchanged + torch.testing.assert_close(modified_scores[0], scores[0]) + + # Second sequence (unsafe) should be suppressed + for i in range(scores.shape[-1]): + self.assertEqual(modified_scores[1, i], float("-inf")) + + def test_safety_check_failure(self): + """Test graceful handling when safety check fails.""" + processor = SafetyLogitsProcessor( + safety_checker=self.mock_checker, tokenizer=self.mock_tokenizer, safety_config=self.safety_config + ) + + # Mock safety checker to raise exception + self.mock_checker.check_safety.side_effect = Exception("Safety check failed") + + input_ids = torch.tensor([[1, 2, 3, 4]]) + scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + + # Process - should not raise, should assume unsafe and suppress + modified_scores = processor(input_ids, scores) + + # Should suppress all tokens when safety check fails (fail-safe behavior) + for i in range(scores.shape[-1]): + self.assertEqual(modified_scores[0, i], float("-inf")) + + +@require_torch +class TestSafetyStoppingCriteria(unittest.TestCase): + """Test SafetyStoppingCriteria functionality.""" + + def setUp(self): + """Set 
up test fixtures.""" + # Mock safety checker + self.mock_checker = Mock() + self.mock_checker.check_safety.return_value = SafetyResult( + is_safe=True, confidence=0.9, violations=[], metadata={} + ) + + # Mock tokenizer + self.mock_tokenizer = Mock() + self.mock_tokenizer.decode.return_value = "test text" + + # Safety config + self.safety_config = SafetyConfig.from_checker(self.mock_checker) + + def test_safe_content_continue_generation(self): + """Test that safe content allows generation to continue.""" + criteria = SafetyStoppingCriteria( + safety_checker=self.mock_checker, tokenizer=self.mock_tokenizer, safety_config=self.safety_config + ) + + input_ids = torch.tensor([[1, 2, 3, 4]]) + scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + + # Call stopping criteria + should_stop = criteria(input_ids, scores) + + # Should not stop for safe content + self.assertFalse(should_stop[0]) + self.mock_checker.check_safety.assert_called_once() + + def test_unsafe_content_stop_generation(self): + """Test that unsafe content stops generation.""" + criteria = SafetyStoppingCriteria( + safety_checker=self.mock_checker, tokenizer=self.mock_tokenizer, safety_config=self.safety_config + ) + + # Mock unsafe result + self.mock_checker.check_safety.return_value = SafetyResult( + is_safe=False, confidence=0.8, violations=[SafetyViolation("toxicity", 0.8, "high")], metadata={} + ) + + input_ids = torch.tensor([[1, 2, 3, 4]]) + scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + + # Call stopping criteria + should_stop = criteria(input_ids, scores) + + # Should stop for unsafe content + self.assertTrue(should_stop[0]) + + def test_check_final_only_mode(self): + """Test check_final_only parameter functionality.""" + criteria = SafetyStoppingCriteria( + safety_checker=self.mock_checker, + tokenizer=self.mock_tokenizer, + safety_config=self.safety_config, + check_final_only=True, + ) + + input_ids = torch.tensor([[1, 2, 3, 4]]) + scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + + # Call without is_final_call - should not check + should_stop = criteria(input_ids, scores) + self.assertFalse(should_stop[0]) + self.assertEqual(self.mock_checker.check_safety.call_count, 0) + + # Call with is_final_call=True - should check + should_stop = criteria(input_ids, scores, is_final_call=True) + self.assertFalse(should_stop[0]) # Safe content + self.assertEqual(self.mock_checker.check_safety.call_count, 1) + + def test_batch_stopping_criteria(self): + """Test stopping criteria with batched inputs.""" + criteria = SafetyStoppingCriteria( + safety_checker=self.mock_checker, tokenizer=self.mock_tokenizer, safety_config=self.safety_config + ) + + # Mock mixed safety results + def mock_check_safety(text): + if "unsafe" in text: + return SafetyResult( + is_safe=False, confidence=0.8, violations=[SafetyViolation("toxicity", 0.8, "high")], metadata={} + ) + else: + return SafetyResult(is_safe=True, confidence=0.9, violations=[], metadata={}) + + self.mock_checker.check_safety.side_effect = mock_check_safety + + # Mock tokenizer for batch + def mock_decode(sequence, skip_special_tokens=True): + if torch.equal(sequence, torch.tensor([1, 2, 3, 4])): + return "safe text" + else: + return "unsafe text" + + self.mock_tokenizer.decode.side_effect = mock_decode + + # Batch input + input_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) # [safe, unsafe] + scores = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]) + + # Call stopping criteria + should_stop = criteria(input_ids, scores) + + # First sequence (safe) should continue, second 
(unsafe) should stop + self.assertFalse(should_stop[0]) + self.assertTrue(should_stop[1]) + + def test_none_safety_checker_raises(self): + """Test that None safety_checker raises ValueError.""" + with self.assertRaises(ValueError): + SafetyStoppingCriteria( + safety_checker=None, tokenizer=self.mock_tokenizer, safety_config=self.safety_config + ) + + +@require_torch +class TestCacheKeyGeneration(unittest.TestCase): + """Test the SHA-256 cache key generation functionality.""" + + def test_cache_key_format(self): + """Test that cache keys follow the expected format.""" + text = "This is a test message" + cache_key = _generate_cache_key(text) + + # Should have format "length:hash" + parts = cache_key.split(":", 1) + self.assertEqual(len(parts), 2) + + # First part should be text length + self.assertEqual(parts[0], str(len(text))) + + # Second part should be a 64-character hex string (SHA-256) + self.assertEqual(len(parts[1]), 64) + self.assertTrue(all(c in "0123456789abcdef" for c in parts[1])) + + def test_cache_key_consistency(self): + """Test that same text produces same cache key.""" + text = "Consistent test message" + key1 = _generate_cache_key(text) + key2 = _generate_cache_key(text) + + self.assertEqual(key1, key2) + + def test_cache_key_uniqueness(self): + """Test that different texts produce different cache keys.""" + text1 = "First message" + text2 = "Second message" + text3 = "First messag" # Same length, different content + + key1 = _generate_cache_key(text1) + key2 = _generate_cache_key(text2) + key3 = _generate_cache_key(text3) + + # All keys should be different + self.assertNotEqual(key1, key2) + self.assertNotEqual(key1, key3) + self.assertNotEqual(key2, key3) + + def test_cache_key_different_lengths(self): + """Test that texts with different lengths have different cache keys.""" + short_text = "Short" + long_text = "This is a much longer text that should produce a different cache key" + + key1 = _generate_cache_key(short_text) + key2 = _generate_cache_key(long_text) + + self.assertNotEqual(key1, key2) + # Verify length prefixes are different + self.assertEqual(key1.split(":")[0], str(len(short_text))) + self.assertEqual(key2.split(":")[0], str(len(long_text))) + + def test_cache_key_empty_text(self): + """Test cache key generation for empty text.""" + empty_text = "" + cache_key = _generate_cache_key(empty_text) + + # Should still follow the format + parts = cache_key.split(":", 1) + self.assertEqual(len(parts), 2) + self.assertEqual(parts[0], "0") + self.assertEqual(len(parts[1]), 64) + + def test_cache_key_unicode_text(self): + """Test cache key generation for unicode text.""" + unicode_text = "Hello δΈ–η•Œ 🌍 cafΓ©" + cache_key = _generate_cache_key(unicode_text) + + # Should handle unicode properly + parts = cache_key.split(":", 1) + self.assertEqual(len(parts), 2) + self.assertEqual(parts[0], str(len(unicode_text))) + self.assertEqual(len(parts[1]), 64) + + # Should be consistent + key2 = _generate_cache_key(unicode_text) + self.assertEqual(cache_key, key2) + + def test_cache_key_collision_resistance(self): + """Test cache key collision resistance with similar texts.""" + texts = [ + "The quick brown fox", + "The quick brown fo", + "The quick brown fox ", # trailing space + " The quick brown fox", # leading space + "THE QUICK BROWN FOX", # different case + "The quick brown fox jumps", # extended + ] + + cache_keys = [_generate_cache_key(text) for text in texts] + + # All keys should be unique + self.assertEqual(len(cache_keys), len(set(cache_keys))) + + def 
test_cache_key_very_long_text(self): + """Test cache key generation for very long text.""" + # Create a long text + long_text = "Very long text " * 1000 + cache_key = _generate_cache_key(long_text) + + # Should still work and follow format + parts = cache_key.split(":", 1) + self.assertEqual(len(parts), 2) + self.assertEqual(parts[0], str(len(long_text))) + self.assertEqual(len(parts[1]), 64) + + +@require_torch +class TestSafetyMetrics(unittest.TestCase): + """Test the SafetyMetrics functionality.""" + + def test_metrics_initialization(self): + """Test that metrics initialize with correct default values.""" + metrics = SafetyMetrics() + + # Check all default values + self.assertEqual(metrics.total_generations, 0) + self.assertEqual(metrics.blocked_generations, 0) + self.assertEqual(metrics.suppression_events, 0) + self.assertEqual(metrics.cache_hits, 0) + self.assertEqual(metrics.cache_misses, 0) + self.assertEqual(metrics.total_safety_check_time_ms, 0.0) + self.assertEqual(metrics.safety_check_count, 0) + + def test_cache_hit_rate_calculation(self): + """Test cache hit rate calculation.""" + metrics = SafetyMetrics() + + # No operations - should be 0.0 + self.assertEqual(metrics.cache_hit_rate, 0.0) + + # Record some hits and misses + metrics.record_cache_hit() + metrics.record_cache_hit() + metrics.record_cache_miss() + + # Should be 66.67% (2 hits out of 3 total) + self.assertAlmostEqual(metrics.cache_hit_rate, 66.666666666666666, places=5) + + def test_avg_safety_check_time_calculation(self): + """Test average safety check time calculation.""" + metrics = SafetyMetrics() + + # No checks - should be 0.0 + self.assertEqual(metrics.avg_safety_check_time_ms, 0.0) + + # Record some checks + metrics.record_safety_check(10.0) + metrics.record_safety_check(20.0) + metrics.record_safety_check(30.0) + + # Should be 20.0ms average + self.assertEqual(metrics.avg_safety_check_time_ms, 20.0) + + def test_block_rate_calculation(self): + """Test block rate calculation.""" + metrics = SafetyMetrics() + + # No generations - should be 0.0 + self.assertEqual(metrics.block_rate, 0.0) + + # Record some generations + metrics.record_generation_attempt() + metrics.record_generation_attempt() + metrics.record_generation_attempt() + metrics.record_blocked_generation() + + # Should be 33.33% (1 blocked out of 3 total) + self.assertAlmostEqual(metrics.block_rate, 33.33333333333333, places=5) + + def test_metrics_recording_methods(self): + """Test all metrics recording methods.""" + metrics = SafetyMetrics() + + # Test safety check recording + metrics.record_safety_check(15.5) + self.assertEqual(metrics.safety_check_count, 1) + self.assertEqual(metrics.total_safety_check_time_ms, 15.5) + + # Test cache operations + metrics.record_cache_hit() + metrics.record_cache_miss() + self.assertEqual(metrics.cache_hits, 1) + self.assertEqual(metrics.cache_misses, 1) + + # Test generation tracking + metrics.record_generation_attempt() + metrics.record_blocked_generation() + self.assertEqual(metrics.total_generations, 1) + self.assertEqual(metrics.blocked_generations, 1) + + # Test suppression events + metrics.record_suppression_event() + self.assertEqual(metrics.suppression_events, 1) + + def test_metrics_to_dict(self): + """Test metrics export to dictionary.""" + metrics = SafetyMetrics() + + # Record some data + metrics.record_safety_check(10.0) + metrics.record_cache_hit() + metrics.record_generation_attempt() + metrics.record_suppression_event() + + result_dict = metrics.to_dict() + + # Check all expected keys are present 
+ expected_keys = { + "total_generations", + "blocked_generations", + "suppression_events", + "cache_hits", + "cache_misses", + "cache_hit_rate", + "avg_safety_check_time_ms", + "block_rate", + "safety_check_count", + } + self.assertEqual(set(result_dict.keys()), expected_keys) + + # Check values + self.assertEqual(result_dict["total_generations"], 1) + self.assertEqual(result_dict["suppression_events"], 1) + self.assertEqual(result_dict["cache_hits"], 1) + self.assertEqual(result_dict["cache_hit_rate"], 100.0) + + def test_metrics_reset(self): + """Test metrics reset functionality.""" + metrics = SafetyMetrics() + + # Record some data + metrics.record_safety_check(10.0) + metrics.record_cache_hit() + metrics.record_generation_attempt() + metrics.record_suppression_event() + + # Verify data is present + self.assertGreater(metrics.safety_check_count, 0) + self.assertGreater(metrics.cache_hits, 0) + + # Reset + metrics.reset() + + # Verify all values are back to zero + self.assertEqual(metrics.total_generations, 0) + self.assertEqual(metrics.blocked_generations, 0) + self.assertEqual(metrics.suppression_events, 0) + self.assertEqual(metrics.cache_hits, 0) + self.assertEqual(metrics.cache_misses, 0) + self.assertEqual(metrics.total_safety_check_time_ms, 0.0) + self.assertEqual(metrics.safety_check_count, 0) + + def test_metrics_combine(self): + """Test combining metrics from multiple instances.""" + metrics1 = SafetyMetrics() + metrics2 = SafetyMetrics() + + # Record data in first instance + metrics1.record_safety_check(10.0) + metrics1.record_cache_hit() + metrics1.record_generation_attempt() + + # Record data in second instance + metrics2.record_safety_check(20.0) + metrics2.record_cache_miss() + metrics2.record_blocked_generation() + + # Combine them + combined = metrics1.combine(metrics2) + + # Check combined values + self.assertEqual(combined.safety_check_count, 2) + self.assertEqual(combined.total_safety_check_time_ms, 30.0) + self.assertEqual(combined.cache_hits, 1) + self.assertEqual(combined.cache_misses, 1) + self.assertEqual(combined.total_generations, 1) + self.assertEqual(combined.blocked_generations, 1) + + def test_logits_processor_metrics_integration(self): + """Test metrics integration with SafetyLogitsProcessor.""" + # Mock safety checker + mock_checker = Mock() + mock_checker.check_safety.return_value = SafetyResult( + is_safe=False, confidence=0.8, violations=[SafetyViolation("toxicity", 0.8, "high")], metadata={} + ) + + # Mock tokenizer + mock_tokenizer = Mock() + mock_tokenizer.decode.return_value = "test unsafe text" + + # Safety config + safety_config = SafetyConfig.from_checker(mock_checker) + + # Create processor + processor = SafetyLogitsProcessor( + safety_checker=mock_checker, tokenizer=mock_tokenizer, safety_config=safety_config + ) + + # Verify metrics are initialized + metrics = processor.get_metrics() + self.assertIsInstance(metrics, SafetyMetrics) + self.assertEqual(metrics.suppression_events, 0) + + # Process some data (this should trigger metrics recording) + input_ids = torch.tensor([[1, 2, 3, 4]]) + scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + + processor(input_ids, scores) + + # Check that metrics were recorded + metrics = processor.get_metrics() + self.assertGreater(metrics.safety_check_count, 0) + self.assertGreater(metrics.suppression_events, 0) # Should have suppression due to unsafe content + + def test_stopping_criteria_metrics_integration(self): + """Test metrics integration with SafetyStoppingCriteria.""" + # Mock safety checker + 
mock_checker = Mock() + mock_checker.check_safety.return_value = SafetyResult( + is_safe=False, confidence=0.8, violations=[SafetyViolation("toxicity", 0.8, "high")], metadata={} + ) + + # Mock tokenizer + mock_tokenizer = Mock() + mock_tokenizer.decode.return_value = "test unsafe text" + + # Safety config + safety_config = SafetyConfig.from_checker(mock_checker) + + # Create stopping criteria + criteria = SafetyStoppingCriteria( + safety_checker=mock_checker, tokenizer=mock_tokenizer, safety_config=safety_config + ) + + # Verify metrics are initialized + metrics = criteria.get_metrics() + self.assertIsInstance(metrics, SafetyMetrics) + self.assertEqual(metrics.total_generations, 0) + + # Process some data + input_ids = torch.tensor([[1, 2, 3, 4]]) + scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + + criteria(input_ids, scores) + + # Check that metrics were recorded + metrics = criteria.get_metrics() + self.assertGreater(metrics.total_generations, 0) + self.assertGreater(metrics.blocked_generations, 0) # Should have blocked generation + + def test_thread_safety_basic(self): + """Test basic thread safety of SafetyMetrics.""" + import threading + import time + + metrics = SafetyMetrics() + errors = [] + + def worker(): + try: + for i in range(100): + metrics.record_cache_hit() + metrics.record_safety_check(1.0) + time.sleep(0.001) # Small delay to encourage race conditions + except Exception as e: + errors.append(e) + + # Run multiple threads + threads = [] + for _ in range(5): + thread = threading.Thread(target=worker) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Should have no errors and correct counts + self.assertEqual(len(errors), 0, f"Thread safety errors: {errors}") + self.assertEqual(metrics.cache_hits, 500) # 5 threads * 100 operations + self.assertEqual(metrics.safety_check_count, 500) + + def test_hash_consistency(self): + """Test that hash inconsistency bug is fixed.""" + from transformers.generation.safety.processors import _generate_cache_key + + text1 = "This is a test message" + text2 = "This is a test message" # Same content + text3 = "Different message" + + # Same text should produce same hash + hash1 = _generate_cache_key(text1) + hash2 = _generate_cache_key(text2) + self.assertEqual(hash1, hash2) + + # Different text should produce different hash + hash3 = _generate_cache_key(text3) + self.assertNotEqual(hash1, hash3) + + # Hashes should be consistent across calls + for _ in range(10): + self.assertEqual(_generate_cache_key(text1), hash1) + + def test_cache_memory_management(self): + """Test that caches properly manage memory.""" + # Mock safety checker + mock_checker = Mock() + mock_checker.check_safety.return_value = SafetyResult(is_safe=True, confidence=0.9, violations=[], metadata={}) + + # Mock tokenizer + mock_tokenizer = Mock() + + # Safety config - disable incremental checking for this test to ensure all calls are made + safety_config = SafetyConfig.from_checker(mock_checker, incremental_checking=False) + + # Create processor + processor = SafetyLogitsProcessor( + safety_checker=mock_checker, tokenizer=mock_tokenizer, safety_config=safety_config + ) + + # Add many different sequences to test cache limits + for i in range(150): # More than default cache size of 100 + mock_tokenizer.decode.return_value = f"test text {i}" + input_ids = torch.tensor([[1, 2, 3, i]]) + scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + processor(input_ids, scores) + + # Cache should be limited and 
not grow unbounded + # The exact size check would depend on internal implementation + # but we can verify calls were made + self.assertEqual(mock_checker.check_safety.call_count, 150) + + def test_empty_and_special_text_handling(self): + """Test handling of edge case text inputs.""" + from transformers.generation.safety.processors import _generate_cache_key + + # Test edge cases + test_cases = [ + "", # Empty string + " ", # Single space + "\n\t", # Whitespace only + "πŸŒπŸš€πŸ’«", # Unicode emoji + "a" * 10000, # Very long string + "Test\x00null", # String with null byte + ] + + for text in test_cases: + try: + cache_key = _generate_cache_key(text) + # Should produce valid cache key + self.assertIsInstance(cache_key, str) + self.assertGreater(len(cache_key), 0) + # Should be consistent + self.assertEqual(cache_key, _generate_cache_key(text)) + except Exception as e: + self.fail(f"Failed to generate cache key for text: {repr(text)}, error: {e}") + + def test_device_mismatch_handling(self): + """Test handling when tensors are on different devices.""" + # Mock safety checker + mock_checker = Mock() + mock_checker.check_safety.return_value = SafetyResult( + is_safe=False, confidence=0.8, violations=[SafetyViolation("toxicity", 0.8, "high")], metadata={} + ) + + # Mock tokenizer + mock_tokenizer = Mock() + mock_tokenizer.decode.return_value = "unsafe text" + + # Safety config + safety_config = SafetyConfig.from_checker(mock_checker) + + # Create processor + processor = SafetyLogitsProcessor( + safety_checker=mock_checker, tokenizer=mock_tokenizer, safety_config=safety_config + ) + + # Test with tensors (simulate device mismatch without actually using CUDA) + input_ids = torch.tensor([[1, 2, 3, 4]]) + scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + + # Should not raise device mismatch errors + try: + result = processor(input_ids, scores) + self.assertEqual(result.shape, scores.shape) + except Exception as e: + self.fail(f"Device handling failed: {e}") + + def test_configurable_cache_size_logits_processor(self): + """Test that SafetyLogitsProcessor respects configured cache size.""" + # Mock safety checker + mock_checker = Mock() + mock_checker.check_safety.return_value = SafetyResult(is_safe=True, confidence=0.9, violations=[], metadata={}) + + # Mock tokenizer + mock_tokenizer = Mock() + + # Test small cache size + small_config = SafetyConfig.from_checker(mock_checker, cache_size=5) + processor = SafetyLogitsProcessor( + safety_checker=mock_checker, tokenizer=mock_tokenizer, safety_config=small_config + ) + + # Verify cache was initialized with correct size + self.assertEqual(processor._sequence_cache.max_size, 5) + + # Test large cache size + large_config = SafetyConfig.from_checker(mock_checker, cache_size=250) + processor = SafetyLogitsProcessor( + safety_checker=mock_checker, tokenizer=mock_tokenizer, safety_config=large_config + ) + + # Verify cache was initialized with correct size + self.assertEqual(processor._sequence_cache.max_size, 250) + + def test_configurable_cache_size_stopping_criteria(self): + """Test that SafetyStoppingCriteria respects configured cache and hash limits.""" + # Mock safety checker + mock_checker = Mock() + mock_checker.check_safety.return_value = SafetyResult(is_safe=True, confidence=0.9, violations=[], metadata={}) + + # Mock tokenizer + mock_tokenizer = Mock() + + # Test custom configuration + custom_config = SafetyConfig.from_checker(mock_checker, cache_size=30, unsafe_hash_limit=300) + + criteria = SafetyStoppingCriteria( + 
safety_checker=mock_checker, tokenizer=mock_tokenizer, safety_config=custom_config + ) + + # Verify cache and hash limit were configured correctly + self.assertEqual(criteria._sequence_cache.max_size, 30) + self.assertEqual(criteria._unsafe_hash_limit, 300) + + def test_default_cache_sizes_for_safety_levels(self): + """Test that different safety levels use appropriate cache sizes.""" + # Mock safety checker and tokenizer + mock_checker = Mock() + mock_checker.check_safety.return_value = SafetyResult(is_safe=True, confidence=0.9, violations=[], metadata={}) + mock_tokenizer = Mock() + + # Test strict configuration + strict_config = SafetyConfig.from_checker(mock_checker, **STRICT_PRESET) + processor = SafetyLogitsProcessor( + safety_checker=mock_checker, tokenizer=mock_tokenizer, safety_config=strict_config + ) + self.assertEqual(processor._sequence_cache.max_size, 50) + + criteria = SafetyStoppingCriteria( + safety_checker=mock_checker, tokenizer=mock_tokenizer, safety_config=strict_config + ) + self.assertEqual(criteria._unsafe_hash_limit, 500) + + # Test moderate configuration + moderate_config = SafetyConfig.from_checker(mock_checker, **MODERATE_PRESET) + processor = SafetyLogitsProcessor( + safety_checker=mock_checker, tokenizer=mock_tokenizer, safety_config=moderate_config + ) + self.assertEqual(processor._sequence_cache.max_size, 100) + + # Test lenient configuration + lenient_config = SafetyConfig.from_checker(mock_checker, **LENIENT_PRESET) + processor = SafetyLogitsProcessor( + safety_checker=mock_checker, tokenizer=mock_tokenizer, safety_config=lenient_config + ) + self.assertEqual(processor._sequence_cache.max_size, 200) + + def test_backward_compatibility_cache_size(self): + """Test that processors work with SafetyConfig without cache_size.""" + # Mock safety checker and tokenizer + mock_checker = Mock() + mock_checker.check_safety.return_value = SafetyResult(is_safe=True, confidence=0.9, violations=[], metadata={}) + mock_tokenizer = Mock() + + # Create a config that might not have cache_size (simulate old configs) + config = SafetyConfig.from_checker(mock_checker) + # Temporarily remove cache_size attribute to simulate old config + if hasattr(config, "cache_size"): + delattr(config, "cache_size") + + # Should still work with default cache size + processor = SafetyLogitsProcessor(safety_checker=mock_checker, tokenizer=mock_tokenizer, safety_config=config) + # Should use DEFAULT_CACHE_SIZE (100) + from transformers.generation.safety.processors import DEFAULT_CACHE_SIZE + + self.assertEqual(processor._sequence_cache.max_size, DEFAULT_CACHE_SIZE) + + def test_cache_size_edge_cases(self): + """Test edge cases for cache size configuration.""" + # Mock safety checker and tokenizer + mock_checker = Mock() + mock_checker.check_safety.return_value = SafetyResult(is_safe=True, confidence=0.9, violations=[], metadata={}) + mock_tokenizer = Mock() + + # Test minimum cache size (1) + min_config = SafetyConfig.from_checker(mock_checker, cache_size=1) + processor = SafetyLogitsProcessor( + safety_checker=mock_checker, tokenizer=mock_tokenizer, safety_config=min_config + ) + self.assertEqual(processor._sequence_cache.max_size, 1) + + # Test that processor works with cache size 1 + input_ids = torch.tensor([[1, 2, 3, 4]]) + scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + mock_tokenizer.decode.return_value = "test text" + + # Should not raise any errors + result = processor(input_ids, scores) + self.assertEqual(result.shape, scores.shape) + + +@require_torch +class 
TestSlidingWindowFunctionality(unittest.TestCase):
+    """Test sliding window and incremental checking functionality."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        # Mock safety checker
+        self.mock_checker = Mock()
+        self.mock_tokenizer = Mock()
+
+    def test_safety_state_initialization(self):
+        """Test SafetyState class initialization and basic functionality."""
+        state = SafetyState()
+
+        # Check initial values
+        self.assertEqual(state.last_check_position, 0)
+        self.assertIsNone(state.last_check_result)
+        self.assertEqual(state.sequence_prefix, "")
+        self.assertTrue(state.is_safe_so_far)
+        self.assertEqual(state.window_start_position, 0)
+
+    def test_safety_state_incremental_check_logic(self):
+        """Test SafetyState incremental checking logic."""
+        state = SafetyState()
+
+        # First check should always be performed
+        self.assertTrue(state.should_check_incremental(0, min_new_tokens=5))
+        self.assertTrue(state.should_check_incremental(10, min_new_tokens=5))
+
+        # Update state after first check
+        result = SafetyResult(is_safe=True, confidence=0.9, violations=[], metadata={})
+        state.update_check_result(10, result, "first check")
+
+        # Check with insufficient new tokens
+        self.assertFalse(state.should_check_incremental(14, min_new_tokens=5))
+
+        # Check with sufficient new tokens
+        self.assertTrue(state.should_check_incremental(15, min_new_tokens=5))
+
+    def test_safety_state_sliding_window(self):
+        """Test SafetyState sliding window extraction."""
+        state = SafetyState()
+        full_text = "This is a very long text that should trigger sliding window behavior when it exceeds the configured window size limit."
+
+        # Test without sliding window (disabled)
+        text_to_check, start_pos = state.get_incremental_text(full_text, sliding_window_size=-1)
+        self.assertEqual(text_to_check, full_text)
+        self.assertEqual(start_pos, 0)
+
+        # Test with sliding window smaller than text
+        window_size = 50
+        text_to_check, start_pos = state.get_incremental_text(full_text, sliding_window_size=window_size)
+        self.assertEqual(len(text_to_check), window_size)
+        self.assertEqual(text_to_check, full_text[-window_size:])
+        self.assertEqual(start_pos, len(full_text) - window_size)
+
+        # Test with sliding window larger than text
+        window_size = 200
+        text_to_check, start_pos = state.get_incremental_text(full_text, sliding_window_size=window_size)
+        self.assertEqual(text_to_check, full_text)
+        self.assertEqual(start_pos, 0)
+
+    def test_sliding_window_config_parameters(self):
+        """Test sliding window configuration parameters in SafetyConfig."""
+        # Test default values
+        config = SafetyConfig()
+        self.assertEqual(config.sliding_window_size, 512)
+        self.assertTrue(config.incremental_checking)
+
+        # Test custom values
+        config = SafetyConfig(sliding_window_size=256, incremental_checking=False)
+        self.assertEqual(config.sliding_window_size, 256)
+        self.assertFalse(config.incremental_checking)
+
+        # Test serialization includes new parameters
+        config_dict = config.to_dict()
+        self.assertEqual(config_dict["sliding_window_size"], 256)
+        self.assertEqual(config_dict["incremental_checking"], False)
+
+        # Test deserialization
+        restored_config = SafetyConfig.from_dict(config_dict)
+        self.assertEqual(restored_config.sliding_window_size, 256)
+        self.assertFalse(restored_config.incremental_checking)
+
+    def test_logits_processor_sliding_window_integration(self):
+        """Test SafetyLogitsProcessor with sliding window functionality."""
+        # Setup mocks
+        self.mock_checker.check_safety.return_value = SafetyResult(
+            is_safe=True, confidence=0.9, violations=[], metadata={}
+        )
+
+        # Create long text that would exceed window
+        long_text = "This is a very long piece of text that should trigger the sliding window behavior. " * 10
+        self.mock_tokenizer.decode.return_value = long_text
+
+        # Test with sliding window enabled
+        config = SafetyConfig.from_checker(
+            self.mock_checker,
+            sliding_window_size=100,
+            incremental_checking=True,
+        )
+
+        processor = SafetyLogitsProcessor(
+            safety_checker=self.mock_checker, tokenizer=self.mock_tokenizer, safety_config=config
+        )
+
+        # Verify sliding window parameters are set
+        self.assertEqual(processor.sliding_window_size, 100)
+        self.assertTrue(processor.incremental_checking)
+
+        # Test processing with sliding window
+        input_ids = torch.tensor([[1, 2, 3, 4, 5]])
+        scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
+
+        result = processor(input_ids, scores)
+        self.assertEqual(result.shape, scores.shape)
+
+        # Verify safety check was called (though with potentially windowed text)
+        self.mock_checker.check_safety.assert_called()
+
+    def test_stopping_criteria_sliding_window_integration(self):
+        """Test SafetyStoppingCriteria with sliding window functionality."""
+        # Setup mocks
+        self.mock_checker.check_safety.return_value = SafetyResult(
+            is_safe=True, confidence=0.9, violations=[], metadata={}
+        )
+
+        long_text = "This is another very long piece of text for testing sliding window in stopping criteria. " * 10
+        self.mock_tokenizer.decode.return_value = long_text
+
+        # Test with sliding window enabled
+        config = SafetyConfig.from_checker(
+            self.mock_checker,
+            sliding_window_size=100,
+            incremental_checking=True,
+        )
+
+        criteria = SafetyStoppingCriteria(
+            safety_checker=self.mock_checker, tokenizer=self.mock_tokenizer, safety_config=config
+        )
+
+        # Verify sliding window parameters are set
+        self.assertEqual(criteria.sliding_window_size, 100)
+        self.assertTrue(criteria.incremental_checking)
+
+        # Test processing
+        input_ids = torch.tensor([[1, 2, 3, 4, 5]])
+        scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
+
+        should_stop = criteria(input_ids, scores)
+        self.assertFalse(should_stop[0])  # Should not stop for safe content
+
+    def test_incremental_checking_performance_benefit(self):
+        """Test that incremental checking reduces safety check calls."""
+        # Setup mock to count calls
+        check_call_count = [0]
+
+        def count_check_calls(text):
+            check_call_count[0] += 1
+            return SafetyResult(is_safe=True, confidence=0.9, violations=[], metadata={})
+
+        self.mock_checker.check_safety.side_effect = count_check_calls
+
+        # Create processor with incremental checking
+        config = SafetyConfig.from_checker(self.mock_checker, incremental_checking=True)
+
+        processor = SafetyLogitsProcessor(
+            safety_checker=self.mock_checker,
+            tokenizer=self.mock_tokenizer,
+            safety_config=config,
+            check_interval=1,  # Check every token
+        )
+
+        # Simulate progressive sequence building
+        sequences = ["Hello", "Hello world", "Hello world this", "Hello world this is", "Hello world this is a test"]
+
+        for seq in sequences:
+            self.mock_tokenizer.decode.return_value = seq
+            input_ids = torch.tensor([[1] * len(seq.split())])  # Approximate tokens
+            scores = torch.randn(1, 1000)
+            processor(input_ids, scores)
+
+        # With incremental checking, we should have fewer calls than sequences
+        # because short additions don't trigger new checks
+        print(f"Check calls made: {check_call_count[0]} out of {len(sequences)} sequences")
+        self.assertLessEqual(check_call_count[0], len(sequences))
+
+    def test_sliding_window_with_unsafe_content(self):
+        """Test sliding window behavior when unsafe content is detected."""
+        # Setup mock to return unsafe result
+        self.mock_checker.check_safety.return_value = SafetyResult(
+            is_safe=False,
+            confidence=0.8,
+            violations=[SafetyViolation("toxicity", 0.8, "high", "Toxic content detected")],
+            metadata={},
+        )
+
+        config = SafetyConfig.from_checker(
+            self.mock_checker,
+            sliding_window_size=50,
+            incremental_checking=True,
+        )
+
+        processor = SafetyLogitsProcessor(
+            safety_checker=self.mock_checker, tokenizer=self.mock_tokenizer, safety_config=config
+        )
+
+        self.mock_tokenizer.decode.return_value = "This contains toxic content that should be blocked"
+
+        input_ids = torch.tensor([[1, 2, 3, 4, 5]])
+        scores = torch.ones(1, 1000)  # All tokens have same score
+
+        result = processor(input_ids, scores)
+
+        # All tokens should be suppressed (set to negative infinity)
+        self.assertTrue(torch.all(result < scores))
+        self.assertTrue(torch.all(result == float("-inf")))
+
+    def test_prefix_cache_functionality(self):
+        """Test that prefix caching works correctly."""
+        # This test verifies the _PrefixSafetyCache is used when incremental_checking=True
+        config = SafetyConfig.from_checker(
+            self.mock_checker,
+            incremental_checking=True,  # Should use prefix cache
+            cache_size=50,
+        )
+
+        processor = SafetyLogitsProcessor(
+            safety_checker=self.mock_checker, tokenizer=self.mock_tokenizer, safety_config=config
+        )
+
+        # Verify correct cache type is used
+        from transformers.generation.safety.processors import _PrefixSafetyCache
+
+        self.assertIsInstance(processor._sequence_cache, _PrefixSafetyCache)
+
+        # Test with incremental_checking=False
+        config_no_incremental = SafetyConfig.from_checker(
+            self.mock_checker,
+            incremental_checking=False,  # Should use simple cache
+        )
+
+        processor_simple = SafetyLogitsProcessor(
+            safety_checker=self.mock_checker, tokenizer=self.mock_tokenizer, safety_config=config_no_incremental
+        )
+
+        # Verify simple cache is used
+        from transformers.generation.safety.processors import _SafetyCache
+
+        self.assertIsInstance(processor_simple._sequence_cache, _SafetyCache)
+
+    def test_safety_state_reset_functionality(self):
+        """Test that safety states can be reset properly."""
+        config = SafetyConfig.from_checker(self.mock_checker, incremental_checking=True)
+
+        processor = SafetyLogitsProcessor(
+            safety_checker=self.mock_checker, tokenizer=self.mock_tokenizer, safety_config=config
+        )
+
+        # Process some sequences to populate safety states
+        self.mock_tokenizer.decode.return_value = "test text"
+        self.mock_checker.check_safety.return_value = SafetyResult(
+            is_safe=True, confidence=0.9, violations=[], metadata={}
+        )
+
+        input_ids = torch.tensor([[1, 2, 3, 4, 5]])
+        scores = torch.randn(1, 1000)
+        processor(input_ids, scores)
+
+        # Verify states were created
+        self.assertGreater(len(processor._safety_states), 0)
+
+        # Reset states
+        processor.reset_safety_states()
+
+        # Verify states were cleared
+        self.assertEqual(len(processor._safety_states), 0)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/pipelines/test_text_generation_safety.py b/tests/pipelines/test_text_generation_safety.py
new file mode 100644
index 000000000000..9199159d98cf
--- /dev/null
+++ b/tests/pipelines/test_text_generation_safety.py
@@ -0,0 +1,108 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers import pipeline
+from transformers.generation.safety import SafetyChecker, SafetyConfig, SafetyResult, SafetyViolation
+from transformers.testing_utils import require_torch, slow
+
+
+class MockSafetyChecker(SafetyChecker):
+    """Mock safety checker for testing"""
+
+    def __init__(self, is_safe=True, name="mock"):
+        self.is_safe = is_safe
+        self.name = name
+        self.check_safety_calls = []
+
+    def check_safety(self, text, **kwargs):
+        self.check_safety_calls.append(text)
+        return SafetyResult(
+            is_safe=self.is_safe,
+            confidence=0.9,
+            violations=[] if self.is_safe else [SafetyViolation("test", 0.9, "high", "Test violation")],
+            metadata={"checker": self.name},
+        )
+
+    @property
+    def supported_categories(self):
+        return ["test"]
+
+
+@require_torch
+class TestTextGenerationPipelineSafety(unittest.TestCase):
+    """Tests for safety integration in TextGenerationPipeline"""
+
+    def test_safety_config_per_call(self):
+        """Test passing safety_config per generate call"""
+        checker = MockSafetyChecker(is_safe=True)
+        config = SafetyConfig.from_checker(checker)
+
+        pipe = pipeline("text-generation", model="sshleifer/tiny-gpt2")
+        result = pipe("Hello", safety_config=config, max_new_tokens=10)
+
+        # Verify safety was applied
+        self.assertGreater(len(checker.check_safety_calls), 0)
+        self.assertIsNotNone(result)
+
+    def test_safety_disabled_by_default(self):
+        """Test that safety is not applied when no config provided"""
+        pipe = pipeline("text-generation", model="sshleifer/tiny-gpt2")
+        result = pipe("Hello", max_new_tokens=10)
+
+        # Should work normally without safety
+        self.assertIsNotNone(result)
+        self.assertEqual(len(result), 1)
+        self.assertIn("generated_text", result[0])
+
+    def test_unsafe_content_blocked(self):
+        """Test that unsafe content generation is blocked"""
+        checker = MockSafetyChecker(is_safe=False)  # Always unsafe
+        config = SafetyConfig.from_checker(checker)
+
+        pipe = pipeline("text-generation", model="sshleifer/tiny-gpt2")
+        result = pipe("Hello", safety_config=config, max_new_tokens=10, do_sample=False)
+
+        # Generation should be stopped early due to safety
+        self.assertIsNotNone(result)
+        # Exact behavior depends on safety implementation
+        # But checker should have been called
+        self.assertGreater(len(checker.check_safety_calls), 0)
+
+    def test_safety_with_batch(self):
+        """Test safety checking with batch input"""
+        checker = MockSafetyChecker(is_safe=True)
+        config = SafetyConfig.from_checker(checker)
+
+        pipe = pipeline("text-generation", model="sshleifer/tiny-gpt2")
+        results = pipe(["Hello", "World"], safety_config=config, max_new_tokens=10)
+
+        # Verify safety was applied to batch
+        self.assertGreater(len(checker.check_safety_calls), 0)
+        self.assertEqual(len(results), 2)
+
+    @slow
+    def test_safety_with_actual_model(self):
+        """Test safety with actual model generation (slow test)"""
+        checker = MockSafetyChecker(is_safe=True)
+        config = SafetyConfig.from_checker(checker)
+
+        pipe = pipeline("text-generation", model="gpt2")
+        result = pipe("The capital of France is", safety_config=config, max_new_tokens=5, do_sample=False)
+
+        self.assertIsNotNone(result)
+        self.assertIn("generated_text", result[0])
+        self.assertGreater(len(checker.check_safety_calls), 0)