Buttplug sex toy control library
1use buttplug_core::{
2 errors::{ButtplugError, ButtplugHandshakeError, ButtplugMessageError},
3 message::{
4 self,
5 ButtplugClientMessageV4,
6 ButtplugMessageFinalizer,
7 ButtplugMessageSpecVersion,
8 ButtplugServerMessageCurrent,
9 ButtplugServerMessageV4,
10 serializer::{
11 ButtplugMessageSerializer,
12 ButtplugSerializedMessage,
13 ButtplugSerializerError,
14 json_serializer::{
15 create_message_validator,
16 deserialize_to_message,
17 msg_to_protocol_json,
18 vec_to_protocol_json,
19 },
20 },
21 },
22};
23use jsonschema::Validator;
24use once_cell::sync::OnceCell;
25use serde::Deserialize;
26
27use super::{
28 ButtplugClientMessageV0,
29 ButtplugClientMessageV1,
30 ButtplugClientMessageV2,
31 ButtplugClientMessageV3,
32 ButtplugClientMessageVariant,
33 ButtplugServerMessageV0,
34 ButtplugServerMessageV1,
35 ButtplugServerMessageV2,
36 ButtplugServerMessageV3,
37 ButtplugServerMessageVariant,
38};
39
40#[derive(Deserialize, ButtplugMessageFinalizer, Clone, Debug)]
41struct RequestServerInfoMessage {
42 #[serde(rename = "RequestServerInfo")]
43 rsi: RequestServerInfoVersion,
44}
45
46#[derive(Deserialize, ButtplugMessageFinalizer, Clone, Debug)]
47struct RequestServerInfoVersion {
48 #[serde(rename = "Id")]
49 _id: u32,
50 #[serde(rename = "ClientName")]
51 _client_name: String,
52 #[serde(default, rename = "MessageVersion")]
53 message_version: Option<u32>,
54 #[serde(default, rename = "ProtocolVersionMajor")]
55 api_major_version: Option<u32>,
56}
57
58pub struct ButtplugServerJSONSerializer {
59 pub(super) message_version: OnceCell<message::ButtplugMessageSpecVersion>,
60 validator: Validator,
61}
62
63impl Default for ButtplugServerJSONSerializer {
64 fn default() -> Self {
65 Self {
66 message_version: OnceCell::new(),
67 validator: create_message_validator(),
68 }
69 }
70}
71
72impl ButtplugServerJSONSerializer {
73 pub fn force_message_version(&self, version: &ButtplugMessageSpecVersion) {
74 self
75 .message_version
76 .set(*version)
77 .expect("This should only ever be called once.");
78 }
79}
80
81impl ButtplugMessageSerializer for ButtplugServerJSONSerializer {
82 type Inbound = ButtplugClientMessageVariant;
83 type Outbound = ButtplugServerMessageVariant;
84
85 fn deserialize(
86 &self,
87 serialized_msg: &ButtplugSerializedMessage,
88 ) -> Result<Vec<ButtplugClientMessageVariant>, ButtplugSerializerError> {
89 let msg = if let ButtplugSerializedMessage::Text(text_msg) = serialized_msg {
90 text_msg
91 } else {
92 return Err(ButtplugSerializerError::BinaryDeserializationError);
93 };
94
95 if let Some(version) = self.message_version.get() {
96 return Ok(match version {
97 ButtplugMessageSpecVersion::Version0 => {
98 deserialize_to_message::<ButtplugClientMessageV0>(Some(&self.validator), msg)?
99 .iter()
100 .cloned()
101 .map(|m| m.into())
102 .collect()
103 }
104 ButtplugMessageSpecVersion::Version1 => {
105 deserialize_to_message::<ButtplugClientMessageV1>(Some(&self.validator), msg)?
106 .iter()
107 .cloned()
108 .map(|m| m.into())
109 .collect()
110 }
111 ButtplugMessageSpecVersion::Version2 => {
112 deserialize_to_message::<ButtplugClientMessageV2>(Some(&self.validator), msg)?
113 .iter()
114 .cloned()
115 .map(|m| m.into())
116 .collect()
117 }
118 ButtplugMessageSpecVersion::Version3 => {
119 deserialize_to_message::<ButtplugClientMessageV3>(Some(&self.validator), msg)?
120 .iter()
121 .cloned()
122 .map(|m| m.into())
123 .collect()
124 }
125 ButtplugMessageSpecVersion::Version4 => {
126 deserialize_to_message::<ButtplugClientMessageV4>(Some(&self.validator), msg)?
127 .iter()
128 .cloned()
129 .map(|m| m.into())
130 .collect()
131 }
132 });
133 }
134 // If we don't have a message version yet, we need to parse this as a RequestServerInfo message
135 // to get the version. As of v4, RequestServerInfo is of a different layout than RSI v0-v3,
136 // therefore we need to step through versions for compatibility sake.
137 info!("{:?}", msg);
138 let msg_version =
139 if let Ok(msg_union) = deserialize_to_message::<RequestServerInfoMessage>(None, msg) {
140 info!("PARSING {:?}", msg_union);
141 if msg_union.is_empty() {
142 Err(ButtplugSerializerError::MessageSpecVersionNotReceived)
143 } else if let Some(v) = msg_union[0].rsi.api_major_version {
144 ButtplugMessageSpecVersion::try_from(v as i32)
145 .map_err(|_| ButtplugSerializerError::MessageSpecVersionNotReceived)
146 } else if let Some(v) = msg_union[0].rsi.message_version {
147 ButtplugMessageSpecVersion::try_from(v as i32)
148 .map_err(|_| ButtplugSerializerError::MessageSpecVersionNotReceived)
149 } else {
150 Ok(ButtplugMessageSpecVersion::Version0)
151 }
152 } else {
153 info!("NOT EVEN PARSING");
154 Err(ButtplugSerializerError::MessageSpecVersionNotReceived)
155 }?;
156
157 info!("Setting JSON Wrapper message version to {}", msg_version);
158 self
159 .message_version
160 .set(msg_version)
161 .expect("This should only ever be called once.");
162 // Now that we know our version, parse the message again.
163 self.deserialize(serialized_msg)
164 }
165
166 fn serialize(&self, msgs: &[ButtplugServerMessageVariant]) -> ButtplugSerializedMessage {
167 if let Some(version) = self.message_version.get() {
168 ButtplugSerializedMessage::Text(match version {
169 ButtplugMessageSpecVersion::Version0 => {
170 let msg_vec: Vec<ButtplugServerMessageV0> = msgs
171 .iter()
172 .map(|msg| match msg {
173 ButtplugServerMessageVariant::V0(msgv0) => msgv0.clone(),
174 _ => ButtplugServerMessageV0::Error(message::ErrorV0::from(ButtplugError::from(
175 ButtplugMessageError::MessageConversionError(format!(
176 "Message {msg:?} not in Spec V0! This is a server bug."
177 )),
178 ))),
179 })
180 .collect();
181 vec_to_protocol_json(&msg_vec)
182 }
183 ButtplugMessageSpecVersion::Version1 => {
184 let msg_vec: Vec<ButtplugServerMessageV1> = msgs
185 .iter()
186 .map(|msg| match msg {
187 ButtplugServerMessageVariant::V1(msgv1) => msgv1.clone(),
188 _ => ButtplugServerMessageV1::Error(message::ErrorV0::from(ButtplugError::from(
189 ButtplugMessageError::MessageConversionError(format!(
190 "Message {msg:?} not in Spec V1! This is a server bug."
191 )),
192 ))),
193 })
194 .collect();
195 vec_to_protocol_json(&msg_vec)
196 }
197 ButtplugMessageSpecVersion::Version2 => {
198 let msg_vec: Vec<ButtplugServerMessageV2> = msgs
199 .iter()
200 .map(|msg| match msg {
201 ButtplugServerMessageVariant::V2(msgv2) => msgv2.clone(),
202 _ => ButtplugServerMessageV2::Error(message::ErrorV0::from(ButtplugError::from(
203 ButtplugMessageError::MessageConversionError(format!(
204 "Message {msg:?} not in Spec V2! This is a server bug."
205 )),
206 ))),
207 })
208 .collect();
209 vec_to_protocol_json(&msg_vec)
210 }
211 ButtplugMessageSpecVersion::Version3 => {
212 let msg_vec: Vec<ButtplugServerMessageV3> = msgs
213 .iter()
214 .map(|msg| match msg {
215 ButtplugServerMessageVariant::V3(msgv3) => msgv3.clone(),
216 _ => ButtplugServerMessageV3::Error(message::ErrorV0::from(ButtplugError::from(
217 ButtplugMessageError::MessageConversionError(format!(
218 "Message {msg:?} not in Spec V3! This is a server bug."
219 )),
220 ))),
221 })
222 .collect();
223 vec_to_protocol_json(&msg_vec)
224 }
225 ButtplugMessageSpecVersion::Version4 => {
226 let msg_vec: Vec<ButtplugServerMessageV4> = msgs
227 .iter()
228 .map(|msg| match msg {
229 ButtplugServerMessageVariant::V4(msgv4) => msgv4.clone(),
230 _ => ButtplugServerMessageV4::Error(message::ErrorV0::from(ButtplugError::from(
231 ButtplugMessageError::MessageConversionError(format!(
232 "Message {msg:?} not in Spec V4! This is a server bug."
233 )),
234 ))),
235 })
236 .collect();
237 vec_to_protocol_json(&msg_vec)
238 }
239 })
240 } else {
241 // If we don't even have enough info to know which message
242 // version to convert to, consider this a handshake error.
243 ButtplugSerializedMessage::Text(msg_to_protocol_json(ButtplugServerMessageCurrent::Error(
244 ButtplugError::from(ButtplugHandshakeError::RequestServerInfoExpected).into(),
245 )))
246 }
247 }
248}
249
#[cfg(test)]
mod test {
  use super::*;

