Alternative ATProto PDS implementation
1#![allow(clippy::arbitrary_source_item_ordering)]
2use std::io::{ErrorKind, Read as _, Seek as _, Write as _};
3
4#[cfg(unix)]
5use std::os::fd::AsRawFd as _;
6#[cfg(windows)]
7use std::os::windows::io::AsRawHandle;
8
9use memmap2::{MmapMut, MmapOptions};
10
11pub(crate) struct MappedFile {
12 /// The underlying file handle.
13 file: std::fs::File,
14 /// The length of the file.
15 len: u64,
16 /// The mapped memory region.
17 map: MmapMut,
18 /// Our current offset into the file.
19 off: u64,
20}
21
22impl MappedFile {
23 pub(crate) fn new(mut f: std::fs::File) -> std::io::Result<Self> {
24 let len = f.seek(std::io::SeekFrom::End(0))?;
25
26 #[cfg(windows)]
27 let raw = f.as_raw_handle();
28 #[cfg(unix)]
29 let raw = f.as_raw_fd();
30
31 #[expect(unsafe_code)]
32 Ok(Self {
33 // SAFETY:
34 // All file-backed memory map constructors are marked \
35 // unsafe because of the potential for Undefined Behavior (UB) \
36 // using the map if the underlying file is subsequently modified, in or out of process.
37 map: unsafe { MmapOptions::new().map_mut(raw)? },
38 file: f,
39 len,
40 off: 0,
41 })
42 }
43
44 /// Resize the memory-mapped file. This will reallocate the memory mapping.
45 #[expect(unsafe_code)]
46 fn resize(&mut self, len: u64) -> std::io::Result<()> {
47 // Resize the file.
48 self.file.set_len(len)?;
49
50 #[cfg(windows)]
51 let raw = self.file.as_raw_handle();
52 #[cfg(unix)]
53 let raw = self.file.as_raw_fd();
54
55 // SAFETY:
56 // All file-backed memory map constructors are marked \
57 // unsafe because of the potential for Undefined Behavior (UB) \
58 // using the map if the underlying file is subsequently modified, in or out of process.
59 self.map = unsafe { MmapOptions::new().map_mut(raw)? };
60 self.len = len;
61
62 Ok(())
63 }
64}
65
66impl std::io::Read for MappedFile {
67 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
68 if self.off == self.len {
69 // If we're at EOF, return an EOF error code. `Ok(0)` tends to trip up some implementations.
70 return Err(std::io::Error::new(ErrorKind::UnexpectedEof, "eof"));
71 }
72
73 // Calculate the number of bytes we're going to read.
74 let remaining_bytes = self.len.saturating_sub(self.off);
75 let buf_len = u64::try_from(buf.len()).unwrap_or(u64::MAX);
76 let len = usize::try_from(std::cmp::min(remaining_bytes, buf_len)).unwrap_or(usize::MAX);
77
78 let off = usize::try_from(self.off).map_err(|e| {
79 std::io::Error::new(
80 ErrorKind::InvalidInput,
81 format!("offset too large for this platform: {e}"),
82 )
83 })?;
84
85 if let (Some(dest), Some(src)) = (
86 buf.get_mut(..len),
87 self.map.get(off..off.saturating_add(len)),
88 ) {
89 dest.copy_from_slice(src);
90 self.off = self.off.saturating_add(u64::try_from(len).unwrap_or(0));
91 Ok(len)
92 } else {
93 Err(std::io::Error::new(
94 ErrorKind::InvalidInput,
95 "invalid buffer range",
96 ))
97 }
98 }
99}
100
101impl std::io::Write for MappedFile {
102 fn flush(&mut self) -> std::io::Result<()> {
103 // This is done by the system.
104 Ok(())
105 }
106 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
107 // Determine if we need to resize the file.
108 let buf_len = u64::try_from(buf.len()).map_err(|e| {
109 std::io::Error::new(
110 ErrorKind::InvalidInput,
111 format!("buffer length too large for this platform: {e}"),
112 )
113 })?;
114
115 if self.off.saturating_add(buf_len) >= self.len {
116 self.resize(self.off.saturating_add(buf_len))?;
117 }
118
119 let off = usize::try_from(self.off).map_err(|e| {
120 std::io::Error::new(
121 ErrorKind::InvalidInput,
122 format!("offset too large for this platform: {e}"),
123 )
124 })?;
125 let len = buf.len();
126
127 if let Some(dest) = self.map.get_mut(off..off.saturating_add(len)) {
128 dest.copy_from_slice(buf);
129 self.off = self.off.saturating_add(buf_len);
130 Ok(len)
131 } else {
132 Err(std::io::Error::new(
133 ErrorKind::InvalidInput,
134 "invalid buffer range",
135 ))
136 }
137 }
138}
139
140impl std::io::Seek for MappedFile {
141 fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
142 let off = match pos {
143 std::io::SeekFrom::Start(i) => i,
144 std::io::SeekFrom::End(i) => {
145 if i <= 0 {
146 // If i is negative or zero, we're seeking backwards from the end
147 // or exactly at the end
148 self.len.saturating_sub(i.unsigned_abs())
149 } else {
150 // If i is positive, we're seeking beyond the end, which is allowed
151 // but requires extending the file
152 self.len.saturating_add(i.unsigned_abs())
153 }
154 }
155 std::io::SeekFrom::Current(i) => {
156 if i >= 0 {
157 self.off.saturating_add(i.unsigned_abs())
158 } else {
159 self.off.saturating_sub(i.unsigned_abs())
160 }
161 }
162 };
163
164 // If the offset is beyond EOF, extend the file to the new size.
165 if off > self.len {
166 self.resize(off)?;
167 }
168
169 self.off = off;
170 Ok(off)
171 }
172}
173
174impl tokio::io::AsyncRead for MappedFile {
175 fn poll_read(
176 mut self: std::pin::Pin<&mut Self>,
177 _cx: &mut std::task::Context<'_>,
178 buf: &mut tokio::io::ReadBuf<'_>,
179 ) -> std::task::Poll<std::io::Result<()>> {
180 let wbuf = buf.initialize_unfilled();
181 let len = wbuf.len();
182
183 std::task::Poll::Ready(match self.read(wbuf) {
184 Ok(_) => {
185 buf.advance(len);
186 Ok(())
187 }
188 Err(e) => Err(e),
189 })
190 }
191}
192
193impl tokio::io::AsyncWrite for MappedFile {
194 fn poll_flush(
195 self: std::pin::Pin<&mut Self>,
196 _cx: &mut std::task::Context<'_>,
197 ) -> std::task::Poll<Result<(), std::io::Error>> {
198 std::task::Poll::Ready(Ok(()))
199 }
200
201 fn poll_shutdown(
202 self: std::pin::Pin<&mut Self>,
203 _cx: &mut std::task::Context<'_>,
204 ) -> std::task::Poll<Result<(), std::io::Error>> {
205 std::task::Poll::Ready(Ok(()))
206 }
207
208 fn poll_write(
209 mut self: std::pin::Pin<&mut Self>,
210 _cx: &mut std::task::Context<'_>,
211 buf: &[u8],
212 ) -> std::task::Poll<Result<usize, std::io::Error>> {
213 std::task::Poll::Ready(self.write(buf))
214 }
215}
216
217impl tokio::io::AsyncSeek for MappedFile {
218 fn poll_complete(
219 self: std::pin::Pin<&mut Self>,
220 _cx: &mut std::task::Context<'_>,
221 ) -> std::task::Poll<std::io::Result<u64>> {
222 std::task::Poll::Ready(Ok(self.off))
223 }
224
225 fn start_seek(
226 mut self: std::pin::Pin<&mut Self>,
227 position: std::io::SeekFrom,
228 ) -> std::io::Result<()> {
229 self.seek(position).map(|_p| ())
230 }
231}
232
#[cfg(test)]
mod test {
    use rand::Rng as _;
    use std::io::Write as _;

    use super::*;

    /// Writes a short payload through a `MappedFile` backed by a fresh
    /// temporary file, rewinds, and checks the same bytes are read back.
    #[test]
    fn basic_rw() {
        // Random 10-character file name inside the system temp directory.
        let file_name: String = rand::thread_rng()
            .sample_iter(rand::distributions::Alphanumeric)
            .take(10)
            .map(char::from)
            .collect();
        let path = std::env::temp_dir().join(file_name);

        let backing = std::fs::File::options()
            .create(true)
            .truncate(true)
            .read(true)
            .write(true)
            .open(&path)
            .expect("Failed to open temporary file");
        let mut mapped = MappedFile::new(backing).expect("Failed to create MappedFile");

        mapped.write_all(b"abcd123").expect("Failed to write data");

        // Rewind so the read starts from the first byte written.
        let _: u64 = mapped
            .seek(std::io::SeekFrom::Start(0))
            .expect("Failed to seek to start");

        let mut readback = [0_u8; 7];
        mapped
            .read_exact(&mut readback)
            .expect("Failed to read data");

        assert_eq!(&readback, b"abcd123");

        // Release the mapping and file handle before deleting the file.
        drop(mapped);
        std::fs::remove_file(path).expect("Failed to remove temporary file");
    }
}