tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

lib.rs (29486B)


      1 /* -*- Mode: rust; rust-indent-offset: 2 -*- */
      2 /* This Source Code Form is subject to the terms of the Mozilla Public
      3 * License, v. 2.0. If a copy of the MPL was not distributed with this
      4 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
      5 use byteorder::{BigEndian, WriteBytesExt};
      6 use socket2::{Domain, Socket, Type};
      7 use std::collections::HashMap;
      8 use std::collections::LinkedList;
      9 use std::ffi::{c_void, CStr, CString};
     10 use std::io;
     11 use std::net;
     12 use std::os::raw::c_char;
     13 use std::sync::mpsc::channel;
     14 use std::thread;
     15 use std::time;
     16 use uuid::Uuid;
     17 
     18 #[macro_use]
     19 extern crate log;
     20 
/// Pair of C callbacks (plus opaque context) supplied through
/// `mdns_service_query_hostname`, used to report the outcome of one query.
struct Callback {
    /// Opaque caller context, passed back unchanged as the first argument of
    /// both callbacks.
    data: *const c_void,
    /// Invoked as `(data, hostname, address)` with NUL-terminated strings when
    /// an A/AAAA answer for the queried hostname arrives.
    resolved: unsafe extern "C" fn(*const c_void, *const c_char, *const c_char),
    /// Invoked as `(data, hostname)` when the query gives up (retries
    /// exhausted, or sending is not permitted).
    timedout: unsafe extern "C" fn(*const c_void, *const c_char),
}

// SAFETY: a `Callback` is moved into the mDNS worker thread. `data` is never
// dereferenced on the Rust side; it is only handed back to the C callbacks,
// and `mdns_service_query_hostname` requires the pointee to outlive the
// service. NOTE(review): this also assumes the C callbacks are safe to call
// from that worker thread — confirm with the C++ caller.
unsafe impl Send for Callback {}
     28 
     29 fn hostname_resolved(callback: &Callback, hostname: &str, addr: &str) {
     30    if let Ok(hostname) = CString::new(hostname) {
     31        if let Ok(addr) = CString::new(addr) {
     32            unsafe {
     33                (callback.resolved)(callback.data, hostname.as_ptr(), addr.as_ptr());
     34            }
     35        }
     36    }
     37 }
     38 
     39 fn hostname_timedout(callback: &Callback, hostname: &str) {
     40    if let Ok(hostname) = CString::new(hostname) {
     41        unsafe {
     42            (callback.timedout)(callback.data, hostname.as_ptr());
     43        }
     44    }
     45 }
     46 
     47 // This code is derived from code for creating questions in the dns-parser
     48 // crate. It would be nice to upstream this, or something similar.
     49 fn create_answer(id: u16, answers: &[(String, &[u8])]) -> Result<Vec<u8>, io::Error> {
     50    let mut buf = Vec::with_capacity(512);
     51    let head = dns_parser::Header {
     52        id,
     53        query: false,
     54        opcode: dns_parser::Opcode::StandardQuery,
     55        authoritative: true,
     56        truncated: false,
     57        recursion_desired: false,
     58        recursion_available: false,
     59        authenticated_data: false,
     60        checking_disabled: false,
     61        response_code: dns_parser::ResponseCode::NoError,
     62        questions: 0,
     63        answers: answers.len() as u16,
     64        nameservers: 0,
     65        additional: 0,
     66    };
     67 
     68    buf.extend([0u8; 12].iter());
     69    head.write(&mut buf[..12]);
     70 
     71    for (name, addr) in answers {
     72        for part in name.split('.') {
     73            if part.len() > 62 {
     74                return Err(io::Error::new(
     75                    io::ErrorKind::Other,
     76                    "Name part length too long",
     77                ));
     78            }
     79            let ln = part.len() as u8;
     80            buf.push(ln);
     81            buf.extend(part.as_bytes());
     82        }
     83        buf.push(0);
     84 
     85        if addr.len() == 4 {
     86            buf.write_u16::<BigEndian>(dns_parser::Type::A as u16)?;
     87        } else {
     88            buf.write_u16::<BigEndian>(dns_parser::Type::AAAA as u16)?;
     89        }
     90        // set cache flush bit
     91        buf.write_u16::<BigEndian>(dns_parser::Class::IN as u16 | (0x1 << 15))?;
     92        buf.write_u32::<BigEndian>(120)?;
     93        buf.write_u16::<BigEndian>(addr.len() as u16)?;
     94        buf.extend(*addr);
     95    }
     96 
     97    Ok(buf)
     98 }
     99 
    100 fn create_query(id: u16, queries: &[String]) -> Result<Vec<u8>, io::Error> {
    101    let mut buf = Vec::with_capacity(512);
    102    let head = dns_parser::Header {
    103        id,
    104        query: true,
    105        opcode: dns_parser::Opcode::StandardQuery,
    106        authoritative: false,
    107        truncated: false,
    108        recursion_desired: false,
    109        recursion_available: false,
    110        authenticated_data: false,
    111        checking_disabled: false,
    112        response_code: dns_parser::ResponseCode::NoError,
    113        questions: queries.len() as u16,
    114        answers: 0,
    115        nameservers: 0,
    116        additional: 0,
    117    };
    118 
    119    buf.extend([0u8; 12].iter());
    120    head.write(&mut buf[..12]);
    121 
    122    for name in queries {
    123        for part in name.split('.') {
    124            assert!(part.len() < 63);
    125            let ln = part.len() as u8;
    126            buf.push(ln);
    127            buf.extend(part.as_bytes());
    128        }
    129        buf.push(0);
    130 
    131        buf.write_u16::<BigEndian>(dns_parser::QueryType::A as u16)?;
    132        buf.write_u16::<BigEndian>(dns_parser::QueryClass::IN as u16)?;
    133    }
    134 
    135    Ok(buf)
    136 }
    137 
    138 fn handle_queries(
    139    socket: &std::net::UdpSocket,
    140    mdns_addr: &std::net::SocketAddr,
    141    pending_queries: &mut HashMap<String, Query>,
    142    unsent_queries: &mut LinkedList<Query>,
    143 ) {
    144    if pending_queries.len() < 50 {
    145        let mut queries: Vec<Query> = Vec::new();
    146        while queries.len() < 5 && !unsent_queries.is_empty() {
    147            if let Some(query) = unsent_queries.pop_front() {
    148                if !pending_queries.contains_key(&query.hostname) {
    149                    queries.push(query);
    150                }
    151            }
    152        }
    153        if !queries.is_empty() {
    154            let query_hostnames: Vec<String> =
    155                queries.iter().map(|q| q.hostname.to_string()).collect();
    156 
    157            if let Ok(buf) = create_query(0, &query_hostnames) {
    158                match socket.send_to(&buf, &mdns_addr) {
    159                    Ok(_) => {
    160                        for query in queries {
    161                            pending_queries.insert(query.hostname.to_string(), query);
    162                        }
    163                    }
    164                    Err(err) => {
    165                        warn!("Sending mDNS query failed: {}", err);
    166                        if err.kind() != io::ErrorKind::PermissionDenied {
    167                            for query in queries {
    168                                unsent_queries.push_back(query);
    169                            }
    170                        } else {
    171                            for query in queries {
    172                                hostname_timedout(&query.callback, &query.hostname);
    173                            }
    174                        }
    175                    }
    176                }
    177            }
    178        }
    179    }
    180 
    181    let now = time::Instant::now();
    182    let expired: Vec<String> = pending_queries
    183        .iter()
    184        .filter(|(_, query)| now.duration_since(query.timestamp).as_secs() >= 3)
    185        .map(|(hostname, _)| hostname.to_string())
    186        .collect();
    187    for hostname in expired {
    188        if let Some(mut query) = pending_queries.remove(&hostname) {
    189            query.attempts += 1;
    190            if query.attempts < 3 {
    191                query.timestamp = now;
    192                unsent_queries.push_back(query);
    193            } else {
    194                hostname_timedout(&query.callback, &hostname);
    195            }
    196        }
    197    }
    198 }
    199 
    200 fn handle_mdns_socket(
    201    socket: &std::net::UdpSocket,
    202    mdns_addr: &std::net::SocketAddr,
    203    mut buffer: &mut [u8],
    204    hosts: &mut HashMap<String, Vec<u8>>,
    205    pending_queries: &mut HashMap<String, Query>,
    206 ) -> bool {
    207    // Record a simple marker to see how often this is called.
    208    gecko_profiler::add_untyped_marker(
    209        "handle_mdns_socket",
    210        gecko_profiler::gecko_profiler_category!(Network),
    211        Default::default(),
    212    );
    213 
    214    match socket.recv_from(&mut buffer) {
    215        Ok((amt, _)) => {
    216            if amt > 0 {
    217                let buffer = &buffer[0..amt];
    218                match dns_parser::Packet::parse(&buffer) {
    219                    Ok(parsed) => {
    220                        let mut answers: Vec<(String, &[u8])> = Vec::new();
    221 
    222                        // If a packet contains both both questions and
    223                        // answers, the questions should be ignored.
    224                        if parsed.answers.is_empty() {
    225                            parsed
    226                                .questions
    227                                .iter()
    228                                .filter(|question| question.qtype == dns_parser::QueryType::A)
    229                                .for_each(|question| {
    230                                    let qname = question.qname.to_string();
    231                                    trace!("mDNS question: {} {:?}", qname, question.qtype);
    232                                    if let Some(octets) = hosts.get(&qname) {
    233                                        trace!("Sending mDNS answer for {}: {:?}", qname, octets);
    234                                        answers.push((qname, &octets));
    235                                    }
    236                                });
    237                        }
    238                        for answer in parsed.answers {
    239                            let hostname = answer.name.to_string();
    240                            match pending_queries.get(&hostname) {
    241                                Some(query) => {
    242                                    match answer.data {
    243                                        dns_parser::RData::A(dns_parser::rdata::a::Record(
    244                                            addr,
    245                                        )) => {
    246                                            let addr = addr.to_string();
    247                                            trace!("mDNS response: {} {}", hostname, addr);
    248                                            hostname_resolved(&query.callback, &hostname, &addr);
    249                                        }
    250                                        dns_parser::RData::AAAA(
    251                                            dns_parser::rdata::aaaa::Record(addr),
    252                                        ) => {
    253                                            let addr = addr.to_string();
    254                                            trace!("mDNS response: {} {}", hostname, addr);
    255                                            hostname_resolved(&query.callback, &hostname, &addr);
    256                                        }
    257                                        _ => {}
    258                                    }
    259                                    pending_queries.remove(&hostname);
    260                                }
    261                                None => {
    262                                    continue;
    263                                }
    264                            }
    265                        }
    266                        // TODO: If we did not answer every query in this
    267                        // question, we should wait for a random amount of time
    268                        // so as to not collide with someone else responding to
    269                        // this query.
    270                        if !answers.is_empty() {
    271                            if let Ok(buf) = create_answer(parsed.header.id, &answers) {
    272                                if let Err(err) = socket.send_to(&buf, &mdns_addr) {
    273                                    warn!("Sending mDNS answer failed: {}", err);
    274                                }
    275                            }
    276                        }
    277                    }
    278                    Err(err) => {
    279                        warn!("Could not parse mDNS packet: {}", err);
    280                    }
    281                }
    282            }
    283        }
    284        Err(err) => {
    285            if err.kind() != io::ErrorKind::Interrupted
    286                && err.kind() != io::ErrorKind::TimedOut
    287                && err.kind() != io::ErrorKind::WouldBlock
    288            {
    289                error!("Socket error: {}", err);
    290                return false;
    291            }
    292        }
    293    }
    294 
    295    true
    296 }
    297 
    298 fn validate_hostname(hostname: &str) -> bool {
    299    match hostname.find(".local") {
    300        Some(index) => match hostname.get(0..index) {
    301            Some(uuid) => match uuid.get(0..36) {
    302                Some(initial) => match Uuid::parse_str(initial) {
    303                    Ok(_) => {
    304                        // Oddly enough, Safari does not generate valid UUIDs,
    305                        // the last part sometimes contains more than 12 digits.
    306                        match uuid.get(36..) {
    307                            Some(trailing) => {
    308                                for c in trailing.chars() {
    309                                    if !c.is_ascii_hexdigit() {
    310                                        return false;
    311                                    }
    312                                }
    313                                true
    314                            }
    315                            None => true,
    316                        }
    317                    }
    318                    Err(_) => false,
    319                },
    320                None => false,
    321            },
    322            None => false,
    323        },
    324        None => false,
    325    }
    326 }
    327 
/// Control messages sent from `MDNSService` to its worker thread over the
/// channel created in `MDNSService::start`.
enum ServiceControl {
    /// Answer A questions for `hostname` with `address`. The worker parses
    /// `address` as an IPv4 or IPv6 literal and ignores unparsable ones.
    Register {
        hostname: String,
        address: String,
    },
    /// Send an mDNS query for `hostname` and report the outcome through
    /// `callback`.
    Query {
        callback: Callback,
        hostname: String,
    },
    /// Stop answering questions for a previously registered `hostname`.
    Unregister {
        hostname: String,
    },
    /// Leave the worker loop so the thread can exit.
    Stop,
}
    342 
    343 struct Query {
    344    hostname: String,
    345    callback: Callback,
    346    timestamp: time::Instant,
    347    attempts: i32,
    348 }
    349 
    350 impl Query {
    351    fn new(hostname: &str, callback: Callback) -> Query {
    352        Query {
    353            hostname: hostname.to_string(),
    354            callback,
    355            timestamp: time::Instant::now(),
    356            attempts: 0,
    357        }
    358    }
    359 }
    360 
/// Handle to the mDNS responder/resolver running on a background thread.
///
/// Instances are created by `mdns_service_start` and destroyed by
/// `mdns_service_stop`.
pub struct MDNSService {
    /// Join handle of the worker thread; set once `start` has spawned it.
    handle: Option<std::thread::JoinHandle<()>>,
    /// Sends `ServiceControl` messages to the worker thread; set by `start`.
    sender: Option<std::sync::mpsc::Sender<ServiceControl>>,
}
    365 
    366 impl MDNSService {
    367    fn register_hostname(&mut self, hostname: &str, address: &str) {
    368        if let Some(sender) = &self.sender {
    369            if let Err(err) = sender.send(ServiceControl::Register {
    370                hostname: hostname.to_string(),
    371                address: address.to_string(),
    372            }) {
    373                warn!(
    374                    "Could not send register hostname {} message: {}",
    375                    hostname, err
    376                );
    377            }
    378        }
    379    }
    380 
    381    fn query_hostname(&mut self, callback: Callback, hostname: &str) {
    382        if let Some(sender) = &self.sender {
    383            if let Err(err) = sender.send(ServiceControl::Query {
    384                callback,
    385                hostname: hostname.to_string(),
    386            }) {
    387                warn!(
    388                    "Could not send query hostname {} message: {}",
    389                    hostname, err
    390                );
    391            }
    392        }
    393    }
    394 
    395    fn unregister_hostname(&mut self, hostname: &str) {
    396        if let Some(sender) = &self.sender {
    397            if let Err(err) = sender.send(ServiceControl::Unregister {
    398                hostname: hostname.to_string(),
    399            }) {
    400                warn!(
    401                    "Could not send unregister hostname {} message: {}",
    402                    hostname, err
    403                );
    404            }
    405        }
    406    }
    407 
    408    fn start(&mut self, addrs: Vec<std::net::Ipv4Addr>) -> io::Result<()> {
    409        let (sender, receiver) = channel();
    410        self.sender = Some(sender);
    411 
    412        let mdns_addr = std::net::Ipv4Addr::new(224, 0, 0, 251);
    413        let port = 5353;
    414 
    415        let socket = Socket::new(Domain::IPV4, Type::DGRAM, None)?;
    416        socket.set_reuse_address(true)?;
    417 
    418        #[cfg(not(target_os = "windows"))]
    419        socket.set_reuse_port(true)?;
    420        socket.bind(&socket2::SockAddr::from(std::net::SocketAddr::from((
    421            [0, 0, 0, 0],
    422            port,
    423        ))))?;
    424 
    425        let socket = std::net::UdpSocket::from(socket);
    426        socket.set_multicast_loop_v4(true)?;
    427        socket.set_read_timeout(Some(time::Duration::from_millis(1)))?;
    428        socket.set_write_timeout(Some(time::Duration::from_millis(1)))?;
    429        for addr in addrs {
    430            if let Err(err) = socket.join_multicast_v4(&mdns_addr, &addr) {
    431                warn!(
    432                    "Could not join multicast group on interface: {:?}: {}",
    433                    addr, err
    434                );
    435            }
    436        }
    437 
    438        let thread_name = "mdns_service";
    439        let builder = thread::Builder::new().name(thread_name.into());
    440        self.handle = Some(builder.spawn(move || {
    441            gecko_profiler::register_thread(thread_name);
    442            let mdns_addr = std::net::SocketAddr::from(([224, 0, 0, 251], port));
    443            let mut buffer: [u8; 9_000] = [0; 9_000];
    444            let mut hosts = HashMap::new();
    445            let mut unsent_queries = LinkedList::new();
    446            let mut pending_queries = HashMap::new();
    447            loop {
    448                match receiver.try_recv() {
    449                    Ok(msg) => match msg {
    450                        ServiceControl::Register { hostname, address } => {
    451                            if !validate_hostname(&hostname) {
    452                                warn!("Not registering invalid hostname: {}", hostname);
    453                                continue;
    454                            }
    455                            trace!("Registering {} for: {}", hostname, address);
    456                            match address.parse().and_then(|ip| {
    457                                Ok(match ip {
    458                                    net::IpAddr::V4(ip) => ip.octets().to_vec(),
    459                                    net::IpAddr::V6(ip) => ip.octets().to_vec(),
    460                                })
    461                            }) {
    462                                Ok(octets) => {
    463                                    let mut v = Vec::new();
    464                                    v.extend(octets);
    465                                    hosts.insert(hostname, v);
    466                                }
    467                                Err(err) => {
    468                                    warn!(
    469                                        "Could not parse address for {}: {}: {}",
    470                                        hostname, address, err
    471                                    );
    472                                }
    473                            }
    474                        }
    475                        ServiceControl::Query { callback, hostname } => {
    476                            trace!("Querying {}", hostname);
    477                            if !validate_hostname(&hostname) {
    478                                warn!("Not sending mDNS query for invalid hostname: {}", hostname);
    479                                continue;
    480                            }
    481                            unsent_queries.push_back(Query::new(&hostname, callback));
    482                        }
    483                        ServiceControl::Unregister { hostname } => {
    484                            trace!("Unregistering {}", hostname);
    485                            hosts.remove(&hostname);
    486                        }
    487                        ServiceControl::Stop => {
    488                            trace!("Stopping");
    489                            break;
    490                        }
    491                    },
    492                    Err(std::sync::mpsc::TryRecvError::Disconnected) => {
    493                        break;
    494                    }
    495                    Err(std::sync::mpsc::TryRecvError::Empty) => {}
    496                }
    497 
    498                handle_queries(
    499                    &socket,
    500                    &mdns_addr,
    501                    &mut pending_queries,
    502                    &mut unsent_queries,
    503                );
    504 
    505                if !handle_mdns_socket(
    506                    &socket,
    507                    &mdns_addr,
    508                    &mut buffer,
    509                    &mut hosts,
    510                    &mut pending_queries,
    511                ) {
    512                    break;
    513                }
    514            }
    515            gecko_profiler::unregister_thread();
    516        })?);
    517 
    518        Ok(())
    519    }
    520 
    521    fn stop(self) {
    522        if let Some(sender) = self.sender {
    523            if let Err(err) = sender.send(ServiceControl::Stop) {
    524                warn!("Could not stop mDNS Service: {}", err);
    525            }
    526            if let Some(handle) = self.handle {
    527                if handle.join().is_err() {
    528                    error!("Error on thread join");
    529                }
    530            }
    531        }
    532    }
    533 
    534    fn new() -> MDNSService {
    535        MDNSService {
    536            handle: None,
    537            sender: None,
    538        }
    539    }
    540 }
    541 
    542 /// # Safety
    543 ///
    544 /// This function must only be called with a valid MDNSService pointer.
    545 /// This hostname and address arguments must be zero terminated strings.
    546 #[no_mangle]
    547 pub unsafe extern "C" fn mdns_service_register_hostname(
    548    serv: *mut MDNSService,
    549    hostname: *const c_char,
    550    address: *const c_char,
    551 ) {
    552    assert!(!serv.is_null());
    553    assert!(!hostname.is_null());
    554    assert!(!address.is_null());
    555    let hostname = CStr::from_ptr(hostname).to_string_lossy();
    556    let address = CStr::from_ptr(address).to_string_lossy();
    557    (*serv).register_hostname(&hostname, &address);
    558 }
    559 
    560 /// # Safety
    561 ///
    562 /// This ifaddrs argument must be a zero terminated string.
    563 #[no_mangle]
    564 pub unsafe extern "C" fn mdns_service_start(ifaddrs: *const c_char) -> *mut MDNSService {
    565    assert!(!ifaddrs.is_null());
    566    let mut r = Box::new(MDNSService::new());
    567    let ifaddrs = CStr::from_ptr(ifaddrs).to_string_lossy();
    568    let addrs: Vec<std::net::Ipv4Addr> =
    569        ifaddrs.split(';').filter_map(|x| x.parse().ok()).collect();
    570 
    571    if addrs.is_empty() {
    572        warn!("Could not parse interface addresses from: {}", ifaddrs);
    573    } else if let Err(err) = r.start(addrs) {
    574        warn!("Could not start mDNS Service: {}", err);
    575    }
    576 
    577    Box::into_raw(r)
    578 }
    579 
    580 /// # Safety
    581 ///
    582 /// This function must only be called with a valid MDNSService pointer.
    583 #[no_mangle]
    584 pub unsafe extern "C" fn mdns_service_stop(serv: *mut MDNSService) {
    585    assert!(!serv.is_null());
    586    let boxed = Box::from_raw(serv);
    587    boxed.stop();
    588 }
    589 
    590 /// # Safety
    591 ///
    592 /// This function must only be called with a valid MDNSService pointer.
    593 /// The data argument will be passed back into the resolved and timedout
    594 /// functions. The object it points to must not be freed until the MDNSService
    595 /// has stopped.
    596 #[no_mangle]
    597 pub unsafe extern "C" fn mdns_service_query_hostname(
    598    serv: *mut MDNSService,
    599    data: *const c_void,
    600    resolved: unsafe extern "C" fn(*const c_void, *const c_char, *const c_char),
    601    timedout: unsafe extern "C" fn(*const c_void, *const c_char),
    602    hostname: *const c_char,
    603 ) {
    604    assert!(!serv.is_null());
    605    assert!(!data.is_null());
    606    assert!(!hostname.is_null());
    607    let hostname = CStr::from_ptr(hostname).to_string_lossy();
    608    let callback = Callback {
    609        data,
    610        resolved,
    611        timedout,
    612    };
    613    (*serv).query_hostname(callback, &hostname);
    614 }
    615 
    616 /// # Safety
    617 ///
    618 /// This function must only be called with a valid MDNSService pointer.
    619 /// This function should only be called once per hostname.
    620 #[no_mangle]
    621 pub unsafe extern "C" fn mdns_service_unregister_hostname(
    622    serv: *mut MDNSService,
    623    hostname: *const c_char,
    624 ) {
    625    assert!(!serv.is_null());
    626    assert!(!hostname.is_null());
    627    let hostname = CStr::from_ptr(hostname).to_string_lossy();
    628    (*serv).unregister_hostname(&hostname);
    629 }
    630 
    631 #[cfg(test)]
    632 mod tests {
    633    use crate::create_query;
    634    use crate::validate_hostname;
    635    use crate::Callback;
    636    use crate::MDNSService;
    637    use socket2::{Domain, Socket, Type};
    638    use std::collections::HashSet;
    639    use std::ffi::c_void;
    640    use std::io;
    641    use std::iter::FromIterator;
    642    use std::os::raw::c_char;
    643    use std::thread;
    644    use std::time;
    645    use uuid::Uuid;
    646 
    647    #[no_mangle]
    648    pub unsafe extern "C" fn mdns_service_resolved(
    649        _: *const c_void,
    650        _: *const c_char,
    651        _: *const c_char,
    652    ) -> () {
    653    }
    654 
    655    #[no_mangle]
    656    pub unsafe extern "C" fn mdns_service_timedout(_: *const c_void, _: *const c_char) -> () {}
    657 
    658    fn listen_until(addr: &std::net::Ipv4Addr, stop: u64) -> thread::JoinHandle<Vec<String>> {
    659        let port = 5353;
    660 
    661        let socket = Socket::new(Domain::IPV4, Type::DGRAM, None).unwrap();
    662        socket.set_reuse_address(true).unwrap();
    663 
    664        #[cfg(not(target_os = "windows"))]
    665        socket.set_reuse_port(true).unwrap();
    666        socket
    667            .bind(&socket2::SockAddr::from(std::net::SocketAddr::from((
    668                [0, 0, 0, 0],
    669                port,
    670            ))))
    671            .unwrap();
    672 
    673        let socket = std::net::UdpSocket::from(socket);
    674        socket.set_multicast_loop_v4(true).unwrap();
    675        socket
    676            .set_read_timeout(Some(time::Duration::from_millis(10)))
    677            .unwrap();
    678        socket
    679            .set_write_timeout(Some(time::Duration::from_millis(10)))
    680            .unwrap();
    681        socket
    682            .join_multicast_v4(&std::net::Ipv4Addr::new(224, 0, 0, 251), &addr)
    683            .unwrap();
    684 
    685        let mut buffer: [u8; 9_000] = [0; 9_000];
    686        thread::spawn(move || {
    687            let start = time::Instant::now();
    688            let mut questions = Vec::new();
    689            while time::Instant::now().duration_since(start).as_secs() < stop {
    690                match socket.recv_from(&mut buffer) {
    691                    Ok((amt, _)) => {
    692                        if amt > 0 {
    693                            let buffer = &buffer[0..amt];
    694                            match dns_parser::Packet::parse(&buffer) {
    695                                Ok(parsed) => {
    696                                    parsed
    697                                        .questions
    698                                        .iter()
    699                                        .filter(|question| {
    700                                            question.qtype == dns_parser::QueryType::A
    701                                        })
    702                                        .for_each(|question| {
    703                                            let qname = question.qname.to_string();
    704                                            questions.push(qname);
    705                                        });
    706                                }
    707                                Err(err) => {
    708                                    warn!("Could not parse mDNS packet: {}", err);
    709                                }
    710                            }
    711                        }
    712                    }
    713                    Err(err) => {
    714                        if err.kind() != io::ErrorKind::WouldBlock
    715                            && err.kind() != io::ErrorKind::TimedOut
    716                        {
    717                            error!("Socket error: {}", err);
    718                            break;
    719                        }
    720                    }
    721                }
    722            }
    723            questions
    724        })
    725    }
    726 
    727    #[test]
    728    fn test_validate_hostname() {
    729        assert_eq!(
    730            validate_hostname("e17f08d4-689a-4df6-ba31-35bb9f041100.local"),
    731            true
    732        );
    733        assert_eq!(
    734            validate_hostname("62240723-ae6d-4f6a-99b8-94a233e3f84a2.local"),
    735            true
    736        );
    737        assert_eq!(
    738            validate_hostname("62240723-ae6d-4f6a-99b8.94e3f84a2.local"),
    739            false
    740        );
    741        assert_eq!(validate_hostname("hi there"), false);
    742    }
    743 
    744    #[test]
    745    fn start_stop() {
    746        let mut service = MDNSService::new();
    747        let addr = "127.0.0.1".parse().unwrap();
    748        service.start(vec![addr]).unwrap();
    749        service.stop();
    750    }
    751 
    752    #[test]
    753    fn simple_query() {
    754        let mut service = MDNSService::new();
    755        let addr = "127.0.0.1".parse().unwrap();
    756        let handle = listen_until(&addr, 1);
    757 
    758        service.start(vec![addr]).unwrap();
    759 
    760        let callback = Callback {
    761            data: 0 as *const c_void,
    762            resolved: mdns_service_resolved,
    763            timedout: mdns_service_timedout,
    764        };
    765        let hostname = Uuid::new_v4().as_hyphenated().to_string() + ".local";
    766        service.query_hostname(callback, &hostname);
    767        service.stop();
    768        let questions = handle.join().unwrap();
    769        assert!(questions.contains(&hostname));
    770    }
    771 
    772    #[test]
    773    fn rate_limited_query() {
    774        let mut service = MDNSService::new();
    775        let addr = "127.0.0.1".parse().unwrap();
    776        let handle = listen_until(&addr, 1);
    777 
    778        service.start(vec![addr]).unwrap();
    779 
    780        let mut hostnames = HashSet::new();
    781        for _ in 0..100 {
    782            let callback = Callback {
    783                data: 0 as *const c_void,
    784                resolved: mdns_service_resolved,
    785                timedout: mdns_service_timedout,
    786            };
    787            let hostname = Uuid::new_v4().as_hyphenated().to_string() + ".local";
    788            service.query_hostname(callback, &hostname);
    789            hostnames.insert(hostname);
    790        }
    791        service.stop();
    792        let questions = HashSet::from_iter(handle.join().unwrap().iter().map(|x| x.to_string()));
    793        let intersection: HashSet<&String> = questions.intersection(&hostnames).collect();
    794        assert_eq!(intersection.len(), 50);
    795    }
    796 
    797    #[test]
    798    fn repeat_failed_query() {
    799        let mut service = MDNSService::new();
    800        let addr = "127.0.0.1".parse().unwrap();
    801        let handle = listen_until(&addr, 4);
    802 
    803        service.start(vec![addr]).unwrap();
    804 
    805        let hostname = Uuid::new_v4().as_hyphenated().to_string() + ".local";
    806        let callback = Callback {
    807            data: 0 as *const c_void,
    808            resolved: mdns_service_resolved,
    809            timedout: mdns_service_timedout,
    810        };
    811        service.query_hostname(callback, &hostname);
    812        thread::sleep(time::Duration::from_secs(4));
    813        service.stop();
    814 
    815        let questions: Vec<String> = handle
    816            .join()
    817            .unwrap()
    818            .iter()
    819            .filter(|x| *x == &hostname)
    820            .map(|x| x.to_string())
    821            .collect();
    822        assert_eq!(questions.len(), 2);
    823    }
    824 
    825    #[test]
    826    fn multiple_queries_in_a_single_packet() {
    827        let mut hostnames: Vec<String> = Vec::new();
    828        for _ in 0..100 {
    829            let hostname = Uuid::new_v4().as_hyphenated().to_string() + ".local";
    830            hostnames.push(hostname);
    831        }
    832 
    833        match create_query(42, &hostnames) {
    834            Ok(q) => match dns_parser::Packet::parse(&q) {
    835                Ok(parsed) => {
    836                    assert_eq!(parsed.questions.len(), 100);
    837                }
    838                Err(_) => assert!(false),
    839            },
    840            Err(_) => assert!(false),
    841        }
    842    }
    843 }