Alternative ATProto PDS implementation
1#![allow(clippy::arbitrary_source_item_ordering)] 2use std::io::{ErrorKind, Read as _, Seek as _, Write as _}; 3 4#[cfg(unix)] 5use std::os::fd::AsRawFd as _; 6#[cfg(windows)] 7use std::os::windows::io::AsRawHandle; 8 9use memmap2::{MmapMut, MmapOptions}; 10 11pub(crate) struct MappedFile { 12 /// The underlying file handle. 13 file: std::fs::File, 14 /// The length of the file. 15 len: u64, 16 /// The mapped memory region. 17 map: MmapMut, 18 /// Our current offset into the file. 19 off: u64, 20} 21 22impl MappedFile { 23 pub(crate) fn new(mut f: std::fs::File) -> std::io::Result<Self> { 24 let len = f.seek(std::io::SeekFrom::End(0))?; 25 26 #[cfg(windows)] 27 let raw = f.as_raw_handle(); 28 #[cfg(unix)] 29 let raw = f.as_raw_fd(); 30 31 #[expect(unsafe_code)] 32 Ok(Self { 33 // SAFETY: 34 // All file-backed memory map constructors are marked \ 35 // unsafe because of the potential for Undefined Behavior (UB) \ 36 // using the map if the underlying file is subsequently modified, in or out of process. 37 map: unsafe { MmapOptions::new().map_mut(raw)? }, 38 file: f, 39 len, 40 off: 0, 41 }) 42 } 43 44 /// Resize the memory-mapped file. This will reallocate the memory mapping. 45 #[expect(unsafe_code)] 46 fn resize(&mut self, len: u64) -> std::io::Result<()> { 47 // Resize the file. 48 self.file.set_len(len)?; 49 50 #[cfg(windows)] 51 let raw = self.file.as_raw_handle(); 52 #[cfg(unix)] 53 let raw = self.file.as_raw_fd(); 54 55 // SAFETY: 56 // All file-backed memory map constructors are marked \ 57 // unsafe because of the potential for Undefined Behavior (UB) \ 58 // using the map if the underlying file is subsequently modified, in or out of process. 59 self.map = unsafe { MmapOptions::new().map_mut(raw)? }; 60 self.len = len; 61 62 Ok(()) 63 } 64} 65 66impl std::io::Read for MappedFile { 67 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { 68 if self.off == self.len { 69 // If we're at EOF, return an EOF error code. `Ok(0)` tends to trip up some implementations. 70 return Err(std::io::Error::new(ErrorKind::UnexpectedEof, "eof")); 71 } 72 73 // Calculate the number of bytes we're going to read. 74 let remaining_bytes = self.len.saturating_sub(self.off); 75 let buf_len = u64::try_from(buf.len()).unwrap_or(u64::MAX); 76 let len = usize::try_from(std::cmp::min(remaining_bytes, buf_len)).unwrap_or(usize::MAX); 77 78 let off = usize::try_from(self.off).map_err(|e| { 79 std::io::Error::new( 80 ErrorKind::InvalidInput, 81 format!("offset too large for this platform: {e}"), 82 ) 83 })?; 84 85 if let (Some(dest), Some(src)) = ( 86 buf.get_mut(..len), 87 self.map.get(off..off.saturating_add(len)), 88 ) { 89 dest.copy_from_slice(src); 90 self.off = self.off.saturating_add(u64::try_from(len).unwrap_or(0)); 91 Ok(len) 92 } else { 93 Err(std::io::Error::new( 94 ErrorKind::InvalidInput, 95 "invalid buffer range", 96 )) 97 } 98 } 99} 100 101impl std::io::Write for MappedFile { 102 fn flush(&mut self) -> std::io::Result<()> { 103 // This is done by the system. 104 Ok(()) 105 } 106 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { 107 // Determine if we need to resize the file. 108 let buf_len = u64::try_from(buf.len()).map_err(|e| { 109 std::io::Error::new( 110 ErrorKind::InvalidInput, 111 format!("buffer length too large for this platform: {e}"), 112 ) 113 })?; 114 115 if self.off.saturating_add(buf_len) >= self.len { 116 self.resize(self.off.saturating_add(buf_len))?; 117 } 118 119 let off = usize::try_from(self.off).map_err(|e| { 120 std::io::Error::new( 121 ErrorKind::InvalidInput, 122 format!("offset too large for this platform: {e}"), 123 ) 124 })?; 125 let len = buf.len(); 126 127 if let Some(dest) = self.map.get_mut(off..off.saturating_add(len)) { 128 dest.copy_from_slice(buf); 129 self.off = self.off.saturating_add(buf_len); 130 Ok(len) 131 } else { 132 Err(std::io::Error::new( 133 ErrorKind::InvalidInput, 134 "invalid buffer range", 135 )) 136 } 137 } 138} 139 140impl std::io::Seek for MappedFile { 141 fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> { 142 let off = match pos { 143 std::io::SeekFrom::Start(i) => i, 144 std::io::SeekFrom::End(i) => { 145 if i <= 0 { 146 // If i is negative or zero, we're seeking backwards from the end 147 // or exactly at the end 148 self.len.saturating_sub(i.unsigned_abs()) 149 } else { 150 // If i is positive, we're seeking beyond the end, which is allowed 151 // but requires extending the file 152 self.len.saturating_add(i.unsigned_abs()) 153 } 154 } 155 std::io::SeekFrom::Current(i) => { 156 if i >= 0 { 157 self.off.saturating_add(i.unsigned_abs()) 158 } else { 159 self.off.saturating_sub(i.unsigned_abs()) 160 } 161 } 162 }; 163 164 // If the offset is beyond EOF, extend the file to the new size. 165 if off > self.len { 166 self.resize(off)?; 167 } 168 169 self.off = off; 170 Ok(off) 171 } 172} 173 174impl tokio::io::AsyncRead for MappedFile { 175 fn poll_read( 176 mut self: std::pin::Pin<&mut Self>, 177 _cx: &mut std::task::Context<'_>, 178 buf: &mut tokio::io::ReadBuf<'_>, 179 ) -> std::task::Poll<std::io::Result<()>> { 180 let wbuf = buf.initialize_unfilled(); 181 let len = wbuf.len(); 182 183 std::task::Poll::Ready(match self.read(wbuf) { 184 Ok(_) => { 185 buf.advance(len); 186 Ok(()) 187 } 188 Err(e) => Err(e), 189 }) 190 } 191} 192 193impl tokio::io::AsyncWrite for MappedFile { 194 fn poll_flush( 195 self: std::pin::Pin<&mut Self>, 196 _cx: &mut std::task::Context<'_>, 197 ) -> std::task::Poll<Result<(), std::io::Error>> { 198 std::task::Poll::Ready(Ok(())) 199 } 200 201 fn poll_shutdown( 202 self: std::pin::Pin<&mut Self>, 203 _cx: &mut std::task::Context<'_>, 204 ) -> std::task::Poll<Result<(), std::io::Error>> { 205 std::task::Poll::Ready(Ok(())) 206 } 207 208 fn poll_write( 209 mut self: std::pin::Pin<&mut Self>, 210 _cx: &mut std::task::Context<'_>, 211 buf: &[u8], 212 ) -> std::task::Poll<Result<usize, std::io::Error>> { 213 std::task::Poll::Ready(self.write(buf)) 214 } 215} 216 217impl tokio::io::AsyncSeek for MappedFile { 218 fn poll_complete( 219 self: std::pin::Pin<&mut Self>, 220 _cx: &mut std::task::Context<'_>, 221 ) -> std::task::Poll<std::io::Result<u64>> { 222 std::task::Poll::Ready(Ok(self.off)) 223 } 224 225 fn start_seek( 226 mut self: std::pin::Pin<&mut Self>, 227 position: std::io::SeekFrom, 228 ) -> std::io::Result<()> { 229 self.seek(position).map(|_p| ()) 230 } 231} 232 233#[cfg(test)] 234mod test { 235 use rand::Rng as _; 236 use std::io::Write as _; 237 238 use super::*; 239 240 #[test] 241 fn basic_rw() { 242 let tmp = std::env::temp_dir().join( 243 rand::thread_rng() 244 .sample_iter(rand::distributions::Alphanumeric) 245 .take(10) 246 .map(char::from) 247 .collect::<String>(), 248 ); 249 250 let mut m = MappedFile::new( 251 std::fs::File::options() 252 .create(true) 253 .truncate(true) 254 .read(true) 255 .write(true) 256 .open(&tmp) 257 .expect("Failed to open temporary file"), 258 ) 259 .expect("Failed to create MappedFile"); 260 261 m.write_all(b"abcd123").expect("Failed to write data"); 262 let _: u64 = m 263 .seek(std::io::SeekFrom::Start(0)) 264 .expect("Failed to seek to start"); 265 266 let mut buf = [0_u8; 7]; 267 m.read_exact(&mut buf).expect("Failed to read data"); 268 269 assert_eq!(&buf, b"abcd123"); 270 271 drop(m); 272 std::fs::remove_file(tmp).expect("Failed to remove temporary file"); 273 } 274}