main.rs (73113B)
1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or 2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license 3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your 4 // option. This file may not be copied, modified, or distributed 5 // except according to those terms. 6 7 use base64::prelude::*; 8 use neqo_bin::server::{HttpServer, Runner}; 9 use neqo_common::Bytes; 10 use neqo_common::{event::Provider, qdebug, qerror, qinfo, qtrace, Datagram, Header}; 11 use neqo_crypto::{generate_ech_keys, init_db, AllowZeroRtt, AntiReplay}; 12 use neqo_http3::{ 13 ConnectUdpRequest, ConnectUdpServerEvent, Error, Http3OrWebTransportStream, Http3Parameters, 14 Http3Server, Http3ServerEvent, SessionAcceptAction, StreamId, WebTransportRequest, 15 WebTransportServerEvent, 16 }; 17 use neqo_transport::server::ConnectionRef; 18 use neqo_transport::{ 19 ConnectionEvent, ConnectionParameters, OutputBatch, RandomConnectionIdGenerator, StreamType, 20 }; 21 use std::env; 22 use std::pin::Pin; 23 use std::task::{Context, Poll}; 24 use tokio::io::AsyncWriteExt; 25 use tokio::io::ReadBuf; 26 use tokio::task::LocalSet; 27 28 use std::cell::RefCell; 29 use std::io; 30 use std::num::NonZeroUsize; 31 use std::path::PathBuf; 32 use std::process::exit; 33 use std::rc::Rc; 34 use std::thread; 35 use std::time::{Duration, Instant}; 36 37 use cfg_if::cfg_if; 38 39 cfg_if! 
{ 40 if #[cfg(not(target_os = "android"))] { 41 use std::sync::mpsc::{channel, Receiver, TryRecvError}; 42 use hyper::body::HttpBody; 43 use hyper::header::{HeaderName, HeaderValue}; 44 use hyper::{Body, Client, Method, Request}; 45 } 46 } 47 48 use std::cmp::min; 49 use std::collections::hash_map::DefaultHasher; 50 use std::collections::HashSet; 51 use std::collections::{HashMap, VecDeque}; 52 use std::hash::{Hash, Hasher}; 53 use std::net::SocketAddr; 54 55 const MAX_TABLE_SIZE: u64 = 65536; 56 const MAX_BLOCKED_STREAMS: u16 = 10; 57 const PROTOCOLS: &[&str] = &["h3"]; 58 const ECH_CONFIG_ID: u8 = 7; 59 const ECH_PUBLIC_NAME: &str = "public.example"; 60 61 const HTTP_RESPONSE_WITH_WRONG_FRAME: &[u8] = &[ 62 0x01, 0x06, 0x00, 0x00, 0xd9, 0x54, 0x01, 0x37, // headers 63 0x0, 0x3, 0x61, 0x62, 0x63, // the first data frame 64 0x3, 0x1, 0x5, // a cancel push frame that is not allowed 65 ]; 66 struct Http3TestServer { 67 server: Http3Server, 68 // This a map from a post request to amount of data ithas been received on the request. 69 // The respons will carry the amount of data received. 
70 posts: HashMap<Http3OrWebTransportStream, usize>, 71 responses: HashMap<Http3OrWebTransportStream, Vec<u8>>, 72 connections_to_close: HashMap<Instant, Vec<ConnectionRef>>, 73 sessions_to_close: HashMap<Instant, Vec<WebTransportRequest>>, 74 sessions_to_create_stream: Vec<(WebTransportRequest, StreamType, Option<Vec<u8>>)>, 75 webtransport_bidi_stream: HashSet<Http3OrWebTransportStream>, 76 wt_unidi_conn_to_stream: HashMap<ConnectionRef, Http3OrWebTransportStream>, 77 wt_unidi_echo_back: HashMap<Http3OrWebTransportStream, Http3OrWebTransportStream>, 78 received_datagram: Option<Bytes>, 79 } 80 81 impl ::std::fmt::Display for Http3TestServer { 82 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { 83 write!(f, "{}", self.server) 84 } 85 } 86 impl Http3TestServer { 87 pub fn new(server: Http3Server) -> Self { 88 Self { 89 server, 90 posts: HashMap::new(), 91 responses: HashMap::new(), 92 connections_to_close: HashMap::new(), 93 sessions_to_close: HashMap::new(), 94 sessions_to_create_stream: Vec::new(), 95 webtransport_bidi_stream: HashSet::new(), 96 wt_unidi_conn_to_stream: HashMap::new(), 97 wt_unidi_echo_back: HashMap::new(), 98 received_datagram: None, 99 } 100 } 101 102 fn new_response(&mut self, stream: Http3OrWebTransportStream, mut data: Vec<u8>, now: Instant) { 103 if data.len() == 0 { 104 let _ = stream.stream_close_send(now); 105 return; 106 } 107 match stream.send_data(&data, now) { 108 Ok(sent) => { 109 if sent < data.len() { 110 self.responses.insert(stream, data.split_off(sent)); 111 } else { 112 let _ = stream.stream_close_send(now); 113 } 114 } 115 Err(e) => { 116 eprintln!("error is {:?}", e); 117 } 118 } 119 } 120 121 fn handle_stream_writable(&mut self, stream: Http3OrWebTransportStream, now: Instant) { 122 if let Some(data) = self.responses.get_mut(&stream) { 123 match stream.send_data(&data, now) { 124 Ok(sent) => { 125 if sent < data.len() { 126 let new_d = (*data).split_off(sent); 127 *data = new_d; 128 } else { 129 
stream.stream_close_send(now).unwrap(); 130 self.responses.remove(&stream); 131 } 132 } 133 Err(_) => { 134 eprintln!("Unexpected error"); 135 } 136 } 137 } 138 } 139 140 fn maybe_close_session(&mut self, now: Instant) { 141 for (expires, sessions) in self.sessions_to_close.iter_mut() { 142 if *expires <= now { 143 for s in sessions.iter_mut() { 144 drop(s.close_session(0, "", now)); 145 } 146 } 147 } 148 self.sessions_to_close.retain(|expires, _| *expires >= now); 149 } 150 151 fn maybe_close_connection(&mut self) { 152 let now = Instant::now(); 153 for (expires, connections) in self.connections_to_close.iter_mut() { 154 if *expires <= now { 155 for c in connections.iter_mut() { 156 c.borrow_mut().close(now, 0x0100, ""); 157 } 158 } 159 } 160 self.connections_to_close 161 .retain(|expires, _| *expires >= now); 162 } 163 164 fn maybe_create_wt_stream(&mut self, now: Instant) { 165 if self.sessions_to_create_stream.is_empty() { 166 return; 167 } 168 let tuple = self.sessions_to_create_stream.pop().unwrap(); 169 let session = tuple.0; 170 let wt_server_stream = session.create_stream(tuple.1).unwrap(); 171 if tuple.1 == StreamType::UniDi { 172 if let Some(data) = tuple.2 { 173 self.new_response(wt_server_stream, data, now); 174 } else { 175 // relaying Http3ServerEvent::Data to uni streams 176 // slows down netwerk/test/unit/test_webtransport_simple.js 177 // to the point of failure. Only do so when necessary. 
178 self.wt_unidi_conn_to_stream 179 .insert(wt_server_stream.conn.clone(), wt_server_stream); 180 } 181 } else { 182 if let Some(data) = tuple.2 { 183 self.new_response(wt_server_stream, data, now); 184 } else { 185 self.webtransport_bidi_stream.insert(wt_server_stream); 186 } 187 } 188 } 189 } 190 191 impl HttpServer for Http3TestServer { 192 fn process_multiple<'a>( 193 &mut self, 194 dgrams: impl IntoIterator<Item = Datagram<&'a mut [u8]>>, 195 now: Instant, 196 max_datagrams: NonZeroUsize, 197 ) -> OutputBatch { 198 let output = self.server.process_multiple(dgrams, now, max_datagrams); 199 200 let output = if self.sessions_to_close.is_empty() && self.connections_to_close.is_empty() { 201 output 202 } else { 203 // In case there are pending sessions to close, use a shorter 204 // timeout to make process_events() to be called earlier. 205 const MIN_INTERVAL: Duration = Duration::from_millis(100); 206 207 match output { 208 OutputBatch::None => OutputBatch::Callback(MIN_INTERVAL), 209 o @ OutputBatch::DatagramBatch(_) => o, 210 OutputBatch::Callback(d) => OutputBatch::Callback(min(d, MIN_INTERVAL)), 211 } 212 }; 213 214 output 215 } 216 217 fn process_events(&mut self, now: Instant) { 218 self.maybe_close_connection(); 219 self.maybe_close_session(now); 220 self.maybe_create_wt_stream(now); 221 222 while let Some(event) = self.server.next_event() { 223 qtrace!("Event: {:?}", event); 224 match event { 225 Http3ServerEvent::Headers { 226 stream, 227 headers, 228 fin, 229 } => { 230 qtrace!("Headers (request={} fin={}): {:?}", stream, fin, headers); 231 232 let connection_hash = { 233 let mut hasher = DefaultHasher::new(); 234 stream.conn.hash(&mut hasher); 235 hasher.finish() 236 }; 237 238 // Some responses do not have content-type. This is on purpose to exercise 239 // UnknownDecoder code. 
240 let default_ret = b"Hello World".to_vec(); 241 let default_headers = vec![ 242 Header::new(":status", "200"), 243 Header::new("cache-control", "no-cache"), 244 Header::new("content-length", default_ret.len().to_string()), 245 Header::new("x-http3-conn-hash", connection_hash.to_string()), 246 ]; 247 248 let path_hdr = headers.iter().find(|&h| h.name() == ":path"); 249 match path_hdr { 250 Some(ph) if !ph.value().is_empty() => { 251 let path = ph.value(); 252 qtrace!( 253 "Serve request {:?}", 254 ph.value_utf8().unwrap_or("<invalid utf8>") 255 ); 256 if path == b"/Response421" { 257 let response_body = b"0123456789".to_vec(); 258 stream 259 .send_headers(&[ 260 Header::new(":status", "421"), 261 Header::new("cache-control", "no-cache"), 262 Header::new("content-type", "text/plain"), 263 Header::new( 264 "content-length", 265 response_body.len().to_string(), 266 ), 267 ]) 268 .unwrap(); 269 self.new_response(stream, response_body, now); 270 } else if path == b"/RequestCancelled" { 271 stream 272 .stream_stop_sending(Error::HttpRequestCancelled.code()) 273 .unwrap(); 274 stream 275 .stream_reset_send(Error::HttpRequestCancelled.code()) 276 .unwrap(); 277 } else if path == b"/VersionFallback" { 278 stream 279 .stream_stop_sending(Error::HttpVersionFallback.code()) 280 .unwrap(); 281 stream 282 .stream_reset_send(Error::HttpVersionFallback.code()) 283 .unwrap(); 284 } else if path == b"/EarlyResponse" { 285 stream.stream_stop_sending(Error::HttpNone.code()).unwrap(); 286 } else if path == b"/RequestRejected" { 287 stream 288 .stream_stop_sending(Error::HttpRequestRejected.code()) 289 .unwrap(); 290 stream 291 .stream_reset_send(Error::HttpRequestRejected.code()) 292 .unwrap(); 293 } else if path == b"/closeafter1000ms" { 294 let response_body = b"0123456789".to_vec(); 295 stream 296 .send_headers(&[ 297 Header::new(":status", "200"), 298 Header::new("cache-control", "no-cache"), 299 Header::new("content-type", "text/plain"), 300 Header::new( 301 "content-length", 
302 response_body.len().to_string(), 303 ), 304 ]) 305 .unwrap(); 306 let expires = Instant::now() + Duration::from_millis(1000); 307 if !self.connections_to_close.contains_key(&expires) { 308 self.connections_to_close.insert(expires, Vec::new()); 309 } 310 self.connections_to_close 311 .get_mut(&expires) 312 .unwrap() 313 .push(stream.conn.clone()); 314 315 self.new_response(stream, response_body, now); 316 } else if path == b"/.well-known/http-opportunistic" { 317 let host_hdr = headers.iter().find(|&h| h.name() == ":authority"); 318 match host_hdr { 319 Some(host) if !host.value().is_empty() => { 320 let mut content = b"[\"http://".to_vec(); 321 content.extend(host.value()); 322 content.extend(b"\"]"); 323 stream 324 .send_headers(&[ 325 Header::new(":status", "200"), 326 Header::new("cache-control", "no-cache"), 327 Header::new("content-type", "application/json"), 328 Header::new( 329 "content-length", 330 content.len().to_string(), 331 ), 332 ]) 333 .unwrap(); 334 self.new_response(stream, content, now); 335 } 336 _ => { 337 stream.send_headers(&default_headers).unwrap(); 338 self.new_response(stream, default_ret, now); 339 } 340 } 341 } else if path == b"/no_body" { 342 qdebug!("Request for no_body"); 343 stream 344 .send_headers(&[ 345 Header::new(":status", "200"), 346 Header::new("cache-control", "no-cache"), 347 ]) 348 .unwrap(); 349 stream.stream_close_send(now).unwrap(); 350 } else if path == b"/no_content_length" { 351 stream 352 .send_headers(&[ 353 Header::new(":status", "200"), 354 Header::new("cache-control", "no-cache"), 355 ]) 356 .unwrap(); 357 self.new_response(stream, vec![b'a'; 4000], now); 358 } else if path == b"/content_length_smaller" { 359 stream 360 .send_headers(&[ 361 Header::new(":status", "200"), 362 Header::new("cache-control", "no-cache"), 363 Header::new("content-type", "text/plain"), 364 Header::new("content-length", 4000.to_string()), 365 ]) 366 .unwrap(); 367 self.new_response(stream, vec![b'a'; 8000], now); 368 } else if path 
== b"/post" { 369 // Read all data before responding. 370 self.posts.insert(stream, 0); 371 } else if path == b"/priority_mirror" { 372 if let Some(priority) = 373 headers.iter().find(|h| h.name() == "priority") 374 { 375 stream 376 .send_headers(&[ 377 Header::new(":status", "200"), 378 Header::new("cache-control", "no-cache"), 379 Header::new("content-type", "text/plain"), 380 Header::new( 381 "priority-mirror", 382 priority.value_utf8().unwrap(), 383 ), 384 Header::new( 385 "content-length", 386 priority.value().len().to_string(), 387 ), 388 ]) 389 .unwrap(); 390 self.new_response(stream, priority.value().to_vec(), now); 391 } else { 392 stream 393 .send_headers(&[ 394 Header::new(":status", "200"), 395 Header::new("cache-control", "no-cache"), 396 ]) 397 .unwrap(); 398 stream.stream_close_send(now).unwrap(); 399 } 400 } else if path == b"/103_response" { 401 if let Some(early_hint) = 402 headers.iter().find(|h| h.name() == "link-to-set") 403 { 404 for l in early_hint.value_utf8().unwrap().split(',') { 405 stream 406 .send_headers(&[ 407 Header::new(":status", "103"), 408 Header::new("link", l), 409 ]) 410 .unwrap(); 411 } 412 } 413 stream 414 .send_headers(&[ 415 Header::new(":status", "200"), 416 Header::new("cache-control", "no-cache"), 417 Header::new("content-length", "0"), 418 ]) 419 .unwrap(); 420 stream.stream_close_send(now).unwrap(); 421 } else if path == b"/get_webtransport_datagram" { 422 if let Some(dgram) = self.received_datagram.take() { 423 stream 424 .send_headers(&[ 425 Header::new(":status", "200"), 426 Header::new("content-length", dgram.len().to_string()), 427 ]) 428 .unwrap(); 429 self.new_response(stream, dgram.as_ref().to_vec(), now); 430 } else { 431 stream 432 .send_headers(&[ 433 Header::new(":status", "404"), 434 Header::new("cache-control", "no-cache"), 435 ]) 436 .unwrap(); 437 stream.stream_close_send(now).unwrap(); 438 } 439 } else if path == b"/alt_svc_header" { 440 if let Some(alt_svc) = 441 headers.iter().find(|h| h.name() == 
"x-altsvc") 442 { 443 stream 444 .send_headers(&[ 445 Header::new(":status", "200"), 446 Header::new("cache-control", "no-cache"), 447 Header::new("content-type", "text/plain"), 448 Header::new("content-length", 100.to_string()), 449 Header::new( 450 "alt-svc", 451 format!("h3={}", alt_svc.value_utf8().unwrap()), 452 ), 453 ]) 454 .unwrap(); 455 self.new_response(stream, vec![b'a'; 100], now); 456 } else { 457 stream 458 .send_headers(&[ 459 Header::new(":status", "200"), 460 Header::new("cache-control", "no-cache"), 461 ]) 462 .unwrap(); 463 self.new_response(stream, vec![b'a'; 100], now); 464 } 465 } else { 466 match ph.value_utf8().ok().and_then(|s| { 467 s.trim_matches(|p| p == '/').parse::<usize>().ok() 468 }) { 469 Some(v) => { 470 stream 471 .send_headers(&[ 472 Header::new(":status", "200"), 473 Header::new("cache-control", "no-cache"), 474 Header::new("content-type", "text/plain"), 475 Header::new("content-length", v.to_string()), 476 ]) 477 .unwrap(); 478 self.new_response(stream, vec![b'a'; v], now); 479 } 480 None => { 481 stream.send_headers(&default_headers).unwrap(); 482 self.new_response(stream, default_ret, now); 483 } 484 } 485 } 486 } 487 _ => { 488 stream.send_headers(&default_headers).unwrap(); 489 self.new_response(stream, default_ret, now); 490 } 491 } 492 } 493 Http3ServerEvent::Data { stream, data, fin } => { 494 // echo bidirectional input back to client 495 if self.webtransport_bidi_stream.contains(&stream) { 496 if stream.handler.borrow().state().active() { 497 self.new_response(stream, data, now); 498 } 499 break; 500 } 501 502 // echo unidirectional input to back to client 503 // need to close or we hang 504 if self.wt_unidi_echo_back.contains_key(&stream) { 505 let echo_back = self.wt_unidi_echo_back.remove(&stream).unwrap(); 506 echo_back.send_data(&data, now).unwrap(); 507 echo_back.stream_close_send(now).unwrap(); 508 break; 509 } 510 511 if let Some(r) = self.posts.get_mut(&stream) { 512 *r += data.len(); 513 } 514 if fin { 515 if 
let Some(r) = self.posts.remove(&stream) { 516 let default_ret = b"Hello World".to_vec(); 517 stream 518 .send_headers(&[ 519 Header::new(":status", "200"), 520 Header::new("cache-control", "no-cache"), 521 Header::new("x-data-received-length", r.to_string()), 522 Header::new("content-length", default_ret.len().to_string()), 523 ]) 524 .unwrap(); 525 self.new_response(stream, default_ret, now); 526 } 527 } 528 } 529 Http3ServerEvent::DataWritable { stream } => { 530 self.handle_stream_writable(stream, now) 531 } 532 Http3ServerEvent::StateChange { .. } => {} 533 Http3ServerEvent::PriorityUpdate { .. } => {} 534 Http3ServerEvent::StreamReset { stream, error } => { 535 qtrace!("Http3ServerEvent::StreamReset {:?} {:?}", stream, error); 536 } 537 Http3ServerEvent::StreamStopSending { stream, error } => { 538 qtrace!( 539 "Http3ServerEvent::StreamStopSending {:?} {:?}", 540 stream, 541 error 542 ); 543 } 544 Http3ServerEvent::WebTransport(WebTransportServerEvent::NewSession { 545 session, 546 headers, 547 }) => { 548 qdebug!( 549 "WebTransportServerEvent::NewSession {:?} {:?}", 550 session, 551 headers 552 ); 553 let path_hdr = headers.iter().find(|&h| h.name() == ":path"); 554 match path_hdr { 555 Some(ph) if !ph.value().is_empty() => { 556 let path = ph.value(); 557 qtrace!( 558 "Serve request {:?}", 559 ph.value_utf8().unwrap_or("<invalid utf8>") 560 ); 561 if path == b"/success" { 562 session.response(&SessionAcceptAction::Accept, now).unwrap(); 563 } else if path == b"/redirect" { 564 session 565 .response( 566 &SessionAcceptAction::Reject( 567 [ 568 Header::new(":status", "302"), 569 Header::new("location", "/"), 570 ] 571 .to_vec(), 572 ), 573 now, 574 ) 575 .unwrap(); 576 } else if path == b"/reject" { 577 session 578 .response( 579 &SessionAcceptAction::Reject( 580 [Header::new(":status", "404")].to_vec(), 581 ), 582 now, 583 ) 584 .unwrap(); 585 } else if path == b"/closeafter0ms" { 586 session.response(&SessionAcceptAction::Accept, now).unwrap(); 587 if 
!self.sessions_to_close.contains_key(&now) { 588 self.sessions_to_close.insert(now, Vec::new()); 589 } 590 self.sessions_to_close.get_mut(&now).unwrap().push(session); 591 } else if path == b"/closeafter100ms" { 592 session.response(&SessionAcceptAction::Accept, now).unwrap(); 593 let expires = Instant::now() + Duration::from_millis(100); 594 if !self.sessions_to_close.contains_key(&expires) { 595 self.sessions_to_close.insert(expires, Vec::new()); 596 } 597 self.sessions_to_close 598 .get_mut(&expires) 599 .unwrap() 600 .push(session); 601 } else if path == b"/create_unidi_stream" { 602 session.response(&SessionAcceptAction::Accept, now).unwrap(); 603 self.sessions_to_create_stream.push(( 604 session, 605 StreamType::UniDi, 606 None, 607 )); 608 } else if path == b"/create_unidi_stream_and_hello" { 609 session.response(&SessionAcceptAction::Accept, now).unwrap(); 610 self.sessions_to_create_stream.push(( 611 session, 612 StreamType::UniDi, 613 Some(Vec::from("qwerty")), 614 )); 615 } else if path == b"/create_bidi_stream" { 616 session.response(&SessionAcceptAction::Accept, now).unwrap(); 617 self.sessions_to_create_stream.push(( 618 session, 619 StreamType::BiDi, 620 None, 621 )); 622 } else if path == b"/create_bidi_stream_and_hello" { 623 self.webtransport_bidi_stream.clear(); 624 session.response(&SessionAcceptAction::Accept, now).unwrap(); 625 self.sessions_to_create_stream.push(( 626 session, 627 StreamType::BiDi, 628 Some(Vec::from("asdfg")), 629 )); 630 } else if path == b"/create_bidi_stream_and_large_data" { 631 self.webtransport_bidi_stream.clear(); 632 let data: Vec<u8> = vec![1u8; 32 * 1024 * 1024]; 633 session.response(&SessionAcceptAction::Accept, now).unwrap(); 634 self.sessions_to_create_stream.push(( 635 session, 636 StreamType::BiDi, 637 Some(data), 638 )); 639 } else { 640 session.response(&SessionAcceptAction::Accept, now).unwrap(); 641 } 642 } 643 _ => { 644 session 645 .response( 646 &SessionAcceptAction::Reject( 647 [Header::new(":status", 
"404")].to_vec(), 648 ), 649 now, 650 ) 651 .unwrap(); 652 } 653 } 654 } 655 Http3ServerEvent::WebTransport(WebTransportServerEvent::SessionClosed { 656 session, 657 reason, 658 headers: _, 659 }) => { 660 qdebug!( 661 "WebTransportServerEvent::SessionClosed {:?} {:?}", 662 session, 663 reason 664 ); 665 } 666 Http3ServerEvent::WebTransport(WebTransportServerEvent::NewStream(stream)) => { 667 // new stream could be from client-outgoing unidirectional 668 // or bidirectional 669 if !stream.stream_info.is_http() { 670 if stream.stream_id().is_bidi() { 671 self.webtransport_bidi_stream.insert(stream); 672 } else { 673 // Newly created stream happens on same connection 674 // as the stream creation for client's incoming stream. 675 // Link the streams with map for echo back 676 if self.wt_unidi_conn_to_stream.contains_key(&stream.conn) { 677 let s = self.wt_unidi_conn_to_stream.remove(&stream.conn).unwrap(); 678 self.wt_unidi_echo_back.insert(stream, s); 679 } 680 } 681 } 682 } 683 Http3ServerEvent::WebTransport(WebTransportServerEvent::Datagram { 684 session, 685 datagram, 686 }) => { 687 qdebug!( 688 "WebTransportServerEvent::Datagram {:?} {:?}", 689 session, 690 datagram 691 ); 692 self.received_datagram = Some(datagram); 693 } 694 Http3ServerEvent::ConnectUdp(_) => { 695 unimplemented!() 696 } 697 } 698 } 699 } 700 701 fn has_events(&self) -> bool { 702 self.server.has_events() 703 } 704 } 705 706 struct Server(neqo_transport::server::Server); 707 708 impl ::std::fmt::Display for Server { 709 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { 710 self.0.fmt(f) 711 } 712 } 713 714 impl HttpServer for Server { 715 fn process_multiple<'a>( 716 &mut self, 717 dgrams: impl IntoIterator<Item = Datagram<&'a mut [u8]>>, 718 now: Instant, 719 max_datagrams: NonZeroUsize, 720 ) -> OutputBatch { 721 self.0.process_multiple(dgrams, now, max_datagrams) 722 } 723 724 fn process_events(&mut self, _now: Instant) { 725 let active_conns = 
self.0.active_connections(); 726 for acr in active_conns { 727 loop { 728 let event = match acr.borrow_mut().next_event() { 729 None => break, 730 Some(e) => e, 731 }; 732 match event { 733 ConnectionEvent::RecvStreamReadable { stream_id } => { 734 if stream_id.is_bidi() && stream_id.is_client_initiated() { 735 // We are only interesting in request streams 736 acr.borrow_mut() 737 .stream_send(stream_id, HTTP_RESPONSE_WITH_WRONG_FRAME) 738 .expect("Read should succeed"); 739 } 740 } 741 _ => {} 742 } 743 } 744 } 745 } 746 747 fn has_events(&self) -> bool { 748 self.0.has_active_connections() 749 } 750 } 751 752 struct Http3ReverseProxyServer { 753 server: Http3Server, 754 responses: HashMap<Http3OrWebTransportStream, Vec<u8>>, 755 server_port: i32, 756 requests: HashMap<Http3OrWebTransportStream, (Vec<Header>, Vec<u8>)>, 757 #[cfg(not(target_os = "android"))] 758 response_to_send: HashMap<Http3OrWebTransportStream, Receiver<(Vec<Header>, Vec<u8>)>>, 759 } 760 761 impl ::std::fmt::Display for Http3ReverseProxyServer { 762 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { 763 write!(f, "{}", self.server) 764 } 765 } 766 767 impl Http3ReverseProxyServer { 768 pub fn new(server: Http3Server, server_port: i32) -> Self { 769 Self { 770 server, 771 responses: HashMap::new(), 772 server_port, 773 requests: HashMap::new(), 774 #[cfg(not(target_os = "android"))] 775 response_to_send: HashMap::new(), 776 } 777 } 778 779 #[cfg(not(target_os = "android"))] 780 fn new_response(&mut self, stream: Http3OrWebTransportStream, mut data: Vec<u8>, now: Instant) { 781 if data.len() == 0 { 782 let _ = stream.stream_close_send(now); 783 return; 784 } 785 match stream.send_data(&data, now) { 786 Ok(sent) => { 787 if sent < data.len() { 788 self.responses.insert(stream, data.split_off(sent)); 789 } else { 790 stream.stream_close_send(now).unwrap(); 791 } 792 } 793 Err(e) => { 794 eprintln!("error is {:?}, stream will be reset", e); 795 let _ = 
stream.stream_reset_send(Error::HttpRequestCancelled.code()); 796 } 797 } 798 } 799 800 fn handle_stream_writable(&mut self, stream: Http3OrWebTransportStream, now: Instant) { 801 if let Some(data) = self.responses.get_mut(&stream) { 802 match stream.send_data(&data, now) { 803 Ok(sent) => { 804 if sent < data.len() { 805 let new_d = (*data).split_off(sent); 806 *data = new_d; 807 } else { 808 stream.stream_close_send(now).unwrap(); 809 self.responses.remove(&stream); 810 } 811 } 812 Err(_) => { 813 eprintln!("Unexpected error"); 814 } 815 } 816 } 817 } 818 819 #[cfg(not(target_os = "android"))] 820 async fn fetch_url( 821 request: Request<Body>, 822 out_header: &mut Vec<Header>, 823 out_body: &mut Vec<u8>, 824 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { 825 let client = Client::new(); 826 let mut resp = client.request(request).await?; 827 out_header.push(Header::new(":status", resp.status().as_str())); 828 for (key, value) in resp.headers() { 829 out_header.push(Header::new( 830 key.as_str().to_ascii_lowercase(), 831 match value.to_str() { 832 Ok(str) => str, 833 _ => "", 834 }, 835 )); 836 } 837 838 while let Some(chunk) = resp.body_mut().data().await { 839 match chunk { 840 Ok(data) => { 841 out_body.append(&mut data.to_vec()); 842 } 843 _ => {} 844 } 845 } 846 847 Ok(()) 848 } 849 850 #[cfg(not(target_os = "android"))] 851 fn fetch( 852 &mut self, 853 stream: Http3OrWebTransportStream, 854 request_headers: &Vec<Header>, 855 request_body: Vec<u8>, 856 ) { 857 let mut request: Request<Body> = Request::default(); 858 let mut path = String::new(); 859 for hdr in request_headers.iter() { 860 match hdr.name() { 861 ":method" => { 862 *request.method_mut() = Method::from_bytes(hdr.value()).unwrap(); 863 } 864 ":scheme" => {} 865 ":authority" => { 866 request.headers_mut().insert( 867 hyper::header::HOST, 868 HeaderValue::from_bytes(hdr.value()).unwrap(), 869 ); 870 } 871 ":path" => { 872 path = hdr.value_utf8().unwrap_or("/").to_string(); 873 } 874 _ 
=> { 875 if let Ok(hdr_name) = HeaderName::from_lowercase(hdr.name().as_bytes()) { 876 request 877 .headers_mut() 878 .insert(hdr_name, HeaderValue::from_bytes(hdr.value()).unwrap()); 879 } 880 } 881 } 882 } 883 *request.body_mut() = Body::from(request_body); 884 *request.uri_mut() = 885 match format!("http://127.0.0.1:{}{}", self.server_port.to_string(), path).parse() { 886 Ok(uri) => uri, 887 _ => { 888 eprintln!("invalid uri: {}", path); 889 stream 890 .send_headers(&[ 891 Header::new(":status", "400"), 892 Header::new("cache-control", "no-cache"), 893 Header::new("content-length", "0"), 894 ]) 895 .unwrap(); 896 return; 897 } 898 }; 899 qtrace!("request header: {:?}", request); 900 901 let (sender, receiver) = channel(); 902 thread::spawn(move || { 903 let rt = tokio::runtime::Runtime::new().unwrap(); 904 let mut h: Vec<Header> = Vec::new(); 905 let mut data: Vec<u8> = Vec::new(); 906 let _ = rt.block_on(Self::fetch_url(request, &mut h, &mut data)); 907 qtrace!("response headers: {:?}", h); 908 qtrace!("res data: {:02X?}", data); 909 910 match sender.send((h, data)) { 911 Ok(()) => {} 912 _ => { 913 eprintln!("sender.send failed"); 914 } 915 } 916 }); 917 self.response_to_send.insert(stream, receiver); 918 } 919 920 #[cfg(target_os = "android")] 921 fn fetch( 922 &mut self, 923 mut _stream: Http3OrWebTransportStream, 924 _request_headers: &Vec<Header>, 925 _request_body: Vec<u8>, 926 ) { 927 // do nothing 928 } 929 930 #[cfg(not(target_os = "android"))] 931 fn maybe_process_response(&mut self, now: Instant) { 932 let mut data_to_send = HashMap::new(); 933 self.response_to_send 934 .retain(|id, receiver| match receiver.try_recv() { 935 Ok((headers, body)) => { 936 data_to_send.insert(id.clone(), (headers.clone(), body.clone())); 937 false 938 } 939 Err(TryRecvError::Empty) => true, 940 Err(TryRecvError::Disconnected) => false, 941 }); 942 while let Some(stream) = data_to_send.keys().next().cloned() { 943 let (header, data) = 
data_to_send.remove(&stream).unwrap(); 944 qtrace!("response headers: {:?}", header); 945 match stream.send_headers(&header) { 946 Ok(()) => { 947 self.new_response(stream, data, now); 948 } 949 _ => {} 950 } 951 } 952 } 953 } 954 955 impl HttpServer for Http3ReverseProxyServer { 956 fn process_multiple<'a>( 957 &mut self, 958 dgrams: impl IntoIterator<Item = Datagram<&'a mut [u8]>>, 959 now: Instant, 960 max_datagrams: NonZeroUsize, 961 ) -> OutputBatch { 962 let output = self.server.process_multiple(dgrams, now, max_datagrams); 963 964 #[cfg(not(target_os = "android"))] 965 let output = if self.response_to_send.is_empty() { 966 output 967 } else { 968 // In case there are pending responses to send, make sure a reasonable 969 // callback is returned. 970 const MIN_INTERVAL: Duration = Duration::from_millis(100); 971 972 match output { 973 OutputBatch::None => OutputBatch::Callback(MIN_INTERVAL), 974 o @ OutputBatch::DatagramBatch(_) => o, 975 OutputBatch::Callback(d) => OutputBatch::Callback(min(d, MIN_INTERVAL)), 976 } 977 }; 978 979 output 980 } 981 982 fn process_events(&mut self, now: Instant) { 983 #[cfg(not(target_os = "android"))] 984 self.maybe_process_response(now); 985 while let Some(event) = self.server.next_event() { 986 qtrace!("Event: {:?}", event); 987 match event { 988 Http3ServerEvent::Headers { 989 stream, 990 headers, 991 fin: _, 992 } => { 993 qtrace!("Headers {:?}", headers); 994 if self.server_port != -1 { 995 let method_hdr = headers.iter().find(|&h| h.name() == ":method"); 996 match method_hdr { 997 Some(method) => match method.value() { 998 b"POST" => { 999 let content_length = 1000 headers.iter().find(|&h| h.name() == "content-length"); 1001 if let Some(length_str) = content_length { 1002 if let Ok(len) = 1003 length_str.value_utf8().unwrap_or("0").parse::<u32>() 1004 { 1005 if len > 0 { 1006 self.requests.insert(stream, (headers, Vec::new())); 1007 } else { 1008 self.fetch(stream, &headers, b"".to_vec()); 1009 } 1010 } 1011 } 1012 } 1013 
_ => { 1014 self.fetch(stream, &headers, b"".to_vec()); 1015 } 1016 }, 1017 _ => {} 1018 } 1019 } else { 1020 let path_hdr = headers.iter().find(|&h| h.name() == ":path"); 1021 match path_hdr { 1022 Some(ph) if !ph.value().is_empty() => { 1023 if let Some(path_str) = ph.value_utf8().ok() { 1024 if let Some(port_str) = path_str.strip_prefix("/port?") { 1025 let port = port_str.parse::<i32>().ok(); 1026 if let Some(port) = port { 1027 qtrace!("got port {}", port); 1028 self.server_port = port; 1029 } 1030 } 1031 } 1032 } 1033 _ => {} 1034 } 1035 stream 1036 .send_headers(&[ 1037 Header::new(":status", "200"), 1038 Header::new("cache-control", "no-cache"), 1039 Header::new("content-length", "0"), 1040 ]) 1041 .unwrap(); 1042 } 1043 } 1044 Http3ServerEvent::Data { 1045 stream, 1046 mut data, 1047 fin, 1048 } => { 1049 if let Some((_, body)) = self.requests.get_mut(&stream) { 1050 body.append(&mut data); 1051 } 1052 if fin { 1053 if let Some((headers, body)) = self.requests.remove(&stream) { 1054 self.fetch(stream, &headers, body); 1055 } 1056 } 1057 } 1058 Http3ServerEvent::DataWritable { stream } => { 1059 self.handle_stream_writable(stream, now) 1060 } 1061 Http3ServerEvent::StateChange { .. } | Http3ServerEvent::PriorityUpdate { .. 
} => {}
                Http3ServerEvent::StreamReset { stream, error } => {
                    qtrace!("Http3ServerEvent::StreamReset {:?} {:?}", stream, error);
                }
                Http3ServerEvent::StreamStopSending { stream, error } => {
                    qtrace!(
                        "Http3ServerEvent::StreamStopSending {:?} {:?}",
                        stream,
                        error
                    );
                }
                Http3ServerEvent::WebTransport(_) => {}
                Http3ServerEvent::ConnectUdp(_) => {}
            }
        }
    }

    fn has_events(&self) -> bool {
        self.server.has_events()
    }
}

