From 435bb783e47f72caa182fe5ead8af5a40de876e7 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Mon, 12 May 2025 18:24:01 +0200 Subject: [PATCH] ci: enable linting in ci - fix all linting errors except: - less important ones (e.g., E501), which are ignored for certain lines, files, and folders to reduce effort --- .github/workflows/checks.yaml | 1 + pdm.lock | 73 +------ pyproject.toml | 23 ++- src/askui/chat/__main__.py | 84 ++++---- src/askui/chat/click_recorder.py | 18 +- src/askui/chat/exceptions.py | 47 +++++ src/askui/locators/locators.py | 5 +- src/askui/locators/relatable.py | 10 +- src/askui/locators/serializers.py | 46 +++-- src/askui/models/anthropic/__init__.py | 1 + src/askui/models/anthropic/claude.py | 32 +-- src/askui/models/anthropic/claude_agent.py | 4 +- src/askui/models/askui/__init__.py | 1 + src/askui/models/askui/ai_element_utils.py | 3 +- src/askui/models/askui/api.py | 12 +- src/askui/models/askui/exceptions.py | 28 +++ src/askui/models/huggingface/__init__.py | 1 + src/askui/models/huggingface/spaces_api.py | 13 +- src/askui/models/models.py | 36 +++- src/askui/models/router.py | 57 ++++-- src/askui/models/types/response_schemas.py | 22 ++- src/askui/models/ui_tars_ep/__init__.py | 1 + src/askui/models/ui_tars_ep/parser.py | 11 +- src/askui/models/ui_tars_ep/prompts.py | 24 ++- src/askui/models/ui_tars_ep/ui_tars_api.py | 16 +- src/askui/reporting.py | 51 +++-- src/askui/telemetry/anonymous_id.py | 26 ++- src/askui/telemetry/context.py | 3 +- src/askui/telemetry/pkg_version.py | 6 +- src/askui/telemetry/processors.py | 3 +- src/askui/telemetry/telemetry.py | 83 ++++---- src/askui/telemetry/user_identification.py | 10 +- src/askui/telemetry/utils.py | 7 +- src/askui/tools/agent_os.py | 43 ++-- src/askui/tools/anthropic/base.py | 3 +- src/askui/tools/anthropic/computer.py | 64 +++--- src/askui/tools/askui/askui_controller.py | 183 ++++++++++-------- src/askui/tools/askui/askui_hub.py | 63 +++--- .../askui/askui_workspaces/api_client.py
| 7 +- .../tools/askui/askui_workspaces/rest.py | 2 +- src/askui/tools/askui/exceptions.py | 28 +++ src/askui/tools/toolbox.py | 6 +- src/askui/tools/utils.py | 8 +- src/askui/utils/image_utils.py | 138 +++++++------ tests/e2e/agent/test_get.py | 4 +- tests/e2e/agent/test_model_composition.py | 2 +- .../test_askui_locator_serializer.py | 2 +- tests/unit/telemetry/__init__.py | 0 tests/unit/telemetry/test_telemetry.py | 22 ++- tests/unit/utils/__init__.py | 0 tests/unit/utils/test_image_utils.py | 29 ++- tests/utils/__init__.py | 0 tests/utils/generate_ai_elements.py | 6 +- 53 files changed, 807 insertions(+), 561 deletions(-) create mode 100644 src/askui/chat/exceptions.py create mode 100644 src/askui/models/anthropic/__init__.py create mode 100644 src/askui/models/askui/__init__.py create mode 100644 src/askui/models/askui/exceptions.py create mode 100644 src/askui/models/huggingface/__init__.py create mode 100644 src/askui/models/ui_tars_ep/__init__.py create mode 100644 tests/unit/telemetry/__init__.py create mode 100644 tests/unit/utils/__init__.py create mode 100644 tests/utils/__init__.py diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index d34250e..4174845 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -20,4 +20,5 @@ jobs: - run: pdm install - run: pdm run typecheck:all - run: pdm run format --check + - run: pdm run lint - run: pdm run test:unit diff --git a/pdm.lock b/pdm.lock index 6a9453d..090558f 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "chat", "test"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:d469eb562d3f96a3079869d51794332527ddd78c05882b68bb6572a7cf796b43" +content_hash = "sha256:ee418ee4c04af70fff91a2b30c650ad1d04eb5e4bb4a1cbb47c2f43a1327a1cc" [[metadata.targets]] requires_python = ">=3.10" @@ -101,42 +101,6 @@ files = [ {file = "backoff-2.2.1.tar.gz", hash = 
"sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, ] -[[package]] -name = "black" -version = "25.1.0" -requires_python = ">=3.9" -summary = "The uncompromising code formatter." -groups = ["test"] -dependencies = [ - "click>=8.0.0", - "mypy-extensions>=0.4.3", - "packaging>=22.0", - "pathspec>=0.9.0", - "platformdirs>=2", - "tomli>=1.1.0; python_version < \"3.11\"", - "typing-extensions>=4.0.1; python_version < \"3.11\"", -] -files = [ - {file = "black-25.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759e7ec1e050a15f89b770cefbf91ebee8917aac5c20483bc2d80a6c3a04df32"}, - {file = "black-25.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e519ecf93120f34243e6b0054db49c00a35f84f195d5bce7e9f5cfc578fc2da"}, - {file = "black-25.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:055e59b198df7ac0b7efca5ad7ff2516bca343276c466be72eb04a3bcc1f82d7"}, - {file = "black-25.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:db8ea9917d6f8fc62abd90d944920d95e73c83a5ee3383493e35d271aca872e9"}, - {file = "black-25.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a39337598244de4bae26475f77dda852ea00a93bd4c728e09eacd827ec929df0"}, - {file = "black-25.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96c1c7cd856bba8e20094e36e0f948718dc688dba4a9d78c3adde52b9e6c2299"}, - {file = "black-25.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bce2e264d59c91e52d8000d507eb20a9aca4a778731a08cfff7e5ac4a4bb7096"}, - {file = "black-25.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:172b1dbff09f86ce6f4eb8edf9dede08b1fce58ba194c87d7a4f1a5aa2f5b3c2"}, - {file = "black-25.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4b60580e829091e6f9238c848ea6750efed72140b91b048770b64e74fe04908b"}, - {file = "black-25.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e2978f6df243b155ef5fa7e558a43037c3079093ed5d10fd84c43900f2d8ecc"}, - {file = 
"black-25.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b48735872ec535027d979e8dcb20bf4f70b5ac75a8ea99f127c106a7d7aba9f"}, - {file = "black-25.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:ea0213189960bda9cf99be5b8c8ce66bb054af5e9e861249cd23471bd7b0b3ba"}, - {file = "black-25.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8f0b18a02996a836cc9c9c78e5babec10930862827b1b724ddfe98ccf2f2fe4f"}, - {file = "black-25.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:afebb7098bfbc70037a053b91ae8437c3857482d3a690fefc03e9ff7aa9a5fd3"}, - {file = "black-25.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:030b9759066a4ee5e5aca28c3c77f9c64789cdd4de8ac1df642c40b708be6171"}, - {file = "black-25.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:a22f402b410566e2d1c950708c77ebf5ebd5d0d88a6a2e87c86d9fb48afa0d18"}, - {file = "black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717"}, - {file = "black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666"}, -] - [[package]] name = "blinker" version = "1.9.0" @@ -238,7 +202,7 @@ name = "click" version = "8.1.8" requires_python = ">=3.7" summary = "Composable command line interface toolkit" -groups = ["chat", "test"] +groups = ["chat"] dependencies = [ "colorama; platform_system == \"Windows\"", "importlib-metadata; python_version < \"3.8\"", @@ -700,17 +664,6 @@ files = [ {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"}, ] -[[package]] -name = "isort" -version = "6.0.1" -requires_python = ">=3.9.0" -summary = "A Python utility / library to sort Python imports." 
-groups = ["test"] -files = [ - {file = "isort-6.0.1-py3-none-any.whl", hash = "sha256:2dc5d7f65c9678d94c88dfc29161a320eec67328bc97aad576874cb4be1e9615"}, - {file = "isort-6.0.1.tar.gz", hash = "sha256:1cb5df28dfbc742e490c5e41bad6da41b805b0a8be7bc93cd0fb2a8a890ac450"}, -] - [[package]] name = "jinja2" version = "3.1.6" @@ -1114,17 +1067,6 @@ files = [ {file = "pandas-2.2.3.tar.gz", hash = "sha256:4f18ba62b61d7e192368b84517265a99b4d7ee8912f8708660fb4a366cc82667"}, ] -[[package]] -name = "pathspec" -version = "0.12.1" -requires_python = ">=3.8" -summary = "Utility library for gitignore style pattern matching of file paths." -groups = ["test"] -files = [ - {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, - {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, -] - [[package]] name = "pillow" version = "11.1.0" @@ -1194,17 +1136,6 @@ files = [ {file = "pillow-11.1.0.tar.gz", hash = "sha256:368da70808b36d73b4b390a8ffac11069f8a5c85f29eff1f1b01bcf3ef5b2a20"}, ] -[[package]] -name = "platformdirs" -version = "4.3.7" -requires_python = ">=3.9" -summary = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." 
-groups = ["test"] -files = [ - {file = "platformdirs-4.3.7-py3-none-any.whl", hash = "sha256:a03875334331946f13c549dbd8f4bac7a13a50a895a0eb1e8c6a8ace80d40a94"}, - {file = "platformdirs-4.3.7.tar.gz", hash = "sha256:eb437d586b6a0986388f0d6f74aa0cde27b48d0e3d66843640bfb6bdcdb6e351"}, -] - [[package]] name = "pluggy" version = "1.5.0" diff --git a/pyproject.toml b/pyproject.toml index b6471ce..47c9fc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -175,8 +175,27 @@ unfixable = [] dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" [tool.ruff.lint.per-file-ignores] -"tests/*" = ["S101", "PLR2004"] -"src/askui/chat/*" = ["F401", "F403"] +"src/askui/agent.py" = ["E501"] +"src/askui/chat/*" = ["E501", "F401", "F403"] +"src/askui/tools/askui/askui_workspaces/*" = ["ALL"] +"src/askui/tools/askui/askui_ui_controller_grpc/*" = ["ALL"] +"src/askui/locators/locators.py" = ["E501"] +"src/askui/locators/relatable.py" = ["E501", "SLF001"] +"src/askui/locators/serializers.py" = ["E501", "SLF001"] +"src/askui/models/anthropic/claude_agent.py" = ["E501"] +"src/askui/models/askui/ai_element_utils.py" = ["E501"] +"src/askui/models/huggingface/spaces_api.py" = ["E501"] +"src/askui/models/ui_tars_ep/ui_tars_api.py" = ["E501"] +"src/askui/reporting.py" = ["E501"] +"src/askui/telemetry/telemetry.py" = ["E501"] +"src/askui/utils/image_utils.py" = ["E501"] +"tests/*" = ["S101", "PLR2004", "SLF001"] +"tests/e2e/agent/test_get.py" = ["E501"] +"tests/e2e/agent/test_locate_with_relations.py" = ["E501"] +"tests/unit/locators/test_locators.py" = ["E501"] +"tests/unit/locators/serializers/test_askui_locator_serializer.py" = ["E501"] +"tests/unit/locators/serializers/test_locator_string_representation.py" = ["E501"] +"tests/unit/utils/test_image_utils.py" = ["E501"] [tool.ruff.lint.flake8-quotes] docstring-quotes = "double" diff --git a/src/askui/chat/__main__.py b/src/askui/chat/__main__.py index a270dfe..05b39ae 100644 --- a/src/askui/chat/__main__.py +++ 
b/src/askui/chat/__main__.py @@ -1,9 +1,8 @@ -import glob import json import logging -import os import re -from datetime import datetime +from datetime import datetime, timezone +from pathlib import Path from random import randint from typing import Union @@ -13,6 +12,7 @@ from askui import VisionAgent from askui.chat.click_recorder import ClickRecorder +from askui.chat.exceptions import FunctionExecutionError, InvalidFunctionError from askui.models import ModelName from askui.reporting import Reporter from askui.utils.image_utils import base64_to_image, draw_point_on_image @@ -23,28 +23,29 @@ ) -CHAT_SESSIONS_DIR_PATH = "./chat/sessions" -CHAT_IMAGES_DIR_PATH = "./chat/images" +CHAT_SESSIONS_DIR_PATH = Path("./chat/sessions") +CHAT_IMAGES_DIR_PATH = Path("./chat/images") click_recorder = ClickRecorder() -def setup_chat_dirs(): - os.makedirs(CHAT_SESSIONS_DIR_PATH, exist_ok=True) - os.makedirs(CHAT_IMAGES_DIR_PATH, exist_ok=True) +def setup_chat_dirs() -> None: + Path.mkdir(CHAT_SESSIONS_DIR_PATH, parents=True, exist_ok=True) + Path.mkdir(CHAT_IMAGES_DIR_PATH, parents=True, exist_ok=True) -def get_session_id_from_path(path): - return os.path.splitext(os.path.basename(path))[0] +def get_session_id_from_path(path: str) -> str: + """Get session ID from file path.""" + return Path(path).stem -def load_chat_history(session_id): - messages = [] - session_path = os.path.join(CHAT_SESSIONS_DIR_PATH, f"{session_id}.jsonl") - if os.path.exists(session_path): - with open(session_path, "r") as f: - for line in f: - messages.append(json.loads(line)) +def load_chat_history(session_id: str) -> list[dict]: + """Load chat history for a given session ID.""" + messages: list[dict] = [] + session_path = CHAT_SESSIONS_DIR_PATH / f"{session_id}.jsonl" + if session_path.exists(): + with session_path.open("r") as f: + messages.extend(json.loads(line) for line in f) return messages @@ -60,7 +61,8 @@ def load_chat_history(session_id): def get_image(img_b64_str_or_path: str) -> Image.Image: 
- if os.path.isfile(img_b64_str_or_path): + """Get image from base64 string or file path.""" + if Path(img_b64_str_or_path).is_file(): return Image.open(img_b64_str_or_path) return base64_to_image(img_b64_str_or_path) @@ -75,7 +77,7 @@ def write_message( | list[str] | list[Image.Image] | None = None, -): +) -> None: _role = ROLE_MAP.get(role.lower(), UNKNOWN_ROLE) avatar = None if _role != UNKNOWN_ROLE else "❔" with st.chat_message(_role, avatar=avatar): @@ -96,10 +98,11 @@ def write_message( def save_image(image: Image.Image) -> str: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") - image_path = os.path.join(CHAT_IMAGES_DIR_PATH, f"image_{timestamp}.png") + """Save image to disk and return path.""" + timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S_%f") + image_path = CHAT_IMAGES_DIR_PATH / f"image_{timestamp}.png" image.save(image_path) - return image_path + return str(image_path) class Message(TypedDict): @@ -127,18 +130,15 @@ def add_message( _images = image else: _images = [image] - for img in _images: - image_paths.append(save_image(img)) + image_paths.extend(save_image(img) for img in _images) message = Message( role=role, content=content, - timestamp=datetime.now().isoformat(), + timestamp=datetime.now(tz=timezone.utc).isoformat(), image=image_paths, ) write_message(**message) - with open( - os.path.join(CHAT_SESSIONS_DIR_PATH, f"{self._session_id}.jsonl"), "a" - ) as f: + with (CHAT_SESSIONS_DIR_PATH / f"{self._session_id}.jsonl").open("a") as f: json.dump(message, f) f.write("\n") @@ -147,17 +147,18 @@ def generate(self) -> None: pass -def get_available_sessions(): - session_files = glob.glob(os.path.join(CHAT_SESSIONS_DIR_PATH, "*.jsonl")) +def get_available_sessions() -> list[str]: + """Get list of available session IDs.""" + session_files = list(CHAT_SESSIONS_DIR_PATH.glob("*.jsonl")) return sorted([get_session_id_from_path(f) for f in session_files], reverse=True) def create_new_session() -> str: - timestamp = 
datetime.now().strftime("%Y%m%d%H%M%S%f") + """Create a new chat session.""" + timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d%H%M%S%f") random_suffix = f"{randint(100, 999)}" session_id = f"{timestamp}{random_suffix}" - with open(os.path.join(CHAT_SESSIONS_DIR_PATH, f"{session_id}.jsonl"), "w") as f: - pass + (CHAT_SESSIONS_DIR_PATH / f"{session_id}.jsonl").touch() return session_id @@ -200,7 +201,7 @@ def paint_crosshair( """ -def rerun(): +def rerun() -> None: st.markdown("### Re-running...") with VisionAgent( log_level=logging.DEBUG, @@ -220,9 +221,8 @@ def rerun(): r"mouse\((\d+),\s*(\d+)\)", message["content"] ): if not screenshot: - raise ValueError( - "Screenshot is required to paint crosshair" - ) + error_msg = "Screenshot is required to paint crosshair" + raise ValueError(error_msg) # noqa: TRY301 x, y = map(int, match.groups()) screenshot_with_crosshair = paint_crosshair( screenshot, (x, y) @@ -235,7 +235,7 @@ def rerun(): write_message( message["role"], f"Move mouse to {element_description}", - datetime.now().isoformat(), + datetime.now(tz=timezone.utc).isoformat(), image=screenshot_with_crosshair, ) agent.mouse_move( @@ -246,7 +246,7 @@ def rerun(): write_message( message["role"], message["content"], - datetime.now().isoformat(), + datetime.now(tz=timezone.utc).isoformat(), message.get("image"), ) func_call = f"agent.tools.os.{message['content']}" @@ -254,9 +254,9 @@ def rerun(): except json.JSONDecodeError: continue except AttributeError: - st.write(f"Invalid function: {message['content']}") - except Exception as e: - st.write(f"Error executing {message['content']}: {str(e)}") + st.write(str(InvalidFunctionError(message["content"]))) + except Exception as e: # noqa: BLE001 - We want to catch all other exceptions here + st.write(str(FunctionExecutionError(message["content"], e))) setup_chat_dirs() diff --git a/src/askui/chat/click_recorder.py b/src/askui/chat/click_recorder.py index 90fb7c0..17b64b8 100644 --- 
a/src/askui/chat/click_recorder.py +++ b/src/askui/chat/click_recorder.py @@ -1,16 +1,18 @@ -import glob import json import os import subprocess import sys import tempfile from datetime import datetime +from pathlib import Path from typing import List, Tuple from PIL import Image from pydantic import UUID4, BaseModel, ConfigDict from pydantic.alias_generators import to_camel +from askui.chat.exceptions import AnnotationError + Coordinate = Tuple[int, int] @@ -60,9 +62,8 @@ def __init__(self) -> None: def __find_remote_device_controller(self) -> str: if sys.platform == "darwin": return f"{os.environ['ASKUI_INSTALLATION_DIRECTORY']}/DependencyCache/AskUIRemoteDeviceSnippingTool-0.2.0/AskuiRemoteDeviceSnippingTool" - raise NotImplementedError( - "Snipping tool not supported on this platform, yet, as the path was unknown at the time of writing" - ) + error_msg = "Snipping tool not supported on this platform, yet, as the path was unknown at the time of writing" + raise NotImplementedError(error_msg) def __start_process(self, binary_path: str, output_directory: str) -> None: self.process = subprocess.check_output( @@ -71,16 +72,17 @@ def __start_process(self, binary_path: str, output_directory: str) -> None: def annotate(self) -> Tuple[Image.Image, AnnoationContainer]: with tempfile.TemporaryDirectory() as tempdir: + tempdir_path = Path(tempdir) self.__start_process(self.__find_remote_device_controller(), tempdir) - json_files = glob.glob(tempdir + "/*.json") - png_files = glob.glob(tempdir + "/*.png") + json_files = list(tempdir_path.glob("*.json")) + png_files = list(tempdir_path.glob("*.png")) if len(json_files) != 1 or len(png_files) != 1: - raise Exception("No annotation Done!") + raise AnnotationError json_file = json_files[0] annotation = None - with open(json_file) as json_data: + with Path.open(json_file) as json_data: annotation = AnnoationContainer(**json.load(json_data)) return Image.open(png_files[0]).copy(), annotation diff --git 
a/src/askui/chat/exceptions.py b/src/askui/chat/exceptions.py new file mode 100644 index 0000000..3c70c17 --- /dev/null +++ b/src/askui/chat/exceptions.py @@ -0,0 +1,47 @@ +"""Exceptions for the chat module.""" + + +class ChatError(Exception): + """Base exception for chat-related errors.""" + + def __init__(self, message: str): + self.message = message + super().__init__(self.message) + + +class InvalidFunctionError(ChatError): + """Exception raised when an invalid function is called.""" + + def __init__(self, function_name: str): + super().__init__(f"Invalid function: {function_name}") + + +class FunctionExecutionError(ChatError): + """Exception raised when a function execution fails.""" + + def __init__(self, function_name: str, error: Exception): + super().__init__(f"Error executing {function_name}: {str(error)}") + self.original_error = error + + +class AnnotationError(ChatError): + """Exception raised when annotation is not done or invalid.""" + + def __init__(self, message: str = "No annotation Done!"): + super().__init__(message) + + +class ActionTimeoutError(ChatError): + """Exception raised when an action times out.""" + + def __init__(self, message: str = "Action not yet done"): + super().__init__(message) + + +__all__ = [ + "ChatError", + "InvalidFunctionError", + "FunctionExecutionError", + "AnnotationError", + "ActionTimeoutError", +] diff --git a/src/askui/locators/locators.py b/src/askui/locators/locators.py index bcaf8a5..7e3f387 100644 --- a/src/askui/locators/locators.py +++ b/src/askui/locators/locators.py @@ -188,9 +188,8 @@ def __init__( ) -> None: super().__init__() if threshold > stop_threshold: - raise ValueError( - f"threshold ({threshold}) must be less than or equal to stop_threshold ({stop_threshold})" - ) + error_msg = f"threshold ({threshold}) must be less than or equal to stop_threshold ({stop_threshold})" + raise ValueError(error_msg) self._threshold = threshold self._stop_threshold = stop_threshold self._mask = mask diff --git 
a/src/askui/locators/relatable.py b/src/askui/locators/relatable.py index 87e8b56..99c1a7f 100644 --- a/src/askui/locators/relatable.py +++ b/src/askui/locators/relatable.py @@ -278,7 +278,7 @@ def __init__( super().__init__(message) -class Relatable(ABC): +class Relatable(ABC): # noqa: B024 """Abstract base class for locators that can be related to other locators, e.g., spatially, logically etc. Cannot be instantiated directly. Subclassed by all (relatable) locators, e.g., `Prompt`, `Text`, `Image`, etc.""" @@ -984,15 +984,17 @@ def _relations_str(self) -> str: for i, relation in enumerate(self._relations): [other_locator_str, *nested_relation_strs] = str(relation).split("\n") result.append(f" {i + 1}. {other_locator_str}") - for nested_relation_str in nested_relation_strs: - result.append(f" {nested_relation_str}") + result.extend( + f" {nested_relation_str}" + for nested_relation_str in nested_relation_strs + ) return "\n" + "\n".join(result) def _str_with_relation(self) -> str: return self._str() + self._relations_str() def raise_if_cycle(self) -> None: - """Raises CircularDependencyError if the relations form a cycle (see [Cycle (graph theory)](https://en.wikipedia.org/wiki/Cycle_(graph_theory))).""" + """Raises CircularDependencyError if the relations form a cycle (see [Cycle (graph theory)](https://en.wikipedia.org/wiki/Cycle_(graph_theory))).""" # noqa: E501 if self._has_cycle(): raise CircularDependencyError diff --git a/src/askui/locators/serializers.py b/src/askui/locators/serializers.py index b9dad21..32dd398 100644 --- a/src/askui/locators/serializers.py +++ b/src/askui/locators/serializers.py @@ -31,9 +31,10 @@ class VlmLocatorSerializer: def serialize(self, locator: Relatable) -> str: locator.raise_if_cycle() if len(locator._relations) > 0: - raise NotImplementedError( + error_msg = ( "Serializing locators with relations is not yet supported for VLMs" ) + raise NotImplementedError(error_msg) if isinstance(locator, Text): return 
self._serialize_text(locator) @@ -42,14 +43,13 @@ def serialize(self, locator: Relatable) -> str: if isinstance(locator, Prompt): return self._serialize_prompt(locator) if isinstance(locator, Image): - raise NotImplementedError( - "Serializing image locators is not yet supported for VLMs" - ) + error_msg = "Serializing image locators is not yet supported for VLMs" + raise NotImplementedError(error_msg) if isinstance(locator, AiElementLocator): - raise NotImplementedError( - "Serializing AI element locators is not yet supported for VLMs" - ) - raise ValueError(f"Unsupported locator type: {type(locator)}") + error_msg = "Serializing AI element locators is not yet supported for VLMs" + raise NotImplementedError(error_msg) + error_msg = f"Unsupported locator type: {type(locator)}" + raise ValueError(error_msg) def _serialize_class(self, class_: Element) -> str: if class_._class_name: @@ -107,10 +107,11 @@ def __init__(self, ai_element_collection: AiElementCollection, reporter: Reporte def serialize(self, locator: Relatable) -> AskUiSerializedLocator: locator.raise_if_cycle() if len(locator._relations) > 1: - # If we lift this constraint, we also have to make sure that custom element references are still working + we need, e.g., some symbol or a structured format to indicate precedence - raise NotImplementedError( - "Serializing locators with multiple relations is not yet supported by AskUI" - ) + # If we lift this constraint, we also have to make sure that custom element + # references are still working + we need, e.g., some symbol or a structured + # format to indicate precedence + error_msg = "Serializing locators with multiple relations is not yet supported by AskUI" + raise NotImplementedError(error_msg) result = AskUiSerializedLocator(instruction="", customElements=[]) if isinstance(locator, Text): @@ -124,7 +125,8 @@ def serialize(self, locator: Relatable) -> AskUiSerializedLocator: elif isinstance(locator, AiElementLocator): result = 
self._serialize_ai_element(locator) else: - raise ValueError(f'Unsupported locator type: "{type(locator)}"') + error_msg = f'Unsupported locator type: "{type(locator)}"' + raise TypeError(error_msg) if len(locator._relations) == 0: return result @@ -154,15 +156,13 @@ def _serialize_text(self, text: Text) -> str: return ( f"text {self._TEXT_DELIMITER}{text._text}{self._TEXT_DELIMITER}" ) - return f"text with text {self._TEXT_DELIMITER}{text._text}{self._TEXT_DELIMITER} that matches to {text._similarity_threshold} %" + return f"text with text {self._TEXT_DELIMITER}{text._text}{self._TEXT_DELIMITER} that matches to {text._similarity_threshold} %" # noqa: E501 case "exact": - return f"text equals text {self._TEXT_DELIMITER}{text._text}{self._TEXT_DELIMITER}" + return f"text equals text {self._TEXT_DELIMITER}{text._text}{self._TEXT_DELIMITER}" # noqa: E501 case "contains": - return f"text contain text {self._TEXT_DELIMITER}{text._text}{self._TEXT_DELIMITER}" + return f"text contain text {self._TEXT_DELIMITER}{text._text}{self._TEXT_DELIMITER}" # noqa: E501 case "regex": - return f"text match regex pattern {self._TEXT_DELIMITER}{text._text}{self._TEXT_DELIMITER}" - case _: - raise ValueError(f'Unsupported text match type: "{text.match_type}"') + return f"text match regex pattern {self._TEXT_DELIMITER}{text._text}{self._TEXT_DELIMITER}" # noqa: E501 def _serialize_relation(self, relation: Relation) -> AskUiSerializedLocator: match relation.type: @@ -174,15 +174,13 @@ def _serialize_relation(self, relation: Relation) -> AskUiSerializedLocator: relation, LogicalRelation | BoundingRelation | NearestToRelation ) return self._serialize_non_neighbor_relation(relation) - case _: - raise ValueError(f'Unsupported relation type: "{relation.type}"') def _serialize_neighbor_relation( self, relation: NeighborRelation ) -> AskUiSerializedLocator: serialized_other_locator = self.serialize(relation.other_locator) return AskUiSerializedLocator( - instruction=f"index {relation.index} 
{self._RELATION_TYPE_MAPPING[relation.type]} intersection_area {self._RP_TO_INTERSECTION_AREA_MAPPING[relation.reference_point]} {serialized_other_locator['instruction']}", + instruction=f"index {relation.index} {self._RELATION_TYPE_MAPPING[relation.type]} intersection_area {self._RP_TO_INTERSECTION_AREA_MAPPING[relation.reference_point]} {serialized_other_locator['instruction']}", # noqa: E501 customElements=serialized_other_locator["customElements"], ) @@ -191,7 +189,7 @@ def _serialize_non_neighbor_relation( ) -> AskUiSerializedLocator: serialized_other_locator = self.serialize(relation.other_locator) return AskUiSerializedLocator( - instruction=f"{self._RELATION_TYPE_MAPPING[relation.type]} {serialized_other_locator['instruction']}", + instruction=f"{self._RELATION_TYPE_MAPPING[relation.type]} {serialized_other_locator['instruction']}", # noqa: E501 customElements=serialized_other_locator["customElements"], ) @@ -225,7 +223,7 @@ def _serialize_image_base( for image_source in image_sources ] return AskUiSerializedLocator( - instruction=f"custom element with text {self._TEXT_DELIMITER}{image_locator._name}{self._TEXT_DELIMITER}", + instruction=f"custom element with text {self._TEXT_DELIMITER}{image_locator._name}{self._TEXT_DELIMITER}", # noqa: E501 customElements=custom_elements, ) diff --git a/src/askui/models/anthropic/__init__.py b/src/askui/models/anthropic/__init__.py new file mode 100644 index 0000000..8b67cb2 --- /dev/null +++ b/src/askui/models/anthropic/__init__.py @@ -0,0 +1 @@ +"""Anthropic model implementations.""" diff --git a/src/askui/models/anthropic/claude.py b/src/askui/models/anthropic/claude.py index 89ea6d2..92a5382 100644 --- a/src/askui/models/anthropic/claude.py +++ b/src/askui/models/anthropic/claude.py @@ -1,8 +1,15 @@ +import json import os import anthropic from PIL import Image +from askui.exceptions import ( + ElementNotFoundError, + NoResponseToQueryError, + UnexpectedResponseToQueryError, +) +from askui.logger import logger from 
askui.utils.image_utils import ( ImageSource, image_to_base64, @@ -10,12 +17,6 @@ scale_image_with_padding, ) -from ...exceptions import ( - ElementNotFoundError, - NoResponseToQueryError, - UnexpectedResponseToQueryError, -) -from ...logger import logger from .utils import extract_click_coordinates @@ -58,7 +59,7 @@ def _inference( def locate_inference(self, image: Image.Image, locator: str) -> tuple[int, int]: prompt = f"Click on {locator}" screen_width, screen_height = self.resolution[0], self.resolution[1] - system_prompt = f"Use a mouse and keyboard to interact with a computer, and take screenshots.\n* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try taking another screenshot.\n* The screen's resolution is {screen_width}x{screen_height}.\n* The display number is 0\n* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.\n" + system_prompt = f"Use a mouse and keyboard to interact with a computer, and take screenshots.\n* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. 
You must click on desktop icons to start applications.\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try taking another screenshot.\n* The screen's resolution is {screen_width}x{screen_height}.\n* The display number is 0\n* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.\n" # noqa: E501 scaled_image = scale_image_with_padding(image, screen_width, screen_height) response = self._inference(image_to_base64(scaled_image), prompt, system_prompt) assert len(response) > 0 @@ -67,8 +68,9 @@ def locate_inference(self, image: Image.Image, locator: str) -> tuple[int, int]: logger.debug("ClaudeHandler received locator: %s", r.text) try: scaled_x, scaled_y = extract_click_coordinates(r.text) - except Exception: - raise ElementNotFoundError(f"Element not found: {locator}") + except (ValueError, json.JSONDecodeError) as e: + error_msg = f"Element not found: {locator}" + raise ElementNotFoundError(error_msg) from e x, y = scale_coordinates_back( scaled_x, scaled_y, image.width, image.height, screen_width, screen_height ) @@ -80,19 +82,17 @@ def get_inference(self, image: ImageSource, query: str) -> str: max_width=self.resolution[0], max_height=self.resolution[1], ) - system_prompt = "You are an agent to process screenshots and answer questions about things on the screen or extract information from it. 
Answer only with the response to the question and keep it short and precise." + system_prompt = "You are an agent to process screenshots and answer questions about things on the screen or extract information from it. Answer only with the response to the question and keep it short and precise." # noqa: E501 response = self._inference( base64_image=image_to_base64(scaled_image), prompt=query, system_prompt=system_prompt, ) if len(response) == 0: - raise NoResponseToQueryError( - f"No response from Claude to query: {query}", query - ) + error_msg = f"No response from Claude to query: {query}" + raise NoResponseToQueryError(error_msg, query) r = response[0] if r.type == "text": return r.text - raise UnexpectedResponseToQueryError( - f"Unexpected response from Claude: {r}", query, r - ) + error_msg = f"Unexpected response from Claude: {r}" + raise UnexpectedResponseToQueryError(error_msg, query, r) diff --git a/src/askui/models/anthropic/claude_agent.py b/src/askui/models/anthropic/claude_agent.py index 4a15df0..564f4e6 100644 --- a/src/askui/models/anthropic/claude_agent.py +++ b/src/askui/models/anthropic/claude_agent.py @@ -1,6 +1,6 @@ import platform import sys -from datetime import datetime +from datetime import datetime, timezone from typing import Any, cast from anthropic import ( @@ -159,7 +159,7 @@ * When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. * Valid keyboard keys available are {", ".join(PC_KEY)} -* The current date is {datetime.today().strftime("%A, %B %d, %Y").replace(" 0", " ")}. +* The current date is {datetime.now(tz=timezone.utc).strftime("%A, %B %d, %Y").replace(" 0", " ")}. 
diff --git a/src/askui/models/askui/__init__.py b/src/askui/models/askui/__init__.py new file mode 100644 index 0000000..9eaf5b3 --- /dev/null +++ b/src/askui/models/askui/__init__.py @@ -0,0 +1 @@ +"""AskUI model implementations.""" diff --git a/src/askui/models/askui/ai_element_utils.py b/src/askui/models/askui/ai_element_utils.py index cfb182a..a9629a5 100644 --- a/src/askui/models/askui/ai_element_utils.py +++ b/src/askui/models/askui/ai_element_utils.py @@ -96,7 +96,8 @@ def __init__( workspace_id = os.getenv("ASKUI_WORKSPACE_ID") if workspace_id is None: - raise ValueError("ASKUI_WORKSPACE_ID is not set") + error_msg = "ASKUI_WORKSPACE_ID is not set" + raise ValueError(error_msg) locations_from_env: list[pathlib.Path] = [] if locations_env := os.getenv("ASKUI_AI_ELEMENT_LOCATIONS"): diff --git a/src/askui/models/askui/api.py b/src/askui/models/askui/api.py index e7eac32..4b3f066 100644 --- a/src/askui/models/askui/api.py +++ b/src/askui/models/askui/api.py @@ -15,6 +15,7 @@ from askui.utils.image_utils import ImageSource, image_to_base64 from ..types.response_schemas import ResponseSchema, to_response_schema +from .exceptions import ApiResponseError, TokenNotSetError class AskUiInferenceApi: @@ -37,12 +38,15 @@ def _build_askui_token_auth_header( return {"Authorization": f"Bearer {bearer_token}"} if self.token is None: - raise Exception("ASKUI_TOKEN is not set.") + raise TokenNotSetError token_base64 = base64.b64encode(self.token.encode("utf-8")).decode("utf-8") return {"Authorization": f"Basic {token_base64}"} def _build_base_url(self, endpoint: str) -> str: - return f"{self.inference_endpoint}/api/v3/workspaces/{self.workspace_id}/{endpoint}" + return ( + f"{self.inference_endpoint}/api/v3/workspaces/" + f"{self.workspace_id}/{endpoint}" + ) def _request(self, endpoint: str, json: dict[str, Any] | None = None) -> Any: response = requests.post( @@ -55,9 +59,7 @@ def _request(self, endpoint: str, json: dict[str, Any] | None = None) -> Any: timeout=30, ) if 
response.status_code != 200: - raise Exception( - f"{response.status_code}: Unknown Status Code\n", response.text - ) + raise ApiResponseError(response.status_code, response.text) return response.json() diff --git a/src/askui/models/askui/exceptions.py b/src/askui/models/askui/exceptions.py new file mode 100644 index 0000000..97be6d6 --- /dev/null +++ b/src/askui/models/askui/exceptions.py @@ -0,0 +1,28 @@ +class AskUiApiError(Exception): + """Base exception for AskUI API errors.""" + + def __init__(self, message: str) -> None: + self.message = message + super().__init__(self.message) + + +class TokenNotSetError(AskUiApiError): + """Exception raised when a token is not set.""" + + def __init__(self, message: str = "Token not set") -> None: + super().__init__(message) + + +class ApiResponseError(AskUiApiError): + """Exception raised when an API response is not as expected.""" + + def __init__(self, status_code: int, message: str) -> None: + self.status_code = status_code + super().__init__(f"API response error: {status_code} - {message}") + + +__all__ = [ + "AskUiApiError", + "TokenNotSetError", + "ApiResponseError", +] diff --git a/src/askui/models/huggingface/__init__.py b/src/askui/models/huggingface/__init__.py new file mode 100644 index 0000000..2f46c77 --- /dev/null +++ b/src/askui/models/huggingface/__init__.py @@ -0,0 +1 @@ +"""Hugging Face model implementations.""" diff --git a/src/askui/models/huggingface/spaces_api.py b/src/askui/models/huggingface/spaces_api.py index d3ece41..27dd3ef 100644 --- a/src/askui/models/huggingface/spaces_api.py +++ b/src/askui/models/huggingface/spaces_api.py @@ -2,6 +2,7 @@ import tempfile from typing import Callable +import httpx from gradio_client import Client, handle_file # type: ignore from PIL import Image @@ -55,10 +56,12 @@ def _rescale_bounding_boxes( # type: ignore def predict( self, screenshot: Image.Image, locator: str, model_name: str = "AskUI/PTA-1" ) -> tuple[int, int]: + """Predict element location using 
Hugging Face Spaces.""" try: return self.spaces[model_name](screenshot, locator, model_name) - except Exception as e: - raise AutomationError(f"Hugging Face Spaces Exception: {e}") + except (ValueError, json.JSONDecodeError, httpx.HTTPError) as e: + error_msg = f"Hugging Face Spaces Exception: {e}" + raise AutomationError(error_msg) from e def predict_askui_pta1( self, screenshot: Image.Image, locator: str, model_name: str | None = None @@ -126,8 +129,12 @@ def predict_qwen2_vl( return x, y def predict_showui( - self, screenshot: Image.Image, locator: str, model_name: str | None = None + self, + screenshot: Image.Image, + locator: str, + model_name: str | None = None, # noqa: ARG002 ) -> tuple[int, int]: + """Predict element location using ShowUI model.""" client = self.get_space_client("showlab/ShowUI") with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file: screenshot.save(temp_file, format="PNG") diff --git a/src/askui/models/models.py b/src/askui/models/models.py index 961c9a4..553659b 100644 --- a/src/askui/models/models.py +++ b/src/askui/models/models.py @@ -30,19 +30,29 @@ class ModelDefinition(BaseModel): A definition of a model. Args: - task (str): The task the model is trained for, e.g., end-to-end OCR (`"e2e_ocr"`) or object detection (`"od"`) - architecture (str): The architecture of the model, e.g., `"easy_ocr"` or `"yolo"` + task (str): The task the model is trained for, e.g., end-to-end OCR + (`"e2e_ocr"`) or object detection (`"od"`) + architecture (str): The architecture of the model, e.g., `"easy_ocr"` or + `"yolo"` version (str): The version of the model - interface (str): The interface the model is trained for, e.g., `"online_learning"` - use_case (str, optional): The use case the model is trained for. In the case of workspace specific AskUI models, this is often the workspace id but with "-" replaced by "_". Defaults to `"00000000_0000_0000_0000_000000000000"` (custom null value). 
- tags (list[str], optional): Tags for identifying the model that cannot be represented by other properties, e.g., `["trained", "word_level"]` + interface (str): The interface the model is trained for, e.g., + `"online_learning"` + use_case (str, optional): The use case the model is trained for. In the case + of workspace specific AskUI models, this is often the workspace id but + with "-" replaced by "_". Defaults to + `"00000000_0000_0000_0000_000000000000"` (custom null value). + tags (list[str], optional): Tags for identifying the model that cannot be + represented by other properties, e.g., `["trained", "word_level"]` """ model_config = ConfigDict( populate_by_name=True, ) task: ModelDefinitionProperty = Field( - description="The task the model is trained for, e.g., end-to-end OCR (e2e_ocr) or object detection (od)", + description=( + "The task the model is trained for, e.g., end-to-end OCR (e2e_ocr) or " + "object detection (od)" + ), examples=["e2e_ocr", "od"], ) architecture: ModelDefinitionProperty = Field( @@ -54,7 +64,10 @@ class ModelDefinition(BaseModel): examples=["online_learning"], ) use_case: ModelDefinitionProperty = Field( - description='The use case the model is trained for. In the case of workspace specific AskUI models, this is often the workspace id but with "-" replaced by "_"', + description=( + "The use case the model is trained for. 
In the case of workspace specific " + 'AskUI models, this is often the workspace id but with "-" replaced by "_"' + ), examples=[ "fb3b9a7b_3aea_41f7_ba02_e55fd66d1c1e", "00000000_0000_0000_0000_000000000000", @@ -64,7 +77,10 @@ class ModelDefinition(BaseModel): ) tags: list[ModelDefinitionProperty] = Field( default_factory=list, - description="Tags for identifying the model that cannot be represented by other properties", + description=( + "Tags for identifying the model that cannot be represented by other " + "properties" + ), examples=["trained", "word_level"], ) @@ -87,7 +103,9 @@ def model_name(self) -> str: class ModelComposition(RootModel[list[ModelDefinition]]): """ - A composition of models (list of `ModelDefinition`) to be used for a task, e.g., locating an element on the screen to be able to click on it or extracting text from an image. + A composition of models (list of `ModelDefinition`) to be used for a task, e.g., + locating an element on the screen to be able to click on it or extracting text from + an image. """ def __iter__(self) -> Iterator[ModelDefinition]: # type: ignore diff --git a/src/askui/models/router.py b/src/askui/models/router.py index 45836c5..4d41141 100644 --- a/src/askui/models/router.py +++ b/src/askui/models/router.py @@ -33,7 +33,8 @@ def handle_response( ) -> tuple[int, int]: x, y = response if x is None or y is None: - raise ElementNotFoundError(f"Element not found: {locator}") + error_msg = f"Element not found: {locator}" + raise ElementNotFoundError(error_msg) return x, y @@ -75,9 +76,11 @@ def locate( model: ModelComposition | str | None = None, ) -> Point: if not self._inference_api.authenticated: - raise AutomationError( - "NoAskUIAuthenticationSet! Please set 'AskUI ASKUI_WORKSPACE_ID' or 'ASKUI_TOKEN' as env variables!" + error_msg = ( + "NoAskUIAuthenticationSet! Please set 'AskUI ASKUI_WORKSPACE_ID' or " + "'ASKUI_TOKEN' as env variables!" 
) + raise AutomationError(error_msg) if not isinstance(model, str) or model == ModelName.ASKUI: logger.debug("Routing locate prediction to askui") locator = Text(locator) if isinstance(locator, str) else locator @@ -85,9 +88,12 @@ def locate( x, y = self._inference_api.predict(screenshot, locator, _model) return handle_response((x, y), locator) if not isinstance(locator, str): - raise AutomationError( - f'Locators of type `{type(locator)}` are not supported for models "askui-pta", "askui-ocr" and "askui-combo" and "askui-ai-element". Please provide a `str`.' + error_msg = ( + f"Locators of type `{type(locator)}` are not supported for models " + '"askui-pta", "askui-ocr" and "askui-combo" and "askui-ai-element". ' + "Please provide a `str`." ) + raise AutomationError(error_msg) if model == ModelName.ASKUI__PTA: logger.debug("Routing locate prediction to askui-pta") x, y = self._inference_api.predict(screenshot, Prompt(locator)) @@ -107,7 +113,8 @@ def locate( _locator = AiElement(locator) x, y = self._inference_api.predict(screenshot, _locator) return handle_response((x, y), _locator) - raise AutomationError(f'Invalid model: "{model}"') + error_msg = f'Invalid model: "{model}"' + raise AutomationError(error_msg) @override def is_responsible(self, model: ModelComposition | str | None = None) -> bool: @@ -152,7 +159,8 @@ def act(self, goal: str, model: ModelComposition | str | None = None) -> None: and model.startswith(ModelName.ANTHROPIC) ): self._claude_computer_agent.run(goal) - raise AutomationError(f"Invalid model for act: {model}") + error_msg = f"Invalid model for act: {model}" + raise AutomationError(error_msg) def get_inference( self, @@ -163,17 +171,21 @@ def get_inference( ) -> ResponseSchema | str: if self._tars.authenticated and model == ModelName.TARS: if response_schema not in [str, None]: - raise NotImplementedError( - "(Non-String) Response schema is not yet supported for UI-TARS models." 
+ error_msg = ( + "(Non-String) Response schema is not yet supported for " + "UI-TARS models." ) + raise NotImplementedError(error_msg) return self._tars.get_inference(image=image, query=query) if self._claude.authenticated and ( isinstance(model, str) and model.startswith(ModelName.ANTHROPIC) ): if response_schema not in [str, None]: - raise NotImplementedError( - "(Non-String) Response schema is not yet supported for Anthropic models." + error_msg = ( + "(Non-String) Response schema is not yet supported for " + "Anthropic models." ) + raise NotImplementedError(error_msg) return self._claude.get_inference(image=image, query=query) if self._askui.authenticated and (model == ModelName.ASKUI or model is None): return self._askui.get_inference( @@ -181,9 +193,11 @@ def get_inference( query=query, response_schema=response_schema, ) - raise AutomationError( - f"Executing get commands requires to authenticate with an Automation Model Provider supporting it: {model}" + error_msg = ( + "Executing get commands requires to authenticate with an Automation " + f"Model Provider supporting it: {model}" ) + raise AutomationError(error_msg) def _serialize_locator(self, locator: str | Locator) -> str: if isinstance(locator, Locator): @@ -191,7 +205,7 @@ def _serialize_locator(self, locator: str | Locator) -> str: return locator @telemetry.record_call(exclude={"locator", "screenshot"}) - def locate( + def locate( # noqa: C901 self, screenshot: Image.Image, locator: str | Locator, @@ -211,13 +225,16 @@ def locate( return handle_response((x, y), locator) if isinstance(model, str): if model.startswith(ModelName.ANTHROPIC) and not self._claude.authenticated: - raise AutomationError( + error_msg = ( "You need to provide Anthropic credentials to use Anthropic models." ) + raise AutomationError(error_msg) if model.startswith(ModelName.TARS) and not self._tars.authenticated: - raise AutomationError( - "You need to provide UI-TARS HF Endpoint credentials to use UI-TARS models." 
+ error_msg = ( + "You need to provide UI-TARS HF Endpoint credentials to use " + "UI-TARS models." ) + raise AutomationError(error_msg) if self._tars.authenticated and model == ModelName.TARS: x, y = self._tars.locate_prediction( screenshot, self._serialize_locator(locator) @@ -249,6 +266,8 @@ def locate( ) return handle_response((x, y), locator) - raise AutomationError( - "Executing locate commands requires to authenticate with an Automation Model Provider." + error_msg = ( + "Executing locate commands requires to authenticate with an " + "Automation Model Provider." ) + raise AutomationError(error_msg) diff --git a/src/askui/models/types/response_schemas.py b/src/askui/models/types/response_schemas.py index 1191ffb..155564a 100644 --- a/src/askui/models/types/response_schemas.py +++ b/src/askui/models/types/response_schemas.py @@ -4,9 +4,11 @@ class ResponseSchemaBase(BaseModel): - """Base class for response schemas to be used for defining the response of data extraction, e.g., using `askui.VisionAgent.get()`. + """Base class for response schemas defining the response of data extraction, + e.g., using `askui.VisionAgent.get()`. - This class extends Pydantic's BaseModel and adds constraints and configuration on top so that it can be used with models to define the schema (type) of the data to be extracted. + This class extends Pydantic's BaseModel with constraints and configuration so it + can be used with models to define the schema (type) of the data to be extracted. Example: ```python @@ -17,9 +19,13 @@ class UrlResponse(ResponseSchemaBase): class NestedResponse(ResponseSchemaBase): nested: UrlResponse - # metadata, e.g., `examples` or `description` of `Field`, is generally also passed to and considered by the models + # metadata, e.g., `examples` or `description` of `Field`, is generally also + # passed to and considered by the models class UrlResponse(ResponseSchemaBase): - url: str = Field(description="The URL of the response. 
Should used `\"https\"` scheme.", examples=["https://www.example.com"]) + url: str = Field( + description="The URL of the response. Should use `\"https\"` scheme.", + examples=["https://www.example.com"] + ) ``` """ @@ -42,8 +48,9 @@ class UrlResponse(ResponseSchemaBase): - `int`: Integer responses - `float`: Floating point responses -Usually, serialized as a JSON schema, e.g., `str` as `{"type": "string"}`, to be passed to model(s). -Also used for validating the responses of the model(s) used for data extraction. +Usually, serialized as a JSON schema, e.g., `str` as `{"type": "string"}`, to be +passed to model(s). Also used for validating the responses of the model(s) used for +data extraction. """ @@ -87,4 +94,5 @@ def to_response_schema( return Float if issubclass(response_schema, ResponseSchemaBase): return response_schema - raise ValueError(f"Invalid response schema type: {response_schema}") + error_msg = f"Invalid response schema type: {response_schema}" + raise ValueError(error_msg) diff --git a/src/askui/models/ui_tars_ep/__init__.py b/src/askui/models/ui_tars_ep/__init__.py new file mode 100644 index 0000000..afcaec4 --- /dev/null +++ b/src/askui/models/ui_tars_ep/__init__.py @@ -0,0 +1 @@ +"""UI TARS model implementations.""" diff --git a/src/askui/models/ui_tars_ep/parser.py b/src/askui/models/ui_tars_ep/parser.py index 39371c0..1d42977 100644 --- a/src/askui/models/ui_tars_ep/parser.py +++ b/src/askui/models/ui_tars_ep/parser.py @@ -15,7 +15,8 @@ def parse(cls, coord_str: str) -> "BoxCoordinate": """Parse a coordinate string in the format (x,y).""" match = re.match(r"\((\d+),(\d+)\)", coord_str) if not match: - raise ValueError(f"Invalid coordinate format: {coord_str}") + error_msg = f"Invalid coordinate format: {coord_str}" + raise ValueError(error_msg) return cls(x=int(match.group(1)), y=int(match.group(2))) @@ -133,12 +134,13 @@ def parse_message(cls, message: str) -> "UITarsEPMessage": return cls(thought=thought, raw_action=action, 
parsed_action=parsed_action) @staticmethod - def parse_action(action_str: str) -> ActionType: + def parse_action(action_str: str) -> ActionType: # noqa: C901 """Parse the action string into the appropriate action type.""" # Extract action type and parameters match = re.match(r"(\w+)\((.*)\)", action_str) if not match: - raise ValueError(f"Invalid action format: {action_str}") + error_msg = f"Invalid action format: {action_str}" + raise ValueError(error_msg) action_type, params_str = match.groups() @@ -178,4 +180,5 @@ def parse_action(action_str: str) -> ActionType: return FinishedAction() if action_type == "call_user": return CallUserAction() - raise ValueError(f"Unknown action type: {action_type}") + error_msg = f"Unknown action type: {action_type}" + raise ValueError(error_msg) diff --git a/src/askui/models/ui_tars_ep/prompts.py b/src/askui/models/ui_tars_ep/prompts.py index bc03018..dddb63c 100644 --- a/src/askui/models/ui_tars_ep/prompts.py +++ b/src/askui/models/ui_tars_ep/prompts.py @@ -1,4 +1,6 @@ -PROMPT = r"""You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. +PROMPT = r"""You are a GUI agent. +You are given a task and your action history, with screenshots. +You need to perform the next action to complete the task. ## Output Format ```\nThought: ... @@ -8,22 +10,32 @@ click(start_box='<|box_start|>(x1,y1)<|box_end|>') left_double(start_box='<|box_start|>(x1,y1)<|box_end|>') right_single(start_box='<|box_start|>(x1,y1)<|box_end|>') -drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>') +drag( + start_box='<|box_start|>(x1,y1)<|box_end|>', + end_box='<|box_start|>(x3,y3)<|box_end|>', +) hotkey(key='') type(content='') #If you want to submit your input, use \"\\" at the end of `content`. 
-scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left') +scroll( + start_box='<|box_start|>(x1,y1)<|box_end|>', + direction='down or up or right or left', +) wait() #Sleep for 5s and take a screenshot to check for any changes. finished() -call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help. +call_user() # Submit the task and call the user when the task is unsolvable, or +# when you need the user's help. ## Note - Use English in `Thought` part. -- Summarize your next action (with its target element) in one sentence in `Thought` part. +- Summarize your next action (with its target element) in one sentence in + `Thought` part. ## User Instruction """ -PROMPT_QA = r"""You are a GUI agent for screen QA. Your are given a question and a screenshot with the answer on it. Your goal is to answer the question. +PROMPT_QA = r"""You are a GUI agent for screen QA. +You are given a question and a screenshot with the answer on it. +Your goal is to answer the question. 
## Output Format ```\nAnswer: ...\n``` diff --git a/src/askui/models/ui_tars_ep/ui_tars_api.py b/src/askui/models/ui_tars_ep/ui_tars_api.py index 00df480..0bb34ce 100644 --- a/src/askui/models/ui_tars_ep/ui_tars_api.py +++ b/src/askui/models/ui_tars_ep/ui_tars_api.py @@ -65,6 +65,7 @@ def locate_prediction( instruction=askui_locator, prompt=PROMPT, ) + assert prediction is not None pattern = r"click\(start_box='(\(\d+,\d+\))'\)" match = re.search(pattern, prediction) if match: @@ -85,9 +86,8 @@ def get_inference(self, image: ImageSource, query: str) -> str: prompt=PROMPT_QA, ) if response is None: - raise NoResponseToQueryError( - f"No response from UI-TARS to query: {query}", query - ) + error_msg = f"No response from UI-TARS to query: {query}" + raise NoResponseToQueryError(error_msg, query) return response def act(self, goal: str) -> None: @@ -99,7 +99,9 @@ def act(self, goal: str) -> None: { "type": "image_url", "image_url": { - "url": f"data:image/png;base64,{image_to_base64(screenshot)}" + "url": ( + f"data:image/png;base64,{image_to_base64(screenshot)}" + ) }, }, {"type": "text", "text": PROMPT + goal}, @@ -117,7 +119,9 @@ def add_screenshot_to_history(self, message_history: list[dict[str, Any]]) -> No { "type": "image_url", "image_url": { - "url": f"data:image/png;base64,{image_to_base64(screenshot)}" + "url": ( + f"data:image/png;base64,{image_to_base64(screenshot)}" + ) }, } ], @@ -190,7 +194,7 @@ def execute_act(self, message_history: list[dict[str, Any]]) -> None: try: message = UITarsEPMessage.parse_message(raw_message) print(message) - except Exception as e: + except Exception as e: # noqa: BLE001 - We want to catch all other exceptions here message_history.append( {"role": "user", "content": [{"type": "text", "text": str(e)}]} ) diff --git a/src/askui/reporting.py b/src/askui/reporting.py index cdb0404..1dc266e 100644 --- a/src/askui/reporting.py +++ b/src/askui/reporting.py @@ -4,7 +4,7 @@ import random import sys from abc import ABC, abstractmethod 
-from datetime import datetime +from datetime import datetime, timezone from importlib.metadata import distributions from io import BytesIO from pathlib import Path @@ -31,9 +31,12 @@ def add_message( """Add a message to the report. Args: - role (str): The role of the message sender (e.g., `"User"`, `"Assistant"`, `"System"`) - content (Union[str, dict, list]): The message content, which can be a string, dictionary, or list, e.g. `'click 2x times on text "Edit"'` - image (Optional[PIL.Image.Image | list[PIL.Image.Image]], optional): PIL Image or list of PIL Images to include with the message + role (str): The role of the message sender (e.g., `"User"`, `"Assistant"`, + `"System"`) + content (str | dict | list): The message content, which can be a string, + dictionary, or list, e.g. `'click 2x times on text "Edit"'` + image (PIL.Image.Image | list[PIL.Image.Image], optional): PIL Image or + list of PIL Images to include with the message """ raise NotImplementedError @@ -41,16 +44,21 @@ def add_message( def generate(self) -> None: """Generates the final report. - Implementing this method is only required if the report is not generated in "real-time", e.g., on calls of `add_message()`, but must be generated at the end of the execution. + Implementing this method is only required if the report is not generated + in "real-time", e.g., on calls of `add_message()`, but must be generated + at the end of the execution. - This method is called when the `askui.VisionAgent` context is exited or `askui.VisionAgent.close()` is called. + This method is called when the `askui.VisionAgent` context is exited or + `askui.VisionAgent.close()` is called. """ class CompositeReporter(Reporter): """A reporter that combines multiple reporters. - Allows generating different reports simultaneously. Each message added will be forwarded to all reporters passed to the constructor. The reporters are called (`add_message()`, `generate()`) in the order they are ordered in the `reporters` list. 
+ Allows generating different reports simultaneously. Each message added will be forwarded to all + reporters passed to the constructor. The reporters are called (`add_message()`, `generate()`) in + the order they are ordered in the `reporters` list. Args: reporters (list[Reporter] | None, optional): List of reporters to combine @@ -87,7 +95,8 @@ class SimpleHtmlReporter(Reporter): """A reporter that generates HTML reports with conversation logs and system information. Args: - report_dir (str, optional): Directory where reports will be saved. Defaults to `reports`. + report_dir (str, optional): Directory where reports will be saved. + Defaults to `reports`. """ def __init__(self, report_dir: str = "reports") -> None: @@ -132,7 +141,7 @@ def add_message( _images = [image] message = { - "timestamp": datetime.now(), + "timestamp": datetime.now(tz=timezone.utc), "role": role, "content": self._format_content(content), "is_json": isinstance(content, (dict, list)), @@ -153,8 +162,11 @@ def generate(self) -> None: Vision Agent Report - {{ timestamp }} - - + +