diff --git a/Cargo.lock b/Cargo.lock index e4930177..8f827bc8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2939,6 +2939,7 @@ dependencies = [ "base64 0.22.1", "bytes", "clap", + "futures", "hex", "hmac", "ipnet", @@ -2965,6 +2966,7 @@ dependencies = [ "tokio", "tokio-rustls", "tokio-stream", + "tokio-tungstenite 0.26.2", "tonic", "tracing", "tracing-appender", diff --git a/crates/openshell-sandbox/Cargo.toml b/crates/openshell-sandbox/Cargo.toml index 26da57ef..68e696e9 100644 --- a/crates/openshell-sandbox/Cargo.toml +++ b/crates/openshell-sandbox/Cargo.toml @@ -81,6 +81,8 @@ uuid = { version = "1", features = ["v4"] } [dev-dependencies] tempfile = "3" temp-env = "0.3" +tokio-tungstenite = { workspace = true } +futures = { workspace = true } [lints] workspace = true diff --git a/crates/openshell-sandbox/src/l7/provider.rs b/crates/openshell-sandbox/src/l7/provider.rs index df0dfb29..7516aa85 100644 --- a/crates/openshell-sandbox/src/l7/provider.rs +++ b/crates/openshell-sandbox/src/l7/provider.rs @@ -14,6 +14,22 @@ use std::collections::HashMap; use std::future::Future; use tokio::io::{AsyncRead, AsyncWrite}; +/// Outcome of relaying a single HTTP request/response pair. +#[derive(Debug)] +pub enum RelayOutcome { + /// Connection is reusable for further HTTP requests (keep-alive). + Reusable, + /// Connection was consumed (e.g. read-until-EOF or `Connection: close`). + Consumed, + /// Server responded with 101 Switching Protocols. + /// The connection has been upgraded (e.g. to WebSocket) and must be + /// relayed as raw bidirectional TCP from this point forward. + /// Contains any overflow bytes read from upstream past the 101 response + /// headers that belong to the upgraded protocol. The 101 headers + /// themselves have already been forwarded to the client. + Upgraded { overflow: Vec }, +} + /// Body framing for HTTP requests/responses. #[derive(Debug, Clone, Copy)] pub enum BodyLength { @@ -57,14 +73,15 @@ pub trait L7Provider: Send + Sync { /// Forward an allowed request to upstream and relay the response back. /// - /// Returns `true` if the upstream connection is reusable (keep-alive), - /// `false` if it was consumed (e.g. read-until-EOF or `Connection: close`). + /// Returns a [`RelayOutcome`] indicating whether the connection is + /// reusable (keep-alive), consumed, or has been upgraded (101 Switching + /// Protocols) and must be relayed as raw bidirectional TCP. fn relay( &self, req: &L7Request, client: &mut C, upstream: &mut U, - ) -> impl Future> + Send + ) -> impl Future> + Send where C: AsyncRead + AsyncWrite + Unpin + Send, U: AsyncRead + AsyncWrite + Unpin + Send; diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index 49caea64..b2fb34b6 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -7,7 +7,7 @@ //! Parses each request within the tunnel, evaluates it against OPA policy, //! and either forwards or denies the request. -use crate::l7::provider::L7Provider; +use crate::l7::provider::{L7Provider, RelayOutcome}; use crate::l7::{EnforcementMode, L7EndpointConfig, L7Protocol, L7RequestInfo}; use crate::secrets::{self, SecretResolver}; use miette::{IntoDiagnostic, Result, miette}; @@ -68,6 +68,40 @@ where } } +/// Handle an upgraded connection (101 Switching Protocols). +/// +/// Forwards any overflow bytes from the upgrade response to the client, then +/// switches to raw bidirectional TCP copy for the upgraded protocol (WebSocket, +/// HTTP/2, etc.). L7 policy enforcement does not apply after the upgrade — +/// the initial HTTP request was already evaluated. +async fn handle_upgrade( + client: &mut C, + upstream: &mut U, + overflow: Vec, + host: &str, + port: u16, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Unpin + Send, + U: AsyncRead + AsyncWrite + Unpin + Send, +{ + info!( + host = %host, + port = port, + overflow_bytes = overflow.len(), + "101 Switching Protocols — switching to raw bidirectional relay \ + (L7 enforcement no longer active)" + ); + if !overflow.is_empty() { + client.write_all(&overflow).await.into_diagnostic()?; + client.flush().await.into_diagnostic()?; + } + tokio::io::copy_bidirectional(client, upstream) + .await + .into_diagnostic()?; + Ok(()) +} + /// REST relay loop: parse request -> evaluate -> allow/deny -> relay response -> repeat. async fn relay_rest( config: &L7EndpointConfig, @@ -137,10 +171,24 @@ where // Evaluate L7 policy via Rego (using redacted target) let (allowed, reason) = evaluate_l7_request(engine, ctx, &request_info)?; - let decision_str = match (allowed, config.enforcement) { - (true, _) => "allow", - (false, EnforcementMode::Audit) => "audit", - (false, EnforcementMode::Enforce) => "deny", + // Check if this is an upgrade request for logging purposes. + let header_end = req + .raw_header + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(req.raw_header.len(), |p| p + 4); + let is_upgrade_request = { + let h = String::from_utf8_lossy(&req.raw_header[..header_end]); + h.lines() + .skip(1) + .any(|l| l.to_ascii_lowercase().starts_with("upgrade:")) + }; + + let decision_str = match (allowed, config.enforcement, is_upgrade_request) { + (true, _, true) => "allow_upgrade", + (true, _, false) => "allow", + (false, EnforcementMode::Audit, _) => "audit", + (false, EnforcementMode::Enforce, _) => "deny", }; // Log every L7 decision (using redacted target — never log real secrets) @@ -162,20 +210,26 @@ where if allowed || config.enforcement == EnforcementMode::Audit { // Forward request to upstream and relay response - let reusable = crate::l7::rest::relay_http_request_with_resolver( + let outcome = crate::l7::rest::relay_http_request_with_resolver( &req, client, upstream, ctx.secret_resolver.as_deref(), ) .await?; - if !reusable { - debug!( - host = %ctx.host, - port = ctx.port, - "Upstream connection not reusable, closing L7 relay" - ); - return Ok(()); + match outcome { + RelayOutcome::Reusable => {} // continue loop + RelayOutcome::Consumed => { + debug!( + host = %ctx.host, + port = ctx.port, + "Upstream connection not reusable, closing L7 relay" + ); + return Ok(()); + } + RelayOutcome::Upgraded { overflow } => { + return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + } } } else { // Enforce mode: deny with 403 and close connection (use redacted target) @@ -334,12 +388,16 @@ where // Forward request with credential rewriting and relay the response. // relay_http_request_with_resolver handles both directions: it sends // the request upstream and reads the response back to the client. - let reusable = + let outcome = crate::l7::rest::relay_http_request_with_resolver(&req, client, upstream, resolver) .await?; - if !reusable { - break; + match outcome { + RelayOutcome::Reusable => {} // continue loop + RelayOutcome::Consumed => break, + RelayOutcome::Upgraded { overflow } => { + return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + } } } diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index ec5494c9..0c136be7 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -7,12 +7,12 @@ //! policy, and relays allowed requests to upstream. Handles Content-Length //! and chunked transfer encoding for body framing. -use crate::l7::provider::{BodyLength, L7Provider, L7Request}; +use crate::l7::provider::{BodyLength, L7Provider, L7Request, RelayOutcome}; use crate::secrets::rewrite_http_header_block; use miette::{IntoDiagnostic, Result, miette}; use std::collections::HashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tracing::debug; +use tracing::{debug, warn}; const MAX_HEADER_BYTES: usize = 16384; // 16 KiB for HTTP headers const RELAY_BUF_SIZE: usize = 8192; @@ -32,7 +32,12 @@ impl L7Provider for RestProvider { parse_http_request(client).await } - async fn relay(&self, req: &L7Request, client: &mut C, upstream: &mut U) -> Result + async fn relay( + &self, + req: &L7Request, + client: &mut C, + upstream: &mut U, + ) -> Result where C: AsyncRead + AsyncWrite + Unpin + Send, U: AsyncRead + AsyncWrite + Unpin + Send, @@ -236,8 +241,13 @@ fn decode_hex_nibble(byte: u8) -> Option { /// Forward an allowed HTTP request to upstream and relay the response back. /// -/// Returns `true` if the upstream connection is reusable, `false` if consumed. -async fn relay_http_request(req: &L7Request, client: &mut C, upstream: &mut U) -> Result +/// Returns the relay outcome indicating whether the connection is reusable, +/// consumed, or has been upgraded (e.g. WebSocket via 101 Switching Protocols). +async fn relay_http_request( + req: &L7Request, + client: &mut C, + upstream: &mut U, +) -> Result where C: AsyncRead + AsyncWrite + Unpin, U: AsyncRead + AsyncWrite + Unpin, @@ -250,7 +260,7 @@ pub(crate) async fn relay_http_request_with_resolver( client: &mut C, upstream: &mut U, resolver: Option<&crate::secrets::SecretResolver>, -) -> Result +) -> Result where C: AsyncRead + AsyncWrite + Unpin, U: AsyncRead + AsyncWrite + Unpin, @@ -288,8 +298,27 @@ where BodyLength::None => {} } upstream.flush().await.into_diagnostic()?; - let (reusable, _) = relay_response(&req.action, upstream, client).await?; - Ok(reusable) + + let outcome = relay_response(&req.action, upstream, client).await?; + + // Validate that the client actually requested an upgrade before accepting + // a 101 from upstream. Per RFC 9110 Section 7.8, the server MUST NOT send + // 101 unless the client sent Upgrade + Connection: Upgrade headers. A + // non-compliant or malicious upstream could send an unsolicited 101 to + // bypass L7 inspection. + if matches!(outcome, RelayOutcome::Upgraded { .. }) { + let header_str = String::from_utf8_lossy(&req.raw_header[..header_end]); + if !client_requested_upgrade(&header_str) { + warn!( + method = %req.action, + target = %req.target, + "upstream sent unsolicited 101 without client Upgrade request — closing connection" + ); + return Ok(RelayOutcome::Consumed); + } + } + + Ok(outcome) } /// Send a 403 Forbidden JSON deny response. @@ -525,29 +554,28 @@ fn find_crlf(buf: &[u8], start: usize) -> Option { /// Read and relay a full HTTP response (headers + body) from upstream to client. /// -/// Returns `true` if the upstream connection is reusable (keep-alive), -/// `false` if it was consumed (read-until-EOF or `Connection: close`). -/// Relay an HTTP response from upstream back to the client. +/// Returns a [`RelayOutcome`] indicating whether the connection is reusable, +/// consumed, or has been upgraded (101 Switching Protocols). /// -/// Returns `true` if the connection should stay alive for further requests. +/// Note: callers that receive `Upgraded` are responsible for switching to +/// raw bidirectional relay and forwarding the overflow bytes. pub(crate) async fn relay_response_to_client( upstream: &mut U, client: &mut C, request_method: &str, -) -> Result +) -> Result where U: AsyncRead + Unpin, C: AsyncWrite + Unpin, { - let (reusable, _status) = relay_response(request_method, upstream, client).await?; - Ok(reusable) + relay_response(request_method, upstream, client).await } async fn relay_response( request_method: &str, upstream: &mut U, client: &mut C, -) -> Result<(bool, u16)> +) -> Result where U: AsyncRead + Unpin, C: AsyncWrite + Unpin, @@ -568,7 +596,7 @@ where if !buf.is_empty() { client.write_all(&buf).await.into_diagnostic()?; } - return Ok((false, 0)); + return Ok(RelayOutcome::Consumed); } buf.extend_from_slice(&tmp[..n]); @@ -594,6 +622,26 @@ where "relay_response framing" ); + // 101 Switching Protocols: the connection has been upgraded (e.g. to + // WebSocket). Forward the 101 headers to the client and signal the + // caller to switch to raw bidirectional TCP relay. Any bytes read + // from upstream beyond the headers are overflow that belong to the + // upgraded protocol and must be forwarded before switching. + if status_code == 101 { + client + .write_all(&buf[..header_end]) + .await + .into_diagnostic()?; + client.flush().await.into_diagnostic()?; + let overflow = buf[header_end..].to_vec(); + debug!( + request_method, + overflow_bytes = overflow.len(), + "101 Switching Protocols — signaling protocol upgrade" + ); + return Ok(RelayOutcome::Upgraded { overflow }); + } + // Bodiless responses (HEAD, 1xx, 204, 304): forward headers only, skip body if is_bodiless_response(request_method, status_code) { client @@ -601,7 +649,11 @@ where .await .into_diagnostic()?; client.flush().await.into_diagnostic()?; - return Ok((!server_wants_close, status_code)); + return if server_wants_close { + Ok(RelayOutcome::Consumed) + } else { + Ok(RelayOutcome::Reusable) + }; } // No explicit framing (no Content-Length, no Transfer-Encoding). @@ -621,7 +673,7 @@ where } relay_until_eof(upstream, client).await?; client.flush().await.into_diagnostic()?; - return Ok((false, status_code)); + return Ok(RelayOutcome::Consumed); } // No Connection: close — an HTTP/1.1 keep-alive server that omits // framing headers has an empty body. Forward headers and continue @@ -632,7 +684,7 @@ where .await .into_diagnostic()?; client.flush().await.into_diagnostic()?; - return Ok((true, status_code)); + return Ok(RelayOutcome::Reusable); } // Forward response headers + any overflow body bytes @@ -665,7 +717,7 @@ where // loop will exit via the normal error path. Exiting early here would // tear down the CONNECT tunnel before the client can detect the close, // causing ~30 s retry delays in clients like `gh`. - Ok((true, status_code)) + Ok(RelayOutcome::Reusable) } /// Parse the HTTP status code from a response status line. @@ -689,6 +741,33 @@ fn parse_connection_close(headers: &str) -> bool { false } +/// Check if the client request headers contain both `Upgrade` and +/// `Connection: Upgrade` headers, indicating the client requested a +/// protocol upgrade (e.g. WebSocket). +/// +/// Per RFC 9110 Section 7.8, a server MUST NOT send 101 Switching Protocols +/// unless the client sent these headers. +fn client_requested_upgrade(headers: &str) -> bool { + let mut has_upgrade_header = false; + let mut connection_contains_upgrade = false; + + for line in headers.lines().skip(1) { + let lower = line.to_ascii_lowercase(); + if lower.starts_with("upgrade:") { + has_upgrade_header = true; + } + if lower.starts_with("connection:") { + let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); + // Connection header can have comma-separated values + if val.split(',').any(|tok| tok.trim() == "upgrade") { + connection_contains_upgrade = true; + } + } + } + + has_upgrade_header && connection_contains_upgrade +} + /// Returns true for responses that MUST NOT contain a message body per RFC 7230 §3.3.3: /// HEAD responses, 1xx informational, 204 No Content, 304 Not Modified. fn is_bodiless_response(request_method: &str, status_code: u16) -> bool { @@ -1136,8 +1215,11 @@ mod tests { .await .expect("relay_response should not deadlock"); - let (reusable, _status) = result.expect("relay_response should succeed"); - assert!(!reusable, "connection consumed by read-until-EOF"); + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Consumed), + "connection consumed by read-until-EOF" + ); client_write.shutdown().await.unwrap(); let mut received = Vec::new(); @@ -1174,8 +1256,11 @@ mod tests { .await .expect("must not block when no Connection: close"); - let (reusable, _status) = result.expect("relay_response should succeed"); - assert!(reusable, "keep-alive implied, connection reusable"); + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "keep-alive implied, connection reusable" + ); client_write.shutdown().await.unwrap(); let mut received = Vec::new(); @@ -1207,8 +1292,11 @@ mod tests { .await .expect("HEAD relay must not deadlock waiting for body"); - let (reusable, _status) = result.expect("relay_response should succeed"); - assert!(reusable, "HEAD response should be reusable"); + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "HEAD response should be reusable" + ); client_write.shutdown().await.unwrap(); let mut received = Vec::new(); @@ -1237,8 +1325,11 @@ mod tests { .await .expect("204 relay must not deadlock"); - let (reusable, _status) = result.expect("relay_response should succeed"); - assert!(reusable, "204 response should be reusable"); + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "204 response should be reusable" + ); client_write.shutdown().await.unwrap(); let mut received = Vec::new(); @@ -1269,8 +1360,11 @@ mod tests { .await .expect("must not block when chunked body is complete in overflow"); - let (reusable, _status) = result.expect("relay_response should succeed"); - assert!(reusable, "connection should be reusable"); + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "connection should be reusable" + ); client_write.shutdown().await.unwrap(); let mut received = Vec::new(); @@ -1305,8 +1399,11 @@ mod tests { .await .expect("must not block when chunked response has trailers"); - let (reusable, _status) = result.expect("relay_response should succeed"); - assert!(reusable, "chunked response should be reusable"); + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "chunked response should be reusable" + ); client_write.shutdown().await.unwrap(); let mut received = Vec::new(); @@ -1340,8 +1437,11 @@ mod tests { .await .expect("normal relay must not deadlock"); - let (reusable, _status) = result.expect("relay_response should succeed"); - assert!(reusable, "Content-Length response should be reusable"); + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "Content-Length response should be reusable" + ); client_write.shutdown().await.unwrap(); let mut received = Vec::new(); @@ -1368,12 +1468,12 @@ mod tests { .await .expect("relay must not deadlock"); - let (reusable, _status) = result.expect("relay_response should succeed"); + let outcome = result.expect("relay_response should succeed"); // With explicit framing, Connection: close is still reported as reusable // so the relay loop continues. The *next* upstream write will fail and // exit the loop via the normal error path. assert!( - reusable, + matches!(outcome, RelayOutcome::Reusable), "explicit framing keeps loop alive despite Connection: close" ); @@ -1383,6 +1483,224 @@ mod tests { assert!(String::from_utf8_lossy(&received).contains("hello")); } + #[tokio::test] + async fn relay_response_101_switching_protocols_returns_upgraded_with_overflow() { + // Build a 101 response followed by WebSocket frame data (overflow). + let mut response = Vec::new(); + response.extend_from_slice(b"HTTP/1.1 101 Switching Protocols\r\n"); + response.extend_from_slice(b"Upgrade: websocket\r\n"); + response.extend_from_slice(b"Connection: Upgrade\r\n"); + response.extend_from_slice(b"\r\n"); + response.extend_from_slice(b"\x81\x05hello"); // WebSocket frame + + let (upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (mut client_read, client_write) = tokio::io::duplex(4096); + + upstream_write.write_all(&response).await.unwrap(); + drop(upstream_write); + + let mut upstream_read = upstream_read; + let mut client_write = client_write; + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response("GET", &mut upstream_read, &mut client_write), + ) + .await + .expect("relay_response should not deadlock"); + + let outcome = result.expect("relay_response should succeed"); + match outcome { + RelayOutcome::Upgraded { overflow } => { + assert_eq!( + &overflow, b"\x81\x05hello", + "overflow should contain WebSocket frame data" + ); + } + other => panic!("Expected Upgraded, got {other:?}"), + } + + client_write.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + let received_str = String::from_utf8_lossy(&received); + assert!( + received_str.contains("101 Switching Protocols"), + "client should receive the 101 response headers" + ); + } + + #[tokio::test] + async fn relay_response_101_no_overflow() { + // 101 response with no trailing bytes — overflow should be empty. + let response = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"; + + let (upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (_client_read, client_write) = tokio::io::duplex(4096); + + upstream_write.write_all(response).await.unwrap(); + drop(upstream_write); + + let mut upstream_read = upstream_read; + let mut client_write = client_write; + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response("GET", &mut upstream_read, &mut client_write), + ) + .await + .expect("relay_response should not deadlock"); + + match result.expect("should succeed") { + RelayOutcome::Upgraded { overflow } => { + assert!(overflow.is_empty(), "no overflow expected"); + } + other => panic!("Expected Upgraded, got {other:?}"), + } + } + + #[tokio::test] + async fn relay_rejects_unsolicited_101_without_client_upgrade_header() { + // Client sends a normal GET without Upgrade headers. + // Upstream responds with 101 (non-compliant). The relay should + // reject the upgrade and return Consumed instead. + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let req = L7Request { + action: "GET".to_string(), + target: "/api".to_string(), + query_params: HashMap::new(), + raw_header: b"GET /api HTTP/1.1\r\nHost: example.com\r\n\r\n".to_vec(), + body_length: BodyLength::None, + }; + + let upstream_task = tokio::spawn(async move { + // Read the request + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + // Send unsolicited 101 + upstream_side + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + }); + + let result = tokio::time::timeout( + std::time::Duration::from_secs(5), + relay_http_request_with_resolver( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + None, + ), + ) + .await + .expect("relay must not deadlock"); + + let outcome = result.expect("relay should succeed"); + assert!( + matches!(outcome, RelayOutcome::Consumed), + "unsolicited 101 should be rejected as Consumed, got {outcome:?}" + ); + + upstream_task.await.expect("upstream task should complete"); + } + + #[tokio::test] + async fn relay_accepts_101_with_client_upgrade_header() { + // Client sends a proper upgrade request with Upgrade + Connection headers. + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n".to_vec(), + body_length: BodyLength::None, + }; + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + upstream_side + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + }); + + let result = tokio::time::timeout( + std::time::Duration::from_secs(5), + relay_http_request_with_resolver( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + None, + ), + ) + .await + .expect("relay must not deadlock"); + + let outcome = result.expect("relay should succeed"); + assert!( + matches!(outcome, RelayOutcome::Upgraded { .. }), + "proper upgrade request should be accepted, got {outcome:?}" + ); + + upstream_task.await.expect("upstream task should complete"); + } + + #[test] + fn client_requested_upgrade_detects_websocket_headers() { + let headers = "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"; + assert!(client_requested_upgrade(headers)); + } + + #[test] + fn client_requested_upgrade_rejects_missing_upgrade_header() { + let headers = "GET /api HTTP/1.1\r\nHost: example.com\r\n\r\n"; + assert!(!client_requested_upgrade(headers)); + } + + #[test] + fn client_requested_upgrade_rejects_upgrade_without_connection() { + let headers = "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\n\r\n"; + assert!(!client_requested_upgrade(headers)); + } + + #[test] + fn client_requested_upgrade_handles_comma_separated_connection() { + let headers = "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: keep-alive, Upgrade\r\n\r\n"; + assert!(client_requested_upgrade(headers)); + } + #[test] fn rewrite_header_block_resolves_placeholder_auth_headers() { let (_, resolver) = SecretResolver::from_provider_env( diff --git a/crates/openshell-sandbox/tests/websocket_upgrade.rs b/crates/openshell-sandbox/tests/websocket_upgrade.rs new file mode 100644 index 00000000..ec226c9c --- /dev/null +++ b/crates/openshell-sandbox/tests/websocket_upgrade.rs @@ -0,0 +1,259 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Integration test: WebSocket upgrade through the L7 relay. +//! +//! Spins up a dummy WebSocket echo server, connects a client through the +//! `L7Provider::relay` pipeline, validates the 101 upgrade succeeds, and +//! exchanges a WebSocket text frame bidirectionally. +//! +//! This test exercises the full upgrade path described in issue #652: +//! 1. Client sends HTTP GET with `Upgrade: websocket` headers +//! 2. Relay forwards to upstream, upstream responds with 101 +//! 3. Relay detects 101, validates client Upgrade headers, returns `Upgraded` +//! 4. Caller forwards overflow + switches to `copy_bidirectional` +//! 5. Client and server exchange a WebSocket text message +//! +//! Reproduction scenario from #652: raw socket test sends upgrade request +//! through the proxy, receives 101, then verifies WebSocket frames flow. + +use futures::SinkExt; +use futures::stream::StreamExt; +use openshell_sandbox::l7::provider::{BodyLength, L7Provider, L7Request, RelayOutcome}; +use openshell_sandbox::l7::rest::RestProvider; +use std::collections::HashMap; +use std::net::SocketAddr; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_tungstenite::accept_async; +use tokio_tungstenite::tungstenite::Message; + +/// Start a minimal WebSocket echo server on an ephemeral port. +async fn start_ws_echo_server() -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let ws_stream = accept_async(stream).await.unwrap(); + let (mut write, mut read) = ws_stream.split(); + + while let Some(msg) = read.next().await { + match msg { + Ok(Message::Text(text)) => { + write + .send(Message::Text(format!("echo: {text}").into())) + .await + .unwrap(); + } + Ok(Message::Close(_)) => break, + Ok(_) => {} + Err(_) => break, + } + } + }); + + addr +} + +/// Build raw HTTP upgrade request bytes (mimics the reproduction script from #652). +fn build_ws_upgrade_request(host: &str) -> Vec { + format!( + "GET / HTTP/1.1\r\n\ + Host: {host}\r\n\ + Upgrade: websocket\r\n\ + Connection: Upgrade\r\n\ + Sec-WebSocket-Key: RylUQAh3p5cysfOlexgubw==\r\n\ + Sec-WebSocket-Version: 13\r\n\ + \r\n" + ) + .into_bytes() +} + +/// Build a masked WebSocket text frame (client -> server must be masked per RFC 6455). +fn build_ws_text_frame(payload: &[u8]) -> Vec { + let mask_key: [u8; 4] = [0x37, 0xfa, 0x21, 0x3d]; + let mut frame = Vec::new(); + frame.push(0x81); // FIN + text opcode + frame.push(0x80 | payload.len() as u8); // masked + length + frame.extend_from_slice(&mask_key); + for (i, b) in payload.iter().enumerate() { + frame.push(b ^ mask_key[i % 4]); + } + frame +} + +/// Core test: WebSocket upgrade through `L7Provider::relay`, then exchange a message. +/// +/// This mirrors the reproduction steps from issue #652: +/// - Send WebSocket upgrade → receive 101 → verify frames flow bidirectionally +/// - Previously, 101 was treated as a generic 1xx and frames were dropped +#[tokio::test] +async fn websocket_upgrade_through_l7_relay_exchanges_message() { + let ws_addr = start_ws_echo_server().await; + + // Open a real TCP connection to the WebSocket server (simulates upstream) + let mut upstream = TcpStream::connect(ws_addr).await.unwrap(); + + // In-memory duplex for the client side of the relay + let (mut client_app, mut client_proxy) = tokio::io::duplex(8192); + + let host = format!("127.0.0.1:{}", ws_addr.port()); + let raw_header = build_ws_upgrade_request(&host); + + let req = L7Request { + action: "GET".to_string(), + target: "/".to_string(), + query_params: HashMap::new(), + raw_header, + body_length: BodyLength::None, + }; + + // Run the relay in a background task (simulates what relay_rest does) + let relay_handle = tokio::spawn(async move { + let outcome = RestProvider + .relay(&req, &mut client_proxy, &mut upstream) + .await + .expect("relay should succeed"); + + match outcome { + RelayOutcome::Upgraded { overflow } => { + // This is what handle_upgrade() does in relay.rs + if !overflow.is_empty() { + client_proxy.write_all(&overflow).await.unwrap(); + client_proxy.flush().await.unwrap(); + } + let _ = tokio::io::copy_bidirectional(&mut client_proxy, &mut upstream).await; + } + other => panic!("Expected Upgraded, got {other:?}"), + } + }); + + // Client side: read the 101 response headers byte-by-byte + // (mirrors the reproduction script's recv() after sending the upgrade) + let mut response_buf = Vec::new(); + let mut tmp = [0u8; 1]; + tokio::time::timeout(std::time::Duration::from_secs(5), async { + loop { + client_app.read_exact(&mut tmp).await.unwrap(); + response_buf.push(tmp[0]); + if response_buf.ends_with(b"\r\n\r\n") { + break; + } + } + }) + .await + .expect("should receive 101 headers within 5 seconds"); + + let response_str = String::from_utf8_lossy(&response_buf); + assert!( + response_str.contains("101 Switching Protocols"), + "should receive 101, got: {response_str}" + ); + + // ---- This is the part that was broken before the fix (issue #652) ---- + // Previously, after 101, the relay re-entered the HTTP parsing loop and + // all WebSocket frames were silently dropped. The reproduction script + // would see RECV2: TIMEOUT here. + + // Send a WebSocket text frame + let frame = build_ws_text_frame(b"hello"); + client_app.write_all(&frame).await.unwrap(); + client_app.flush().await.unwrap(); + + // Read the echo response (unmasked server -> client frame) + tokio::time::timeout(std::time::Duration::from_secs(5), async { + let mut header = [0u8; 2]; + client_app.read_exact(&mut header).await.unwrap(); + + let fin_opcode = header[0]; + assert_eq!(fin_opcode & 0x0F, 1, "should be text frame"); + assert!(fin_opcode & 0x80 != 0, "FIN bit should be set"); + + let len = (header[1] & 0x7F) as usize; + let mut payload_buf = vec![0u8; len]; + client_app.read_exact(&mut payload_buf).await.unwrap(); + let text = String::from_utf8(payload_buf).unwrap(); + assert_eq!( + text, "echo: hello", + "server should echo our message back through the relay" + ); + }) + .await + .expect("should receive WebSocket echo within 5 seconds (previously timed out per #652)"); + + // Clean shutdown + let close_frame = [0x88, 0x82, 0x00, 0x00, 0x00, 0x00, 0x03, 0xe8]; + let _ = client_app.write_all(&close_frame).await; + drop(client_app); + + let _ = tokio::time::timeout(std::time::Duration::from_secs(2), relay_handle).await; +} + +/// Test that a normal (non-upgrade) HTTP request still works correctly +/// after the relay_response changes. Ensures the 101 detection doesn't +/// break regular HTTP traffic. +#[tokio::test] +async fn normal_http_request_still_works_after_relay_changes() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + // Simple HTTP echo server + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = stream.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + stream + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok") + .await + .unwrap(); + stream.flush().await.unwrap(); + }); + + let mut upstream = TcpStream::connect(addr).await.unwrap(); + let (mut client_read, mut client_proxy) = tokio::io::duplex(8192); + + let raw_header = format!( + "GET /api HTTP/1.1\r\nHost: 127.0.0.1:{}\r\n\r\n", + addr.port() + ) + .into_bytes(); + + let req = L7Request { + action: "GET".to_string(), + target: "/api".to_string(), + query_params: HashMap::new(), + raw_header, + body_length: BodyLength::None, + }; + + let outcome = tokio::time::timeout( + std::time::Duration::from_secs(5), + RestProvider.relay(&req, &mut client_proxy, &mut upstream), + ) + .await + .expect("should not deadlock") + .expect("relay should succeed"); + + assert!( + matches!(outcome, RelayOutcome::Reusable), + "normal 200 response should be Reusable, got {outcome:?}" + ); + + client_proxy.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + let body = String::from_utf8_lossy(&received); + assert!(body.contains("200 OK"), "should forward 200 response"); + assert!(body.contains("ok"), "should forward response body"); +}