/// HTTP/3 proxy that serves CONNECT (TCP tunnels) and CONNECT-UDP (MASQUE) requests.
///
/// Each CONNECT stream is paired with an origin TCP connection and each CONNECT-UDP
/// session with an origin UDP socket; `poll` shuttles bytes between the two sides.
struct Http3ConnectProxyServer {
    server: Http3Server,
    /// Origin TCP connections, keyed by the CONNECT request stream.
    tcp_streams: HashMap<StreamId, TcpStream>,
    /// Origin UDP sockets, keyed by the CONNECT-UDP session stream.
    udp_sockets: HashMap<StreamId, UdpSocket>,
}

impl ::std::fmt::Display for Http3ConnectProxyServer {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        write!(f, "{}", self.server)
    }
}

impl Http3ConnectProxyServer {
    /// Wraps `server` with no open tunnels.
    pub fn new(server: Http3Server) -> Self {
        Self {
            server,
            tcp_streams: HashMap::new(),
            udp_sockets: HashMap::new(),
        }
    }

    /// Returns `true` for the test host names that the proxy connects to on
    /// 127.0.0.1 instead of resolving them.
    fn is_loopback_alias(host: &str) -> bool {
        matches!(
            host,
            "foo.example.com" | "alt1.example.com" | "alt2.example.com"
        )
    }
}

