Fix closure mirascope response model by willbakst · Pull Request #56 · Mirascope/lilypad · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Fix closure mirascope response model #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions lilypad/_utils/closure.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,23 +224,34 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
self.definitions_to_analyze.append(definition)
self.generic_visit(node)

def visit_Call(self, node: ast.Call) -> None:
if isinstance(node.func, ast.Name):
if obj := getattr(self.module, node.func.id, None):
def _process_name_or_attribute(self, node: ast.AST) -> None:
if isinstance(node, ast.Name):
if (obj := getattr(self.module, node.id, None)) and hasattr(
obj, "__name__"
):
self.definitions_to_include.append(obj)
elif isinstance(node.func, ast.Attribute):
elif isinstance(node, ast.Attribute):
names = []
current = node.func
current = node
while isinstance(current, ast.Attribute):
names.append(current.attr)
current = current.value
if isinstance(current, ast.Name):
names.append(current.id)
full_path = ".".join(reversed(names))
if full_path in self.used_names and (
definition := getattr(self.module, names[0], None)
if (
full_path in self.used_names
and (definition := getattr(self.module, names[0], None))
and hasattr(definition, "__name__")
):
self.definitions_to_include.append(definition)

def visit_Call(self, node: ast.Call) -> None:
self._process_name_or_attribute(node.func)
for arg in node.args:
self._process_name_or_attribute(arg)
for keyword in node.keywords:
self._process_name_or_attribute(keyword.value)
self.generic_visit(node)


Expand Down
4 changes: 3 additions & 1 deletion lilypad/server/models/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .base_organization_sql_model import BaseOrganizationSQLModel
from .generations import GenerationPublic
from .prompts import PromptPublic
from .response_models import ResponseModelPublic
from .table_names import PROJECT_TABLE_NAME

if TYPE_CHECKING:
Expand Down Expand Up @@ -35,6 +36,7 @@ class ProjectPublic(_ProjectBase):
uuid: UUID
generations: list[GenerationPublic] = []
prompts: list[PromptPublic] = []
response_models: list[ResponseModelPublic] = []


class ProjectTable(_ProjectBase, BaseOrganizationSQLModel, table=True):
Expand All @@ -47,7 +49,7 @@ class ProjectTable(_ProjectBase, BaseOrganizationSQLModel, table=True):
prompts: list["PromptTable"] = Relationship(
back_populates="project", cascade_delete=True
)
organization: "OrganizationTable" = Relationship(back_populates="projects")
response_models: list["ResponseModelTable"] = Relationship(
back_populates="project", cascade_delete=True
)
organization: "OrganizationTable" = Relationship(back_populates="projects")
4 changes: 4 additions & 0 deletions lilypad/server/models/response_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
if TYPE_CHECKING:
from .generations import GenerationTable
from .projects import ProjectTable
from .spans import SpanTable


class _ResponseModelBase(SQLModel):
Expand Down Expand Up @@ -49,6 +50,9 @@ class ResponseModelTable(_ResponseModelBase, BaseOrganizationSQLModel, table=Tru
__tablename__ = "response_models" # pyright: ignore [reportAssignmentType]

project: "ProjectTable" = Relationship(back_populates="response_models")
spans: list["SpanTable"] = Relationship(
back_populates="response_model", cascade_delete=True
)
generations: list["GenerationTable"] = Relationship(
back_populates="response_model", cascade_delete=True
)
10 changes: 10 additions & 0 deletions lilypad/server/models/spans.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@
from .base_organization_sql_model import BaseOrganizationSQLModel
from .generations import GenerationPublic
from .prompts import PromptPublic
from .response_models import ResponseModelPublic
from .table_names import (
GENERATION_TABLE_NAME,
PROJECT_TABLE_NAME,
PROMPT_TABLE_NAME,
RESPONSE_MODEL_TABLE_NAME,
SPAN_TABLE_NAME,
)

if TYPE_CHECKING:
from .generations import GenerationTable
from .prompts import PromptTable
from .response_models import ResponseModelTable


class Scope(str, Enum):
Expand Down Expand Up @@ -51,6 +54,9 @@ class _SpanBase(SQLModel):
prompt_uuid: UUID | None = Field(
default=None, foreign_key=f"{PROMPT_TABLE_NAME}.uuid"
)
response_model_uuid: UUID | None = Field(
default=None, foreign_key=f"{RESPONSE_MODEL_TABLE_NAME}.uuid"
)
type: SpanType | None = Field(default=None)
cost: float | None = Field(default=None)
scope: Scope = Field(nullable=False)
Expand All @@ -73,6 +79,7 @@ class SpanPublic(_SpanBase):
display_name: str | None = None
generation: GenerationPublic | None = None
prompt: PromptPublic | None = None
response_model: ResponseModelPublic | None = None
child_spans: list["SpanPublic"]
created_at: datetime

Expand Down Expand Up @@ -121,6 +128,9 @@ class SpanTable(_SpanBase, BaseOrganizationSQLModel, table=True):
__table_args__ = (UniqueConstraint("span_id"), Index("ix_spans_span_id", "span_id"))
generation: Optional["GenerationTable"] = Relationship(back_populates="spans")
prompt: Optional["PromptTable"] = Relationship(back_populates="spans")
response_model: Optional["ResponseModelTable"] = Relationship(
back_populates="spans"
)
child_spans: list["SpanTable"] = Relationship(
back_populates="parent_span", cascade_delete=True
)
Expand Down
2 changes: 1 addition & 1 deletion lilypad/server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Settings(BaseSettings):
# Server settings
environment: str = Field(default="production")
port: int = Field(default=8000)
remote_base_url: str = Field(default="https://lilypad-production.up.railway.app")
remote_base_url: str = Field(default="https://app.lilypad.so")

# GitHub OAuth settings
github_client_id: str = Field(default="my_client_id")
Expand Down
1 change: 1 addition & 0 deletions tests/_utils/closure/closure_test_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
inner_fn,
inner_sub_fn,
internal_imports_fn,
mirascope_response_model_fn,
multi_decorated_fn,
self_fn_class_fn,
single_fn,
Expand Down
25 changes: 25 additions & 0 deletions tests/_utils/closure/closure_test_functions/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mirascope.core import BaseMessageParam, openai, prompt_template
from openai import OpenAI as OAI
from openai.types.chat import ChatCompletionUserMessageParam
from pydantic import BaseModel

import tests._utils.closure.closure_test_functions.other
import tests._utils.closure.closure_test_functions.other as cloth
Expand Down Expand Up @@ -564,3 +565,27 @@ def closure_with_long_function_name_that_wraps_around_fn(
return {"role": "user", "content": "Hello, world!"}
"""
return {"role": "user", "content": "Hello, world!"}


# Pydantic model passed as ``response_model`` to the ``openai.call`` decorator
# on ``mirascope_response_model_fn`` below. It is referenced there only as a
# keyword argument of the decorator call, so the closure has to discover it by
# walking the call's keyword values (see ``visit_Call`` in closure.py).
# NOTE(review): the class body must stay in sync with the expected closure code
# embedded in ``mirascope_response_model_fn``'s docstring (which omits this
# class's own docstring).
class Response(BaseModel):
"""Test response model."""

response: str


# Fixture for ``test_mirascope_response_model_fn``: a Mirascope call that uses a
# user-defined ``response_model``. The docstring appears to hold the expected
# closure code (the test compares ``closure.code`` against
# ``_expected(mirascope_response_model_fn)``) — TODO confirm against
# ``_expected``. Do not edit the docstring or body without updating that
# expectation, since the closure captures this function's source.
@openai.call("gpt-4o-mini", response_model=Response)
def mirascope_response_model_fn() -> str:
"""
from mirascope.core import openai
from pydantic import BaseModel


class Response(BaseModel):
response: str


@openai.call("gpt-4o-mini", response_model=Response)
def mirascope_response_model_fn() -> str:
return "Hello, world!"
"""
return "Hello, world!"
17 changes: 17 additions & 0 deletions tests/_utils/closure/test_closure.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
inner_fn,
inner_sub_fn,
internal_imports_fn,
mirascope_response_model_fn,
multi_decorated_fn,
self_fn_class_fn,
single_fn,
Expand Down Expand Up @@ -388,3 +389,19 @@ def fn(arg: str) -> str:

closure = Closure.from_fn(fn)
assert closure.run("Hello, world!") == "Hello, world!"


def test_mirascope_response_model_fn() -> None:
    """Test the `Closure` class with a Mirascope response model.

    The closure of a function decorated with ``response_model=Response`` must
    reproduce the ``Response`` class (referenced only as a keyword argument of
    the decorator call), and its dependencies must include ``pydantic``
    alongside ``mirascope``.
    """
    closure = Closure.from_fn(mirascope_response_model_fn)
    assert closure.code == _expected(mirascope_response_model_fn)
    assert closure.dependencies == {
        "mirascope": {
            "version": importlib.metadata.version("mirascope"),
            "extras": ["anthropic", "gemini", "openai", "opentelemetry"],
        },
        # Resolve the installed version at runtime, exactly like "mirascope"
        # above, instead of pinning a literal that breaks on any upgrade.
        "pydantic": {
            "version": importlib.metadata.version("pydantic"),
            "extras": None,
        },
    }
Loading
0