  /// A v4 handshake (`ProtocolVersionMajor`) pins the serializer to spec v4.
  #[test]
  fn test_correct_message_version() {
    let json = r#"[{
      "RequestServerInfo": {
        "Id": 1,
        "ClientName": "Test Client",
        "ProtocolVersionMajor": 4,
        "ProtocolVersionMinor": 0
      }
    }]"#;
    let serializer = ButtplugServerJSONSerializer::default();
    serializer
      .deserialize(&ButtplugSerializedMessage::Text(json.to_owned()))
      .expect("Infallible deserialization");
    assert_eq!(
      *serializer.message_version.get().unwrap(),
      ButtplugMessageSpecVersion::Version4
    );
  }

  /// A legacy handshake stating `MessageVersion` selects that version. Only version detection
  /// is checked here, not whether the message then validates.
  #[test]
  fn test_legacy_message_version() {
    let json = r#"[{
      "RequestServerInfo": {
        "Id": 1,
        "ClientName": "Test Client",
        "MessageVersion": 2
      }
    }]"#;
    let serializer = ButtplugServerJSONSerializer::default();
    let _ = serializer.deserialize(&ButtplugSerializedMessage::Text(json.to_owned()));
    assert_eq!(
      serializer.message_version.get(),
      Some(&ButtplugMessageSpecVersion::Version2)
    );
  }

  /// A handshake with no version field at all falls back to spec v0.
  #[test]
  fn test_missing_message_version_defaults_to_version0() {
    let json = r#"[{
      "RequestServerInfo": {
        "Id": 1,
        "ClientName": "Test Client"
      }
    }]"#;
    let serializer = ButtplugServerJSONSerializer::default();
    let _ = serializer.deserialize(&ButtplugSerializedMessage::Text(json.to_owned()));
    assert_eq!(
      serializer.message_version.get(),
      Some(&ButtplugMessageSpecVersion::Version0)
    );
  }

  /// An unsupported major version is rejected and leaves the spec version unset.
  #[test]
  fn test_wrong_message_version() {
    let json = r#"[{
      "RequestServerInfo": {
        "Id": 1,
        "ClientName": "Test Client",
        "ProtocolVersionMajor": 100,
        "ProtocolVersionMinor": 0
      }
    }]"#;
    let serializer = ButtplugServerJSONSerializer::default();
    let msg = serializer.deserialize(&ButtplugSerializedMessage::Text(json.to_owned()));
    assert!(msg.is_err());
    assert!(serializer.message_version.get().is_none());
  }

  /// Several messages in one JSON array are all returned.
  #[test]
  fn test_message_array() {
    let json = r#"[
      {
        "RequestServerInfo": {
          "Id": 1,
          "ClientName": "Test Client",
          "ProtocolVersionMajor": 4,
          "ProtocolVersionMinor": 0
        }
      },
      {
        "RequestServerInfo": {
          "Id": 1,
          "ClientName": "Test Client",
          "ProtocolVersionMajor": 4,
          "ProtocolVersionMinor": 0
        }
      },
      {
        "RequestServerInfo": {
          "Id": 1,
          "ClientName": "Test Client",
          "ProtocolVersionMajor": 4,
          "ProtocolVersionMinor": 0
        }
      }
    ]"#;
    let serializer = ButtplugServerJSONSerializer::default();
    let messages = serializer
      .deserialize(&ButtplugSerializedMessage::Text(json.to_owned()))
      .expect("Infallible deserialization");
    assert_eq!(messages.len(), 3);
  }

  /// Several back-to-back JSON arrays in one payload are all parsed.
  #[test]
  fn test_streamed_message_array() {
    let json = r#"[
      {
        "RequestServerInfo": {
          "Id": 1,
          "ClientName": "Test Client",
          "ProtocolVersionMajor": 4,
          "ProtocolVersionMinor": 0
        }
      }]
      [{
        "RequestServerInfo": {
          "Id": 1,
          "ClientName": "Test Client",
          "ProtocolVersionMajor": 4,
          "ProtocolVersionMinor": 0
        }
      }]
      [{
        "RequestServerInfo": {
          "Id": 1,
          "ClientName": "Test Client",
          "ProtocolVersionMajor": 4,
          "ProtocolVersionMinor": 0
        }
      }]
      "#;
    let serializer = ButtplugServerJSONSerializer::default();
    let messages = serializer
      .deserialize(&ButtplugSerializedMessage::Text(json.to_owned()))
      .expect("Infallible deserialization");
    assert_eq!(messages.len(), 3);
  }

  /// A malformed stream is rejected.
  #[test]
  fn test_invalid_streamed_message_array() {
    // The first message is missing its opening `{` and the second is missing a closing `}`.
    let json = r#"[
        "RequestServerInfo": {
          "Id": 1,
          "ClientName": "Test Client",
          "ProtocolVersionMajor": 4,
          "ProtocolVersionMinor": 0
        }
      }]
      [{
        "RequestServerInfo": {
          "Id": 1,
          "ClientName": "Test Client",
          "ProtocolVersionMajor": 4,
          "ProtocolVersionMinor": 0
      }]
      [{
        "RequestServerInfo": {
          "Id": 1,
          "ClientName": "Test Client",
          "ProtocolVersionMajor": 4,
          "ProtocolVersionMinor": 0
        }
      }]
      "#;
    let serializer = ButtplugServerJSONSerializer::default();
    assert!(serializer
      .deserialize(&ButtplugSerializedMessage::Text(json.to_owned()))
      .is_err());
  }
}