impl HttpServer for Http3ConnectProxyServer {
    fn process_multiple<'a>(
        &mut self,
        dgrams: impl IntoIterator<Item = Datagram<&'a mut [u8]>>,
        now: Instant,
        max_datagrams: NonZeroUsize,
    ) -> OutputBatch {
        self.server.process_multiple(dgrams, now, max_datagrams)
    }

    /// Drains the HTTP/3 event queue.
    ///
    /// CONNECT requests open an origin TCP connection (502 on failure); CONNECT-UDP
    /// sessions open an origin UDP socket. Payload from the client is buffered for
    /// `poll` to write to the origin.
    ///
    /// # Panics
    /// On a non-CONNECT request, a missing `:method`/`:authority`/`:path` header,
    /// or a malformed CONNECT-UDP path.
    fn process_events(&mut self, now: Instant) {
        while let Some(event) = self.server.next_event() {
            qtrace!("Event: {:?}", event);
            match event {
                Http3ServerEvent::Headers {
                    stream,
                    headers,
                    fin: _,
                } => {
                    qtrace!("Headers {:?}", headers);
                    let method_hdr = headers.iter().find(|&h| h.name() == ":method").unwrap();
                    assert_eq!(
                        method_hdr.value(),
                        b"CONNECT",
                        "{:?} not supported",
                        method_hdr.value_utf8().unwrap_or("<invalid utf8>")
                    );
                    let host_hdr = headers.iter().find(|&h| h.name() == ":authority").unwrap();
                    let host_str = host_hdr.value_utf8().unwrap();

                    // Test host names are redirected to 127.0.0.1, keeping the requested
                    // port (or port 80 when the authority carries none).
                    let target = match host_str.rsplit_once(':') {
                        Some((host, port)) if Self::is_loopback_alias(host) => {
                            format!("127.0.0.1:{port}")
                        }
                        None if Self::is_loopback_alias(host_str) => "127.0.0.1:80".to_string(),
                        _ => host_str.to_string(),
                    };

                    let tcp_stream = match std::net::TcpStream::connect(&target) {
                        Ok(c) => c,
                        Err(_) => {
                            stream
                                .send_headers(&[
                                    Header::new(":status", "502"),
                                    Header::new("cache-control", "no-cache"),
                                ])
                                .unwrap();
                            stream.stream_close_send(now).unwrap();
                            // Only this request failed: keep draining the queue rather
                            // than returning and leaving later events unprocessed.
                            continue;
                        }
                    };

                    tcp_stream.set_nonblocking(true).unwrap();
                    qtrace!("tcp_stream to {:?} created", host_hdr);
                    stream
                        .send_headers(&[
                            Header::new(":status", "200"),
                            Header::new("cache-control", "no-cache"),
                        ])
                        .unwrap();
                    self.tcp_streams.insert(
                        stream.stream_id(),
                        TcpStream {
                            send_buffer: VecDeque::new(),
                            recv_buffer: VecDeque::new(),
                            stream: tokio::net::TcpStream::from_std(tcp_stream).unwrap(),
                            send_fin: false,
                            received_fin: false,
                            session: stream,
                        },
                    );
                }
                Http3ServerEvent::Data { stream, data, fin } => {
                    qtrace!("tcp_stream send to server len={}", data.len());
                    if let Some(tcp_stream) = self.tcp_streams.get_mut(&stream.stream_id()) {
                        // TODO: extend() effectively breaks backpressure.
                        tcp_stream.send_buffer.extend(data);
                        tcp_stream.send_fin |= fin;
                    } else {
                        qdebug!("Data for unknown CONNECT stream {}", stream.stream_id());
                    }
                }
                Http3ServerEvent::DataWritable { stream } => {
                    qtrace!(
                        "Http3ServerEvent::DataWritable streamid={}",
                        stream.stream_id()
                    );
                    if let Some(tcp_stream) = self.tcp_streams.get_mut(&stream.stream_id()) {
                        // Flush whatever the origin sent while the stream was blocked.
                        while !tcp_stream.recv_buffer.is_empty() {
                            let sent = stream
                                .send_data(&tcp_stream.recv_buffer.make_contiguous(), now)
                                .unwrap();
                            qtrace!("tcp_stream send to client sent={}", sent);
                            if sent == 0 {
                                break;
                            }
                            tcp_stream.recv_buffer.drain(0..sent);
                        }
                    } else {
                        qdebug!(
                            "DataWritable for unknown CONNECT stream {}",
                            stream.stream_id()
                        );
                    }
                }
                Http3ServerEvent::ConnectUdp(ConnectUdpServerEvent::NewSession {
                    session,
                    headers,
                }) => {
                    session.response(&SessionAcceptAction::Accept, now).unwrap();

                    let host_hdr = headers.iter().find(|&h| h.name() == ":path").unwrap();
                    let path_str = host_hdr.value_utf8().unwrap();
                    let path_parts: Vec<&str> = path_str.split('/').collect();

                    // Format is /.well-known/masque/udp/{target_host}/{target_port}/
                    if path_parts.len() < 6 {
                        panic!("{}", path_str)
                    }

                    let target_host = path_parts[4];
                    let target_port = match path_parts[5].trim_end_matches('/').parse::<u16>() {
                        Ok(port) => port,
                        Err(_) => {
                            panic!("{}", path_str)
                        }
                    };

                    // Replace target_host with 127.0.0.1 for specific hosts
                    let actual_host = if Self::is_loopback_alias(target_host) {
                        "127.0.0.1"
                    } else {
                        target_host
                    };

                    let host_port = format!("{}:{}", actual_host, target_port);
                    qdebug!("CONNECT-UDP to {}", host_port);

                    let socket = {
                        let s =
                            socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::DGRAM, None)
                                .unwrap();
                        s.bind(&"0.0.0.0:0".parse::<SocketAddr>().unwrap().into())
                            .unwrap();
                        let s: std::net::UdpSocket = s.into();
                        s.connect((actual_host, target_port)).unwrap();
                        s.set_nonblocking(true).unwrap();
                        s.into()
                    };

                    self.udp_sockets.insert(
                        session.stream_id(),
                        UdpSocket {
                            session,
                            send_buffer: VecDeque::new(),
                            socket: tokio::net::UdpSocket::from_std(socket).unwrap(),
                        },
                    );
                }
                Http3ServerEvent::ConnectUdp(ConnectUdpServerEvent::Datagram {
                    session,
                    datagram,
                }) => {
                    if let Some(udp_socket) = self.udp_sockets.get_mut(&session.stream_id()) {
                        // TODO: effectively breaks backpressure.
                        udp_socket.send_buffer.push_back(datagram);
                    } else {
                        qdebug!(
                            "Datagram for unknown CONNECT-UDP session {}",
                            session.stream_id()
                        );
                    }
                }
                Http3ServerEvent::ConnectUdp(ConnectUdpServerEvent::SessionClosed {
                    session,
                    reason,
                    headers: _,
                }) => {
                    qdebug!(
                        "ConnectUdp session closed: {:?} reason: {:?}",
                        session,
                        reason
                    );
                    self.udp_sockets.remove(&session.stream_id());
                }
                Http3ServerEvent::StateChange { .. } | Http3ServerEvent::PriorityUpdate { .. } => {}
                Http3ServerEvent::StreamReset { stream, error } => {
                    qtrace!("Http3ServerEvent::StreamReset {:?} {:?}", stream, error);
                }
                Http3ServerEvent::StreamStopSending { stream, error } => {
                    qtrace!(
                        "Http3ServerEvent::StreamStopSending {:?} {:?}",
                        stream,
                        error
                    );
                }
                Http3ServerEvent::WebTransport(_) => {}
            }
        }
    }

    fn has_events(&self) -> bool {
        self.server.has_events()
    }

    /// Moves data between the origin sockets and the client streams.
    ///
    /// TCP: reads from the origin and forwards to the CONNECT stream, writes the
    /// buffered client data to the origin, and forwards the client's FIN once that
    /// buffer is empty. UDP: forwards datagrams in both directions; sockets that
    /// fail are removed and their session is closed with code 0x0100.
    ///
    /// Returns `Ready` when any data moved, `Pending` otherwise.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let mut progressed = false;
        let mut failed_udp_sockets: Vec<StreamId> = Vec::new();

        // Scratch buffers reused for every read instead of allocating per read.
        let mut tcp_chunk = [0u8; 1024];
        let mut udp_chunk = vec![0u8; u16::MAX as usize];

        for (_sessionid, stream) in &mut self.tcp_streams {
            if let Poll::Ready(Ok(())) = stream.stream.poll_read_ready(cx) {
                loop {
                    match stream.stream.try_read(&mut tcp_chunk) {
                        Ok(0) => {
                            qdebug!("TCP: Received 0 bytes -FIN");
                            stream.received_fin = true;
                            // TODO: Reset CONNECT stream.
                            break;
                        }
                        Ok(n) => {
                            qdebug!("TCP: Received {} bytes from origin", n);
                            // TODO: extend() effectively breaks backpressure.
                            stream.recv_buffer.extend(&tcp_chunk[..n]);
                            while !stream.recv_buffer.is_empty() {
                                let sent = match stream.session.send_data(
                                    &stream.recv_buffer.make_contiguous(),
                                    Instant::now(),
                                ) {
                                    Ok(n) => n,
                                    Err(e) => {
                                        qdebug!("TCP: send_data failed: {}", e);
                                        break;
                                    }
                                };
                                qdebug!("TCP: stream send to client sent={}", sent);
                                if sent == 0 {
                                    break;
                                }
                                stream.recv_buffer.drain(0..sent);
                            }
                            progressed = true;
                        }
                        // Nothing more to read right now; this is not an EOF or a failure.
                        Err(e) if e.kind() == io::ErrorKind::WouldBlock => break,
                        Err(e) => {
                            qdebug!("TCP read error: {e:?}");
                            stream.received_fin = true;
                            // TODO: Handle the error
                            break;
                        }
                    }
                }
            }

            if let Poll::Ready(Ok(())) = stream.stream.poll_write_ready(cx) {
                while !stream.send_buffer.is_empty() {
                    match stream
                        .stream
                        .try_write(&stream.send_buffer.make_contiguous())
                    {
                        Ok(0) => break,
                        Ok(n) => {
                            qdebug!("TCP: Sent {} bytes to origin", n);
                            stream.send_buffer.drain(0..n);
                            progressed = true;
                        }
                        // Origin socket is full; retry on a later poll.
                        Err(e) if e.kind() == io::ErrorKind::WouldBlock => break,
                        Err(e) => {
                            qdebug!("TCP write error: {e:?}");
                            stream.received_fin = true;
                            // TODO: Handle the error
                            break;
                        }
                    }
                }
            }

            // Forward the client's FIN only after everything it sent has been written.
            // `AsyncWriteExt::shutdown` returns a future that does nothing unless polled.
            if stream.send_fin && stream.send_buffer.is_empty() {
                let shutdown_result = {
                    let shutdown = std::pin::pin!(stream.stream.shutdown());
                    std::future::Future::poll(shutdown, cx)
                };
                if let Poll::Ready(result) = shutdown_result {
                    if let Err(e) = result {
                        qdebug!("TCP shutdown error: {e:?}");
                    }
                    // The FIN has been handled; do not shut down again on later polls.
                    stream.send_fin = false;
                }
            }
        }

        for (stream_id, socket) in &mut self.udp_sockets {
            loop {
                let mut read_buf = ReadBuf::new(udp_chunk.as_mut_slice());
                match socket.socket.poll_recv(cx, &mut read_buf) {
                    Poll::Ready(Ok(())) => {
                        let datagram = read_buf.filled();
                        qinfo!("Received {} bytes from origin", datagram.len());
                        // TODO: Might overflow our current datagram buffer of 10
                        // https://github.com/mozilla/neqo/issues/2852
                        socket.session.send_datagram(datagram, None).unwrap();
                        progressed = true;
                    }
                    Poll::Ready(Err(e)) => {
                        qerror!("Error receiving UDP datagram: {}, closing socket", e);
                        failed_udp_sockets.push(*stream_id);
                        break;
                    }
                    Poll::Pending => break,
                }
            }

            while let Some(datagram) = socket.send_buffer.pop_front() {
                match socket.socket.poll_send(cx, datagram.as_ref()) {
                    Poll::Ready(Ok(0)) | Poll::Pending => {
                        socket.send_buffer.push_front(datagram);
                        break;
                    }
                    Poll::Ready(Ok(n)) => {
                        assert_eq!(n, datagram.len());
                        qinfo!("Sent {}/{} bytes to origin", n, datagram.len());
                        progressed = true;
                    }
                    Poll::Ready(Err(e)) => {
                        qerror!(
                            "Error sending UDP datagram: {} {:?}, closing socket",
                            e,
                            socket.socket
                        );
                        failed_udp_sockets.push(*stream_id);
                        break;
                    }
                }
            }
        }

        // Remove failed UDP sockets from the list
        for stream_id in failed_udp_sockets {
            if let Some(socket) = self.udp_sockets.remove(&stream_id) {
                qdebug!("Removed failed UDP socket for stream {}", stream_id);
                // Close the session with an error code
                let _ = socket
                    .session
                    .close_session(0x0100, "UDP socket error", Instant::now());
            }
        }

        if progressed {
            return Poll::Ready(());
        }

        Poll::Pending
    }
}

