refactor: ros2 tools in manipulation demo by boczekbartek · Pull Request #551 · RobotecAI/rai · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

refactor: ros2 tools in manipulation demo #551

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: development
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions src/rai_core/rai/tools/ros2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
raise ImportError(
"This is a ROS2 feature. Make sure ROS2 is installed and sourced."
)

from .cli import (
ROS2CLIToolkit,
ros2_action,
Expand Down Expand Up @@ -48,11 +47,6 @@
ROS2TopicsToolkit,
StartROS2ActionTool,
)
from .manipulation.custom import (
GetObjectPositionsTool,
MoveToPointTool,
MoveToPointToolInput,
)
from .navigation.nav2 import (
CancelNavigateToPoseTool,
GetNavigateToPoseFeedbackTool,
Expand All @@ -71,7 +65,6 @@
"CancelROS2ActionTool",
"GetNavigateToPoseFeedbackTool",
"GetNavigateToPoseResultTool",
"GetObjectPositionsTool",
"GetROS2ActionFeedbackTool",
"GetROS2ActionIDsTool",
"GetROS2ActionResultTool",
Expand All @@ -83,8 +76,6 @@
"GetROS2TopicsNamesAndTypesTool",
"GetROS2TransformConfiguredTool",
"GetROS2TransformTool",
"MoveToPointTool",
"MoveToPointToolInput",
"Nav2Toolkit",
"NavigateToPoseTool",
"PublishROS2MessageTool",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@

import numpy as np
import sensor_msgs.msg
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
from rai.communication.ros2 import ROS2Connector
from rai.communication.ros2.api import convert_ros_img_to_ndarray
from rai.communication.ros2.ros_async import get_future_result
from rai.tools.ros2.base import BaseROS2Tool
from rclpy.exceptions import (
ParameterNotDeclaredException,
ParameterUninitializedException,
Expand Down Expand Up @@ -78,9 +77,7 @@ class DistanceMeasurement(NamedTuple):


# --------------------- Tools ---------------------
class GroundingDinoBaseTool(BaseTool):
connector: ROS2Connector = Field(..., exclude=True)

class GroundingDinoBaseTool(BaseROS2Tool):
box_threshold: float = Field(default=0.35, description="Box threshold for GDINO")
text_threshold: float = Field(default=0.45, description="Text threshold for GDINO")

Expand All @@ -89,7 +86,7 @@ def _call_gdino_node(
) -> Future:
cli = self.connector.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME)
while not cli.wait_for_service(timeout_sec=1.0):
self.node.get_logger().info(
self.connector.node.get_logger().info(
f"service {GDINO_SERVICE_NAME} not available, waiting again..."
)
req = RAIGroundingDino.Request()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@
import numpy as np
import rclpy
import sensor_msgs.msg
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
from rai.communication.ros2.api import (
convert_ros_img_to_base64,
convert_ros_img_to_ndarray,
)
from rai.communication.ros2.connectors import ROS2Connector
from rai.communication.ros2.ros_async import get_future_result
from rai.tools.ros2.base import BaseROS2Tool
from rclpy import Future
from rclpy.exceptions import (
ParameterNotDeclaredException,
Expand Down Expand Up @@ -67,12 +66,7 @@ class GetGrabbingPointInput(BaseModel):


# --------------------- Tools ---------------------
class GetSegmentationTool:
connector: ROS2Connector = Field(..., exclude=True)

name: str = ""
description: str = ""

class GetSegmentationTool(BaseROS2Tool):
box_threshold: float = Field(default=0.35, description="Box threshold for GDINO")
text_threshold: float = Field(default=0.45, description="Text threshold for GDINO")

Expand Down Expand Up @@ -194,9 +188,7 @@ def depth_to_point_cloud(
return points


class GetGrabbingPointTool(BaseTool):
connector: ROS2Connector = Field(..., exclude=True)

class GetGrabbingPointTool(BaseROS2Tool):
name: str = "GetGrabbingPointTool"
description: str = "Get the grabbing point of an object"
pcd: List[Any] = []
Expand Down
0