refactor: remove rai.utils module by maciejmajek · Pull Request #524 · RobotecAI/rai · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

refactor: remove rai.utils module #524

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ rosdep install --from-paths src --ignore-src -r -y
Run the configuration tool to set up your vendor and other settings:

```bash
poetry run streamlit run src/rai_core/rai/utils/configurator.py
poetry run streamlit run src/rai_core/rai/frontend/configurator.py
```

> [!TIP]
Expand Down
2 changes: 1 addition & 1 deletion docs/debugging_assistant.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ The ROS 2 Debugging Assistant is an interactive tool that helps developers inspe

```sh
source setup_shell.sh
streamlit run src/rai_core/rai/tools/debugging_assistant.py
streamlit run examples/debugging_assistant.py
```

## Usage Examples
Expand Down
2 changes: 1 addition & 1 deletion docs/developer_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ from myrobot import robot

from rai.agents.conversational_agent import create_conversational_agent
from rai.agents.state_based import create_state_based_agent
from rai.utils.model_initialization import get_llm_model
from rai import get_llm_model

SYSTEM_PROMPT = "You are a robot with interfaces..."

Expand Down
9 changes: 5 additions & 4 deletions docs/developer_guide/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ TODO(docs): add link to the BaseAgent docs (regarding distributed setup)
from rai.agents import ReActAgent
from rai.communication import ROS2ARIConnector, ROS2HRIConnector
from rai.tools.ros2 import ROS2Toolkit
from rai.utils import ROS2Context, wait_for_shutdown
from rai.communication.ros2 import ROS2Context
from rai import AgentRunner

@ROS2Context()
def main() -> None:
Expand All @@ -206,8 +207,8 @@ def main() -> None:
connectors={"hri": connector},
tools=initialize_tools(connector=ari_connector),
)
agent.run()
wait_for_shutdown([agent])
runner = AgentRunner([agent])
runner.run_and_wait_for_shutdown()

# Example:
# ros2 topic pub /from_human rai_interfaces/msg/HRIMessage "{\"text\": \"What do you see?\"}"
Expand All @@ -221,7 +222,7 @@ def main() -> None:
```python
from rai.agents.langchain import create_react_runnable
from langchain.schema import HumanMessage
from rai.utils import ROS2Context, wait_for_shutdown
from rai.communication.ros2 import ROS2Context

@ROS2Context()
def main():
Expand Down
2 changes: 1 addition & 1 deletion docs/multimodal_messages.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Example:

```python
from rai.messages import HumanMultimodalMessage, preprocess_image
from rai.utils.model_initialization import get_llm_model
from rai import get_llm_model

base64_image = preprocess_image('https://raw.githubusercontent.com/RobotecAI/RobotecGPULidar/develop/docs/image/rgl-logo.png')

Expand Down
2 changes: 1 addition & 1 deletion docs/tracing.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ To enable tracing in your RAI application, you need to import the get_tracing_ca
1. First, import the get_tracing_callbacks() function:

```python
from rai.utils.model_initialization import get_tracing_callbacks
from rai import get_tracing_callbacks
```

2. Then, add it to the configuration when invoking your agent or model:
Expand Down
9 changes: 4 additions & 5 deletions examples/agents/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from rai.agents import ReActAgent
from rai.communication.ros2 import ROS2ARIConnector, ROS2HRIConnector
from rai.agents import AgentRunner, ReActAgent
from rai.communication.ros2 import ROS2ARIConnector, ROS2Context, ROS2HRIConnector
from rai.tools.ros2 import ROS2Toolkit
from rai.utils import ROS2Context, wait_for_shutdown


@ROS2Context()
Expand All @@ -26,8 +25,8 @@ def main():
connectors={"hri": connector},
tools=ROS2Toolkit(connector=ari_connector).get_tools(),
) # type: ignore
agent.run()
wait_for_shutdown([agent])
runner = AgentRunner([agent])
runner.run_and_wait_for_shutdown()


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion examples/agriculture-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
import rclpy
from langchain_core.messages import HumanMessage
from langchain_core.runnables import Runnable
from rai import get_llm_model
from rai.agents.conversational_agent import State, create_conversational_agent
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros2 import ROS2ServicesToolkit, ROS2TopicsToolkit
from rai.tools.time import WaitForSecondsTool
from rai.utils.model_initialization import get_llm_model
from rclpy.callback_groups import ReentrantCallbackGroup
from rclpy.executors import MultiThreadedExecutor
from rclpy.node import Node
Expand Down
50 changes: 50 additions & 0 deletions examples/debugging_assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import streamlit as st
from rai import get_llm_model
from rai.agents import create_conversational_agent
from rai.frontend import run_streamlit_app
from rai.tools.ros2 import ROS2CLIToolkit


@st.cache_resource
def initialize_agent():
llm = get_llm_model(model_type="complex_model", streaming=True)
agent = create_conversational_agent(
llm,
ROS2CLIToolkit().get_tools(),
system_prompt="""You are a ROS 2 expert helping a user with their ROS 2 questions. You have access to various tools that allow you to query the ROS 2 system.
Be proactive and use the tools to answer questions. Retrieve as much information from the ROS 2 system as possible.
""",
)
return agent