/// One CONNECT tunnel: the client-facing HTTP/3 stream plus its origin TCP connection.
struct TcpStream {
    /// Data received from the client, waiting to be written to the origin.
    send_buffer: VecDeque<u8>,
    /// Data read from the origin, waiting to be sent to the client.
    recv_buffer: VecDeque<u8>,
    stream: tokio::net::TcpStream,
    /// The client sent FIN and it has not yet been forwarded to the origin.
    send_fin: bool,
    /// The origin connection reported EOF or an I/O error.
    received_fin: bool,
    session: Http3OrWebTransportStream,
}

/// One CONNECT-UDP session plus its connected origin UDP socket.
struct UdpSocket {
    session: ConnectUdpRequest,
    /// Datagrams received from the client, waiting to be sent to the origin.
    send_buffer: VecDeque<Bytes>,
    socket: tokio::net::UdpSocket,
}

/// A server that never answers: it produces no output for any datagram it receives.
#[derive(Default)]
struct NonRespondingServer {}

impl ::std::fmt::Display for NonRespondingServer {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        write!(f, "NonRespondingServer")
    }
}

impl HttpServer for NonRespondingServer {
    fn process_multiple<'a>(
        &mut self,
        _dgrams: impl IntoIterator<Item = Datagram<&'a mut [u8]>>,
        _now: Instant,
        _max_datagrams: NonZeroUsize,
    ) -> OutputBatch {
        OutputBatch::None
    }

    fn process_events(&mut self, _now: Instant) {}

    fn has_events(&self) -> bool {
        false
    }
}

