refactor: restructure websocket implementation by augustoccesar · Pull Request #235 · mentimeter/linkup · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

refactor: restructure websocket implementation #235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 13, 2025
Merged
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion local-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ name = "linkup_local_server"
path = "src/lib.rs"

[dependencies]
axum = { version = "0.8.1", features = ["http2", "json"] }
axum = { version = "0.8.1", features = ["http2", "json", "ws"] }
axum-server = { version = "0.7", features = ["tls-rustls"] }
http = "1.2.0"
hickory-server = { version = "0.25.1", features = ["resolver"] }
Expand All @@ -28,6 +28,7 @@ tokio = { version = "1.43.0", features = [
"signal",
"rt-multi-thread",
] }
tokio-tungstenite = "0.26.1"
tower-http = { version = "0.6.2", features = ["trace"] }
tower = "0.5.2"
rcgen = { version = "0.13", features = ["x509-parser"] }
Expand Down
179 changes: 56 additions & 123 deletions local-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use http::{header::HeaderMap, Uri};
use hyper_rustls::HttpsConnector;
use hyper_util::{
client::legacy::{connect::HttpConnector, Client},
rt::{TokioExecutor, TokioIo},
rt::TokioExecutor,
};
use linkup::{
allow_all_cors, get_additional_headers, get_target_service, MemoryStringStore, NameKind,
Expand All @@ -40,10 +40,12 @@ use std::{
};
use std::{path::Path, sync::Arc};
use tokio::{net::UdpSocket, signal};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tower::ServiceBuilder;
use tower_http::trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer};

pub mod certificates;
mod ws;

type HttpsClient = Client<HttpsConnector<HttpConnector>, Body>;

