[async] Propagate trace context to webhook and upload requests by aron · Pull Request #1787 · replicate/cog · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

[async] Propagate trace context to webhook and upload requests #1787

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions python/cog/server/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .eventtypes import PredictionInput
from .response_throttler import ResponseThrottler
from .retry_transport import RetryTransport
from .telemetry import current_trace_context

log = structlog.get_logger(__name__)

Expand Down Expand Up @@ -45,14 +46,25 @@ def _get_version() -> str:
WebhookSenderType = Callable[[Any, WebhookEvent], Awaitable[None]]


def webhook_headers() -> "dict[str, str]":
def common_headers() -> "dict[str, str]":
    """Return headers attached to every outbound HTTP request (user-agent only)."""
    return {"user-agent": _user_agent}


def webhook_headers() -> "dict[str, str]":
    """Return headers for webhook requests: the common headers plus, when the
    WEBHOOK_AUTH_TOKEN environment variable is set, a bearer authorization
    header built from it."""
    headers = common_headers()
    token = os.environ.get("WEBHOOK_AUTH_TOKEN")
    if token:
        headers["authorization"] = "Bearer " + token
    return headers


async def on_request_trace_context_hook(request: httpx.Request) -> None:
    """httpx request event hook: copy the current trace context (if any) onto
    the outgoing request's headers so downstream services can link spans."""
    request.headers.update(current_trace_context() or {})


def httpx_webhook_client() -> httpx.AsyncClient:
    """Async client for non-retrying webhook delivery.

    Registers the trace-context request hook, matching httpx_retry_client and
    httpx_file_client, so traceparent/tracestate headers propagate on webhooks
    sent through this client as well.
    """
    return httpx.AsyncClient(
        event_hooks={"request": [on_request_trace_context_hook]},
        headers=webhook_headers(),
        follow_redirects=True,
    )

Expand All @@ -68,7 +80,10 @@ def httpx_retry_client() -> httpx.AsyncClient:
retryable_methods=["POST"],
)
return httpx.AsyncClient(
headers=webhook_headers(), transport=transport, follow_redirects=True
event_hooks={"request": [on_request_trace_context_hook]},
headers=webhook_headers(),
transport=transport,
follow_redirects=True,
)


Expand All @@ -87,6 +102,8 @@ def httpx_file_client() -> httpx.AsyncClient:
# httpx default for pool is 5, use that
timeout = httpx.Timeout(connect=10, read=15, write=None, pool=5)
return httpx.AsyncClient(
event_hooks={"request": [on_request_trace_context_hook]},
headers=common_headers(),
transport=transport,
follow_redirects=True,
timeout=timeout,
Expand Down
39 changes: 30 additions & 9 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
Dict,
Optional,
TypeVar,
Union,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,6 +51,7 @@
SetupTask,
UnknownPredictionError,
)
from .telemetry import make_trace_context, trace_context

log = structlog.get_logger("cog.server.http")

Expand Down Expand Up @@ -190,9 +190,16 @@ class TrainingRequest(
)
def train(
    request: TrainingRequest = Body(default=None),
    prefer: Optional[str] = Header(default=None),
    traceparent: Optional[str] = Header(default=None, include_in_schema=False),
    tracestate: Optional[str] = Header(default=None, include_in_schema=False),
) -> Any:  # type: ignore
    # Trainings are predictions under a different route: forward everything,
    # including the W3C trace-context headers, to the predict handler.
    return predict(
        request,
        prefer=prefer,
        traceparent=traceparent,
        tracestate=tracestate,
    )