/// Binds a UDP socket on `port` and runs `server` on it inside `task_set`.
///
/// The socket is bound to 127.0.0.1 on Windows and to `[::]` elsewhere; port 0
/// lets the OS choose. The bound address is appended to `hosts`. On a bind
/// failure the process prints a message and exits with status 1.
fn spawn_server<S: HttpServer + Unpin + 'static>(
    server: S,
    port: u16,
    task_set: &LocalSet,
    hosts: &mut Vec<SocketAddr>,
) -> Result<(), io::Error> {
    let addr: SocketAddr = if cfg!(target_os = "windows") {
        format!("127.0.0.1:{}", port).parse().unwrap()
    } else {
        format!("[::]:{}", port).parse().unwrap()
    };

    let socket = match neqo_bin::udp::Socket::bind(&addr) {
        Err(err) => {
            eprintln!("Unable to bind UDP socket: {}", err);
            exit(1)
        }
        Ok(s) => s,
    };

    let local_addr = match socket.local_addr() {
        Err(err) => {
            eprintln!("Socket local address not bound: {}", err);
            exit(1)
        }
        Ok(s) => s,
    };

    task_set
        .spawn_local(Runner::new(server, Box::new(Instant::now), vec![(local_addr, socket)]).run());
    hosts.push(local_addr);

    Ok(())
}

