refactor: tools by maciejmajek · Pull Request #521 · RobotecAI/rai · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

refactor: tools #521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 14, 2025
2 changes: 1 addition & 1 deletion examples/manipulation-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from langchain_core.messages import HumanMessage
from rai.agents.conversational_agent import create_conversational_agent
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool
from rai.tools.ros2 import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool
from rai.tools.ros2.manipulation import GetObjectPositionsTool, MoveToPointTool
from rai.utils.model_initialization import get_llm_model
from rai_open_set_vision.tools import GetGrabbingPointTool

Expand Down
3 changes: 2 additions & 1 deletion examples/rosbot-xl-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
from rai.agents import ReActAgent
from rai.communication.ros2 import ROS2ARIConnector
from rai.frontend.streamlit import run_streamlit_app
from rai.tools.ros.manipulation import GetGrabbingPointTool, GetObjectPositionsTool
from rai.tools.ros2 import (
GetObjectPositionsTool,
GetROS2ImageConfiguredTool,
GetROS2TransformConfiguredTool,
Nav2Toolkit,
)
from rai.tools.time import WaitForSecondsTool
from rai.utils.model_initialization import get_llm_model
from rai_open_set_vision.tools import GetGrabbingPointTool

# Set page configuration first
st.set_page_config(
Expand Down
6 changes: 2 additions & 4 deletions src/rai_bench/rai_bench/examples/o3de_test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@
from langchain.tools import BaseTool
from rai.agents.conversational_agent import create_conversational_agent
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros.manipulation import (
GetObjectPositionsTool,
MoveToPointTool,
)
from rai.tools.ros2 import (
GetObjectPositionsTool,
GetROS2ImageTool,
GetROS2TopicsNamesAndTypesTool,
MoveToPointTool,
)
from rai.utils.model_initialization import get_llm_model
from rai_open_set_vision.tools import GetGrabbingPointTool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,14 @@
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.communication.ros2.messages import ROS2ARIMessage
from rai.messages import MultimodalArtifact, preprocess_image
from rai.tools.ros.manipulation import (
GetGrabbingPointTool,
GetObjectPositionsTool,
MoveToPointTool,
)
from rai.tools.ros2 import (
GetObjectPositionsTool,
GetROS2ImageTool,
GetROS2TopicsNamesAndTypesTool,
MoveToPointTool,
ReceiveROS2MessageTool,
)
from rai_open_set_vision.tools import GetGrabbingPointTool


class MockGetROS2TopicsNamesAndTypesTool(GetROS2TopicsNamesAndTypesTool):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from langchain_core.messages import AIMessage
from langchain_core.messages.tool import ToolCall
from langchain_core.tools import BaseTool
from rai.tools.ros.manipulation import MoveToPointToolInput
from rai.tools.ros2 import MoveToPointToolInput

from rai_bench.tool_calling_agent_bench.agent_tasks_interfaces import (
ROS2ToolCallingAgentTask,
Expand Down
12 changes: 12 additions & 0 deletions src/rai_core/rai/communication/ros2/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@

from .action import ROS2ActionAPI
from .base import IROS2Message
from .conversion import (
convert_ros_img_to_base64,
convert_ros_img_to_cv2mat,
convert_ros_img_to_ndarray,
import_message_from_str,
ros2_message_to_dict,
)
from .service import ROS2ServiceAPI
from .topic import ConfigurableROS2TopicAPI, ROS2TopicAPI, TopicConfig

Expand All @@ -24,4 +31,9 @@
"ROS2ServiceAPI",
"ROS2TopicAPI",
"TopicConfig",
"convert_ros_img_to_base64",
"convert_ros_img_to_cv2mat",
"convert_ros_img_to_ndarray",
"import_message_from_str",
"ros2_message_to_dict",
]
2 changes: 1 addition & 1 deletion src/rai_core/rai/communication/ros2/api/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
BaseROS2API,
IROS2Message,
)
from rai.tools.ros.utils import import_message_from_str
from rai.communication.ros2.api.conversion import import_message_from_str


