Add basis for conformance-based tests by zastrowm · Pull Request #403 · strands-agents/sdk-python · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add basis for conformance-based tests #403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions tests_integ/models/conformance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest

from strands.types.models import Model
from tests_integ.models.providers import ProviderInfo, all_providers


def get_models():
    """Return one ``pytest.param`` per entry in ``all_providers``.

    Each param wraps the provider info object itself, uses the provider's
    ``id`` as the param id (so the provider name shows up in the generated
    test name), and attaches the provider's ``mark`` so that tests whose
    requirements are not met get skipped.
    """

    def _to_param(info):
        # A single parametrization entry for one provider.
        return pytest.param(info, id=info.id, marks=[info.mark])

    return [_to_param(info) for info in all_providers]


@pytest.fixture(params=get_models())
def provider_info(request) -> ProviderInfo:
    """Hand each test the provider selected by the current parametrization."""
    selected: ProviderInfo = request.param
    return selected


@pytest.fixture()
def model(provider_info):
    """Build a new model instance from the parametrized provider's factory."""
    build = provider_info.create_model
    return build()


def test_model_can_be_constructed(model: Model):
    """Each provider's factory must produce an instance of the ``Model`` interface.

    Checking the type (rather than just ``is not None``) makes this a real
    conformance check: every provider's factory must return an object that
    implements the shared ``Model`` interface.
    """
    assert isinstance(model, Model)
127 changes: 112 additions & 15 deletions tests_integ/models/providers.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,51 @@
"""
Aggregates every model provider so that integration tests can run against all of them in one go.
"""

import os
from dataclasses import dataclass
from typing import Callable, Optional

import requests
from pytest import mark

from strands.models import BedrockModel
from strands.models.anthropic import AnthropicModel
from strands.models.litellm import LiteLLMModel
from strands.models.llamaapi import LlamaAPIModel
from strands.models.mistral import MistralModel
from strands.models.ollama import OllamaModel
from strands.models.openai import OpenAIModel
from strands.models.writer import WriterModel
from strands.types.models import Model


class ProviderInfo:
    """Test metadata for a single model provider.

    Bundles three things:

    * the provider's test ``id``;
    * a zero-argument ``factory`` that builds the provider's model;
    * a pytest ``skipif`` ``mark``.

    When ``environment_variable`` is given (in this module, the provider's
    API-key variable), tests are skipped unless that variable is set.
    Providers created without one are never skipped by this mark.

    Deliberately not a ``@dataclass``. With no annotated fields, a dataclass
    would generate an ``__eq__`` that treats every instance as equal, and it
    would set ``__hash__`` to ``None``.
    """

    def __init__(
        self,
        id: str,
        factory: Callable[[], Model],
        environment_variable: Optional[str] = None,
    ) -> None:
        self.id = id
        self.environment_variable = environment_variable
        self.model_factory = factory
        # Skip only when a required variable is declared AND absent from the
        # environment; providers that declare no variable always run.
        self.mark = mark.skipif(
            environment_variable is not None and environment_variable not in os.environ,
            reason=f"{environment_variable} environment variable missing",
        )

    def create_model(self) -> Model:
        """Build a new model instance by calling the configured factory."""
        return self.model_factory()


class OllamaProviderInfo:
class OllamaProviderInfo(ProviderInfo):
"""Special case ollama as it's dependent on the server being available."""

def __init__(self):
self.id = "ollama"
super().__init__(
id="ollama", factory=lambda: OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b")
)

is_server_available = False
try:
Expand All @@ -36,11 +59,85 @@ def __init__(self):
)


# Provider definitions. Any API key is read inside each factory lambda, so it is
# only looked up when a model is actually constructed, not at import time.
anthropic = ProviderInfo(
    id="anthropic",
    environment_variable="ANTHROPIC_API_KEY",
    factory=lambda: AnthropicModel(
        client_args={
            "api_key": os.getenv("ANTHROPIC_API_KEY"),
        },
        model_id="claude-3-7-sonnet-20250219",
        max_tokens=512,
    ),
)
# No environment variable declared, so ProviderInfo's mark never skips it.
# The class itself is a zero-argument factory; no wrapper lambda is needed.
bedrock = ProviderInfo(id="bedrock", factory=BedrockModel)
# Cohere is reached through its OpenAI-compatible endpoint.
cohere = ProviderInfo(
    id="cohere",
    environment_variable="CO_API_KEY",
    factory=lambda: OpenAIModel(
        client_args={
            "base_url": "https://api.cohere.com/compatibility/v1",
            "api_key": os.getenv("CO_API_KEY"),
        },
        model_id="command-a-03-2025",
        params={"stream_options": None},
    ),
)
# LiteLLM routed to a Bedrock-hosted model; no environment variable declared.
litellm = ProviderInfo(
    id="litellm", factory=lambda: LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0")
)
llama = ProviderInfo(
    id="llama",
    environment_variable="LLAMA_API_KEY",
    factory=lambda: LlamaAPIModel(
        model_id="Llama-4-Maverick-17B-128E-Instruct-FP8",
        client_args={
            "api_key": os.getenv("LLAMA_API_KEY"),
        },
    ),
)
mistral = ProviderInfo(
    id="mistral",
    environment_variable="MISTRAL_API_KEY",
    factory=lambda: MistralModel(
        model_id="mistral-medium-latest",
        api_key=os.getenv("MISTRAL_API_KEY"),
        stream=True,
        temperature=0.7,
        max_tokens=1000,
        top_p=0.9,
    ),
)
openai = ProviderInfo(
    id="openai",
    environment_variable="OPENAI_API_KEY",
    factory=lambda: OpenAIModel(
        model_id="gpt-4o",
        client_args={
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    ),
)
writer = ProviderInfo(
    id="writer",
    environment_variable="WRITER_API_KEY",
    factory=lambda: WriterModel(
        model_id="palmyra-x4",
        client_args={"api_key": os.getenv("WRITER_API_KEY", "")},
        stream_options={"include_usage": True},
    ),
)

# Ollama depends on a locally reachable server rather than an API key.
ollama = OllamaProviderInfo()


# Every provider exercised by the conformance suite, in parametrization order.
# NOTE(review): ``ollama`` is defined above but not listed here — confirm the
# exclusion is intentional.
all_providers = [bedrock, anthropic, cohere, llama, litellm, mistral, openai, writer]
0