Extract template endpoint by guillaq · Pull Request #237 · WorkflowAI/WorkflowAI · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Extract template endpoint #237

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions api/api/routers/agents_v1.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Mapping
from datetime import datetime, timedelta
from typing import Annotated, Any

Expand All @@ -10,13 +11,15 @@
from api.dependencies.storage import StorageDep
from api.tags import RouteTags
from core.domain.analytics_events.analytics_events import CreatedTaskProperties, TaskProperties
from core.domain.errors import BadRequestError
from core.domain.events import TaskSchemaCreatedEvent
from core.domain.fields.chat_message import ChatMessage
from core.domain.page import Page
from core.domain.task_io import SerializableTaskIO
from core.domain.task_variant import SerializableTaskVariant
from core.utils import strings
from core.utils.fields import datetime_factory
from core.utils.templates import InvalidTemplateError, extract_variable_schema

# Every route registered on this router lives under /v1/{tenant}/agents.
router = APIRouter(tags=[RouteTags.AGENTS], prefix="/v1/{tenant}/agents")

Expand Down Expand Up @@ -154,3 +157,20 @@ async def get_agent_stats(
async for stat in storage.task_runs.run_count_by_agent_uid(from_date)
]
return Page(items=items)


# Request body of the `POST /{agent_id}/templates/extract` endpoint.
class ExtractTemplateRequest(BaseModel):
    # Raw Jinja template source; the endpoint only parses it, it is never rendered.
    template: str


# Response body of the `POST /{agent_id}/templates/extract` endpoint.
class ExtractTemplateResponse(BaseModel):
    # JSON schema (root `"type": "object"`) of the variables the template reads.
    json_schema: Mapping[str, Any]


@router.post("/{agent_id}/templates/extract")
async def extract_template(request: ExtractTemplateRequest) -> ExtractTemplateResponse:
    """Infer a JSON schema describing the variables referenced by a Jinja template.

    Responds with a bad request error when the template cannot be parsed or
    when it calls functions, which schema extraction does not support.
    """
    # NOTE(review): `agent_id` (and the router's `tenant` prefix) are part of
    # the route but are not read by this handler.
    try:
        json_schema = extract_variable_schema(request.template)
    except InvalidTemplateError as e:
        # Chain the parse failure so the original Jinja error stays attached as __cause__.
        raise BadRequestError(e.message) from e
    return ExtractTemplateResponse(json_schema=json_schema)
161 changes: 160 additions & 1 deletion api/core/utils/templates.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import asyncio
import re
from collections.abc import Mapping, Sequence
from typing import Any

from cachetools import LRUCache
from jinja2 import Environment, Template, TemplateError
from jinja2 import Environment, Template, TemplateError, nodes
from jinja2.meta import find_undeclared_variables
from jinja2.visitor import NodeVisitor

from core.domain.errors import BadRequestError

# Compiled regex to check if instructions are a template
# Jinja templates use {% ... %} for statements, {{ ... }} for expressions and {# ... #} for comments
Expand Down Expand Up @@ -67,3 +71,158 @@ async def render_template(self, template: str, data: dict[str, Any]):

rendered = await compiled.render_async(data)
return rendered, variables


