+9
-2
slingshot/src/main.rs
+9
-2
slingshot/src/main.rs
···
81
81
Ok(())
82
82
});
83
83
84
-
let repo = Repo::new(identity);
84
+
let repo = Repo::new(identity.clone());
85
85
86
86
let server_shutdown = shutdown.clone();
87
87
let server_cache_handle = cache.clone();
88
88
tasks.spawn(async move {
89
-
serve(server_cache_handle, repo, args.host, server_shutdown).await?;
89
+
serve(
90
+
server_cache_handle,
91
+
identity,
92
+
repo,
93
+
args.host,
94
+
server_shutdown,
95
+
)
96
+
.await?;
90
97
Ok(())
91
98
});
92
99
+8
-18
slingshot/src/record.rs
+8
-18
slingshot/src/record.rs
···
1
1
//! cached record storage
2
2
3
3
use crate::{Identity, error::RecordError};
4
-
use atrium_api::types::string::{Cid, Did, Handle};
4
+
use atrium_api::types::string::{Cid, Did, Nsid, RecordKey};
5
5
use reqwest::Client;
6
6
use serde::{Deserialize, Serialize};
7
7
use serde_json::value::RawValue;
···
78
78
79
79
pub async fn get_record(
80
80
&self,
81
-
did_or_handle: String,
82
-
collection: String,
83
-
rkey: String,
84
-
cid: Option<String>,
81
+
did: &Did,
82
+
collection: &Nsid,
83
+
rkey: &RecordKey,
84
+
cid: &Option<Cid>,
85
85
) -> Result<CachedRecord, RecordError> {
86
-
let did = match Did::new(did_or_handle.clone()) {
87
-
Ok(did) => did,
88
-
Err(_) => {
89
-
let handle = Handle::new(did_or_handle).map_err(|_| RecordError::BadRepo)?;
90
-
let Some(did) = self.identity.handle_to_did(handle).await? else {
91
-
return Err(RecordError::NotFound("could not resolve and verify handle"));
92
-
};
93
-
did
94
-
}
95
-
};
96
86
let Some(pds) = self.identity.did_to_pds(did.clone()).await? else {
97
87
return Err(RecordError::NotFound("could not get pds for DID"));
98
88
};
···
101
91
102
92
let mut params = vec![
103
93
("repo", did.to_string()),
104
-
("collection", collection),
105
-
("rkey", rkey),
94
+
("collection", collection.to_string()),
95
+
("rkey", rkey.to_string()),
106
96
];
107
97
if let Some(cid) = cid {
108
-
params.push(("cid", cid));
98
+
params.push(("cid", cid.as_ref().to_string()));
109
99
}
110
100
let mut url = Url::parse_with_params(&pds, ¶ms)?;
111
101
url.set_path("/xrpc/com.atproto.repo.getRecord");
+85
-17
slingshot/src/server.rs
+85
-17
slingshot/src/server.rs
···
1
-
use crate::{CachedRecord, Repo, error::ServerError};
1
+
use crate::{CachedRecord, Identity, Repo, error::ServerError};
2
+
use atrium_api::types::string::{Cid, Did, Handle, Nsid, RecordKey};
2
3
use foyer::HybridCache;
3
4
use serde::Serialize;
5
+
use std::str::FromStr;
4
6
use std::sync::Arc;
5
7
use tokio_util::sync::CancellationToken;
6
8
···
43
45
message: "This record was deleted".to_string(),
44
46
}
45
47
}
48
+
}
49
+
type XrpcError = Json<XrpcErrorResponseObject>;
50
+
fn xrpc_error(error: impl AsRef<str>, message: impl AsRef<str>) -> XrpcError {
51
+
Json(XrpcErrorResponseObject {
52
+
error: error.as_ref().to_string(),
53
+
message: message.as_ref().to_string(),
54
+
})
46
55
}
47
56
48
57
fn bad_request_handler(err: poem::Error) -> GetRecordResponse {
···
100
109
/// also list `InvalidRequest`, `ExpiredToken`, and `InvalidToken`. Of
101
110
/// these, slingshot will only return `RecordNotFound` or `InvalidRequest`.
102
111
#[oai(status = 400)]
103
-
BadRequest(Json<XrpcErrorResponseObject>),
112
+
BadRequest(XrpcError),
113
+
/// Just using 500 for potentially upstream errors for now
114
+
#[oai(status = 500)]
115
+
ServerError(XrpcError),
104
116
}
105
117
106
118
struct Xrpc {
107
119
cache: HybridCache<String, CachedRecord>,
120
+
identity: Identity,
108
121
repo: Arc<Repo>,
109
122
}
110
123
···
140
153
/// record.
141
154
Query(cid): Query<Option<String>>,
142
155
) -> GetRecordResponse {
143
-
// TODO: yeah yeah
144
-
let at_uri = format!("at://{repo}/{collection}/{rkey}");
156
+
let did = match Did::new(repo.clone()) {
157
+
Ok(did) => did,
158
+
Err(_) => {
159
+
let Ok(handle) = Handle::new(repo) else {
160
+
return GetRecordResponse::BadRequest(xrpc_error(
161
+
"InvalidRequest",
162
+
"repo was not a valid DID or handle",
163
+
));
164
+
};
165
+
if let Ok(res) = self.identity.handle_to_did(handle).await {
166
+
if let Some(did) = res {
167
+
did
168
+
} else {
169
+
return GetRecordResponse::BadRequest(xrpc_error(
170
+
"InvalidRequest",
171
+
"Could not resolve handle repo to a DID",
172
+
));
173
+
}
174
+
} else {
175
+
return GetRecordResponse::ServerError(xrpc_error(
176
+
"ResolutionFailed",
177
+
"errored while trying to resolve handle to DID",
178
+
));
179
+
}
180
+
}
181
+
};
182
+
183
+
let Ok(collection) = Nsid::new(collection) else {
184
+
return GetRecordResponse::BadRequest(xrpc_error(
185
+
"InvalidRequest",
186
+
"invalid NSID for collection",
187
+
));
188
+
};
189
+
190
+
let Ok(rkey) = RecordKey::new(rkey) else {
191
+
return GetRecordResponse::BadRequest(xrpc_error("InvalidRequest", "invalid rkey"));
192
+
};
193
+
194
+
let cid: Option<Cid> = if let Some(cid) = cid {
195
+
let Ok(cid) = Cid::from_str(&cid) else {
196
+
return GetRecordResponse::BadRequest(xrpc_error("InvalidRequest", "invalid CID"));
197
+
};
198
+
Some(cid)
199
+
} else {
200
+
None
201
+
};
202
+
203
+
let at_uri = format!("at://{}/{}/{}", &*did, &*collection, &*rkey);
145
204
146
205
let entry = self
147
206
.cache
···
150
209
let repo_api = self.repo.clone();
151
210
|| async move {
152
211
repo_api
153
-
.get_record(repo, collection, rkey, cid)
212
+
.get_record(&did, &collection, &rkey, &cid)
154
213
.await
155
214
.map_err(|e| foyer::Error::Other(Box::new(e)))
156
215
}
···
163
222
match *entry {
164
223
CachedRecord::Found(ref raw) => {
165
224
let (found_cid, raw_value) = raw.into();
166
-
let found_cid = found_cid.as_ref().to_string();
167
225
if cid.clone().map(|c| c != found_cid).unwrap_or(false) {
168
226
return GetRecordResponse::BadRequest(Json(XrpcErrorResponseObject {
169
227
error: "RecordNotFound".to_string(),
···
176
234
serde_json::from_str(raw_value.get()).expect("RawValue to be valid json");
177
235
GetRecordResponse::Ok(Json(FoundRecordResponseObject {
178
236
uri: at_uri,
179
-
cid: Some(found_cid),
237
+
cid: Some(found_cid.as_ref().to_string()),
180
238
value,
181
239
}))
182
240
}
···
234
292
235
293
pub async fn serve(
236
294
cache: HybridCache<String, CachedRecord>,
295
+
identity: Identity,
237
296
repo: Repo,
238
297
host: Option<String>,
239
298
_shutdown: CancellationToken,
240
299
) -> Result<(), ServerError> {
241
300
let repo = Arc::new(repo);
242
-
let api_service =
243
-
OpenApiService::new(Xrpc { cache, repo }, "Slingshot", env!("CARGO_PKG_VERSION"))
244
-
.server("http://localhost:3000")
245
-
.url_prefix("/xrpc");
301
+
let api_service = OpenApiService::new(
302
+
Xrpc {
303
+
cache,
304
+
identity,
305
+
repo,
306
+
},
307
+
"Slingshot",
308
+
env!("CARGO_PKG_VERSION"),
309
+
)
310
+
.server("http://localhost:3000")
311
+
.url_prefix("/xrpc");
246
312
247
313
let mut app = Route::new()
248
314
.nest("/", api_service.scalar())
···
254
320
.install_default()
255
321
.expect("alskfjalksdjf");
256
322
257
-
app = app
258
-
.at("/.well-known/did.json", get_did_doc(&host));
323
+
app = app.at("/.well-known/did.json", get_did_doc(&host));
259
324
260
325
let auto_cert = AutoCert::builder()
261
326
.directory_url(LETS_ENCRYPT_PRODUCTION)
···
271
336
272
337
async fn run<L>(listener: L, app: Route) -> Result<(), ServerError>
273
338
where
274
-
L: Listener + 'static
339
+
L: Listener + 'static,
275
340
{
276
341
let app = app
277
-
.with(Cors::new()
278
-
.allow_method(Method::GET)
279
-
.allow_credentials(false))
342
+
.with(
343
+
Cors::new()
344
+
.allow_origin("*")
345
+
.allow_methods([Method::GET])
346
+
.allow_credentials(false),
347
+
)
280
348
.with(Tracing);
281
349
Server::new(listener)
282
350
.name("slingshot")