Tool to order the items contained in Rust source files
at develop 570 lines 14 kB view raw
1use std::{ 2 borrow::Cow, cmp::Ordering, convert::Infallible, fmt::Display, path::PathBuf, process::ExitCode, 3}; 4 5use pico_args::Arguments; 6use tree_sitter::{Node, Parser}; 7 8const CLI_HELP: &str = r#"USAGE 9 $ rust-organizer [-c] [-w] FILE 10 11ARGUMENTS 12 FILE File name of the Rust source file to reorganize. 13 14FLAGS 15 -c, --check Check whether reorganizing the file would change the file contents. 16 -w, --write Overwrite the file with the reorganized contents. 17"#; 18 19#[derive(Debug, Clone, PartialEq, Eq, Hash)] 20struct Cli { 21 check: bool, 22 overwrite: bool, 23 path: PathBuf, 24} 25 26#[derive(Debug, Clone, PartialEq, Eq, Hash)] 27enum Item<'a> { 28 InnerDoc(Cow<'a, str>), 29 Macro { 30 name: &'a str, 31 content: Cow<'a, str>, 32 }, 33 ModDecl { 34 name: &'a str, 35 content: Cow<'a, str>, 36 }, 37 Use(Cow<'a, str>), 38 Const { 39 name: &'a str, 40 content: Cow<'a, str>, 41 }, 42 Type { 43 name: &'a str, 44 content: Cow<'a, str>, 45 }, 46 Func { 47 name: &'a str, 48 content: Cow<'a, str>, 49 }, 50 Trait { 51 name: &'a str, 52 content: SortableContent<'a>, 53 }, 54 Impl { 55 name: TypeIdent<'a>, 56 trt: Option<&'a str>, 57 content: SortableContent<'a>, 58 }, 59 MacroInvocation(Cow<'a, str>), 60 Mod { 61 name: &'a str, 62 content: SortableContent<'a>, 63 }, 64} 65 66#[derive(Debug, Clone, PartialEq, Eq, Hash)] 67struct Module<'a> { 68 items: Vec<(bool, Item<'a>)>, 69} 70 71#[derive(Debug, Clone, PartialEq, Eq, Hash)] 72struct SortableContent<'a> { 73 before: Cow<'a, str>, 74 inner: Module<'a>, 75 after: Cow<'a, str>, 76} 77 78#[derive(Debug, Clone, PartialEq, Eq, Hash)] 79struct TypeIdent<'a> { 80 name: &'a str, 81 generics: Option<&'a str>, 82 reference_type: Option<&'a str>, 83} 84 85fn main() -> ExitCode { 86 // Parse commandline arguments 87 let mut args = Arguments::from_env(); 88 if args.contains(["-h", "--help"]) { 89 print!("{}", CLI_HELP); 90 return ExitCode::SUCCESS; 91 } 92 let cli: Cli = match args.try_into() { 93 Ok(cli) => cli, 94 
Err(e) => { 95 eprintln!("Error: {}", e); 96 return ExitCode::FAILURE; 97 } 98 }; 99 // Run the main program 100 match cli.run() { 101 Ok(code) => code, 102 Err(e) => { 103 eprintln!("Error: {}", e); 104 ExitCode::FAILURE 105 } 106 } 107} 108 109impl Cli { 110 fn run(&self) -> Result<ExitCode, String> { 111 let mut parser = Parser::new(); 112 parser 113 .set_language(&tree_sitter_rust::LANGUAGE.into()) 114 .expect("Error loading Rust grammar"); 115 116 let text = std::fs::read_to_string(&self.path) 117 .map_err(|e| format!("unable to read file: {}", e))?; 118 119 let Some(tree) = parser.parse(&text, None) else { 120 return Err("unable to parse file".to_owned()); 121 }; 122 123 let root = tree.root_node(); 124 assert_eq!(root.kind(), "source_file"); 125 let mut root = Module::from_node(&text, root); 126 127 let is_sorted = root.is_sorted(self.check); 128 if self.check { 129 return if is_sorted { 130 Ok(ExitCode::SUCCESS) 131 } else { 132 Ok(ExitCode::FAILURE) 133 }; 134 } 135 if self.overwrite && is_sorted { 136 return Ok(ExitCode::SUCCESS); 137 } 138 139 root.sort(); 140 141 if self.overwrite { 142 std::fs::write(&self.path, root.to_string()) 143 .map_err(|e| format!("unable to write file: {}", e))?; 144 } else { 145 println!("{}", root); 146 } 147 148 Ok(ExitCode::SUCCESS) 149 } 150} 151 152impl TryFrom<Arguments> for Cli { 153 type Error = String; 154 155 fn try_from(mut args: Arguments) -> Result<Self, Self::Error> { 156 let cli = Cli { 157 check: args.contains(["-c", "--check"]), 158 overwrite: args.contains(["-w", "--write"]), 159 path: args 160 .free_from_os_str::<_, Infallible>(|s| Ok(PathBuf::from(s))) 161 .unwrap(), 162 }; 163 164 let remaining = args.finish(); 165 match remaining.len() { 166 0 => Ok(()), 167 1 => Err(format!( 168 "unexpected argument: '{}'", 169 remaining[0].to_string_lossy() 170 )), 171 _ => Err(format!( 172 "unexpected arguments: {}", 173 remaining 174 .into_iter() 175 .map(|s| format!("'{}'", s.to_string_lossy())) 176 
.collect::<Vec<_>>() 177 .join(", ") 178 )), 179 }?; 180 Ok(cli) 181 } 182} 183 184impl<'a> Item<'a> { 185 fn append_content(&mut self, text: &str) { 186 match self { 187 Item::Macro { content, .. } 188 | Item::ModDecl { content, .. } 189 | Item::Const { content, .. } 190 | Item::Type { content, .. } 191 | Item::Func { content, .. } 192 | Item::InnerDoc(content) 193 | Item::Use(content) 194 | Item::MacroInvocation(content) => { 195 *content = Cow::Owned(format!("{}{}", content, text)); 196 } 197 Item::Impl { .. } | Item::Mod { .. } | Item::Trait { .. } => { 198 // Cannot add content to these items 199 } 200 } 201 } 202 203 fn item_order(&self) -> u8 { 204 match self { 205 Item::InnerDoc(_) => 0, 206 Item::Macro { .. } => 1, 207 Item::ModDecl { .. } => 2, 208 Item::Use(_) => 3, 209 Item::Const { .. } => 4, 210 Item::Type { .. } => 5, 211 Item::Trait { .. } => 5, 212 Item::Func { .. } => 6, 213 Item::Impl { .. } => 7, 214 Item::MacroInvocation(_) => 8, 215 Item::Mod { .. } => 9, 216 } 217 } 218 219 fn maybe_item(text: &'a str, node: Node<'a>, start: Option<usize>) -> Option<Self> { 220 let get_field_str = |field_name| { 221 node.child_by_field_name(field_name) 222 .map(|n| n.utf8_text(text.as_bytes()).unwrap()) 223 }; 224 225 let start = start.unwrap_or(node.start_byte()); 226 let end = if node.utf8_text(text.as_bytes()).unwrap().ends_with('\n') { 227 node.end_byte() - 1 228 } else { 229 node.end_byte() 230 }; 231 let content: Cow<'a, str> = Cow::Borrowed(&text[start..end]); 232 match node.kind() { 233 "attribute_item" => { 234 // Ignore and add to the next item 235 None 236 } 237 "block_comment" | "line_comment" => { 238 let comment = node.utf8_text(text.as_bytes()).unwrap(); 239 if comment.starts_with("//!") || comment.starts_with("/*!") { 240 // Doc comment for the file (ensure that it's at the top of the file). 
241 Some(Self::InnerDoc(content)) 242 } else { 243 None // Move comment with the next item 244 } 245 } 246 "const_item" | "static_item" => { 247 let name = get_field_str("name").unwrap(); 248 Some(Self::Const { name, content }) 249 } 250 "associated_type" | "enum_item" | "struct_item" | "type_item" => { 251 let name = get_field_str("name").unwrap(); 252 Some(Self::Type { name, content }) 253 } 254 "function_item" | "function_signature_item" => { 255 let name = get_field_str("name").unwrap(); 256 Some(Self::Func { name, content }) 257 } 258 "trait_item" => { 259 let name = get_field_str("name").unwrap(); 260 let content = SortableContent::within_node(text, node, Some(start), "body"); 261 Some(Self::Trait { name, content }) 262 } 263 "impl_item" => { 264 let name = TypeIdent::from_node(text, node.child_by_field_name("type").unwrap()); 265 let trt = get_field_str("trait"); 266 let content = SortableContent::within_node(text, node, Some(start), "body"); 267 Some(Self::Impl { name, trt, content }) 268 } 269 "inner_attribute_item" => { 270 // Should be at the top (treat like inner doc, to keep it in the chosen 271 // order compared to the module documentation). 
272 Some(Self::InnerDoc(content)) 273 } 274 "macro_definition" => { 275 let name = get_field_str("name").unwrap(); 276 Some(Self::Macro { name, content }) 277 } 278 "macro_invocation" => Some(Self::MacroInvocation(content)), 279 "mod_item" => { 280 let name = get_field_str("name").unwrap(); 281 if node.child_by_field_name("body").is_some() { 282 let content = SortableContent::within_node(text, node, Some(start), "body"); 283 Some(Self::Mod { name, content }) 284 } else { 285 Some(Self::ModDecl { name, content }) 286 } 287 } 288 "use_declaration" => Some(Self::Use(content)), 289 _ => panic!( 290 "unexpected node kind: {}\ncontent: {}", 291 node.kind(), 292 content 293 ), 294 } 295 } 296} 297 298impl Display for Item<'_> { 299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 300 match self { 301 Item::InnerDoc(content) 302 | Item::Macro { content, .. } 303 | Item::MacroInvocation(content) 304 | Item::ModDecl { content, .. } 305 | Item::Use(content) 306 | Item::Const { content, .. } 307 | Item::Type { content, .. } 308 | Item::Func { content, .. } => write!(f, "{content}"), 309 Item::Mod { content, .. } 310 | Item::Impl { content, .. } 311 | Item::Trait { content, .. } => { 312 write!(f, "{content}") 313 } 314 } 315 } 316} 317 318impl Ord for Item<'_> { 319 fn cmp(&self, other: &Self) -> Ordering { 320 use Item::*; 321 322 let self_order = self.item_order(); 323 let other_order = other.item_order(); 324 if self_order != other_order { 325 return self_order.cmp(&other_order); 326 } 327 match (self, other) { 328 (InnerDoc(_), InnerDoc(_)) => Ordering::Equal, 329 (Const { name: a, .. }, Const { name: b, .. }) 330 | (Macro { name: a, .. }, Macro { name: b, .. }) 331 | (Mod { name: a, .. }, Mod { name: b, .. }) 332 | (ModDecl { name: a, .. }, ModDecl { name: b, .. }) 333 | ( 334 Type { name: a, .. } | Trait { name: a, .. }, 335 Type { name: b, .. } | Trait { name: b, .. }, 336 ) 337 | (Func { name: a, .. }, Func { name: b, .. 
}) => a.cmp(b), 338 (Use(_), Use(_)) | (MacroInvocation(_), MacroInvocation(_)) => Ordering::Equal, 339 ( 340 Impl { 341 name: a, trt: t_a, .. 342 }, 343 Impl { 344 name: b, trt: t_b, .. 345 }, 346 ) => { 347 let name_order = a.name.cmp(b.name); 348 if name_order == Ordering::Equal { 349 let trt_order = t_a.unwrap_or("").cmp(t_b.unwrap_or("")); 350 if trt_order == Ordering::Equal { 351 let a_parts = (a.generics.unwrap_or(""), a.reference_type.unwrap_or("")); 352 let b_parts = (b.generics.unwrap_or(""), b.reference_type.unwrap_or("")); 353 a_parts.cmp(&b_parts) 354 } else { 355 trt_order 356 } 357 } else { 358 name_order 359 } 360 } 361 _ => { 362 // eprintln!("{} -- {}", self, other); 363 unreachable!(); 364 } 365 } 366 } 367} 368 369impl PartialOrd for Item<'_> { 370 fn partial_cmp(&self, other: &Self) -> Option<Ordering> { 371 Some(self.cmp(other)) 372 } 373} 374 375impl<'a> Module<'a> { 376 pub fn from_node(text: &'a str, root: Node<'a>) -> Self { 377 assert!(matches!(root.kind(), "source_file" | "declaration_list")); 378 let mut cursor = root.walk(); 379 cursor.goto_first_child(); 380 381 let mut items: Vec<(bool, Item)> = Vec::new(); 382 let mut start = None; 383 let mut last = None; 384 if cursor.node().kind() == "{" { 385 last = Some(cursor.node().end_byte()); 386 cursor.goto_next_sibling(); 387 } 388 loop { 389 if cursor.node().kind() == "}" { 390 assert!(!cursor.goto_next_sibling()); 391 break; 392 } 393 let node = cursor.node(); 394 // eprintln!("{} : {}\n\n", node.kind(), node.to_sexp()); 395 let inbetween = 396 &text[last.unwrap_or(root.start_byte())..start.unwrap_or(node.start_byte())]; 397 if node.kind() == "empty_statement" { 398 if let Some((_, it)) = items.last_mut() { 399 it.append_content(";"); 400 } 401 debug_assert!( 402 inbetween.trim().is_empty(), 403 "unexpected skipped content: {:?}", 404 inbetween 405 ); 406 start = None; 407 last = Some(node.end_byte()); 408 } else if let Some(item) = Item::maybe_item(text, node, start) { 409 
debug_assert!( 410 inbetween.trim().is_empty(), 411 "unexpected skipped content: {:?}", 412 inbetween 413 ); 414 let newline_before = inbetween.contains("\n\n"); 415 items.push((items.is_empty() || newline_before, item)); 416 start = None; 417 last = Some(node.end_byte()); 418 } else if start.is_none() { 419 start = Some(node.start_byte()); 420 } 421 if !cursor.goto_next_sibling() { 422 break; 423 } 424 } 425 426 Self { items } 427 } 428 429 pub fn is_sorted(&self, print_diff: bool) -> bool { 430 for it in &self.items { 431 match &it.1 { 432 Item::Mod { content, .. } 433 | Item::Impl { content, .. } 434 | Item::Trait { content, .. } => { 435 if !content.is_sorted(print_diff) { 436 return false; 437 } 438 } 439 _ => {} 440 } 441 } 442 for window in self.items.windows(2) { 443 if window[0].1 > window[1].1 { 444 if print_diff { 445 eprintln!( 446 "Expected \n\"\"\"\n{}\n\"\"\"\n before \n\"\"\"\n{}\n\"\"\"", 447 window[1].1, window[0].1 448 ); 449 } 450 return false; 451 } 452 } 453 true 454 } 455 456 pub fn sort(&mut self) { 457 for it in self.items.iter_mut() { 458 match &mut it.1 { 459 Item::Mod { content, .. } 460 | Item::Impl { content, .. } 461 | Item::Trait { content, .. 
} => content.sort(), 462 _ => {} 463 } 464 } 465 self.items.sort_unstable_by(|a, b| a.1.cmp(&b.1)); 466 } 467} 468 469impl Display for Module<'_> { 470 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 471 let mut last = None; 472 for (newline, item) in &self.items { 473 if *newline || (last.is_some() && last != Some(item.item_order())) { 474 writeln!(f)?; 475 } 476 writeln!(f, "{}", item)?; 477 last = Some(item.item_order()); 478 } 479 Ok(()) 480 } 481} 482 483impl<'a> SortableContent<'a> { 484 fn is_sorted(&self, print_diff: bool) -> bool { 485 self.inner.is_sorted(print_diff) 486 } 487 488 fn sort(&mut self) { 489 self.inner.sort(); 490 } 491 492 fn within_node( 493 text: &'a str, 494 node: Node<'a>, 495 start: Option<usize>, 496 child: &'static str, 497 ) -> Self { 498 let start = start.unwrap_or(node.start_byte()); 499 let body = node.child_by_field_name(child).unwrap(); 500 501 let mut cursor = body.walk(); 502 cursor.goto_first_child(); 503 assert_eq!(cursor.node().kind(), "{"); 504 let before = Cow::Borrowed(&text[start..cursor.node().end_byte()]); 505 506 cursor.goto_parent(); 507 cursor.goto_last_child(); 508 assert_eq!(cursor.node().kind(), "}"); 509 let after = Cow::Borrowed(&text[cursor.node().start_byte()..node.end_byte()]); 510 511 let inner = Module::from_node(text, body); 512 Self { 513 before, 514 inner, 515 after, 516 } 517 } 518} 519 520impl Display for SortableContent<'_> { 521 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 522 write!(f, "{}{}{}", self.before, self.inner, self.after) 523 } 524} 525 526impl<'a> TypeIdent<'a> { 527 fn from_node(text: &'a str, node: Node<'a>) -> Self { 528 let get_field_str = |field_name| { 529 node.child_by_field_name(field_name) 530 .map(|n| n.utf8_text(text.as_bytes()).unwrap()) 531 }; 532 533 match node.kind() { 534 "array_type" => { 535 let inner = node.child_by_field_name("element").unwrap(); 536 let mut ty = TypeIdent::from_node(text, inner); 537 let reference_str = 
&text[node.start_byte()..inner.start_byte()]; 538 ty.reference_type = Some(reference_str); 539 ty 540 } 541 "generic_type" => { 542 let name = get_field_str("type").unwrap(); 543 let generics = get_field_str("type_arguments"); 544 debug_assert!(generics.is_some()); 545 Self { 546 name, 547 generics, 548 reference_type: None, 549 } 550 } 551 "reference_type" => { 552 let inner = node.child_by_field_name("type").unwrap(); 553 let mut ty = TypeIdent::from_node(text, inner); 554 let reference_str = &text[node.start_byte()..inner.start_byte()]; 555 ty.reference_type = Some(reference_str); 556 ty 557 } 558 "type_identifier" | "primitive_type" | "bounded_type" => Self { 559 name: node.utf8_text(text.as_bytes()).unwrap(), 560 generics: None, 561 reference_type: None, 562 }, 563 _ => panic!( 564 "invalid type identifier node: {}, `{}'", 565 node.kind(), 566 node.utf8_text(text.as_bytes()).unwrap() 567 ), 568 } 569 } 570}