@app.put(
"/trainings/{training_id}",
Expand All @@ -202,9 +209,17 @@ def train(
def train_idempotent(
    training_id: str = Path(..., title="Training ID"),
    request: TrainingRequest = Body(..., title="Training Request"),
    prefer: Optional[str] = Header(default=None),
    traceparent: Optional[str] = Header(default=None, include_in_schema=False),
    tracestate: Optional[str] = Header(default=None, include_in_schema=False),
) -> Any:
    # Idempotent training creation: forwarded verbatim (headers included)
    # to the idempotent prediction handler.
    return predict_idempotent(
        prediction_id=training_id,
        request=request,
        prefer=prefer,
        traceparent=traceparent,
        tracestate=tracestate,
    )

@app.post("/trainings/{training_id}/cancel")
def cancel_training(training_id: str = Path(..., title="Training ID")) -> Any:
Expand Down Expand Up @@ -270,7 +285,9 @@ async def ready() -> Any:
)
async def predict(
request: PredictionRequest = Body(default=None),
prefer: Union[str, None] = Header(default=None),
prefer: Optional[str] = Header(default=None),
traceparent: Optional[str] = Header(default=None, include_in_schema=False),
tracestate: Optional[str] = Header(default=None, include_in_schema=False),
) -> Any: # type: ignore
"""
Run a single prediction on the model
Expand All @@ -285,7 +302,8 @@ async def predict(
# TODO: spec-compliant parsing of Prefer header.
respond_async = prefer == "respond-async"

return await shared_predict(request=request, respond_async=respond_async)
with trace_context(make_trace_context(traceparent, tracestate)):
return await shared_predict(request=request, respond_async=respond_async)

@limited
@app.put(
Expand All @@ -296,7 +314,9 @@ async def predict(
async def predict_idempotent(
prediction_id: str = Path(..., title="Prediction ID"),
request: PredictionRequest = Body(..., title="Prediction Request"),
prefer: Union[str, None] = Header(default=None),
prefer: Optional[str] = Header(default=None),
traceparent: Optional[str] = Header(default=None, include_in_schema=False),
tracestate: Optional[str] = Header(default=None, include_in_schema=False),
) -> Any:
"""
Run a single prediction on the model (idempotent creation).
Expand All @@ -314,7 +334,8 @@ async def predict_idempotent(
# TODO: spec-compliant parsing of Prefer header.
respond_async = prefer == "respond-async"

return await shared_predict(request=request, respond_async=respond_async)
with trace_context(make_trace_context(traceparent, tracestate)):
return await shared_predict(request=request, respond_async=respond_async)

async def shared_predict(
*, request: Optional[PredictionRequest], respond_async: bool = False
Expand Down
54 changes: 54 additions & 0 deletions python/cog/server/telemetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Generator, Optional

# TypedDict was added in 3.8
from typing_extensions import TypedDict


# See: https://www.w3.org/TR/trace-context/
class TraceContext(TypedDict, total=False):
    """W3C Trace Context headers, propagated verbatim between services.

    Both keys are optional (``total=False``): either header may be absent
    on the inbound request.
    """

    traceparent: str
    tracestate: str


# Per-execution-context storage for the inbound trace context. ContextVar
# keeps the value isolated per async task, so concurrent requests do not
# leak headers into each other. ``None`` means no context is installed.
TRACE_CONTEXT: ContextVar[Optional[TraceContext]] = ContextVar(
    "trace_context", default=None
)


def make_trace_context(
    traceparent: Optional[str] = None, tracestate: Optional[str] = None
) -> TraceContext:
    """
    Build a TraceContext from the inbound ``traceparent`` and ``tracestate``
    header values. Missing or empty values are omitted, so the result only
    carries headers that were actually provided.
    """
    context: TraceContext = {}
    if traceparent:
        context["traceparent"] = traceparent
    if tracestate:
        context["tracestate"] = tracestate
    return context


def current_trace_context() -> Optional[TraceContext]:
    """
    Returns the current trace context, this needs to be added via HTTP headers
    to all outgoing HTTP requests.
    """
    # Reads the ContextVar set by trace_context(); None outside any request.
    return TRACE_CONTEXT.get()


@contextmanager
def trace_context(ctx: TraceContext) -> Generator[None, None, None]:
    """
    Install ``ctx`` as the current trace context for the duration of the
    ``with`` block, restoring the previous value afterwards — even if the
    body raises. Used to scope the inbound request's trace headers so they
    can be attached to internal outgoing HTTP requests.
    """
    token = TRACE_CONTEXT.set(ctx)
    try:
        yield
    finally:
        # reset() restores whatever context (possibly None) was active before.
        TRACE_CONTEXT.reset(token)
60 changes: 60 additions & 0 deletions python/tests/server/test_http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import base64
import httpx
import io
import respx
import time
import unittest.mock as mock

Expand Down Expand Up @@ -560,6 +562,64 @@ def test_asynchronous_prediction_endpoint(client, match):
assert webhook.call_count == 1


# End-to-end test for passing tracing headers on to downstream services.
@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
@uses_predictor_with_client_options(
    "output_file", upload_url="https://example.com/upload"
)
async def test_asynchronous_prediction_endpoint_with_trace_context(
    respx_mock: respx.MockRouter, client, match
):
    # The mocked routes only match requests that carry the trace headers, so
    # a dropped header surfaces as call_count == 0 in the assertions below.
    webhook = respx_mock.post(
        "/webhook",
        json__id="12345abcde",
        json__status="succeeded",
        json__output="https://example.com/upload/file",
        headers={
            "traceparent": "traceparent-123",
            "tracestate": "tracestate-123",
        },
    ).respond(200)
    uploader = respx_mock.put(
        "/upload/file",
        headers={
            "content-type": "application/octet-stream",
            "traceparent": "traceparent-123",
            "tracestate": "tracestate-123",
        },
    ).respond(200)

    # Start an async prediction with W3C trace-context headers attached.
    resp = client.post(
        "/predictions",
        json={
            "id": "12345abcde",
            "input": {},
            "webhook": "https://example.com/webhook",
            "webhook_events_filter": ["completed"],
        },
        headers={
            "Prefer": "respond-async",
            "traceparent": "traceparent-123",
            "tracestate": "tracestate-123",
        },
    )
    assert resp.status_code == 202

    assert resp.json() == match(
        {"status": "processing", "output": None, "started_at": mock.ANY}
    )
    assert resp.json()["started_at"] is not None

    # The prediction completes in the background; poll (up to ~1s) for the
    # completion webhook before asserting on the mock call counts.
    n = 0
    while webhook.call_count < 1 and n < 10:
        time.sleep(0.1)
        n += 1

    assert webhook.call_count == 1
    assert uploader.call_count == 1


@uses_predictor("sleep")
def test_prediction_cancel(client):
resp = client.post("/predictions/123/cancel")
Expand Down
Loading
0