tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

message.rs (10290B)


      1 /* This Source Code Form is subject to the terms of the Mozilla Public
      2 * License, v. 2.0. If a copy of the MPL was not distributed with this
      3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
      4 
      5 use serde::de::{self, SeqAccess, Unexpected, Visitor};
      6 use serde::{Deserialize, Deserializer, Serialize, Serializer};
      7 use serde_json::{Map, Value};
      8 use serde_repr::{Deserialize_repr, Serialize_repr};
      9 use std::fmt;
     10 
     11 use crate::error::MarionetteError;
     12 use crate::marionette;
     13 use crate::result::MarionetteResult;
     14 use crate::webdriver;
     15 
     16 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
     17 #[serde(untagged)]
     18 pub enum Command {
     19    WebDriver(webdriver::Command),
     20    Marionette(marionette::Command),
     21 }
     22 
     23 impl Command {
     24    pub fn name(&self) -> String {
     25        let (command_name, _) = self.first_entry();
     26        command_name
     27    }
     28 
     29    fn params(&self) -> Value {
     30        let (_, params) = self.first_entry();
     31        params
     32    }
     33 
     34    fn first_entry(&self) -> (String, serde_json::Value) {
     35        match serde_json::to_value(self).unwrap() {
     36            Value::String(cmd) => (cmd, Value::Object(Map::new())),
     37            Value::Object(items) => {
     38                let mut iter = items.iter();
     39                let (cmd, params) = iter.next().unwrap();
     40                (cmd.to_string(), params.clone())
     41            }
     42            _ => unreachable!(),
     43        }
     44    }
     45 }
     46 
     47 #[derive(Clone, Debug, PartialEq, Serialize_repr, Deserialize_repr)]
     48 #[repr(u8)]
     49 enum MessageDirection {
     50    Incoming = 0,
     51    Outgoing = 1,
     52 }
     53 
     54 pub type MessageId = u32;
     55 
     56 #[derive(Debug, Clone, PartialEq)]
     57 pub struct Request(pub MessageId, pub Command);
     58 
     59 impl Request {
     60    pub fn id(&self) -> MessageId {
     61        self.0
     62    }
     63 
     64    pub fn command(&self) -> &Command {
     65        &self.1
     66    }
     67 
     68    pub fn params(&self) -> Value {
     69        self.command().params()
     70    }
     71 }
     72 
     73 impl Serialize for Request {
     74    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     75    where
     76        S: Serializer,
     77    {
     78        (
     79            MessageDirection::Incoming,
     80            self.id(),
     81            self.command().name(),
     82            self.params(),
     83        )
     84            .serialize(serializer)
     85    }
     86 }
     87 
     88 #[derive(Debug, PartialEq)]
     89 pub enum Response {
     90    Result {
     91        id: MessageId,
     92        result: MarionetteResult,
     93    },
     94    Error {
     95        id: MessageId,
     96        error: MarionetteError,
     97    },
     98 }
     99 
    100 impl Serialize for Response {
    101    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    102    where
    103        S: Serializer,
    104    {
    105        match self {
    106            Response::Result { id, result } => {
    107                (MessageDirection::Outgoing, id, Value::Null, &result).serialize(serializer)
    108            }
    109            Response::Error { id, error } => {
    110                (MessageDirection::Outgoing, id, &error, Value::Null).serialize(serializer)
    111            }
    112        }
    113    }
    114 }
    115 
    116 #[derive(Debug, PartialEq, Serialize)]
    117 #[serde(untagged)]
    118 pub enum Message {
    119    Incoming(Request),
    120    Outgoing(Response),
    121 }
    122 
    123 struct MessageVisitor;
    124 
    125 impl<'de> Visitor<'de> for MessageVisitor {
    126    type Value = Message;
    127 
    128    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
    129        formatter.write_str("four-element array")
    130    }
    131 
    132    fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
    133        let direction = seq
    134            .next_element::<MessageDirection>()?
    135            .ok_or_else(|| de::Error::invalid_length(0, &self))?;
    136        let id: MessageId = seq
    137            .next_element()?
    138            .ok_or_else(|| de::Error::invalid_length(1, &self))?;
    139 
    140        let msg = match direction {
    141            MessageDirection::Incoming => {
    142                let name: String = seq
    143                    .next_element()?
    144                    .ok_or_else(|| de::Error::invalid_length(2, &self))?;
    145                let params: Value = seq
    146                    .next_element()?
    147                    .ok_or_else(|| de::Error::invalid_length(3, &self))?;
    148 
    149                let command = match params {
    150                    Value::Object(ref items) if !items.is_empty() => {
    151                        let command_to_params = {
    152                            let mut m = Map::new();
    153                            m.insert(name, params);
    154                            Value::Object(m)
    155                        };
    156                        serde_json::from_value(command_to_params).map_err(de::Error::custom)
    157                    }
    158                    Value::Object(_) | Value::Null => {
    159                        serde_json::from_value(Value::String(name)).map_err(de::Error::custom)
    160                    }
    161                    x => Err(de::Error::custom(format!("unknown params type: {}", x))),
    162                }?;
    163                Message::Incoming(Request(id, command))
    164            }
    165 
    166            MessageDirection::Outgoing => {
    167                let maybe_error: Option<MarionetteError> = seq
    168                    .next_element()?
    169                    .ok_or_else(|| de::Error::invalid_length(2, &self))?;
    170 
    171                let response = if let Some(error) = maybe_error {
    172                    seq.next_element::<Value>()?
    173                        .ok_or_else(|| de::Error::invalid_length(3, &self))?
    174                        .as_null()
    175                        .ok_or_else(|| de::Error::invalid_type(Unexpected::Unit, &self))?;
    176                    Response::Error { id, error }
    177                } else {
    178                    let result: MarionetteResult = seq
    179                        .next_element()?
    180                        .ok_or_else(|| de::Error::invalid_length(3, &self))?;
    181                    Response::Result { id, result }
    182                };
    183 
    184                Message::Outgoing(response)
    185            }
    186        };
    187 
    188        Ok(msg)
    189    }
    190 }
    191 
    192 impl<'de> Deserialize<'de> for Message {
    193    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    194    where
    195        D: Deserializer<'de>,
    196    {
    197        deserializer.deserialize_seq(MessageVisitor)
    198    }
    199 }
    200 
    201 #[cfg(test)]
    202 mod tests {
    203    use serde_json::json;
    204 
    205    use super::*;
    206 
    207    use crate::common::*;
    208    use crate::error::{ErrorKind, MarionetteError};
    209    use crate::test::assert_ser_de;
    210 
    211    #[test]
    212    fn test_incoming() {
    213        let json =
    214            json!([0, 42, "WebDriver:FindElement", {"using": "css selector", "value": "value"}]);
    215        let find_element = webdriver::Command::FindElement(webdriver::Locator {
    216            element: None,
    217            using: webdriver::Selector::Css,
    218            value: "value".into(),
    219        });
    220        let req = Request(42, Command::WebDriver(find_element));
    221        let msg = Message::Incoming(req);
    222        assert_ser_de(&msg, json);
    223    }
    224 
    225    #[test]
    226    fn test_incoming_empty_params() {
    227        let json = json!([0, 42, "WebDriver:GetTimeouts", {}]);
    228        let req = Request(42, Command::WebDriver(webdriver::Command::GetTimeouts));
    229        let msg = Message::Incoming(req);
    230        assert_ser_de(&msg, json);
    231    }
    232 
    233    #[test]
    234    fn test_incoming_common_params() {
    235        let json = json!([0, 42, "Marionette:AcceptConnections", {"value": false}]);
    236        let params = BoolValue::new(false);
    237        let req = Request(
    238            42,
    239            Command::Marionette(marionette::Command::AcceptConnections(params)),
    240        );
    241        let msg = Message::Incoming(req);
    242        assert_ser_de(&msg, json);
    243    }
    244 
    245    #[test]
    246    fn test_incoming_params_derived() {
    247        assert!(serde_json::from_value::<Message>(
    248            json!([0,42,"WebDriver:FindElement",{"using":"foo","value":"foo"}])
    249        )
    250        .is_err());
    251        assert!(serde_json::from_value::<Message>(
    252            json!([0,42,"Marionette:AcceptConnections",{"value":"foo"}])
    253        )
    254        .is_err());
    255    }
    256 
    257    #[test]
    258    fn test_incoming_no_params() {
    259        assert!(serde_json::from_value::<Message>(
    260            json!([0,42,"WebDriver:GetTimeouts",{"value":true}])
    261        )
    262        .is_err());
    263        assert!(serde_json::from_value::<Message>(
    264            json!([0,42,"Marionette:Context",{"value":"foo"}])
    265        )
    266        .is_err());
    267        assert!(serde_json::from_value::<Message>(
    268            json!([0,42,"Marionette:GetScreenOrientation",{"value":true}])
    269        )
    270        .is_err());
    271    }
    272 
    273    #[test]
    274    fn test_outgoing_result() {
    275        let json = json!([1, 42, null, { "value": null }]);
    276        let result = MarionetteResult::Null;
    277        let msg = Message::Outgoing(Response::Result { id: 42, result });
    278 
    279        assert_ser_de(&msg, json);
    280    }
    281 
    282    #[test]
    283    fn test_outgoing_error() {
    284        let json =
    285            json!([1, 42, {"error": "no such element", "message": "", "stacktrace": ""}, null]);
    286        let error = MarionetteError {
    287            kind: ErrorKind::NoSuchElement,
    288            message: "".into(),
    289            stack: "".into(),
    290        };
    291        let msg = Message::Outgoing(Response::Error { id: 42, error });
    292 
    293        assert_ser_de(&msg, json);
    294    }
    295 
    296    #[test]
    297    fn test_invalid_type() {
    298        assert!(
    299            serde_json::from_value::<Message>(json!([2, 42, "WebDriver:GetTimeouts", {}])).is_err()
    300        );
    301        assert!(serde_json::from_value::<Message>(json!([3, 42, "no such element", {}])).is_err());
    302    }
    303 
    304    #[test]
    305    fn test_missing_fields() {
    306        // all fields are required
    307        assert!(
    308            serde_json::from_value::<Message>(json!([2, 42, "WebDriver:GetTimeouts"])).is_err()
    309        );
    310        assert!(serde_json::from_value::<Message>(json!([2, 42])).is_err());
    311        assert!(serde_json::from_value::<Message>(json!([2])).is_err());
    312        assert!(serde_json::from_value::<Message>(json!([])).is_err());
    313    }
    314 
    315    #[test]
    316    fn test_unknown_command() {
    317        assert!(serde_json::from_value::<Message>(json!([0, 42, "hooba", {}])).is_err());
    318    }
    319 
    320    #[test]
    321    fn test_unknown_error() {
    322        assert!(serde_json::from_value::<Message>(json!([1, 42, "flooba", {}])).is_err());
    323    }
    324 
    325    #[test]
    326    fn test_message_id_bounds() {
    327        let overflow = i64::from(std::u32::MAX) + 1;
    328        let underflow = -1;
    329 
    330        fn get_timeouts(message_id: i64) -> Value {
    331            json!([0, message_id, "WebDriver:GetTimeouts", {}])
    332        }
    333 
    334        assert!(serde_json::from_value::<Message>(get_timeouts(overflow)).is_err());
    335        assert!(serde_json::from_value::<Message>(get_timeouts(underflow)).is_err());
    336    }
    337 }