// message.rs (10290 bytes)
1 /* This Source Code Form is subject to the terms of the Mozilla Public 2 * License, v. 2.0. If a copy of the MPL was not distributed with this 3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ 4 5 use serde::de::{self, SeqAccess, Unexpected, Visitor}; 6 use serde::{Deserialize, Deserializer, Serialize, Serializer}; 7 use serde_json::{Map, Value}; 8 use serde_repr::{Deserialize_repr, Serialize_repr}; 9 use std::fmt; 10 11 use crate::error::MarionetteError; 12 use crate::marionette; 13 use crate::result::MarionetteResult; 14 use crate::webdriver; 15 16 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 17 #[serde(untagged)] 18 pub enum Command { 19 WebDriver(webdriver::Command), 20 Marionette(marionette::Command), 21 } 22 23 impl Command { 24 pub fn name(&self) -> String { 25 let (command_name, _) = self.first_entry(); 26 command_name 27 } 28 29 fn params(&self) -> Value { 30 let (_, params) = self.first_entry(); 31 params 32 } 33 34 fn first_entry(&self) -> (String, serde_json::Value) { 35 match serde_json::to_value(self).unwrap() { 36 Value::String(cmd) => (cmd, Value::Object(Map::new())), 37 Value::Object(items) => { 38 let mut iter = items.iter(); 39 let (cmd, params) = iter.next().unwrap(); 40 (cmd.to_string(), params.clone()) 41 } 42 _ => unreachable!(), 43 } 44 } 45 } 46 47 #[derive(Clone, Debug, PartialEq, Serialize_repr, Deserialize_repr)] 48 #[repr(u8)] 49 enum MessageDirection { 50 Incoming = 0, 51 Outgoing = 1, 52 } 53 54 pub type MessageId = u32; 55 56 #[derive(Debug, Clone, PartialEq)] 57 pub struct Request(pub MessageId, pub Command); 58 59 impl Request { 60 pub fn id(&self) -> MessageId { 61 self.0 62 } 63 64 pub fn command(&self) -> &Command { 65 &self.1 66 } 67 68 pub fn params(&self) -> Value { 69 self.command().params() 70 } 71 } 72 73 impl Serialize for Request { 74 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> 75 where 76 S: Serializer, 77 { 78 ( 79 MessageDirection::Incoming, 80 self.id(), 81 
self.command().name(), 82 self.params(), 83 ) 84 .serialize(serializer) 85 } 86 } 87 88 #[derive(Debug, PartialEq)] 89 pub enum Response { 90 Result { 91 id: MessageId, 92 result: MarionetteResult, 93 }, 94 Error { 95 id: MessageId, 96 error: MarionetteError, 97 }, 98 } 99 100 impl Serialize for Response { 101 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> 102 where 103 S: Serializer, 104 { 105 match self { 106 Response::Result { id, result } => { 107 (MessageDirection::Outgoing, id, Value::Null, &result).serialize(serializer) 108 } 109 Response::Error { id, error } => { 110 (MessageDirection::Outgoing, id, &error, Value::Null).serialize(serializer) 111 } 112 } 113 } 114 } 115 116 #[derive(Debug, PartialEq, Serialize)] 117 #[serde(untagged)] 118 pub enum Message { 119 Incoming(Request), 120 Outgoing(Response), 121 } 122 123 struct MessageVisitor; 124 125 impl<'de> Visitor<'de> for MessageVisitor { 126 type Value = Message; 127 128 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { 129 formatter.write_str("four-element array") 130 } 131 132 fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> { 133 let direction = seq 134 .next_element::<MessageDirection>()? 135 .ok_or_else(|| de::Error::invalid_length(0, &self))?; 136 let id: MessageId = seq 137 .next_element()? 138 .ok_or_else(|| de::Error::invalid_length(1, &self))?; 139 140 let msg = match direction { 141 MessageDirection::Incoming => { 142 let name: String = seq 143 .next_element()? 144 .ok_or_else(|| de::Error::invalid_length(2, &self))?; 145 let params: Value = seq 146 .next_element()? 
147 .ok_or_else(|| de::Error::invalid_length(3, &self))?; 148 149 let command = match params { 150 Value::Object(ref items) if !items.is_empty() => { 151 let command_to_params = { 152 let mut m = Map::new(); 153 m.insert(name, params); 154 Value::Object(m) 155 }; 156 serde_json::from_value(command_to_params).map_err(de::Error::custom) 157 } 158 Value::Object(_) | Value::Null => { 159 serde_json::from_value(Value::String(name)).map_err(de::Error::custom) 160 } 161 x => Err(de::Error::custom(format!("unknown params type: {}", x))), 162 }?; 163 Message::Incoming(Request(id, command)) 164 } 165 166 MessageDirection::Outgoing => { 167 let maybe_error: Option<MarionetteError> = seq 168 .next_element()? 169 .ok_or_else(|| de::Error::invalid_length(2, &self))?; 170 171 let response = if let Some(error) = maybe_error { 172 seq.next_element::<Value>()? 173 .ok_or_else(|| de::Error::invalid_length(3, &self))? 174 .as_null() 175 .ok_or_else(|| de::Error::invalid_type(Unexpected::Unit, &self))?; 176 Response::Error { id, error } 177 } else { 178 let result: MarionetteResult = seq 179 .next_element()? 
180 .ok_or_else(|| de::Error::invalid_length(3, &self))?; 181 Response::Result { id, result } 182 }; 183 184 Message::Outgoing(response) 185 } 186 }; 187 188 Ok(msg) 189 } 190 } 191 192 impl<'de> Deserialize<'de> for Message { 193 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> 194 where 195 D: Deserializer<'de>, 196 { 197 deserializer.deserialize_seq(MessageVisitor) 198 } 199 } 200 201 #[cfg(test)] 202 mod tests { 203 use serde_json::json; 204 205 use super::*; 206 207 use crate::common::*; 208 use crate::error::{ErrorKind, MarionetteError}; 209 use crate::test::assert_ser_de; 210 211 #[test] 212 fn test_incoming() { 213 let json = 214 json!([0, 42, "WebDriver:FindElement", {"using": "css selector", "value": "value"}]); 215 let find_element = webdriver::Command::FindElement(webdriver::Locator { 216 element: None, 217 using: webdriver::Selector::Css, 218 value: "value".into(), 219 }); 220 let req = Request(42, Command::WebDriver(find_element)); 221 let msg = Message::Incoming(req); 222 assert_ser_de(&msg, json); 223 } 224 225 #[test] 226 fn test_incoming_empty_params() { 227 let json = json!([0, 42, "WebDriver:GetTimeouts", {}]); 228 let req = Request(42, Command::WebDriver(webdriver::Command::GetTimeouts)); 229 let msg = Message::Incoming(req); 230 assert_ser_de(&msg, json); 231 } 232 233 #[test] 234 fn test_incoming_common_params() { 235 let json = json!([0, 42, "Marionette:AcceptConnections", {"value": false}]); 236 let params = BoolValue::new(false); 237 let req = Request( 238 42, 239 Command::Marionette(marionette::Command::AcceptConnections(params)), 240 ); 241 let msg = Message::Incoming(req); 242 assert_ser_de(&msg, json); 243 } 244 245 #[test] 246 fn test_incoming_params_derived() { 247 assert!(serde_json::from_value::<Message>( 248 json!([0,42,"WebDriver:FindElement",{"using":"foo","value":"foo"}]) 249 ) 250 .is_err()); 251 assert!(serde_json::from_value::<Message>( 252 json!([0,42,"Marionette:AcceptConnections",{"value":"foo"}]) 253 ) 254 
.is_err()); 255 } 256 257 #[test] 258 fn test_incoming_no_params() { 259 assert!(serde_json::from_value::<Message>( 260 json!([0,42,"WebDriver:GetTimeouts",{"value":true}]) 261 ) 262 .is_err()); 263 assert!(serde_json::from_value::<Message>( 264 json!([0,42,"Marionette:Context",{"value":"foo"}]) 265 ) 266 .is_err()); 267 assert!(serde_json::from_value::<Message>( 268 json!([0,42,"Marionette:GetScreenOrientation",{"value":true}]) 269 ) 270 .is_err()); 271 } 272 273 #[test] 274 fn test_outgoing_result() { 275 let json = json!([1, 42, null, { "value": null }]); 276 let result = MarionetteResult::Null; 277 let msg = Message::Outgoing(Response::Result { id: 42, result }); 278 279 assert_ser_de(&msg, json); 280 } 281 282 #[test] 283 fn test_outgoing_error() { 284 let json = 285 json!([1, 42, {"error": "no such element", "message": "", "stacktrace": ""}, null]); 286 let error = MarionetteError { 287 kind: ErrorKind::NoSuchElement, 288 message: "".into(), 289 stack: "".into(), 290 }; 291 let msg = Message::Outgoing(Response::Error { id: 42, error }); 292 293 assert_ser_de(&msg, json); 294 } 295 296 #[test] 297 fn test_invalid_type() { 298 assert!( 299 serde_json::from_value::<Message>(json!([2, 42, "WebDriver:GetTimeouts", {}])).is_err() 300 ); 301 assert!(serde_json::from_value::<Message>(json!([3, 42, "no such element", {}])).is_err()); 302 } 303 304 #[test] 305 fn test_missing_fields() { 306 // all fields are required 307 assert!( 308 serde_json::from_value::<Message>(json!([2, 42, "WebDriver:GetTimeouts"])).is_err() 309 ); 310 assert!(serde_json::from_value::<Message>(json!([2, 42])).is_err()); 311 assert!(serde_json::from_value::<Message>(json!([2])).is_err()); 312 assert!(serde_json::from_value::<Message>(json!([])).is_err()); 313 } 314 315 #[test] 316 fn test_unknown_command() { 317 assert!(serde_json::from_value::<Message>(json!([0, 42, "hooba", {}])).is_err()); 318 } 319 320 #[test] 321 fn test_unknown_error() { 322 
assert!(serde_json::from_value::<Message>(json!([1, 42, "flooba", {}])).is_err()); 323 } 324 325 #[test] 326 fn test_message_id_bounds() { 327 let overflow = i64::from(std::u32::MAX) + 1; 328 let underflow = -1; 329 330 fn get_timeouts(message_id: i64) -> Value { 331 json!([0, message_id, "WebDriver:GetTimeouts", {}]) 332 } 333 334 assert!(serde_json::from_value::<Message>(get_timeouts(overflow)).is_err()); 335 assert!(serde_json::from_value::<Message>(get_timeouts(underflow)).is_err()); 336 } 337 }