refactor: Disallow similar tool names in the tool registry by zastrowm · Pull Request #193 · strands-agents/sdk-python · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

refactor: Disallow similar tool names in the tool registry #193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,10 @@ def find_normalized_tool_name() -> Optional[str]:
# all tools that can be represented with the normalized name
if "_" in name:
filtered_tools = [
tool_name
for (tool_name, tool) in tool_registry.items()
if tool_name.replace("-", "_") == name
tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name
]

if len(filtered_tools) > 1:
raise AttributeError(f"Multiple tools matching '{name}' found: {', '.join(filtered_tools)}")

# The registry itself defends against similar names, so we can just take the first match
if filtered_tools:
return filtered_tools[0]

Expand Down
4 changes: 1 addition & 3 deletions src/strands/telemetry/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@

from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter

# See https://github.com/open-telemetry/opentelemetry-python/issues/4615 for the type ignore
from opentelemetry.sdk.resources import Resource # type: ignore[attr-defined]
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor
from opentelemetry.trace import StatusCode
Expand Down
15 changes: 15 additions & 0 deletions src/strands/tools/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,21 @@ def register_tool(self, tool: AgentTool) -> None:
tool.is_dynamic,
)

if self.registry.get(tool.tool_name) is None:
normalized_name = tool.tool_name.replace("-", "_")

matching_tools = [
tool_name
for (tool_name, tool) in self.registry.items()
if tool_name.replace("-", "_") == normalized_name
]

if matching_tools:
raise ValueError(
f"Tool name '{tool.tool_name}' already exists as '{matching_tools[0]}'."
" Cannot add a duplicate tool which differs by a '-' or '_'"
)

# Register in main registry
self.registry[tool.tool_name] = tool

Expand Down
22 changes: 0 additions & 22 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,28 +739,6 @@ def function(system_prompt: str) -> str:
}


def test_agent_tool_with_multiple_normalized_matches(agent, tool_registry, mock_randint):
    """Expect AttributeError when a normalized tool name matches more than one registered tool.

    Two tools are registered whose names differ only by '-' versus '_', then the
    agent is asked for the shared underscore-normalized name.
    """
    agent.tool_handler = unittest.mock.Mock()

    @strands.tools.tool(name="system-prompter_1")
    def function1(system_prompt: str) -> str:
        return system_prompt

    @strands.tools.tool(name="system-prompter-1")
    def function2(system_prompt: str) -> str:
        return system_prompt

    # Register in this order; the expected message below lists the names in the same order.
    for decorated in (function1, function2):
        agent.tool_registry.register_tool(strands.tools.tools.FunctionTool(decorated))

    mock_randint.return_value = 1

    expected_message = "Multiple tools matching 'system_prompter_1' found: system-prompter_1, system-prompter-1"

    with pytest.raises(AttributeError) as exc_info:
        agent.tool.system_prompter_1(system_prompt="tool prompt")

    assert str(exc_info.value) == expected_message


def test_agent_tool_with_no_normalized_match(agent, tool_registry, mock_randint):
agent.tool_handler = unittest.mock.Mock()

Expand Down
20 changes: 20 additions & 0 deletions tests/strands/tools/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
Tests for the SDK tool registry module.
"""

from unittest.mock import MagicMock

import pytest

from strands.tools import PythonAgentTool
from strands.tools.registry import ToolRegistry


Expand All @@ -23,3 +26,20 @@ def test_process_tools_with_invalid_path():

with pytest.raises(ValueError, match=f"Failed to load tool {invalid_path.split('.')[0]}: Tool file not found:.*"):
tool_registry.process_tools([invalid_path])


def test_register_tool_with_similar_name_raises():
    """Expect ValueError when registering a tool whose name differs from an existing one only by '-' vs '_'."""
    hyphenated_tool = PythonAgentTool(tool_name="tool-like-this", tool_spec=MagicMock(), callback=lambda: None)
    underscored_tool = PythonAgentTool(tool_name="tool_like_this", tool_spec=MagicMock(), callback=lambda: None)

    registry = ToolRegistry()

    # The first registration succeeds; only the look-alike second name is rejected.
    registry.register_tool(hyphenated_tool)

    expected_message = (
        "Tool name 'tool_like_this' already exists as 'tool-like-this'. "
        "Cannot add a duplicate tool which differs by a '-' or '_'"
    )

    with pytest.raises(ValueError) as exc_info:
        registry.register_tool(underscored_tool)

    assert str(exc_info.value) == expected_message
0