/// Starts all test servers and prints their ports and ECH config for the test runner.
///
/// Expects the crypto database directory (passed to `init_db`) as the first argument.
#[tokio::main]
async fn main() -> Result<(), io::Error> {
    neqo_common::log::init(None);

    let args: Vec<String> = env::args().collect();
    if args.len() < 2 {
        eprintln!("Wrong arguments.");
        exit(1)
    }

    // Read data from stdin and terminate the server if EOF is detected, which
    // means that runxpcshelltests.py ended without shutting down the server.
    thread::spawn(|| loop {
        let mut buffer = String::new();
        match io::stdin().read_line(&mut buffer) {
            Ok(0) | Err(_) => exit(0),
            Ok(_) => {}
        }
    });

    init_db(PathBuf::from(&args[1])).unwrap();

    let local = LocalSet::new();
    let mut hosts = vec![];

    let proxy_port = match env::var("MOZ_HTTP3_PROXY_PORT") {
        Ok(val) => val.parse::<u16>().unwrap(),
        _ => 0,
    };

    let anti_replay = || {
        AntiReplay::new(Instant::now(), Duration::from_secs(10), 7, 14)
            .expect("unable to setup anti-replay")
    };
    let cid_mgr = Rc::new(RefCell::new(RandomConnectionIdGenerator::new(10)));

    // hosts[0]: HTTP/3 test server with WebTransport enabled.
    spawn_server(
        Http3TestServer::new(
            Http3Server::new(
                Instant::now(),
                &[" HTTP2 Test Cert"],
                PROTOCOLS,
                anti_replay(),
                cid_mgr.clone(),
                Http3Parameters::default()
                    .max_table_size_encoder(MAX_TABLE_SIZE)
                    .max_table_size_decoder(MAX_TABLE_SIZE)
                    .max_blocked_streams(MAX_BLOCKED_STREAMS)
                    .webtransport(true)
                    .connection_parameters(ConnectionParameters::default().datagram_size(1200)),
                None,
            )
            .expect("We cannot make a server!"),
        ),
        0,
        &local,
        &mut hosts,
    )?;

    // hosts[1]: plain QUIC transport server.
    spawn_server(
        Server(
            neqo_transport::server::Server::new(
                Instant::now(),
                &[" HTTP2 Test Cert"],
                PROTOCOLS,
                anti_replay(),
                Box::new(AllowZeroRtt {}),
                cid_mgr.clone(),
                ConnectionParameters::default(),
            )
            .expect("We cannot make a server!"),
        ),
        0,
        &local,
        &mut hosts,
    )?;

    // hosts[2]: HTTP/3 test server with ECH enabled; its config is printed below.
    let ech_config = {
        let mut server = Http3TestServer::new(
            Http3Server::new(
                Instant::now(),
                &[" HTTP2 Test Cert"],
                PROTOCOLS,
                anti_replay(),
                cid_mgr.clone(),
                Http3Parameters::default()
                    .max_table_size_encoder(MAX_TABLE_SIZE)
                    .max_table_size_decoder(MAX_TABLE_SIZE)
                    .max_blocked_streams(MAX_BLOCKED_STREAMS),
                None,
            )
            .expect("We cannot make a server!"),
        );
        let (sk, pk) = generate_ech_keys().unwrap();
        server
            .server
            .enable_ech(ECH_CONFIG_ID, ECH_PUBLIC_NAME, &sk, &pk)
            .expect("unable to enable ech");
        let ech_config = server.server.ech_config().to_vec();
        spawn_server(server, 0, &local, &mut hosts)?;
        ech_config
    };

    // hosts[3]: reverse proxy, on MOZ_HTTP3_PROXY_PORT when set.
    spawn_server(
        {
            let (cert_name, server_port) = if env::var("MOZ_HTTP3_MOCHITEST").is_ok() {
                ("mochitest-cert", 8888)
            } else {
                (" HTTP2 Test Cert", -1)
            };
            Http3ReverseProxyServer::new(
                Http3Server::new(
                    Instant::now(),
                    &[cert_name],
                    PROTOCOLS,
                    anti_replay(),
                    cid_mgr.clone(),
                    Http3Parameters::default()
                        .max_table_size_encoder(MAX_TABLE_SIZE)
                        .max_table_size_decoder(MAX_TABLE_SIZE)
                        .max_blocked_streams(MAX_BLOCKED_STREAMS)
                        .webtransport(true)
                        .connection_parameters(ConnectionParameters::default().datagram_size(1200)),
                    None,
                )
                .expect("We cannot make a server!"),
                server_port,
            )
        },
        proxy_port,
        &local,
        &mut hosts,
    )?;

    // hosts[4]: server that never responds.
    spawn_server(NonRespondingServer::default(), 0, &local, &mut hosts)?;

    // hosts[5]: CONNECT / CONNECT-UDP proxy.
    spawn_server(
        Http3ConnectProxyServer::new(
            Http3Server::new(
                Instant::now(),
                &[" HTTP2 Test Cert"],
                PROTOCOLS,
                anti_replay(),
                cid_mgr,
                Http3Parameters::default()
                    .max_table_size_encoder(MAX_TABLE_SIZE)
                    .connection_parameters(
                        ConnectionParameters::default()
                            // TODO: Restrict in size.
                            .datagram_size(u64::from(u16::MAX))
                            .pmtud(true),
                    )
                    .max_table_size_decoder(MAX_TABLE_SIZE)
                    .max_blocked_streams(MAX_BLOCKED_STREAMS)
                    .connect(true)
                    .http3_datagram(true),
                None,
            )
            .expect("We cannot make a server!"),
        ),
        0,
        &local,
        &mut hosts,
    )?;

    // Note this is parsed by test runner.
    // https://searchfox.org/mozilla-central/rev/e69f323af80c357d287fb6314745e75c62eab92a/testing/mozbase/mozserve/mozserve/servers.py#116-121
    println!(
        "HTTP3 server listening on ports {}, {}, {}, {}, {} and {}. EchConfig is @{}@",
        hosts[0].port(),
        hosts[1].port(),
        hosts[2].port(),
        hosts[3].port(),
        hosts[4].port(),
        hosts[5].port(),
        BASE64_STANDARD.encode(ech_config)
    );

    local.await;

    Ok(())
}

#[no_mangle]
extern "C" fn __tsan_default_suppressions() -> *const std::os::raw::c_char {
    // https://github.com/rust-lang/rust/issues/128769
    "race:tokio::runtime::io::registration_set::RegistrationSet::allocate\0".as_ptr() as *const _
}

// Work around until we can use raw-dylibs.
#[cfg_attr(target_os = "windows", link(name = "runtimeobject"))]
extern "C" {}
#[cfg_attr(target_os = "windows", link(name = "propsys"))]
extern "C" {}
#[cfg_attr(target_os = "windows", link(name = "iphlpapi"))]
extern "C" {}