-
Notifications
You must be signed in to change notification settings - Fork 682
FEAT: Add modality support detection for targets #1383
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2a403c6
d72b2b4
6be87e6
4905948
3202098
94b9aa8
8ecc5bc
0d7277c
56b68c8
02f6f1e
cce207c
51a20bb
aa036f2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| # Modality Test Assets | ||
|
|
||
| Benign, minimal test files used by `pyrit.prompt_target.modality_verification` to | ||
| verify which modalities a target actually supports at runtime. | ||
|
|
||
| - **test_image.png** — 1×1 white pixel PNG | ||
| - **test_audio.wav** — TTS-generated speech: "raccoons are extraordinary creatures" | ||
| - **test_video.mp4** — 1-frame, 16×16 solid color video | ||
|
|
||
| These are intentionally simple and non-controversial so they won't be blocked by | ||
| content filters during modality verification. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,7 +7,7 @@ | |
|
|
||
| from pyrit.identifiers import ComponentIdentifier, Identifiable | ||
| from pyrit.memory import CentralMemory, MemoryInterface | ||
| from pyrit.models import Message | ||
| from pyrit.models import Message, PromptDataType | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -26,6 +26,17 @@ class PromptTarget(Identifiable): | |
| #: An empty list implies that the prompt target supports all converters. | ||
| supported_converters: List[Any] | ||
|
|
||
| #: Set of supported input modality combinations. | ||
| #: Each frozenset represents a valid combination of modalities that can be sent together. | ||
| #: For example: {frozenset(["text"]), frozenset(["text", "image_path"])} | ||
| #: means the target supports either text-only OR text+image combinations. | ||
| SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} | ||
|
|
||
| #: Set of supported output modality combinations. | ||
| #: Each frozenset represents a valid combination of modalities that can be returned. | ||
| #: Most targets currently only support text output. | ||
| SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} | ||
|
|
||
| _identifier: Optional[ComponentIdentifier] = None | ||
|
|
||
| def __init__( | ||
|
|
@@ -78,6 +89,52 @@ def _validate_request(self, *, message: Message) -> None: | |
| message: The message to validate. | ||
| """ | ||
|
|
||
def input_modality_supported(self, modalities: set[PromptDataType]) -> bool:
    """
    Report whether the given combination of input modalities can be sent together.

    The check is for the exact combination: a target whose declarations include
    {"text", "image_path"} does not implicitly accept {"image_path"} alone unless
    that combination is also declared.

    Args:
        modalities: The modality types to check as a single combination
            (e.g., {"text", "image_path"}).

    Returns:
        True when this exact combination appears in SUPPORTED_INPUT_MODALITIES,
        False otherwise.
    """
    return frozenset(modalities) in self.SUPPORTED_INPUT_MODALITIES
|
|
||
def output_modality_supported(self, modalities: set[PromptDataType]) -> bool:
    """
    Report whether the given combination of output modalities can be returned.

    Most targets only declare text output; the check is for the exact
    combination, mirroring input_modality_supported.

    Args:
        modalities: The modality types to check as a single combination.

    Returns:
        True when this exact combination appears in SUPPORTED_OUTPUT_MODALITIES,
        False otherwise.
    """
    return frozenset(modalities) in self.SUPPORTED_OUTPUT_MODALITIES
|
|
||
async def verify_actual_modalities(self) -> set[frozenset[PromptDataType]]:
    """
    Probe this target at runtime to discover which modalities it really supports.

    This optional verification sends minimal test requests through
    pyrit.prompt_target.modality_verification; the result may be a subset of
    the static API declarations in SUPPORTED_INPUT_MODALITIES.

    Returns:
        The set of input modality combinations that verification confirmed.

    Example:
        actual = await target.verify_actual_modalities()
        # e.g. {frozenset(["text"])} or {frozenset(["text"]), frozenset(["text", "image_path"])}
    """
    # Imported locally: the verification module itself imports PromptTarget,
    # so a module-level import here would be circular.
    from pyrit.prompt_target.modality_verification import verify_target_modalities

    return await verify_target_modalities(self)
|
|
||
| def set_model_name(self, *, model_name: str) -> None: | ||
| """ | ||
| Set the model name for this target. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,156 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT license. | ||
|
|
||
| """ | ||
| Optional modality verification system for prompt targets. | ||
|
|
||
| This module provides runtime modality discovery to determine what modalities | ||
| a specific target actually supports, beyond what the API declares as possible. | ||
|
|
||
| Usage: | ||
| from pyrit.prompt_target.modality_verification import verify_target_modalities | ||
|
|
||
| # Get static API modalities | ||
| api_modalities = target.SUPPORTED_INPUT_MODALITIES | ||
|
|
||
| # Optionally verify actual model modalities | ||
| actual_modalities = await verify_target_modalities(target) | ||
| """ | ||
|
|
||
| import logging | ||
| import os | ||
| from typing import Optional | ||
|
|
||
| from pyrit.common.path import DATASETS_PATH | ||
| from pyrit.models import Message, MessagePiece, PromptDataType | ||
| from pyrit.prompt_target.common.prompt_target import PromptTarget | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
# Directory of benign test assets shipped with the datasets and used only for
# modality verification.
_ASSETS_DIR = DATASETS_PATH / "modality_test_assets"

# Maps a non-text PromptDataType value to the minimal asset file used to probe
# that modality. "text" needs no asset and is handled inline by
# _create_test_message.
_TEST_ASSETS: dict[str, str] = {
    "image_path": str(_ASSETS_DIR / "test_image.png"),
    "audio_path": str(_ASSETS_DIR / "test_audio.wav"),
    "video_path": str(_ASSETS_DIR / "test_video.mp4"),
}
|
|
||
|
|
||
async def verify_target_modalities(
    target: PromptTarget,
    test_modalities: Optional[set[frozenset[PromptDataType]]] = None,
) -> set[frozenset[PromptDataType]]:
    """
    Verify which modality combinations a target actually supports.

    Each candidate combination is probed with a minimal request; combinations
    whose probe fails or raises are dropped, so the result may be a subset of
    the target's static API declarations.

    Args:
        target: The prompt target to test.
        test_modalities: Specific modality combinations to test. Defaults to
            the target's declared SUPPORTED_INPUT_MODALITIES.

    Returns:
        Set of actually supported input modality combinations.

    Example:
        actual = await verify_target_modalities(openai_target)
        # Returns: {frozenset(["text"])} or {frozenset(["text"]), frozenset(["text", "image_path"])}
    """
    if test_modalities is None:
        test_modalities = target.SUPPORTED_INPUT_MODALITIES

    verified_modalities: set[frozenset[PromptDataType]] = set()

    for modality_combination in test_modalities:
        try:
            is_supported = await _test_modality_combination(target, modality_combination)
            if is_supported:
                verified_modalities.add(modality_combination)
        except Exception as e:
            # Broad catch is deliberate: verification is best-effort, and one
            # bad combination (e.g. a missing test asset raising
            # FileNotFoundError) must not abort the remaining probes.
            # Lazy %-style args so the message is only formatted when INFO
            # logging is enabled.
            logger.info("Failed to verify %s: %s", modality_combination, e)

    return verified_modalities
|
|
||
|
|
||
async def _test_modality_combination(
    target: PromptTarget,
    modalities: frozenset[PromptDataType],
) -> bool:
    """
    Test a specific modality combination with a minimal API request.

    Args:
        target: The target to test.
        modalities: The combination of modalities to test.

    Returns:
        True if the combination is supported, False otherwise.

    Raises:
        FileNotFoundError, ValueError: Propagated from _create_test_message
            (it runs outside the try block below) when a test asset is missing
            or a modality has no configured asset.
    """
    test_message = _create_test_message(modalities)

    try:
        responses = await target.send_prompt_async(message=test_message)

        # A target may accept the request yet flag an error on a piece of the
        # response; treat any non-"none" error marker as unsupported.
        for response in responses:
            for piece in response.message_pieces:
                if piece.response_error != "none":
                    # Lazy %-style args: avoid building the message unless
                    # INFO logging is enabled.
                    logger.info(
                        "Modality %s returned error response: %s", modalities, piece.converted_value
                    )
                    return False

        return True

    except Exception as e:
        # Broad by design: any transport or API failure simply means this
        # combination is unusable, not that verification should abort.
        logger.info("Modality %s not supported: %s", modalities, e)
        return False
|
|
||
|
|
||
def _create_test_message(modalities: frozenset[PromptDataType]) -> Message:
    """
    Build a minimal user message containing one piece per requested modality.

    A "text" piece carries the literal value "test"; every other modality is
    backed by a benign asset file from _TEST_ASSETS.

    Args:
        modalities: The modalities to include in the test message.

    Returns:
        A Message object with minimal content for each requested modality.

    Raises:
        FileNotFoundError: If a required test asset file is missing.
        ValueError: If a modality has no configured test asset or no pieces
            could be created.
    """
    conversation_id = "modality-verification-test"
    pieces: list[MessagePiece] = []

    for modality in modalities:
        # Resolve the piece's value first, validating the modality as we go,
        # then append a single uniformly-shaped MessagePiece.
        if modality == "text":
            value = "test"
        elif modality in _TEST_ASSETS:
            value = _TEST_ASSETS[modality]
            if not os.path.isfile(value):
                raise FileNotFoundError(f"Test asset not found for modality '{modality}': {value}")
        else:
            raise ValueError(f"No test asset configured for modality: {modality}")

        pieces.append(
            MessagePiece(
                role="user",
                original_value=value,
                original_value_data_type=modality,
                conversation_id=conversation_id,
            )
        )

    if not pieces:
        raise ValueError(f"Could not create test message for modalities: {modalities}")

    return Message(pieces)
Uh oh!
There was an error while loading. Please reload this page.