st.set_page_config(
page_title="ROS 2 Debugging Assistant",
page_icon=":robot:",
)


def main():
run_streamlit_app(
initialize_agent(),
page_title="ROS 2 Debugging Assistant",
initial_message="Hi! I am a ROS 2 assistant. How can I help you?",
)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion examples/manipulation-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
import rclpy
import rclpy.qos
from langchain_core.messages import HumanMessage
from rai import get_llm_model
from rai.agents.conversational_agent import create_conversational_agent
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros2 import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool
from rai.tools.ros2.manipulation import GetObjectPositionsTool, MoveToPointTool
from rai.utils.model_initialization import get_llm_model
from rai_open_set_vision.tools import GetGrabbingPointTool


Expand Down
2 changes: 1 addition & 1 deletion examples/rosbot-xl-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import rclpy
import streamlit as st
from langchain_core.tools import BaseTool
from rai import get_llm_model
from rai.agents import ReActAgent
from rai.communication.ros2 import ROS2ARIConnector
from rai.frontend.streamlit import run_streamlit_app
Expand All @@ -27,7 +28,6 @@
Nav2Toolkit,
)
from rai.tools.time import WaitForSecondsTool
from rai.utils.model_initialization import get_llm_model
from rai_open_set_vision.tools import GetGrabbingPointTool

# Set page configuration first
Expand Down
2 changes: 1 addition & 1 deletion examples/s2s/conversational.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
import rclpy
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from rai import get_llm_model
from rai.agents.base import BaseAgent
from rai.communication import BaseConnector
from rai.communication.ros2 import IROS2Message, ROS2HRIConnector, TopicConfig
from rai.utils.model_initialization import get_llm_model

from rai_interfaces.msg import HRIMessage as InterfacesHRIMessage

Expand Down
2 changes: 1 addition & 1 deletion src/rai_bench/rai_bench/examples/o3de_test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
from langchain.tools import BaseTool
from rai.agents.conversational_agent import create_conversational_agent
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.initialization import get_llm_model
from rai.tools.ros2 import (
GetObjectPositionsTool,
GetROS2ImageTool,
GetROS2TopicsNamesAndTypesTool,
MoveToPointTool,
)
from rai.utils.model_initialization import get_llm_model
from rai_open_set_vision.tools import GetGrabbingPointTool

from rai_bench.benchmark_model import Benchmark
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
from datetime import datetime
from pathlib import Path

from rai.agents.conversational_agent import create_conversational_agent
from rai.utils.model_initialization import (
from rai import (
get_llm_model,
get_llm_model_config_and_vendor,
)
from rai.agents.conversational_agent import create_conversational_agent

from rai_bench.examples.tool_calling_agent_bench_tasks import tasks
from rai_bench.tool_calling_agent_bench.agent_bench import ToolCallingAgentBenchmark
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from langgraph.errors import GraphRecursionError
from langgraph.graph.state import CompiledStateGraph
from pydantic import BaseModel, Field
from rai.messages.multimodal import HumanMultimodalMessage
from rai.messages import HumanMultimodalMessage

from rai_bench.tool_calling_agent_bench.agent_tasks_interfaces import (
ToolCallingAgentTask,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.tracers.langchain import LangChainTracer
from langfuse.callback import CallbackHandler
from rai.utils.model_initialization import get_tracing_callbacks
from rai.initialization import get_tracing_callbacks


class ScoreTracingHandler:
Expand Down
17 changes: 17 additions & 0 deletions src/rai_core/rai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .agents import AgentRunner, ReActAgent
from .initialization import (
get_embeddings_model,
get_llm_model,
get_llm_model_config_and_vendor,
get_tracing_callbacks,
)

__all__ = [
"AgentRunner",
"ReActAgent",
"get_embeddings_model",
"get_llm_model",
"get_llm_model_config_and_vendor",
"get_tracing_callbacks",
]
3 changes: 3 additions & 0 deletions src/rai_core/rai/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@

from rai.agents.conversational_agent import create_conversational_agent
from rai.agents.react_agent import ReActAgent
from rai.agents.runner import AgentRunner, wait_for_shutdown
from rai.agents.state_based import create_state_based_agent
from rai.agents.tool_runner import ToolRunner

__all__ = [
"AgentRunner",
"ReActAgent",
"ToolRunner",
"create_conversational_agent",
"create_state_based_agent",
"wait_for_shutdown",
]
2 changes: 1 addition & 1 deletion src/rai_core/rai/agents/langchain/runnables.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from langgraph.prebuilt.tool_node import tools_condition

from rai.agents.tool_runner import ToolRunner
from rai.utils.model_initialization import get_llm_model
from rai.initialization import get_llm_model


class ReActAgentState(TypedDict):
Expand Down
2 changes: 1 addition & 1 deletion src/rai_core/rai/agents/react_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from rai.agents.langchain import HRICallbackHandler, create_react_runnable
from rai.agents.langchain.runnables import ReActAgentState
from rai.communication.hri_connector import HRIConnector, HRIMessage
from rai.utils.model_initialization import get_tracing_callbacks
from rai.initialization import get_tracing_callbacks


class ReActAgent(BaseAgent):
Expand Down
Loading
Loading
0