class ROS2ActionData(TypedDict):
Expand Down
2 changes: 1 addition & 1 deletion src/rai_core/rai/communication/ros2/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
)
from rclpy.topic_endpoint_info import TopicEndpointInfo

from rai.tools.ros.utils import import_message_from_str
from rai.communication.ros2.api.conversion import import_message_from_str


@runtime_checkable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import base64
from typing import Optional, Type, Union, cast
from typing import Any, OrderedDict, Type, cast

import cv2
import numpy as np
import rclpy
import rclpy.executors
import rclpy.node
import rclpy.time
import rosidl_runtime_py.convert
import rosidl_runtime_py.set_message
import rosidl_runtime_py.utilities
import sensor_msgs.msg
from cv_bridge import CvBridge
from rclpy.duration import Duration
from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from rclpy.signals import SignalHandlerGuardCondition
from rclpy.utilities import timeout_sec_to_nsec
from rosidl_parser.definition import NamespacedType
from rosidl_runtime_py.import_message import import_message_from_namespaced_type
from rosidl_runtime_py.utilities import get_namespaced_type
from tf2_ros import Buffer, LookupException, TransformListener, TransformStamped


def ros2_message_to_dict(message: Any) -> OrderedDict[str, Any]:
    """Turn a ROS 2 message instance into an ``OrderedDict`` of its fields.

    The conversion is delegated entirely to
    ``rosidl_runtime_py.convert.message_to_ordereddict``.

    Args:
        message: The ROS 2 message to convert.

    Returns:
        The ordered field mapping produced by ``message_to_ordereddict``.

    Raises:
        Any exception from ``message_to_ordereddict`` propagates unchanged.
        NOTE(review): the exact type raised for non-message input (e.g.
        ``TypeError``) is not verified here — confirm.
    """
    to_ordered_dict = rosidl_runtime_py.convert.message_to_ordereddict
    result: OrderedDict[str, Any] = to_ordered_dict(message)  # type: ignore
    return result


def import_message_from_str(msg_type: str) -> Type[object]:
Expand Down Expand Up @@ -101,81 +110,3 @@ def convert_ros_img_to_base64(msg: sensor_msgs.msg.Image) -> str:
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
image_data = cv2.imencode(".png", cv_image)[1].tostring() # type: ignore
return base64.b64encode(image_data).decode("utf-8") # type: ignore


# Copied from https://github.com/ros2/rclpy/blob/jazzy/rclpy/rclpy/wait_for_message.py, to support humble
def wait_for_message(
    msg_type: Type[object],
    node: "Node",
    topic: str,
    *,
    qos_profile: Union[QoSProfile, int] = 1,
    time_to_wait: float = -1,
) -> tuple[bool, Optional[object]]:
    """
    Wait for the next incoming message.

    Backport of rclpy's jazzy ``wait_for_message`` so it can be used on Humble.
    A temporary subscription is created on ``node``, and a wait set blocks
    until one of three things happens: a message becomes ready on ``topic``,
    the signal-handler guard condition fires, or ``time_to_wait`` elapses.

    :param msg_type: message type
    :param node: node to initialize the subscription on
    :param topic: topic name to wait for message
    :param qos_profile: QoS profile to use for the subscription
    :param time_to_wait: seconds to wait before returning. The value is passed
        to ``timeout_sec_to_nsec``; the default of -1 presumably means "wait
        indefinitely" per rclpy's conversion. TODO confirm.
    :returns: (True, msg) if a message was successfully received, (False, None) if message
        could not be obtained or shutdown was triggered asynchronously on the context.

    NOTE(review): the subscription is never destroyed (see the TODO in the
    ``finally`` block), so every call leaves one extra subscription on ``node``.
    """
    context = node.context
    # Wait set holding the single subscription and single guard condition
    # added below (positional counts follow rclpy's internal WaitSet binding).
    wait_set = _rclpy.WaitSet(1, 1, 0, 0, 0, 0, context.handle)
    wait_set.clear_entities()

    # No-op callback: the message is taken directly from the subscription
    # handle once the wait set reports it ready, not delivered via a callback.
    sub = node.create_subscription(
        msg_type, topic, lambda _: None, qos_profile=qos_profile
    )
    try:
        wait_set.add_subscription(sub.handle)
        # Guard condition that can wake the wait early (named sigint_gc;
        # presumably triggered by signal handling such as SIGINT).
        sigint_gc = SignalHandlerGuardCondition(context=context)
        wait_set.add_guard_condition(sigint_gc.handle)

        timeout_nsec = timeout_sec_to_nsec(time_to_wait)
        wait_set.wait(timeout_nsec)

        subs_ready = wait_set.get_ready_entities("subscription")
        guards_ready = wait_set.get_ready_entities("guard_condition")

        # Woken by the signal guard condition: report that no message was obtained.
        if guards_ready:
            if sigint_gc.handle.pointer in guards_ready:
                return False, None

        if subs_ready:
            if sub.handle.pointer in subs_ready:
                # take_message can yield None; only a non-None result's first
                # element is returned as the message.
                msg_info = sub.handle.take_message(sub.msg_type, sub.raw)
                if msg_info is not None:
                    return True, msg_info[0]
    finally:
        # TODO(boczekbartek): uncomment when rclpy resolves: https://github.com/ros2/rclpy/issues/1142
        # node.destroy_subscription(sub)
        pass

    # Timed out, or nothing could be taken from the subscription.
    return False, None


