server.rs (26046B)
1 /* This Source Code Form is subject to the terms of the Mozilla Public 2 * License, v. 2.0. If a copy of the MPL was not distributed with this 3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ 4 5 use crate::command::{WebDriverCommand, WebDriverMessage}; 6 use crate::error::{ErrorStatus, WebDriverError, WebDriverResult}; 7 use crate::httpapi::{ 8 standard_routes, Route, VoidWebDriverExtensionRoute, WebDriverExtensionRoute, 9 }; 10 use crate::response::{CloseWindowResponse, WebDriverResponse}; 11 use crate::Parameters; 12 use bytes::Bytes; 13 use http::{Method, StatusCode}; 14 use std::marker::PhantomData; 15 use std::net::{SocketAddr, TcpListener as StdTcpListener}; 16 use std::sync::mpsc::{channel, Receiver, Sender}; 17 use std::sync::{Arc, Mutex}; 18 use std::thread; 19 use tokio::net::TcpListener; 20 use tokio_stream::wrappers::TcpListenerStream; 21 use url::{Host, Url}; 22 use warp::{Buf, Filter, Rejection}; 23 24 // Silence warning about Quit being unused for now. 25 #[allow(dead_code)] 26 enum DispatchMessage<U: WebDriverExtensionRoute> { 27 HandleWebDriver( 28 WebDriverMessage<U>, 29 Sender<WebDriverResult<WebDriverResponse>>, 30 ), 31 Quit, 32 } 33 34 #[derive(Clone, Debug, PartialEq)] 35 /// Representation of whether we managed to successfully send a DeleteSession message 36 /// and read the response during session teardown. 37 pub enum SessionTeardownKind { 38 /// A DeleteSession message has been sent and the response handled. 39 Deleted, 40 /// No DeleteSession message has been sent, or the response was not received. 41 NotDeleted, 42 } 43 44 #[derive(Clone, Debug, PartialEq)] 45 pub struct Session { 46 pub id: String, 47 } 48 49 impl Session { 50 fn new(id: String) -> Session { 51 Session { id } 52 } 53 } 54 55 pub trait WebDriverHandler<U: WebDriverExtensionRoute = VoidWebDriverExtensionRoute>: Send { 56 fn handle_command( 57 &mut self, 58 session: &Option<Session>, 59 msg: WebDriverMessage<U>, 60 ) -> WebDriverResult<WebDriverResponse>; 61 fn teardown_session(&mut self, kind: SessionTeardownKind); 62 } 63 64 #[derive(Debug)] 65 struct Dispatcher<T: WebDriverHandler<U>, U: WebDriverExtensionRoute> { 66 handler: T, 67 session: Option<Session>, 68 extension_type: PhantomData<U>, 69 } 70 71 impl<T: WebDriverHandler<U>, U: WebDriverExtensionRoute> Dispatcher<T, U> { 72 fn new(handler: T) -> Dispatcher<T, U> { 73 Dispatcher { 74 handler, 75 session: None, 76 extension_type: PhantomData, 77 } 78 } 79 80 fn run(&mut self, msg_chan: &Receiver<DispatchMessage<U>>) { 81 loop { 82 match msg_chan.recv() { 83 Ok(DispatchMessage::HandleWebDriver(msg, resp_chan)) => { 84 let resp = match self.check_session(&msg) { 85 Ok(_) => self.handler.handle_command(&self.session, msg), 86 Err(e) => Err(e), 87 }; 88 89 match resp { 90 Ok(WebDriverResponse::NewSession(ref new_session)) => { 91 self.session = Some(Session::new(new_session.session_id.clone())); 92 } 93 Ok(WebDriverResponse::CloseWindow(CloseWindowResponse(ref handles))) => { 94 if handles.is_empty() { 95 debug!("Last window was closed, deleting session"); 96 // The teardown_session implementation is responsible for actually 97 // sending the DeleteSession message in this case 98 self.teardown_session(SessionTeardownKind::NotDeleted); 99 } 100 } 101 Ok(WebDriverResponse::DeleteSession) => { 102 self.teardown_session(SessionTeardownKind::Deleted); 103 } 104 Err(ref x) if x.delete_session => { 105 // This includes the case where we failed during session creation 106 self.teardown_session(SessionTeardownKind::NotDeleted) 107 } 108 _ => {} 109 } 110 111 if resp_chan.send(resp).is_err() { 112 error!("Sending response to the main thread failed"); 113 }; 114 } 115 Ok(DispatchMessage::Quit) => break, 116 Err(e) => panic!("Error receiving message in handler: {:?}", e), 117 } 118 } 119 } 120 121 fn teardown_session(&mut self, kind: SessionTeardownKind) { 122 debug!("Teardown session"); 123 let final_kind = match kind { 124 SessionTeardownKind::NotDeleted if self.session.is_some() => { 125 let delete_session = WebDriverMessage { 126 session_id: Some( 127 self.session 128 .as_ref() 129 .expect("Failed to get session") 130 .id 131 .clone(), 132 ), 133 command: WebDriverCommand::DeleteSession, 134 }; 135 match self.handler.handle_command(&self.session, delete_session) { 136 Ok(_) => SessionTeardownKind::Deleted, 137 Err(_) => SessionTeardownKind::NotDeleted, 138 } 139 } 140 _ => kind, 141 }; 142 self.handler.teardown_session(final_kind); 143 self.session = None; 144 } 145 146 fn check_session(&self, msg: &WebDriverMessage<U>) -> WebDriverResult<()> { 147 match msg.session_id { 148 Some(ref msg_session_id) => match self.session { 149 Some(ref existing_session) => { 150 if existing_session.id != *msg_session_id { 151 Err(WebDriverError::new( 152 ErrorStatus::InvalidSessionId, 153 format!("Got unexpected session id {}", msg_session_id), 154 )) 155 } else { 156 Ok(()) 157 } 158 } 159 None => Ok(()), 160 }, 161 None => { 162 match self.session { 163 Some(_) => { 164 match msg.command { 165 WebDriverCommand::Status => Ok(()), 166 WebDriverCommand::NewSession(_) => Err(WebDriverError::new( 167 ErrorStatus::SessionNotCreated, 168 "Session is already started", 169 )), 170 _ => { 171 //This should be impossible 172 error!("Got a message with no session id"); 173 Err(WebDriverError::new( 174 ErrorStatus::UnknownError, 175 "Got a command with no session?!", 176 )) 177 } 178 } 179 } 180 None => match msg.command { 181 WebDriverCommand::NewSession(_) => Ok(()), 182 WebDriverCommand::Status => Ok(()), 183 _ => Err(WebDriverError::new( 184 ErrorStatus::InvalidSessionId, 185 "Tried to run a command before creating a session", 186 )), 187 }, 188 } 189 } 190 } 191 } 192 } 193 194 pub struct Listener { 195 guard: Option<thread::JoinHandle<()>>, 196 pub socket: SocketAddr, 197 } 198 199 impl Drop for Listener { 200 fn drop(&mut self) { 201 let _ = self.guard.take().map(|j| j.join()); 202 } 203 } 204 205 pub fn start<T, U>( 206 mut address: SocketAddr, 207 allow_hosts: Vec<Host>, 208 allow_origins: Vec<Url>, 209 handler: T, 210 extension_routes: Vec<(Method, &'static str, U)>, 211 ) -> ::std::io::Result<Listener> 212 where 213 T: 'static + WebDriverHandler<U>, 214 U: 'static + WebDriverExtensionRoute + Send + Sync, 215 { 216 let listener = StdTcpListener::bind(address)?; 217 listener.set_nonblocking(true)?; 218 let addr = listener.local_addr()?; 219 if address.port() == 0 { 220 // If we passed in 0 as the port number the OS will assign an unused port; 221 // we want to update the address to the actual used port 222 address.set_port(addr.port()) 223 } 224 let (msg_send, msg_recv) = channel(); 225 226 let builder = thread::Builder::new().name("webdriver server".to_string()); 227 let handle = builder.spawn(move || { 228 let rt = tokio::runtime::Builder::new_current_thread() 229 .enable_io() 230 .build() 231 .unwrap(); 232 let listener = rt.block_on(async { TcpListener::from_std(listener).unwrap() }); 233 let wroutes = build_warp_routes( 234 address, 235 allow_hosts, 236 allow_origins, 237 &extension_routes, 238 msg_send.clone(), 239 ); 240 let fut = warp::serve(wroutes).run_incoming(TcpListenerStream::new(listener)); 241 rt.block_on(fut); 242 })?; 243 244 let builder = thread::Builder::new().name("webdriver dispatcher".to_string()); 245 builder.spawn(move || { 246 let mut dispatcher = Dispatcher::new(handler); 247 dispatcher.run(&msg_recv); 248 })?; 249 250 Ok(Listener { 251 guard: Some(handle), 252 socket: addr, 253 }) 254 } 255 256 fn build_warp_routes<U: 'static + WebDriverExtensionRoute + Send + Sync>( 257 address: SocketAddr, 258 allow_hosts: Vec<Host>, 259 allow_origins: Vec<Url>, 260 ext_routes: &[(Method, &'static str, U)], 261 chan: Sender<DispatchMessage<U>>, 262 ) -> impl Filter<Extract = (impl warp::Reply,), Error = Rejection> + Clone { 263 let chan = Arc::new(Mutex::new(chan)); 264 let mut std_routes = standard_routes::<U>(); 265 266 let (method, path, res) = std_routes.pop().unwrap(); 267 trace!("Build standard route for {path}"); 268 let mut wroutes = build_route( 269 address, 270 allow_hosts.clone(), 271 allow_origins.clone(), 272 method, 273 path, 274 res, 275 chan.clone(), 276 ); 277 278 for (method, path, res) in std_routes { 279 trace!("Build standard route for {path}"); 280 wroutes = wroutes 281 .or(build_route( 282 address, 283 allow_hosts.clone(), 284 allow_origins.clone(), 285 method, 286 path, 287 res.clone(), 288 chan.clone(), 289 )) 290 .unify() 291 .boxed() 292 } 293 294 for (method, path, res) in ext_routes { 295 trace!("Build vendor route for {path}"); 296 wroutes = wroutes 297 .or(build_route( 298 address, 299 allow_hosts.clone(), 300 allow_origins.clone(), 301 method.clone(), 302 path, 303 Route::Extension(res.clone()), 304 chan.clone(), 305 )) 306 .unify() 307 .boxed() 308 } 309 310 wroutes 311 } 312 313 fn is_host_allowed(server_address: &SocketAddr, allow_hosts: &[Host], host_header: &str) -> bool { 314 // Validate that the Host header value has a hostname in allow_hosts and 315 // the port matches the server configuration 316 let header_host_url = match Url::parse(&format!("http://{}", &host_header)) { 317 Ok(x) => x, 318 Err(_) => { 319 return false; 320 } 321 }; 322 323 let host = match header_host_url.host() { 324 Some(host) => host.to_owned(), 325 None => { 326 // This shouldn't be possible since http URL always have a 327 // host, but conservatively return false here, which will cause 328 // an error response 329 return false; 330 } 331 }; 332 let port = match header_host_url.port_or_known_default() { 333 Some(port) => port, 334 None => { 335 // This shouldn't be possible since http URL always have a 336 // default port, but conservatively return false here, which will cause 337 // an error response 338 return false; 339 } 340 }; 341 342 let host_matches = match host { 343 Host::Domain(_) => allow_hosts.contains(&host), 344 Host::Ipv4(_) | Host::Ipv6(_) => true, 345 }; 346 let port_matches = server_address.port() == port; 347 host_matches && port_matches 348 } 349 350 fn is_origin_allowed(allow_origins: &[Url], origin_url: Url) -> bool { 351 // Validate that the Origin header value is in allow_origins 352 allow_origins.contains(&origin_url) 353 } 354 355 fn build_route<U: 'static + WebDriverExtensionRoute + Send + Sync>( 356 server_address: SocketAddr, 357 allow_hosts: Vec<Host>, 358 allow_origins: Vec<Url>, 359 method: Method, 360 path: &'static str, 361 route: Route<U>, 362 chan: Arc<Mutex<Sender<DispatchMessage<U>>>>, 363 ) -> warp::filters::BoxedFilter<(impl warp::Reply,)> { 364 // Create an empty filter based on the provided method and append an empty hashmap to it. The 365 // hashmap will be used to store path parameters. 366 let mut subroute = match method { 367 Method::GET => warp::get().boxed(), 368 Method::POST => warp::post().boxed(), 369 Method::DELETE => warp::delete().boxed(), 370 Method::OPTIONS => warp::options().boxed(), 371 Method::PUT => warp::put().boxed(), 372 _ => panic!("Unsupported method"), 373 } 374 .or(warp::head()) 375 .unify() 376 .map(Parameters::new) 377 .boxed(); 378 379 // For each part of the path, if it's a normal part, just append it to the current filter, 380 // otherwise if it's a parameter (a named enclosed in { }), we take that parameter and insert 381 // it into the hashmap created earlier. 382 for part in path.split('/') { 383 if part.is_empty() { 384 continue; 385 } else if part.starts_with('{') { 386 assert!(part.ends_with('}')); 387 388 subroute = subroute 389 .and(warp::path::param()) 390 .map(move |mut params: Parameters, param: String| { 391 let name = &part[1..part.len() - 1]; 392 params.insert(name.to_string(), param); 393 params 394 }) 395 .boxed(); 396 } else { 397 subroute = subroute.and(warp::path(part)).boxed(); 398 } 399 } 400 401 // Finally, tell warp that the path is complete 402 subroute 403 .and(warp::path::end()) 404 .and(warp::path::full()) 405 .and(warp::method()) 406 .and(warp::header::optional::<String>("origin")) 407 .and(warp::header::optional::<String>("host")) 408 .and(warp::header::optional::<String>("content-type")) 409 .and(warp::body::bytes()) 410 .map( 411 move |params, 412 full_path: warp::path::FullPath, 413 method, 414 origin_header: Option<String>, 415 host_header: Option<String>, 416 content_type_header: Option<String>, 417 body: Bytes| { 418 if method == Method::HEAD { 419 return warp::reply::with_status("".into(), StatusCode::OK); 420 } 421 if let Some(host) = host_header { 422 if !is_host_allowed(&server_address, &allow_hosts, &host) { 423 warn!( 424 "Rejected request with Host header {}, allowed values are [{}]", 425 host, 426 allow_hosts 427 .iter() 428 .map(|x| format!("{}:{}", x, server_address.port())) 429 .collect::<Vec<_>>() 430 .join(",") 431 ); 432 let err = WebDriverError::new( 433 ErrorStatus::UnknownError, 434 format!("Invalid Host header {}", host), 435 ); 436 return warp::reply::with_status( 437 serde_json::to_string(&err).unwrap(), 438 StatusCode::INTERNAL_SERVER_ERROR, 439 ); 440 }; 441 } else { 442 warn!("Rejected request with missing Host header"); 443 let err = WebDriverError::new( 444 ErrorStatus::UnknownError, 445 "Missing Host header".to_string(), 446 ); 447 return warp::reply::with_status( 448 serde_json::to_string(&err).unwrap(), 449 StatusCode::INTERNAL_SERVER_ERROR, 450 ); 451 } 452 if let Some(origin) = origin_header { 453 let make_err = || { 454 warn!( 455 "Rejected request with Origin header {}, allowed values are [{}]", 456 origin, 457 allow_origins 458 .iter() 459 .map(|x| x.to_string()) 460 .collect::<Vec<_>>() 461 .join(",") 462 ); 463 WebDriverError::new( 464 ErrorStatus::UnknownError, 465 format!("Invalid Origin header {}", origin), 466 ) 467 }; 468 let origin_url = match Url::parse(&origin) { 469 Ok(url) => url, 470 Err(_) => { 471 return warp::reply::with_status( 472 serde_json::to_string(&make_err()).unwrap(), 473 StatusCode::INTERNAL_SERVER_ERROR, 474 ); 475 } 476 }; 477 if !is_origin_allowed(&allow_origins, origin_url) { 478 return warp::reply::with_status( 479 serde_json::to_string(&make_err()).unwrap(), 480 StatusCode::INTERNAL_SERVER_ERROR, 481 ); 482 } 483 } 484 if method == Method::POST { 485 // Disallow CORS-safelisted request headers 486 // c.f. https://fetch.spec.whatwg.org/#cors-safelisted-request-header 487 let content_type = content_type_header 488 .as_ref() 489 .map(|x| x.find(';').and_then(|idx| x.get(0..idx)).unwrap_or(x)) 490 .map(|x| x.trim()) 491 .map(|x| x.to_lowercase()); 492 match content_type.as_ref().map(|x| x.as_ref()) { 493 Some("application/x-www-form-urlencoded") 494 | Some("multipart/form-data") 495 | Some("text/plain") => { 496 warn!( 497 "Rejected POST request with disallowed content type {}", 498 content_type.unwrap_or_else(|| "".into()) 499 ); 500 let err = WebDriverError::new( 501 ErrorStatus::UnknownError, 502 "Invalid Content-Type", 503 ); 504 return warp::reply::with_status( 505 serde_json::to_string(&err).unwrap(), 506 StatusCode::INTERNAL_SERVER_ERROR, 507 ); 508 } 509 Some(_) | None => {} 510 } 511 } 512 let body = String::from_utf8(body.chunk().to_vec()); 513 if body.is_err() { 514 let err = WebDriverError::new( 515 ErrorStatus::UnknownError, 516 "Request body wasn't valid UTF-8", 517 ); 518 return warp::reply::with_status( 519 serde_json::to_string(&err).unwrap(), 520 StatusCode::INTERNAL_SERVER_ERROR, 521 ); 522 } 523 let body = body.unwrap(); 524 525 debug!("-> {} {} {}", method, full_path.as_str(), body); 526 let msg_result = WebDriverMessage::from_http( 527 route.clone(), 528 ¶ms, 529 &body, 530 method == Method::POST, 531 ); 532 533 let (status, resp_body) = match msg_result { 534 Ok(message) => { 535 let (send_res, recv_res) = channel(); 536 match chan.lock() { 537 Ok(ref c) => { 538 let res = 539 c.send(DispatchMessage::HandleWebDriver(message, send_res)); 540 match res { 541 Ok(x) => x, 542 Err(e) => panic!("Error: {:?}", e), 543 } 544 } 545 Err(e) => panic!("Error reading response: {:?}", e), 546 } 547 548 match recv_res.recv() { 549 Ok(data) => match data { 550 Ok(response) => { 551 (StatusCode::OK, serde_json::to_string(&response).unwrap()) 552 } 553 Err(e) => (e.http_status(), serde_json::to_string(&e).unwrap()), 554 }, 555 Err(e) => panic!("Error reading response: {:?}", e), 556 } 557 } 558 Err(e) => (e.http_status(), serde_json::to_string(&e).unwrap()), 559 }; 560 561 debug!("<- {} {}", status, resp_body); 562 warp::reply::with_status(resp_body, status) 563 }, 564 ) 565 .with(warp::reply::with::header( 566 http::header::CONTENT_TYPE, 567 "application/json; charset=utf-8", 568 )) 569 .with(warp::reply::with::header( 570 http::header::CACHE_CONTROL, 571 "no-cache", 572 )) 573 .boxed() 574 } 575 576 #[cfg(test)] 577 mod tests { 578 use super::*; 579 use std::net::IpAddr; 580 use std::str::FromStr; 581 582 #[test] 583 fn test_host_allowed() { 584 let addr_80 = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80); 585 let addr_8000 = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 8000); 586 let addr_v6_80 = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 80); 587 let addr_v6_8000 = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 8000); 588 589 // We match the host ip address to the server, so we can only use hosts that actually resolve 590 let localhost_host = Host::Domain("localhost".to_string()); 591 let test_host = Host::Domain("example.test".to_string()); 592 let subdomain_localhost_host = Host::Domain("subdomain.localhost".to_string()); 593 594 assert!(is_host_allowed( 595 &addr_80, 596 &[localhost_host.clone()], 597 "localhost:80" 598 )); 599 assert!(is_host_allowed( 600 &addr_80, 601 &[test_host.clone()], 602 "example.test:80" 603 )); 604 assert!(is_host_allowed( 605 &addr_80, 606 &[test_host.clone(), localhost_host.clone()], 607 "example.test" 608 )); 609 assert!(is_host_allowed( 610 &addr_80, 611 &[subdomain_localhost_host.clone()], 612 "subdomain.localhost" 613 )); 614 615 // ip address cases 616 assert!(is_host_allowed(&addr_80, &[], "127.0.0.1:80")); 617 assert!(is_host_allowed(&addr_v6_80, &[], "127.0.0.1")); 618 assert!(is_host_allowed(&addr_80, &[], "[::1]")); 619 assert!(is_host_allowed(&addr_8000, &[], "127.0.0.1:8000")); 620 assert!(is_host_allowed( 621 &addr_80, 622 &[subdomain_localhost_host.clone()], 623 "[::1]" 624 )); 625 assert!(is_host_allowed( 626 &addr_v6_8000, 627 &[subdomain_localhost_host.clone()], 628 "[::1]:8000" 629 )); 630 631 // Mismatch cases 632 633 assert!(!is_host_allowed(&addr_80, &[test_host], "localhost")); 634 635 assert!(!is_host_allowed(&addr_80, &[], "localhost:80")); 636 637 // Port mismatch cases 638 639 assert!(!is_host_allowed( 640 &addr_80, 641 &[localhost_host.clone()], 642 "localhost:8000" 643 )); 644 assert!(!is_host_allowed( 645 &addr_8000, 646 &[localhost_host.clone()], 647 "localhost" 648 )); 649 assert!(!is_host_allowed( 650 &addr_v6_8000, 651 &[localhost_host.clone()], 652 "[::1]" 653 )); 654 } 655 656 #[test] 657 fn test_origin_allowed() { 658 assert!(is_origin_allowed( 659 &[Url::parse("http://localhost").unwrap()], 660 Url::parse("http://localhost").unwrap() 661 )); 662 assert!(is_origin_allowed( 663 &[Url::parse("http://localhost").unwrap()], 664 Url::parse("http://localhost:80").unwrap() 665 )); 666 assert!(is_origin_allowed( 667 &[ 668 Url::parse("https://test.example").unwrap(), 669 Url::parse("http://localhost").unwrap() 670 ], 671 Url::parse("http://localhost").unwrap() 672 )); 673 assert!(is_origin_allowed( 674 &[ 675 Url::parse("https://test.example").unwrap(), 676 Url::parse("http://localhost").unwrap() 677 ], 678 Url::parse("https://test.example:443").unwrap() 679 )); 680 // Mismatch cases 681 assert!(!is_origin_allowed( 682 &[], 683 Url::parse("http://localhost").unwrap() 684 )); 685 assert!(!is_origin_allowed( 686 &[Url::parse("http://localhost").unwrap()], 687 Url::parse("http://localhost:8000").unwrap() 688 )); 689 assert!(!is_origin_allowed( 690 &[Url::parse("https://localhost").unwrap()], 691 Url::parse("http://localhost").unwrap() 692 )); 693 assert!(!is_origin_allowed( 694 &[Url::parse("https://example.test").unwrap()], 695 Url::parse("http://subdomain.example.test").unwrap() 696 )); 697 } 698 }