8000 Pass parameters to custom routers through LLMConfig by eicherseiji · Pull Request #53870 · ray-project/ray · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Pass parameters to custom routers through LLMConfig #53870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 21 additions & 17 deletions java/serve/src/main/java/io/ray/serve/config/DeploymentConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,10 @@ public class DeploymentConfig implements Serializable {
*/
private Double healthCheckTimeoutS = Constants.DEFAULT_HEALTH_CHECK_TIMEOUT_S;

/** Frequency at which the controller will record request routing stats. */
private Double requestRoutingStatsPeriodS = Constants.DEFAULT_REQUEST_ROUTING_STATS_PERIOD_S;

/**
* Timeout that the controller will wait for a response from the replica's request routing stats
* before retrying.
*/
private Double requestRoutingStatsTimeoutS = Constants.DEFAULT_REQUEST_ROUTING_STATS_TIMEOUT_S;

private AutoscalingConfig autoscalingConfig;

private RouterConfig routerConfig;

/** This flag is used to let replicas know they are deployed from a different language. */
private Boolean isCrossLanguage = false;

Expand Down Expand Up @@ -150,23 +143,23 @@ public DeploymentConfig setHealthCheckTimeoutS(Double healthCheckTimeoutS) {
}

public Double getRequestRoutingStatsPeriodS() {
return requestRoutingStatsPeriodS;
return routerConfig.getRequestRoutingStatsPeriodS();
}

public DeploymentConfig setRequestRoutingStatsPeriodS(Double requestRoutingStatsPeriodS) {
if (requestRoutingStatsPeriodS != null) {
this.requestRoutingStatsPeriodS = requestRoutingStatsPeriodS;
routerConfig.setRequestRoutingStatsPeriodS(requestRoutingStatsPeriodS);
}
return this;
}

public Double getRequestRoutingStatsTimeoutS() {
return requestRoutingStatsTimeoutS;
return routerConfig.getRequestRoutingStatsTimeoutS();
}

public DeploymentConfig setRequestRoutingStatsTimeoutS(Double requestRoutingStatsTimeoutS) {
if (requestRoutingStatsTimeoutS != null) {
this.requestRoutingStatsTimeoutS = requestRoutingStatsTimeoutS;
routerConfig.setRequestRoutingStatsTimeoutS(requestRoutingStatsTimeoutS);
}
return this;
}
Expand All @@ -180,6 +173,15 @@ public DeploymentConfig setAutoscalingConfig(AutoscalingConfig autoscalingConfig
return this;
}

public RouterConfig getRouterConfig() {
return routerConfig;
}

public DeploymentConfig setRouterConfig(RouterConfig routerConfig) {
this.routerConfig = routerConfig;
return this;
}

