Don't return a PredictionResponse from PredictionRunner.setup by nickstenning · Pull Request #1433 · replicate/cog · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Don't return a PredictionResponse from PredictionRunner.setup #1433

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
if TYPE_CHECKING:
from typing import ParamSpec

import attrs
import structlog
import uvicorn
from fastapi import Body, FastAPI, Header, HTTPException, Path, Response
Expand All @@ -43,7 +44,13 @@
load_config,
load_predictor_from_ref,
)
from .runner import PredictionRunner, RunnerBusyError, UnknownPredictionError
from .runner import (
PredictionRunner,
RunnerBusyError,
SetupResult,
SetupTask,
UnknownPredictionError,
)

# Module-level structured logger for the HTTP server, named "cog.server.http".
log = structlog.get_logger("cog.server.http")

Expand All @@ -59,8 +66,8 @@ class Health(Enum):

class MyState:
# Typed view of the attributes stored on app.state by create_app.
# NOTE(review): this page renders the diff without +/- markers; the two
# declarations using the old types (the task stored as setup_result, and
# setup_result_payload) are the removed side, the last two are the added side.
# Lifecycle state reported by /health-check.
health: Health
setup_result: "Optional[asyncio.Task[schema.PredictionResponse]]"
setup_result_payload: Optional[schema.PredictionResponse]
# In-flight setup task; set to None once _check_setup_task has consumed it.
setup_task: Optional[SetupTask]
# Outcome of setup; None until the setup task has finished.
setup_result: Optional[SetupResult]


class MyFastAPI(FastAPI):
Expand All @@ -83,8 +90,8 @@ def create_app(
)

app.state.health = Health.STARTING
app.state.setup_task = None
app.state.setup_result = None
app.state.setup_result_payload = None

predictor_ref = get_predictor_ref(config, mode)

Expand Down Expand Up @@ -122,7 +129,7 @@ async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T":

@app.on_event("startup")
def startup() -> None:
# Kick off predictor setup when the server starts.
# NOTE(review): diff rendered without +/- markers; the first assignment is the
# removed line and the second is its replacement. After the change the task
# itself is kept in setup_task and polled by _check_setup_task, which stores
# the finished SetupResult in setup_result.
app.state.setup_result = runner.setup()
app.state.setup_task = runner.setup()

@app.on_event("shutdown")
def shutdown() -> None:
Expand All @@ -138,15 +145,15 @@ async def root() -> Any:

@app.get("/health-check")
async def healthcheck() -> Any:
    """Report server health and, once available, the outcome of setup.

    Polls the setup task first so that ``app.state.health`` and
    ``app.state.setup_result`` reflect a setup run that has just finished.
    While the server is READY, the status is refined to BUSY whenever the
    runner is currently busy.

    Returns:
        A JSON-encodable dict with ``status`` (the Health name) and ``setup``
        (the SetupResult as a dict, or None while setup is still running).
    """
    await _check_setup_task()
    if app.state.health == Health.READY:
        health = Health.BUSY if runner.is_busy() else Health.READY
    else:
        health = app.state.health
    # setup_result stays None until the setup task has completed.
    # attrs.asdict(None) raises, which would turn every health check made while
    # the server is still STARTING into a 500. Report None instead, as the
    # pre-change setup_result_payload did.
    setup = (
        attrs.asdict(app.state.setup_result)
        if app.state.setup_result is not None
        else None
    )
    return jsonable_encoder(
        {
            "status": health.name,
            "setup": setup,
        }
    )

Expand Down Expand Up @@ -274,25 +281,25 @@ async def start_shutdown() -> Any:
shutdown_event.set()
return JSONResponse({}, status_code=200)

# Poll the setup task started at startup. Returns immediately if there is no
# task (never started, or already consumed) or if it is still running.
# Otherwise it maps the result's status onto app.state.health, stores the
# result for /health-check, and clears the task so later calls are no-ops.
# NOTE(review): this page renders the diff without +/- markers; wherever two
# versions of the same statement (or group of statements) appear back to back,
# the earlier one is the removed side and the later one its replacement.
async def _check_setup_result() -> Any:
if app.state.setup_result is None:
async def _check_setup_task() -> Any:
if app.state.setup_task is None:
return

if not app.state.setup_result.done():
if not app.state.setup_task.done():
return

# this can raise CancelledError
# Task.result() also re-raises any exception raised by the setup coroutine.
result = app.state.setup_result.result()
result = app.state.setup_task.result()

if result.status == schema.Status.SUCCEEDED:
app.state.health = Health.READY
else:
app.state.health = Health.SETUP_FAILED

app.state.setup_result_payload = result
app.state.setup_result = result