class _SchemaBuilder(NodeVisitor):
    """Walk a parsed Jinja AST and infer a JSON schema for the variables it reads.

    Inference is purely structural:
    - ``a.b`` makes ``a`` an object with a ``b`` property;
    - ``a[...]`` and ``{% for x in a %}`` make ``a`` an array;
    - leaves default to ``{"type": "string"}``.

    Loop targets are resolved through a stack of alias scopes, so ``item.name``
    inside ``{% for item in order.items %}`` lands under ``order.items[*].name``.
    Jinja's ``loop`` helper, and targets of loops over iterables that are not a
    plain variable chain (e.g. ``xs|sort``), are template-local and never appear
    in the schema. Names bound with ``{% set %}`` or as macro parameters are not
    tracked, so later reads of them are reported as inputs.
    """

    # Alias value marking a template-local name that must not reach the schema.
    _LOCAL: Any = object()

    def __init__(self):
        self._schema: dict[str, Any] = {"type": "object", "properties": {}}
        # Stack of loop scopes, innermost last. Each maps a loop-local name to
        # the schema path it stands for, or to ``_LOCAL``.
        self._aliases: list[Mapping[str, Any]] = []

    # ---- helpers ----------------------------------------------------------
    def _ensure_path(self, path: Sequence[str]) -> None:
        """Create the nested schema nodes for ``path`` if they are missing.

        ``path`` looks like ``("order", "items", "*", "price")`` where ``"*"``
        means "element of an array". Nodes only get more specific over time:
        a node first recorded as a ``"string"`` leaf becomes an object when a
        property is read from it, or an array when it is subscripted/iterated.
        """
        cur = self._schema
        last = len(path) - 1
        for i, part in enumerate(path):
            if part == "*":
                if cur.get("type") != "array":
                    # Drop the empty placeholder "properties" of a default
                    # object so the node is a clean array schema.
                    if not cur.get("properties"):
                        cur.pop("properties", None)
                    cur["type"] = "array"
                    cur.setdefault("items", {"type": "object", "properties": {}})
                cur = cur["items"]
                continue

            # Property access: an untyped node, or one still carrying the crude
            # "string" leaf default, must be an object. Arrays keep their type.
            if cur.get("type", "string") == "string":
                cur["type"] = "object"
            cur = cur.setdefault("properties", {}).setdefault(part, {})

            if i == last:
                # crude default -- upgraded later if the name is dereferenced
                cur.setdefault("type", "string")

    def _push_scope(self, mapping: Mapping[str, Any] | None):
        self._aliases.append(mapping or {})

    def _pop_scope(self):
        self._aliases.pop()

    def _lookup_alias(self, name: str) -> Any | None:
        """Return what ``name`` is bound to by the innermost loop scope, or None."""
        for scope in reversed(self._aliases):
            if name in scope:
                return scope[name]
        return None

    def _expr_to_path(self, node: nodes.Node) -> list[str] | None:
        """Return the schema path of a Name/Getattr/Getitem chain.

        Attribute names are kept and every subscript becomes ``"*"``. A base
        name bound by an enclosing loop is replaced by the path it aliases.
        Returns None when the chain does not start at a Name, or starts at a
        template-local name.
        """
        suffix: list[str] = []
        while isinstance(node, (nodes.Getattr, nodes.Getitem)):
            suffix.append(node.attr if isinstance(node, nodes.Getattr) else "*")
            node = node.node
        if not isinstance(node, nodes.Name):
            return None
        suffix.reverse()

        alias = self._lookup_alias(node.name)
        if alias is self._LOCAL:
            return None
        if alias is not None:
            return [*alias, *suffix]  # expand loop alias
        return [node.name, *suffix]

    def _visit_chain(self, node: nodes.Node) -> None:
        """Record a Getattr/Getitem chain, then visit what its path does not cover."""
        path = self._expr_to_path(node)
        if path:
            self._ensure_path(path)

        # Subscripts (``items[idx]``) and a non-Name base (``(xs|first).name``)
        # may reference variables of their own.
        cur = node
        while isinstance(cur, (nodes.Getattr, nodes.Getitem)):
            if isinstance(cur, nodes.Getitem):
                self.visit(cur.arg)
            cur = cur.node
        if path is None:
            self.visit(cur)

    def _bind_targets(self, target: nodes.Node, base: Any, scope: dict[str, Any]) -> None:
        """Bind loop target name(s) in ``scope`` to ``base`` (a path or ``_LOCAL``).

        Tuple targets (``{% for a, b in pairs %}``) unpack one more array level.
        """
        if isinstance(target, nodes.Name):
            scope[target.name] = base
        elif isinstance(target, nodes.Tuple):
            item_base = base if base is self._LOCAL else [*base, "*"]
            for item in target.items:
                self._bind_targets(item, item_base, scope)

    # ---- NodeVisitor interface -------------------------------------------
    # Not overrides: NodeVisitor.visit dispatches on ``visit_<NodeClass>`` names.

    def visit_Name(self, node: nodes.Name):
        # Only reads are template inputs; store/param names are bound by the template.
        if node.ctx != "load":
            return
        path = self._expr_to_path(node)
        if path:
            self._ensure_path(path)

    def visit_Getattr(self, node: nodes.Getattr):
        self._visit_chain(node)

    def visit_Getitem(self, node: nodes.Getitem):
        self._visit_chain(node)

    def visit_For(self, node: nodes.For):
        """Handle ``{% for target in iter [if test] %}body{% else %}else_{% endfor %}``."""
        iter_path = self._expr_to_path(node.iter)
        if iter_path is not None:
            # Iterating always yields elements one level below the iterable,
            # even when the iterable is itself an element (``row`` of ``rows``).
            element: Any = [*iter_path, "*"]
            self._ensure_path(element)
        else:
            # Not a plain variable chain (e.g. ``xs|sort``): targets are locals.
            element = self._LOCAL

        # The iterable is evaluated in the enclosing scope, before targets are bound.
        self.visit(node.iter)

        scope: dict[str, Any] = {"loop": self._LOCAL}  # Jinja's loop helper
        self._bind_targets(node.target, element, scope)
        self._push_scope(scope)
        try:
            if node.test is not None:
                self.visit(node.test)
            for child in node.body:
                self.visit(child)
        finally:
            self._pop_scope()

        # The else block runs only when nothing was iterated: targets are unbound there.
        for child in node.else_:
            self.visit(child)

    def visit_Call(self, node: nodes.Call):
        raise BadRequestError("Template functions are not supported", capture=True)

    @property
    def schema(self) -> Mapping[str, Any]:
        return self._schema