def get_transform(
    node: rclpy.node.Node,
    target_frame: str,
    source_frame: str,
    timeout_sec: float = 5.0,
) -> TransformStamped:
    """Look up the transform from ``source_frame`` to ``target_frame``.

    A temporary TF buffer and listener are attached to ``node`` for the
    duration of the lookup. The listener is always unregistered afterwards,
    even if the lookup raises.

    Args:
        node: Node the TF listener is created on.
        target_frame: Frame to transform into.
        source_frame: Frame to transform from.
        timeout_sec: Maximum time, in seconds, passed to ``lookup_transform``
            as its timeout.

    Returns:
        The transform looked up at ``rclpy.time.Time()``.

    Raises:
        LookupException: If ``lookup_transform`` returns no transform.
            Exceptions raised by ``Buffer.lookup_transform`` itself propagate
            unchanged.
    """
    tf_buffer = Buffer(node=node)
    tf_listener = TransformListener(tf_buffer, node)
    try:
        # Honour the caller's timeout. Previously a hard-coded 3 s was used,
        # while the error message below reported ``timeout_sec``.
        transform = tf_buffer.lookup_transform(
            target_frame,
            source_frame,
            rclpy.time.Time(),
            timeout=Duration(seconds=timeout_sec),
        )
    finally:
        # Detach the listener on every path, including when the lookup raises.
        tf_listener.unregister()

    if transform is None:
        raise LookupException(
            f"Could not find transform from {source_frame} to {target_frame} in {timeout_sec} seconds"
        )

    return transform
2 changes: 1 addition & 1 deletion src/rai_core/rai/communication/ros2/api/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from rai.communication.ros2.api.base import (
BaseROS2API,
)
from rai.tools.ros.utils import import_message_from_str
from rai.communication.ros2.api.conversion import import_message_from_str


class ROS2ServiceAPI(BaseROS2API):
Expand Down
2 changes: 1 addition & 1 deletion src/rai_core/rai/communication/ros2/api/topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
BaseROS2API,
IROS2Message,
)
from rai.tools.ros.utils import import_message_from_str
from rai.communication.ros2.api.conversion import import_message_from_str


class ROS2TopicAPI(BaseROS2API):
Expand Down
2 changes: 1 addition & 1 deletion src/rai_core/rai/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.


from .conversion import preprocess_image
from .multimodal import (
AIMultimodalMessage,
HumanMultimodalMessage,
MultimodalArtifact,
SystemMultimodalMessage,
ToolMultimodalMessage,
)
from .utils import preprocess_image

__all__ = [
"AIMultimodalMessage",
Expand Down
37 changes: 0 additions & 37 deletions src/rai_core/rai/tools/ros/nav2/basic_navigator.py

This file was deleted.

65 changes: 0 additions & 65 deletions src/rai_core/rai/tools/ros/nav2/navigator.py

This file was deleted.

Loading
0