# Reset app.state.setup_result so future calls are a no-op
app.state.setup_result = None
# Reset app.state.setup_task so future calls are a no-op
app.state.setup_task = None

return app

Expand Down
40 changes: 22 additions & 18 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import threading
import traceback
import typing # TypeAlias, py3.10
from asyncio import Task
from datetime import datetime, timezone
from typing import Any, Callable, Optional, Tuple, cast
from typing import Any, Callable, Optional, Tuple, Union, cast

import requests
import structlog
from attrs import define
from fastapi.encoders import jsonable_encoder
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry # type: ignore
Expand Down Expand Up @@ -36,7 +36,17 @@ class UnknownPredictionError(Exception):
pass


# NOTE(review): removed side of the diff (rendered without +/- markers); this
# alias is redefined in terms of asyncio.Task after the SetupResult class.
PredictionTask: "typing.TypeAlias" = "Task[schema.PredictionResponse]"
@define
class SetupResult:
# Outcome of running the predictor's setup, produced by the module-level
# setup() coroutine (which PredictionRunner.setup wraps in a SetupTask).
# It replaces the PredictionResponse previously used to carry this data.
# Field order defines the attrs-generated __init__.
# When setup began (timezone-aware UTC, taken from datetime.now(tz=timezone.utc)).
started_at: datetime
completed_at: datetime
# All captured setup log output, joined into a single string.
logs: str
# Compared against schema.Status.SUCCEEDED to decide READY vs SETUP_FAILED.
status: schema.Status


# Aliases for the asyncio tasks the runner tracks in self._result: either an
# in-flight prediction or the setup run. They are string forward references
# so the asyncio.Task generic is never subscripted at runtime.
PredictionTask: "typing.TypeAlias" = "asyncio.Task[schema.PredictionResponse]"
SetupTask: "typing.TypeAlias" = "asyncio.Task[SetupResult]"
RunnerTask: "typing.TypeAlias" = Union[PredictionTask, SetupTask]


class PredictionRunner:
Expand All @@ -48,16 +58,16 @@ def __init__(
upload_url: Optional[str] = None,
) -> None:
self._response: Optional[schema.PredictionResponse] = None
self._result: "Optional[PredictionTask]" = None
self._result: Optional[RunnerTask] = None

self._worker = Worker(predictor_ref=predictor_ref)
self._should_cancel = asyncio.Event()

self._shutdown_event = shutdown_event
self._upload_url = upload_url

def make_error_handler(self, activity: str) -> Callable[[PredictionTask], None]:
def handle_error(task: PredictionTask) -> None:
def make_error_handler(self, activity: str) -> Callable[[RunnerTask], None]:
def handle_error(task: RunnerTask) -> None:
exc = task.exception()
if not exc:
return
Expand All @@ -73,7 +83,7 @@ def handle_error(task: PredictionTask) -> None:

return handle_error

def setup(self) -> "Task[schema.PredictionResponse]":
def setup(self) -> SetupTask:
if self.is_busy():
raise RunnerBusyError()
self._result = asyncio.create_task(setup(worker=self._worker))
Expand All @@ -84,7 +94,7 @@ def setup(self) -> "Task[schema.PredictionResponse]":
# no longer have to support Python 3.8
def predict(
self, prediction: schema.PredictionRequest, upload: bool = True
) -> Tuple[schema.PredictionResponse, "Task[schema.PredictionResponse]"]:
) -> Tuple[schema.PredictionResponse, PredictionTask]:
# It's the caller's responsibility to not call us if we're busy.
if self.is_busy():
# If self._result is set, but self._response is not, we're still
Expand All @@ -93,7 +103,8 @@ def predict(
raise RunnerBusyError()
assert self._result is not None
if prediction.id is not None and prediction.id == self._response.id:
return (self._response, self._result)
result = cast(PredictionTask, self._result)
return (self._response, result)
raise RunnerBusyError()

# Set up logger context for main thread. The same thing happens inside
Expand Down Expand Up @@ -279,7 +290,7 @@ def _upload_files(self, output: Any) -> Any:
raise FileUploadError("Got error trying to upload output files") from error


async def setup(*, worker: Worker) -> schema.PredictionResponse:
async def setup(*, worker: Worker) -> SetupResult:
logs = []
status = None
started_at = datetime.now(tz=timezone.utc)
Expand Down Expand Up @@ -309,17 +320,10 @@ async def setup(*, worker: Worker) -> schema.PredictionResponse:
probes = ProbeHelper()
probes.ready()

return schema.PredictionResponse(
input={},
output=None,
id=None,
version=None,
created_at=None,
return SetupResult(
started_at=started_at,
completed_at=completed_at,
logs="".join(logs),
error=None,
metrics=None,
status=status,
)

Expand Down
0