Expand Down Expand Up @@ -182,6 +184,7 @@ pub async fn start_dns_server(linkup_session_name: String, domains: Vec<String>)
async fn linkup_request_handler(
Extension(store): Extension<MemoryStringStore>,
Extension(client): Extension<HttpsClient>,
ws: ws::ExtractOptionalWebSocketUpgrade,
req: Request,
) -> Response {
let sessions = SessionAllocator::new(&store);
Expand Down Expand Up @@ -224,15 +227,58 @@ async fn linkup_request_handler(

let extra_headers = get_additional_headers(&url, &headers, &session_name, &target_service);

if req
.headers()
.get("upgrade")
.map(|v| v == "websocket")
.unwrap_or(false)
{
handle_ws_req(req, target_service, extra_headers, client).await
} else {
handle_http_req(req, target_service, extra_headers, client).await
match ws.0 {
Some(downstream_upgrade) => {
let mut url = target_service.url;
if url.starts_with("http://") {
url = url.replace("http://", "ws://");
} else if url.starts_with("https://") {
url = url.replace("https://", "wss://");
}

let uri = url.parse::<Uri>().unwrap();
let mut upstream_request = uri.into_client_request().unwrap();

let extra_http_headers: HeaderMap = extra_headers.into();
for (key, value) in extra_http_headers.iter() {
upstream_request.headers_mut().insert(key, value.clone());
}

let (upstream_ws_stream, upstream_response) =
match tokio_tungstenite::connect_async(upstream_request).await {
Ok(connection) => connection,
Err(error) => match error {
tokio_tungstenite::tungstenite::Error::Http(response) => {
let (parts, body) = response.into_parts();
let body = body.unwrap_or_default();

return Response::from_parts(parts, Body::from(body));
}
error => {
return Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from(error.to_string()))
.unwrap()
}
},
};

let mut upstream_upgrade_response =
downstream_upgrade.on_upgrade(ws::context_handle_socket(upstream_ws_stream));

let websocket_upgrade_response_headers = upstream_upgrade_response.headers_mut();
for upstream_header in upstream_response.headers() {
if !websocket_upgrade_response_headers.contains_key(upstream_header.0) {
websocket_upgrade_response_headers
.append(upstream_header.0, upstream_header.1.clone());
}
}

websocket_upgrade_response_headers.extend(allow_all_cors());

upstream_upgrade_response
}
None => handle_http_req(req, target_service, extra_headers, client).await,
}
}

Expand Down Expand Up @@ -272,119 +318,6 @@ async fn handle_http_req(
resp.into_response()
}

/// Proxies a WebSocket upgrade request to `target_service` at the HTTP level.
///
/// How it works:
/// 1. It sends an upstream request with the incoming request's method and headers, plus
///    `extra_headers`, minus `Host`.
/// 2. It requires the upstream to answer `101 Switching Protocols`.
/// 3. It spawns a task that splices the two upgraded connections together byte-for-byte.
///
/// Returns:
/// - a `101` response carrying every upstream response header, including repeated values
///   of the same header;
/// - `500 Internal Server Error` if the upstream request cannot be built;
/// - `502 Bad Gateway` if the upstream request fails, does not return `101`, or cannot be
///   upgraded.
async fn handle_ws_req(
    req: Request,
    target_service: TargetService,
    extra_headers: linkup::HeaderMap,
    client: HttpsClient,
) -> Response {
    let extra_http_headers: HeaderMap = extra_headers.into();

    let target_ws_req_result = Request::builder()
        .uri(target_service.url)
        .method(req.method().clone())
        .body(Body::empty());

    let mut target_ws_req = match target_ws_req_result {
        Ok(request) => request,
        Err(e) => {
            return ApiError::new(
                format!("Failed to build request: {}", e),
                StatusCode::INTERNAL_SERVER_ERROR,
            )
            .into_response();
        }
    };

    // Forward the client's headers (including its upgrade handshake headers), then the
    // linkup extra headers. `HeaderMap::extend` replaces same-named entries, so the extra
    // headers win. `req` is still needed below for the client-side upgrade, hence the clone.
    target_ws_req.headers_mut().extend(req.headers().clone());
    target_ws_req.headers_mut().extend(extra_http_headers);
    // Do not forward the incoming Host; presumably the HTTP client derives it from the
    // target URI instead.
    target_ws_req.headers_mut().remove(http::header::HOST);

    // Send the modified request to the target service.
    let target_ws_resp = match client.request(target_ws_req).await {
        Ok(resp) => resp,
        Err(e) => {
            return ApiError::new(
                format!("Failed to proxy request: {}", e),
                StatusCode::BAD_GATEWAY,
            )
            .into_response()
        }
    };

    let status = target_ws_resp.status();
    if status != 101 {
        return ApiError::new(
            format!(
                "Failed to proxy request: expected 101 Switching Protocols, got {}",
                status
            ),
            StatusCode::BAD_GATEWAY,
        )
        .into_response();
    }

    // The response itself is consumed by `hyper::upgrade::on`, so keep its headers for
    // the 101 we send back to the client.
    let target_ws_resp_headers = target_ws_resp.headers().clone();

    let upgraded_target = match hyper::upgrade::on(target_ws_resp).await {
        Ok(upgraded) => upgraded,
        Err(e) => {
            return ApiError::new(
                format!("Failed to upgrade connection: {}", e),
                StatusCode::BAD_GATEWAY,
            )
            .into_response()
        }
    };

    tokio::spawn(async move {
        // We won't get past this until the 101 response returns to the client
        let upgraded_incoming = match hyper::upgrade::on(req).await {
            Ok(upgraded) => upgraded,
            Err(e) => {
                eprintln!("Failed to upgrade incoming connection: {}", e);
                return;
            }
        };

        let mut incoming_stream = TokioIo::new(upgraded_incoming);
        let mut target_stream = TokioIo::new(upgraded_target);

        let res = tokio::io::copy_bidirectional(&mut incoming_stream, &mut target_stream).await;

        match res {
            Ok((incoming_to_target, target_to_incoming)) => {
                println!(
                    "Copied {} bytes from incoming to target and {} bytes from target to incoming",
                    incoming_to_target, target_to_incoming
                );
            }
            Err(e) => {
                eprintln!("Error copying between incoming and target: {}", e);
            }
        }
    });

    let mut resp_builder = Response::builder().status(101);
    if let Some(resp_headers) = resp_builder.headers_mut() {
        // An owned `HeaderMap` iterator yields `None` as the name for the second and later
        // values of a repeated header. Filtering on `Some(name)` would drop those values.
        // `extend` understands this encoding and copies every value.
        resp_headers.extend(target_ws_resp_headers);
    }

    match resp_builder.body(Body::empty()) {
        Ok(response) => response,
        Err(e) => ApiError::new(
            format!("Failed to build response: {}", e),
            StatusCode::INTERNAL_SERVER_ERROR,
        )
        .into_response(),
    }
}

async fn linkup_config_handler(
Extension(store): Extension<MemoryStringStore>,
Json(update_req): Json<UpdateSessionRequest>,
Expand Down
Loading
0