[NA] Alexkuzmik/fix adk tracer to not close traces too early by alexkuzmik · Pull Request #2403 · comet-ml/opik · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

[NA] Alexkuzmik/fix adk tracer to not close traces too early #2403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/sdk-e2e-library-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ jobs:
GOOGLE_GENAI_USE_VERTEXAI: TRUE
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
OPENAI_ORG_ID: ${{ secrets.OPENAI_ORG_ID }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
uses: ./.github/actions/install_opik_and_run_e2e_lib_integration_tests
with:
python_version: ${{ matrix.python_version }}
Expand Down
280 changes: 94 additions & 186 deletions sdks/python/src/opik/integrations/adk/opik_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from google.adk.tools.tool_context import ToolContext

from opik import context_storage
from opik.api_objects import helpers, opik_client, span, trace
from opik.types import DistributedTraceHeadersDict, LLMProvider, SpanType
from opik.decorator import arguments_helpers, span_creation_handler
from opik.api_objects import opik_client, span, trace
from opik.types import DistributedTraceHeadersDict
from . import helpers as adk_helpers, litellm_wrappers, llm_response_wrapper

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -42,112 +43,31 @@ def __init__(
# in case we need to use different context storage for ADK in the future
self._context_storage = context_storage.get_current_context_instance()

self._external_parent_span_id: contextvars.ContextVar[Optional[str]] = (
contextvars.ContextVar("external_parent_span_id", default=None)
)
self._opik_created_spans: Set[str] = (
set()
) # TODO: use contextvar set for a more reliable clean-up?

self._opik_created_trace_id: Optional[str] = None
self._opik_created_spans: Set[str] = set()
self._current_trace_created_by_opik_tracer: contextvars.ContextVar[
Optional[str]
] = contextvars.ContextVar("current_trace_created_by_opik_tracer", default=None)

self._opik_client = opik_client.get_client_cached()

_patch_adk()

def _attach_span_to_existing_span(
self,
current_span_data: span.SpanData,
name: str,
input: Dict[str, Any],
type: SpanType,
metadata: Optional[Dict[str, Any]] = None,
provider: Optional[Union[str, LLMProvider]] = None,
model: Optional[str] = None,
) -> None:
project_name = helpers.resolve_child_span_project_name(
parent_project_name=current_span_data.project_name,
child_project_name=self.project_name,
)

span_data = span.SpanData(
trace_id=current_span_data.trace_id,
parent_span_id=current_span_data.id,
name=name,
provider=provider,
model=model,
input=input,
type=type,
project_name=project_name,
metadata=self.metadata
if metadata is None
else {**self.metadata, **metadata},
)
self._set_current_context_data(span_data)
self._opik_created_spans.add(span_data.id)

def _attach_span_to_existing_trace(
self,
current_trace_data: trace.TraceData,
name: str,
input: Dict[str, Any],
type: SpanType,
metadata: Optional[Dict[str, Any]] = None,
provider: Optional[Union[str, LLMProvider]] = None,
model: Optional[str] = None,
) -> None:
project_name = helpers.resolve_child_span_project_name(
parent_project_name=current_trace_data.project_name,
child_project_name=self.project_name,
)
span_data = span.SpanData(
trace_id=current_trace_data.id,
parent_span_id=None,
name=name,
provider=provider,
model=model,
input=input,
type=type,
project_name=project_name,
metadata=self.metadata
if metadata is None
else {**self.metadata, **metadata},
)
self._set_current_context_data(span_data)
self._opik_created_spans.add(span_data.id)

def _start_trace(
self,
new_trace_data: trace.TraceData,
) -> None:
self._set_current_context_data(new_trace_data)
self._opik_created_trace_id = new_trace_data.id

def _end_current_trace(self) -> None:
is_error = True

if trace_data := self._context_storage.get_trace_data():
if trace_data.id == self._opik_created_trace_id:
self._context_storage.set_trace_data(None)
trace_data.init_end_time()
self._opik_client.trace(**trace_data.__dict__)
is_error = False

if is_error:
LOGGER.error("Failed during _end_current_trace(): trace is not found.")
trace_data = self._context_storage.pop_trace_data()
assert trace_data is not None
trace_data.init_end_time()
self._opik_client.trace(**trace_data.__dict__)

def _end_current_span(
self,
) -> None:
is_error = True

if span_data := self._context_storage.top_span_data():
if span_data.id in self._opik_created_spans:
self._context_storage.pop_span_data()
span_data.init_end_time()
self._opik_client.span(**span_data.__dict__)
is_error = False

if is_error:
LOGGER.error("Failed during _end_current_span(): span is not found.")
span_data = self._context_storage.pop_span_data()
assert span_data is not None
span_data.init_end_time()
self._opik_client.span(**span_data.__dict__)

def _set_current_context_data(self, value: SpanOrTraceData) -> None:
if isinstance(value, span.SpanData):
Expand All @@ -157,22 +77,6 @@ def _set_current_context_data(self, value: SpanOrTraceData) -> None:
else:
raise ValueError(f"Invalid context type: {type(value)}")

def _ensure_no_hanging_opik_tracer_spans(self) -> None:
# handle spans created by this tracer
if external_parent_span_id := self._external_parent_span_id.get():
self._context_storage.trim_span_data_stack_to_certain_span(
external_parent_span_id
)
else:
self._context_storage.clear_spans()
self._opik_created_spans.clear()

# handle trace created by this tracer
if current_trace_data := self._context_storage.get_trace_data():
if current_trace_data.id == self._opik_created_trace_id:
self._opik_created_trace_id = None
self._context_storage.set_trace_data(None)

def before_agent_callback(
self, callback_context: CallbackContext, *args: Any, **kwargs: Any
) -> None:
Expand All @@ -191,32 +95,33 @@ def before_agent_callback(
)
name = self.name or callback_context.agent_name

if current_span_data := self._context_storage.top_span_data():
self._attach_span_to_existing_span(
current_span_data=current_span_data,
name=name,
input=user_input,
type="general",
metadata=self.metadata,
)
self._external_parent_span_id.set(current_span_data.id)
elif current_trace_data := self._context_storage.get_trace_data():
self._attach_span_to_existing_trace(
current_trace_data=current_trace_data,
current_trace_data = self._context_storage.get_trace_data()
if current_trace_data is None: # todo: support distributed headers
current_trace = trace.TraceData(
name=name,
project_name=self.project_name,
metadata=trace_metadata,
thread_id=thread_id,
input=user_input,
type="general",
metadata=self.metadata,
)
self._context_storage.set_trace_data(current_trace)
self._current_trace_created_by_opik_tracer.set(current_trace.id)
else:
new_trace_data = trace.TraceData(
start_span_arguments = arguments_helpers.StartSpanParameters(
name=name,
input=user_input,
metadata=trace_metadata,
project_name=self.project_name,
thread_id=thread_id,
metadata=trace_metadata,
type="general",
)
self._start_trace(new_trace_data)
_, opik_span_data = (
span_creation_handler.create_span_respecting_context(
start_span_arguments=start_span_arguments,
distributed_trace_headers=None,
opik_context_storage=self._context_storage,
)
)
self._context_storage.add_span_data(opik_span_data)
self._opik_created_spans.add(opik_span_data.id)
except Exception as e:
LOGGER.error(f"Failed during before_agent_callback(): {e}", exc_info=True)

Expand All @@ -226,19 +131,22 @@ def after_agent_callback(
try:
output = self._last_model_output

if span_data := self._context_storage.top_span_data():
span_data.update(output=output).init_end_time()
self._end_current_span()
if (span_data := self._context_storage.top_span_data()) is not None:
if span_data.id in self._opik_created_spans:
span_data.update(output=output)
self._end_current_span()
self._opik_created_spans.discard(span_data.id)
else:
trace_data = self._context_storage.get_trace_data()
assert trace_data is not None
trace_data.update(output=output).init_end_time()
self._end_current_trace()
self._last_model_output = None

if trace_data.id == self._current_trace_created_by_opik_tracer.get():
trace_data.update(output=output)
self._end_current_trace()
self._current_trace_created_by_opik_tracer.set(None)
self._last_model_output = None
except Exception as e:
LOGGER.error(f"Failed during after_agent_callback(): {e}", exc_info=True)
finally:
self._ensure_no_hanging_opik_tracer_spans()

def before_model_callback(
self,
Expand All @@ -254,28 +162,23 @@ def before_model_callback(
llm_request.model
)

if current_span_data := self._context_storage.top_span_data():
self._attach_span_to_existing_span(
current_span_data=current_span_data,
_, span_data = span_creation_handler.create_span_respecting_context(
start_span_arguments=arguments_helpers.StartSpanParameters(
name=llm_request.model,
provider=provider,
model=model,
input=input,
type="llm",
project_name=self.project_name,
metadata=self.metadata,
)
else:
current_trace_data = self._context_storage.get_trace_data()
assert current_trace_data is not None
self._attach_span_to_existing_trace(
current_trace_data=current_trace_data,
name=llm_request.model,
provider=provider,
type="llm",
model=model,
provider=provider,
input=input,
type="llm",
metadata=self.metadata,
)
),
distributed_trace_headers=None,
opik_context_storage=self._context_storage,
)

self._context_storage.add_span_data(span_data)
self._opik_created_spans.add(span_data.id)

except Exception as e:
LOGGER.error(f"Failed during before_model_callback(): {e}", exc_info=True)

Expand Down Expand Up @@ -317,13 +220,16 @@ def after_model_callback(
try:
span_data = self._context_storage.top_span_data()
assert span_data is not None
span_data.update(
output=output,
usage=usage,
model=model,
provider=provider,
)
self._end_current_span()

if span_data.id in self._opik_created_spans:
span_data.update(
output=output,
usage=usage,
model=model,
provider=provider,
)
self._end_current_span()
self._opik_created_spans.discard(span_data.id)
except Exception as e:
LOGGER.error(f"Failed during after_model_callback(): {e}", exc_info=True)

Expand All @@ -336,26 +242,24 @@ def before_tool_callback(
**kwargs: Any,
) -> None:
try:
metadata = {"function_call_id": tool_context.function_call_id}
metadata = {
"function_call_id": tool_context.function_call_id,
**self.metadata,
}

if (current_span_data := self._context_storage.top_span_data()) is not None:
self._attach_span_to_existing_span(
current_span_data=current_span_data,
_, span_data = span_creation_handler.create_span_respecting_context(
start_span_arguments=arguments_helpers.StartSpanParameters(
name=tool.name,
input=args,
type="tool",
project_name=self.project_name,
metadata=metadata,
)
else:
current_trace_data = self._context_storage.get_trace_data()
assert current_trace_data is not None
self._attach_span_to_existing_trace(
current_trace_data=current_trace_data,
name=tool.name,
input=args,
type="tool",
metadata=metadata,
)
input=args,
),
distributed_trace_headers=None,
opik_context_storage=self._context_storage,
)
self._context_storage.add_span_data(span_data)
self._opik_created_spans.add(span_data.id)
except Exception as e:
LOGGER.error(f"Failed during before_tool_callback(): {e}", exc_info=True)

Expand All @@ -372,11 +276,15 @@ def after_tool_callback(
current_span_data = self._context_storage.top_span_data()
assert current_span_data is not None

if isinstance(tool_response, dict):
current_span_data.update(output=tool_response)
else:
current_span_data.update(output={"output": tool_response})
self._end_current_span()
output = (
tool_response
if isinstance(tool_response, dict)
else {"output": tool_response}
)
if current_span_data.id in self._opik_created_spans:
current_span_data.update(output=output)
self._end_current_span()
self._opik_created_spans.discard(current_span_data.id)
except Exception as e:
LOGGER.error(f"Failed during after_tool_callback(): {e}", exc_info=True)

Expand Down
0