public boolean isCrossLanguage() {
return isCrossLanguage;
}
Expand Down Expand Up @@ -230,8 +232,6 @@ public byte[] toProtoBytes() {
.setGracefulShutdownTimeoutS(gracefulShutdownTimeoutS)
.setHealthCheckPeriodS(healthCheckPeriodS)
.setHealthCheckTimeoutS(healthCheckTimeoutS)
.setRequestRoutingStatsPeriodS(requestRoutingStatsPeriodS)
.setRequestRoutingStatsTimeoutS(requestRoutingStatsTimeoutS)
.setIsCrossLanguage(isCrossLanguage)
.setDeploymentLanguage(deploymentLanguage)
.setVersion(version);
Expand All @@ -241,6 +241,9 @@ public byte[] toProtoBytes() {
if (null != autoscalingConfig) {
builder.setAutoscalingConfig(autoscalingConfig.toProto());
}
if (null != routerConfig) {
builder.setRouterConfig(routerConfig.toProto());
}
return builder.build().toByteArray();
}

Expand All @@ -253,8 +256,6 @@ public io.ray.serve.generated.DeploymentConfig toProto() {
.setGracefulShutdownTimeoutS(gracefulShutdownTimeoutS)
.setHealthCheckPeriodS(healthCheckPeriodS)
.setHealthCheckTimeoutS(healthCheckTimeoutS)
.setRequestRoutingStatsPeriodS(requestRoutingStatsPeriodS)
.setRequestRoutingStatsTimeoutS(requestRoutingStatsTimeoutS)
.setIsCrossLanguage(isCrossLanguage)
.setDeploymentLanguage(deploymentLanguage);
if (null != userConfig) {
Expand All @@ -263,6 +264,9 @@ public io.ray.serve.generated.DeploymentConfig toProto() {
if (null != autoscalingConfig) {
builder.setAutoscalingConfig(autoscalingConfig.toProto());
}
if (null != routerConfig) {
builder.setRouterConfig(routerConfig.toProto());
}
return builder.build();
}

Expand Down
38 changes: 38 additions & 0 deletions java/serve/src/main/java/io/ray/serve/config/RouterConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package io.ray.serve.config;

import io.ray.serve.common.Constants;
import java.io.Serializable;

/**
 * Request-routing settings of a deployment: how often routing stats are recorded and how long to
 * wait for them.
 *
 * <p>Both values start out at the defaults from {@link Constants}. The setters store whatever they
 * are given, including {@code null}; {@link #toProto()} substitutes the default for any value that
 * is {@code null} at conversion time.
 */
public class RouterConfig implements Serializable {
  /** Frequency at which the controller will record request routing stats. */
  private Double requestRoutingStatsPeriodS = Constants.DEFAULT_REQUEST_ROUTING_STATS_PERIOD_S;

  /**
   * Timeout that the controller will wait for a response from the replica's request routing stats
   * before retrying.
   */
  private Double requestRoutingStatsTimeoutS = Constants.DEFAULT_REQUEST_ROUTING_STATS_TIMEOUT_S;

  /**
   * Returns the period, in seconds, at which request routing stats are recorded.
   *
   * @return the stored period; {@code null} if the setter was called with {@code null}
   */
  public Double getRequestRoutingStatsPeriodS() {
    return requestRoutingStatsPeriodS;
  }

  /**
   * Returns the timeout, in seconds, for a request routing stats response.
   *
   * @return the stored timeout; {@code null} if the setter was called with {@code null}
   */
  public Double getRequestRoutingStatsTimeoutS() {
    return requestRoutingStatsTimeoutS;
  }

  /**
   * Sets the period, in seconds, at which request routing stats are recorded.
   *
   * @param requestRoutingStatsPeriodS the new period; stored as-is
   */
  public void setRequestRoutingStatsPeriodS(Double requestRoutingStatsPeriodS) {
    this.requestRoutingStatsPeriodS = requestRoutingStatsPeriodS;
  }

  /**
   * Sets the timeout, in seconds, for a request routing stats response.
   *
   * @param requestRoutingStatsTimeoutS the new timeout; stored as-is
   */
  public void setRequestRoutingStatsTimeoutS(Double requestRoutingStatsTimeoutS) {
    this.requestRoutingStatsTimeoutS = requestRoutingStatsTimeoutS;
  }

  /**
   * Converts this config into its protobuf representation.
   *
   * <p>A field that is currently {@code null} is written as the corresponding default from {@link
   * Constants} instead of failing on unboxing.
   *
   * @return a newly built protobuf message carrying both routing-stats values
   */
  public io.ray.serve.generated.RouterConfig toProto() {
    // The fields are boxed and the setters accept null, so resolve to primitives here;
    // handing a null Double to the builder would throw a NullPointerException on unboxing.
    double periodS =
        requestRoutingStatsPeriodS != null
            ? requestRoutingStatsPeriodS
            : Constants.DEFAULT_REQUEST_ROUTING_STATS_PERIOD_S;
    double timeoutS =
        requestRoutingStatsTimeoutS != null
            ? requestRoutingStatsTimeoutS
            : Constants.DEFAULT_REQUEST_ROUTING_STATS_TIMEOUT_S;
    return io.ray.serve.generated.RouterConfig.newBuilder()
        .setRequestRoutingStatsPeriodS(periodS)
        .setRequestRoutingStatsTimeoutS(timeoutS)
        .build();
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,6 @@ async def construct_request_router(loop: asyncio.AbstractEventLoop):
deployment_id=DeploymentID(name="TEST_DEPLOYMENT"),
handle_source=DeploymentHandleSource.REPLICA,
use_replica_queue_len_cache=False,
imbalanced_threshold=params.get("imbalanced_threshold", 10),
match_rate_threshold=params.get("match_rate_threshold", 0.1),
do_eviction=params.get("do_eviction", False),
eviction_threshold_chars=params.get("eviction_threshold_chars"),
eviction_target_chars=params.get("eviction_target_chars"),
eviction_interval_secs=params.get("eviction_interval_secs"),
get_curr_time_s=TIMER.time,
tree_actor=tree_actor,
)
Expand All @@ -62,6 +56,14 @@ async def construct_request_router(loop: asyncio.AbstractEventLoop):
request_router = asyncio.new_event_loop().run_until_complete(
construct_request_router(get_or_create_event_loop())
)
request_router.initialize_state(
imbalanced_threshold=params.get("imbalanced_threshold", 10),
match_rate_threshold=params.get("match_rate_threshold", 0.1),
do_eviction=params.get("do_eviction", False),
eviction_threshold_chars=params.get("eviction_threshold_chars"),
eviction_target_chars=params.get("eviction_target_chars"),
eviction_interval_secs=params.get("eviction_interval_secs"),
)

yield request_router
assert request_router.curr_num_routing_tasks == 0
Expand Down Expand Up @@ -124,7 +126,7 @@ async def test_fallback_when_no_prompt(self, prefix_request_router):

req = fake_pending_request()
for _ in range(10):
chosen = await prefix_request_router.choose_replica_for_request(req)
chosen = await prefix_request_router._choose_replica_for_request(req)
assert chosen == r1

@pytest.mark.asyncio
Expand Down Expand Up @@ -161,7 +163,7 @@ async def test_fallback_when_imbalanced(self, prefix_request_router):

req = fake_pending_request(prompt="hello world")
for _ in range(10):
chosen = await prefix_request_router.choose_replica_for_request(req)
chosen = await prefix_request_router._choose_replica_for_request(req)
# Even though r2 has a higher match rate, it is not chosen because the load is imbalanced
assert chosen == r1

Expand Down Expand Up @@ -199,13 +201,13 @@ async def test_high_match_rate_selects_matching_replica(

prompt_req = fake_pending_request(prompt="Hello world")
for _ in range(10):
chosen = await prefix_request_router.choose_replica_for_request(prompt_req)
chosen = await prefix_request_router._choose_replica_for_request(prompt_req)
assert chosen == r2
chat_req = fake_pending_request(
messages=[{"content": "Hello"}, {"content": " world"}]
)
for _ in range(10):
chosen = await prefix_request_router.choose_replica_for_request(chat_req)
chosen = await prefix_request_router._choose_replica_for_request(chat_req)
assert chosen == r2

@pytest.mark.asyncio
Expand Down Expand Up @@ -240,14 +242,15 @@ async def test_low_match_rate_uses_smallest_tree(self, prefix_request_router):
for _ in range(10):
# Both tenants have 0% match rate, so the smaller tenant (r1) is chosen
assert (
await prefix_request_router.choose_replica_for_request(prompt_req) == r1
await prefix_request_router._choose_replica_for_request(prompt_req)
== r1
)

chat_req = fake_pending_request(messages=[{"content": "z"}])
for _ in range(10):
# Both tenants have 0% match rate, so the smaller tenant (r1) is chosen
assert (
await prefix_request_router.choose_replica_for_request(chat_req) == r1
await prefix_request_router._choose_replica_for_request(chat_req) == r1
)


Expand Down
Loading
0