Buttplug sex toy control library
1use buttplug_core::{ 2 errors::{ButtplugError, ButtplugHandshakeError, ButtplugMessageError}, 3 message::{ 4 self, 5 ButtplugClientMessageV4, 6 ButtplugMessageFinalizer, 7 ButtplugMessageSpecVersion, 8 ButtplugServerMessageCurrent, 9 ButtplugServerMessageV4, 10 serializer::{ 11 ButtplugMessageSerializer, 12 ButtplugSerializedMessage, 13 ButtplugSerializerError, 14 json_serializer::{ 15 create_message_validator, 16 deserialize_to_message, 17 msg_to_protocol_json, 18 vec_to_protocol_json, 19 }, 20 }, 21 }, 22}; 23use jsonschema::Validator; 24use once_cell::sync::OnceCell; 25use serde::Deserialize; 26 27use super::{ 28 ButtplugClientMessageV0, 29 ButtplugClientMessageV1, 30 ButtplugClientMessageV2, 31 ButtplugClientMessageV3, 32 ButtplugClientMessageVariant, 33 ButtplugServerMessageV0, 34 ButtplugServerMessageV1, 35 ButtplugServerMessageV2, 36 ButtplugServerMessageV3, 37 ButtplugServerMessageVariant, 38}; 39 40#[derive(Deserialize, ButtplugMessageFinalizer, Clone, Debug)] 41struct RequestServerInfoMessage { 42 #[serde(rename = "RequestServerInfo")] 43 rsi: RequestServerInfoVersion, 44} 45 46#[derive(Deserialize, ButtplugMessageFinalizer, Clone, Debug)] 47struct RequestServerInfoVersion { 48 #[serde(rename = "Id")] 49 _id: u32, 50 #[serde(rename = "ClientName")] 51 _client_name: String, 52 #[serde(default, rename = "MessageVersion")] 53 message_version: Option<u32>, 54 #[serde(default, rename = "ProtocolVersionMajor")] 55 api_major_version: Option<u32>, 56} 57 58pub struct ButtplugServerJSONSerializer { 59 pub(super) message_version: OnceCell<message::ButtplugMessageSpecVersion>, 60 validator: Validator, 61} 62 63impl Default for ButtplugServerJSONSerializer { 64 fn default() -> Self { 65 Self { 66 message_version: OnceCell::new(), 67 validator: create_message_validator(), 68 } 69 } 70} 71 72impl ButtplugServerJSONSerializer { 73 pub fn force_message_version(&self, version: &ButtplugMessageSpecVersion) { 74 self 75 .message_version 76 .set(*version) 77 .expect("This should only ever be called once."); 78 } 79} 80 81impl ButtplugMessageSerializer for ButtplugServerJSONSerializer { 82 type Inbound = ButtplugClientMessageVariant; 83 type Outbound = ButtplugServerMessageVariant; 84 85 fn deserialize( 86 &self, 87 serialized_msg: &ButtplugSerializedMessage, 88 ) -> Result<Vec<ButtplugClientMessageVariant>, ButtplugSerializerError> { 89 let msg = if let ButtplugSerializedMessage::Text(text_msg) = serialized_msg { 90 text_msg 91 } else { 92 return Err(ButtplugSerializerError::BinaryDeserializationError); 93 }; 94 95 if let Some(version) = self.message_version.get() { 96 return Ok(match version { 97 ButtplugMessageSpecVersion::Version0 => { 98 deserialize_to_message::<ButtplugClientMessageV0>(Some(&self.validator), msg)? 99 .iter() 100 .cloned() 101 .map(|m| m.into()) 102 .collect() 103 } 104 ButtplugMessageSpecVersion::Version1 => { 105 deserialize_to_message::<ButtplugClientMessageV1>(Some(&self.validator), msg)? 106 .iter() 107 .cloned() 108 .map(|m| m.into()) 109 .collect() 110 } 111 ButtplugMessageSpecVersion::Version2 => { 112 deserialize_to_message::<ButtplugClientMessageV2>(Some(&self.validator), msg)? 113 .iter() 114 .cloned() 115 .map(|m| m.into()) 116 .collect() 117 } 118 ButtplugMessageSpecVersion::Version3 => { 119 deserialize_to_message::<ButtplugClientMessageV3>(Some(&self.validator), msg)? 120 .iter() 121 .cloned() 122 .map(|m| m.into()) 123 .collect() 124 } 125 ButtplugMessageSpecVersion::Version4 => { 126 deserialize_to_message::<ButtplugClientMessageV4>(Some(&self.validator), msg)? 127 .iter() 128 .cloned() 129 .map(|m| m.into()) 130 .collect() 131 } 132 }); 133 } 134 // If we don't have a message version yet, we need to parse this as a RequestServerInfo message 135 // to get the version. As of v4, RequestServerInfo is of a different layout than RSI v0-v3, 136 // therefore we need to step through versions for compatibility sake. 137 info!("{:?}", msg); 138 let msg_version = 139 if let Ok(msg_union) = deserialize_to_message::<RequestServerInfoMessage>(None, msg) { 140 info!("PARSING {:?}", msg_union); 141 if msg_union.is_empty() { 142 Err(ButtplugSerializerError::MessageSpecVersionNotReceived) 143 } else if let Some(v) = msg_union[0].rsi.api_major_version { 144 ButtplugMessageSpecVersion::try_from(v as i32) 145 .map_err(|_| ButtplugSerializerError::MessageSpecVersionNotReceived) 146 } else if let Some(v) = msg_union[0].rsi.message_version { 147 ButtplugMessageSpecVersion::try_from(v as i32) 148 .map_err(|_| ButtplugSerializerError::MessageSpecVersionNotReceived) 149 } else { 150 Ok(ButtplugMessageSpecVersion::Version0) 151 } 152 } else { 153 info!("NOT EVEN PARSING"); 154 Err(ButtplugSerializerError::MessageSpecVersionNotReceived) 155 }?; 156 157 info!("Setting JSON Wrapper message version to {}", msg_version); 158 self 159 .message_version 160 .set(msg_version) 161 .expect("This should only ever be called once."); 162 // Now that we know our version, parse the message again. 163 self.deserialize(serialized_msg) 164 } 165 166 fn serialize(&self, msgs: &[ButtplugServerMessageVariant]) -> ButtplugSerializedMessage { 167 if let Some(version) = self.message_version.get() { 168 ButtplugSerializedMessage::Text(match version { 169 ButtplugMessageSpecVersion::Version0 => { 170 let msg_vec: Vec<ButtplugServerMessageV0> = msgs 171 .iter() 172 .map(|msg| match msg { 173 ButtplugServerMessageVariant::V0(msgv0) => msgv0.clone(), 174 _ => ButtplugServerMessageV0::Error(message::ErrorV0::from(ButtplugError::from( 175 ButtplugMessageError::MessageConversionError(format!( 176 "Message {msg:?} not in Spec V0! This is a server bug." 177 )), 178 ))), 179 }) 180 .collect(); 181 vec_to_protocol_json(&msg_vec) 182 } 183 ButtplugMessageSpecVersion::Version1 => { 184 let msg_vec: Vec<ButtplugServerMessageV1> = msgs 185 .iter() 186 .map(|msg| match msg { 187 ButtplugServerMessageVariant::V1(msgv1) => msgv1.clone(), 188 _ => ButtplugServerMessageV1::Error(message::ErrorV0::from(ButtplugError::from( 189 ButtplugMessageError::MessageConversionError(format!( 190 "Message {msg:?} not in Spec V1! This is a server bug." 191 )), 192 ))), 193 }) 194 .collect(); 195 vec_to_protocol_json(&msg_vec) 196 } 197 ButtplugMessageSpecVersion::Version2 => { 198 let msg_vec: Vec<ButtplugServerMessageV2> = msgs 199 .iter() 200 .map(|msg| match msg { 201 ButtplugServerMessageVariant::V2(msgv2) => msgv2.clone(), 202 _ => ButtplugServerMessageV2::Error(message::ErrorV0::from(ButtplugError::from( 203 ButtplugMessageError::MessageConversionError(format!( 204 "Message {msg:?} not in Spec V2! This is a server bug." 205 )), 206 ))), 207 }) 208 .collect(); 209 vec_to_protocol_json(&msg_vec) 210 } 211 ButtplugMessageSpecVersion::Version3 => { 212 let msg_vec: Vec<ButtplugServerMessageV3> = msgs 213 .iter() 214 .map(|msg| match msg { 215 ButtplugServerMessageVariant::V3(msgv3) => msgv3.clone(), 216 _ => ButtplugServerMessageV3::Error(message::ErrorV0::from(ButtplugError::from( 217 ButtplugMessageError::MessageConversionError(format!( 218 "Message {msg:?} not in Spec V3! This is a server bug." 219 )), 220 ))), 221 }) 222 .collect(); 223 vec_to_protocol_json(&msg_vec) 224 } 225 ButtplugMessageSpecVersion::Version4 => { 226 let msg_vec: Vec<ButtplugServerMessageV4> = msgs 227 .iter() 228 .map(|msg| match msg { 229 ButtplugServerMessageVariant::V4(msgv4) => msgv4.clone(), 230 _ => ButtplugServerMessageV4::Error(message::ErrorV0::from(ButtplugError::from( 231 ButtplugMessageError::MessageConversionError(format!( 232 "Message {msg:?} not in Spec V4! This is a server bug." 233 )), 234 ))), 235 }) 236 .collect(); 237 vec_to_protocol_json(&msg_vec) 238 } 239 }) 240 } else { 241 // If we don't even have enough info to know which message 242 // version to convert to, consider this a handshake error. 243 ButtplugSerializedMessage::Text(msg_to_protocol_json(ButtplugServerMessageCurrent::Error( 244 ButtplugError::from(ButtplugHandshakeError::RequestServerInfoExpected).into(), 245 ))) 246 } 247 } 248} 249 250#[cfg(test)] 251mod test { 252 use super::*; 253 254 #[test] 255 fn test_correct_message_version() { 256 let json = r#"[{ 257 "RequestServerInfo": { 258 "Id": 1, 259 "ClientName": "Test Client", 260 "ProtocolVersionMajor": 4, 261 "ProtocolVersionMinor": 0 262 } 263 }]"#; 264 let serializer = ButtplugServerJSONSerializer::default(); 265 serializer 266 .deserialize(&ButtplugSerializedMessage::Text(json.to_owned())) 267 .expect("Infallible deserialization"); 268 assert_eq!( 269 *serializer.message_version.get().unwrap(), 270 ButtplugMessageSpecVersion::Version4 271 ); 272 } 273 274 #[test] 275 fn test_wrong_message_version() { 276 let json = r#"[{ 277 "RequestServerInfo": { 278 "Id": 1, 279 "ClientName": "Test Client", 280 "ProtocolVersionMajor": 100, 281 "ProtocolVersionMinor": 0 282 } 283 }]"#; 284 let serializer = ButtplugServerJSONSerializer::default(); 285 let msg = serializer.deserialize(&ButtplugSerializedMessage::Text(json.to_owned())); 286 info!("{:?}", msg); 287 assert!(msg.is_err()); 288 } 289 290 #[test] 291 fn test_message_array() { 292 let json = r#"[ 293 { 294 "RequestServerInfo": { 295 "Id": 1, 296 "ClientName": "Test Client", 297 "ProtocolVersionMajor": 4, 298 "ProtocolVersionMinor": 0 299 } 300 }, 301 { 302 "RequestServerInfo": { 303 "Id": 1, 304 "ClientName": "Test Client", 305 "ProtocolVersionMajor": 4, 306 "ProtocolVersionMinor": 0 307 } 308 }, 309 { 310 "RequestServerInfo": { 311 "Id": 1, 312 "ClientName": "Test Client", 313 "ProtocolVersionMajor": 4, 314 "ProtocolVersionMinor": 0 315 } 316 } 317 ]"#; 318 let serializer = ButtplugServerJSONSerializer::default(); 319 let messages = serializer 320 .deserialize(&ButtplugSerializedMessage::Text(json.to_owned())) 321 .expect("Infallible deserialization"); 322 assert_eq!(messages.len(), 3); 323 } 324 325 #[test] 326 fn test_streamed_message_array() { 327 let json = r#"[ 328 { 329 "RequestServerInfo": { 330 "Id": 1, 331 "ClientName": "Test Client", 332 "ProtocolVersionMajor": 4, 333 "ProtocolVersionMinor": 0 334 } 335 }] 336 [{ 337 "RequestServerInfo": { 338 "Id": 1, 339 "ClientName": "Test Client", 340 "ProtocolVersionMajor": 4, 341 "ProtocolVersionMinor": 0 342 } 343 }] 344 [{ 345 "RequestServerInfo": { 346 "Id": 1, 347 "ClientName": "Test Client", 348 "ProtocolVersionMajor": 4, 349 "ProtocolVersionMinor": 0 350 } 351 }] 352 "#; 353 let serializer = ButtplugServerJSONSerializer::default(); 354 let messages = serializer 355 .deserialize(&ButtplugSerializedMessage::Text(json.to_owned())) 356 .expect("Infallible deserialization"); 357 assert_eq!(messages.len(), 3); 358 } 359 360 #[test] 361 fn test_invalid_streamed_message_array() { 362 // Missing a } in the second message. 363 let json = r#"[ 364 "RequestServerInfo": { 365 "Id": 1, 366 "ClientName": "Test Client", 367 "ProtocolVersionMajor": 4, 368 "ProtocolVersionMinor": 0 369 } 370 }] 371 [{ 372 "RequestServerInfo": { 373 "Id": 1, 374 "ClientName": "Test Client", 375 "ProtocolVersionMajor": 4, 376 "ProtocolVersionMinor": 0 377 }] 378 [{ 379 "RequestServerInfo": { 380 "Id": 1, 381 "ClientName": "Test Client", 382 "ProtocolVersionMajor": 4, 383 "ProtocolVersionMinor": 0 384 } 385 }] 386 "#; 387 let serializer = ButtplugServerJSONSerializer::default(); 388 assert!(matches!( 389 serializer.deserialize(&ButtplugSerializedMessage::Text(json.to_owned())), 390 Err(_) 391 )); 392 } 393}