def extract_variable_schema(template: str) -> Mapping[str, Any]:
    """Build a JSON schema describing the variables a Jinja template reads.

    The template is parsed (never rendered) and its AST is walked by
    ``_SchemaBuilder``; leaf values default to ``{"type": "string"}``.

    Raises:
        InvalidTemplateError: if Jinja cannot parse the template.
        BadRequestError: if the template calls functions (raised by the builder).
    """
    env = Environment()
    try:
        parsed = env.parse(template)
    except TemplateError as e:
        # Keep the Jinja error as the explicit cause for debugging.
        raise InvalidTemplateError.from_jinja(e) from e

    builder = _SchemaBuilder()
    builder.visit(parsed)
    return builder.schema
136 changes: 135 additions & 1 deletion api/core/utils/templates_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from core.utils.templates import InvalidTemplateError, TemplateManager
from core.domain.errors import BadRequestError
from core.utils.templates import InvalidTemplateError, TemplateManager, extract_variable_schema


@pytest.fixture
Expand Down Expand Up @@ -63,3 +64,136 @@ async def test_render_template_remaining(self, template_manager: TemplateManager
assert rendered == "Hello, John!"
assert variables == {"name"}
assert data == {"name": "John", "hello": "world"}


def _leaf() -> dict[str, object]:
    """Schema node for a leaf value (the extractor's default leaf type)."""
    return {"type": "string"}


def _obj(**properties: dict[str, object]) -> dict[str, object]:
    """Schema node for an object with the given properties."""
    return {"type": "object", "properties": properties}


def _arr(items: dict[str, object]) -> dict[str, object]:
    """Schema node for an array whose elements match ``items``."""
    return {"type": "array", "items": items}


class TestExtractVariableSchema:
    """Schema inference from Jinja templates; expected schemas use the helpers above."""

    def test_extract_variable_schema(self):
        assert extract_variable_schema("Hello, {{ name }}!") == _obj(name=_leaf())

    def test_attribute_access(self):
        expected = _obj(user=_obj(name=_leaf()))
        assert extract_variable_schema("User: {{ user.name }}") == expected

    def test_nested_attribute_access(self):
        expected = _obj(user=_obj(profile=_obj(email=_leaf())))
        assert extract_variable_schema("Email: {{ user.profile.email }}") == expected

    def test_item_access_as_array(self):
        # Note: Getitem is always treated as array access ('*') by the current implementation
        expected = _obj(users=_arr(_obj(name=_leaf())))
        assert extract_variable_schema("First user: {{ users[0].name }}") == expected

    def test_for_loop(self):
        tpl = "{% for item in items %}{{ item.name }}{% endfor %}"
        assert extract_variable_schema(tpl) == _obj(items=_arr(_obj(name=_leaf())))

    def test_nested_for_loop(self):
        tpl = "{% for user in users %}{% for post in user.posts %}{{ post.title }}{% endfor %}{% endfor %}"
        expected = _obj(users=_arr(_obj(posts=_arr(_obj(title=_leaf())))))
        assert extract_variable_schema(tpl) == expected

    def test_conditional(self):
        tpl = "{% if user.is_admin %}{{ user.name }}{% else %}Guest{% endif %}"
        # Type defaults to string, even for the condition operand
        expected = _obj(user=_obj(is_admin=_leaf(), name=_leaf()))
        assert extract_variable_schema(tpl) == expected

    def test_combined(self):
        tpl = "{{ user.name }} {% for project in user.projects %}{{ project.id }}{% endfor %}"
        expected = _obj(user=_obj(name=_leaf(), projects=_arr(_obj(id=_leaf()))))
        assert extract_variable_schema(tpl) == expected

    def test_no_variables(self):
        assert extract_variable_schema("Just plain text.") == _obj()

    def test_function_call_raises_error(self):
        # Functions are not supported
        with pytest.raises(BadRequestError, match="Template functions are not supported"):
            extract_variable_schema("{{ my_func() }}")
0