Tool to order the items contained in Rust source files
1use std::{
2 borrow::Cow, cmp::Ordering, convert::Infallible, fmt::Display, path::PathBuf, process::ExitCode,
3};
4
5use pico_args::Arguments;
6use tree_sitter::{Node, Parser};
7
/// Help text printed for `-h`/`--help`.
///
/// Lists every flag `main`/`Cli::try_from` accept, including the help flag itself.
const CLI_HELP: &str = r#"USAGE
  $ rust-organizer [-c] [-w] FILE

ARGUMENTS
  FILE           File name of the Rust source file to reorganize.

FLAGS
  -c, --check    Check whether reorganizing the file would change the file contents.
  -w, --write    Overwrite the file with the reorganized contents.
  -h, --help     Print this help message.
"#;
18
/// Options parsed from the command line.
///
/// When both `check` and `overwrite` are set, `check` wins: the file is only
/// checked, never written.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct Cli {
    /// `-c`/`--check`: only report (via the exit code) whether the items are sorted.
    check: bool,
    /// `-w`/`--write`: write the reorganized source back to `path` instead of stdout.
    overwrite: bool,
    /// The Rust source file to reorganize (the free-standing `FILE` argument).
    path: PathBuf,
}
25
/// One item of a module body: the whole file, or the `{ … }` of an inline `mod`,
/// an `impl` or a `trait`.
///
/// The variant selects the group the item is sorted into (see `Item::item_order`),
/// and the `name` fields order items within a group. `content` is the item's source
/// text, including the comments and outer attributes that directly precede it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Item<'a> {
    /// Inner doc comment (`//!`, `/*!`) or inner attribute (`#![...]`).
    InnerDoc(Cow<'a, str>),
    /// Macro definition (tree-sitter `macro_definition` node).
    Macro {
        name: &'a str,
        content: Cow<'a, str>,
    },
    /// Module declaration without a body (`mod name;`).
    ModDecl {
        name: &'a str,
        content: Cow<'a, str>,
    },
    /// Import item such as a `use` declaration.
    Use(Cow<'a, str>),
    /// `const` or `static` item.
    Const {
        name: &'a str,
        content: Cow<'a, str>,
    },
    /// Type definition (`struct`, `enum`, type alias, associated type, ...).
    Type {
        name: &'a str,
        content: Cow<'a, str>,
    },
    /// Function definition or function signature.
    Func {
        name: &'a str,
        content: Cow<'a, str>,
    },
    /// `trait` definition; its body is sorted recursively.
    Trait {
        name: &'a str,
        content: SortableContent<'a>,
    },
    /// `impl` block. `name` is the implementing type and `trt` is the implemented
    /// trait, if any. The body is sorted recursively.
    Impl {
        name: TypeIdent<'a>,
        trt: Option<&'a str>,
        content: SortableContent<'a>,
    },
    /// Item-level macro invocation.
    MacroInvocation(Cow<'a, str>),
    /// Inline module (`mod name { … }`); its body is sorted recursively.
    Mod {
        name: &'a str,
        content: SortableContent<'a>,
    },
}
65
/// The items of one body (the whole file, or the inside of a `{ … }`), in order.
///
/// The `bool` paired with each item is `true` when a blank line preceded the item
/// in the original source. It is always `true` for the first item. Printing uses
/// it to decide where to emit blank lines.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct Module<'a> {
    items: Vec<(bool, Item<'a>)>,
}
70
/// An item split around its `{ … }` body so that the body's items can be sorted.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct SortableContent<'a> {
    /// Source text from the start of the item (including leading comments and
    /// attributes) up to and including the opening `{`.
    before: Cow<'a, str>,
    /// The items between the braces.
    inner: Module<'a>,
    /// Source text from the closing `}` to the end of the item.
    after: Cow<'a, str>,
}
77
/// The self type of an `impl` block, split into the parts used to order impls.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct TypeIdent<'a> {
    /// Base type name, e.g. `Foo` for `&Foo<T>`.
    name: &'a str,
    /// Text of the `type_arguments` node for generic types, if any.
    generics: Option<&'a str>,
    /// Text in front of the base type for reference/array types (e.g. `&`, `[`).
    reference_type: Option<&'a str>,
}
84
85fn main() -> ExitCode {
86 // Parse commandline arguments
87 let mut args = Arguments::from_env();
88 if args.contains(["-h", "--help"]) {
89 print!("{}", CLI_HELP);
90 return ExitCode::SUCCESS;
91 }
92 let cli: Cli = match args.try_into() {
93 Ok(cli) => cli,
94 Err(e) => {
95 eprintln!("Error: {}", e);
96 return ExitCode::FAILURE;
97 }
98 };
99 // Run the main program
100 match cli.run() {
101 Ok(code) => code,
102 Err(e) => {
103 eprintln!("Error: {}", e);
104 ExitCode::FAILURE
105 }
106 }
107}
108
109impl Cli {
110 fn run(&self) -> Result<ExitCode, String> {
111 let mut parser = Parser::new();
112 parser
113 .set_language(&tree_sitter_rust::LANGUAGE.into())
114 .expect("Error loading Rust grammar");
115
116 let text = std::fs::read_to_string(&self.path)
117 .map_err(|e| format!("unable to read file: {}", e))?;
118
119 let Some(tree) = parser.parse(&text, None) else {
120 return Err("unable to parse file".to_owned());
121 };
122
123 let root = tree.root_node();
124 assert_eq!(root.kind(), "source_file");
125 let mut root = Module::from_node(&text, root);
126
127 let is_sorted = root.is_sorted(self.check);
128 if self.check {
129 return if is_sorted {
130 Ok(ExitCode::SUCCESS)
131 } else {
132 Ok(ExitCode::FAILURE)
133 };
134 }
135 if self.overwrite && is_sorted {
136 return Ok(ExitCode::SUCCESS);
137 }
138
139 root.sort();
140
141 if self.overwrite {
142 std::fs::write(&self.path, root.to_string())
143 .map_err(|e| format!("unable to write file: {}", e))?;
144 } else {
145 println!("{}", root);
146 }
147
148 Ok(ExitCode::SUCCESS)
149 }
150}
151
152impl TryFrom<Arguments> for Cli {
153 type Error = String;
154
155 fn try_from(mut args: Arguments) -> Result<Self, Self::Error> {
156 let cli = Cli {
157 check: args.contains(["-c", "--check"]),
158 overwrite: args.contains(["-w", "--write"]),
159 path: args
160 .free_from_os_str::<_, Infallible>(|s| Ok(PathBuf::from(s)))
161 .unwrap(),
162 };
163
164 let remaining = args.finish();
165 match remaining.len() {
166 0 => Ok(()),
167 1 => Err(format!(
168 "unexpected argument: '{}'",
169 remaining[0].to_string_lossy()
170 )),
171 _ => Err(format!(
172 "unexpected arguments: {}",
173 remaining
174 .into_iter()
175 .map(|s| format!("'{}'", s.to_string_lossy()))
176 .collect::<Vec<_>>()
177 .join(", ")
178 )),
179 }?;
180 Ok(cli)
181 }
182}
183
184impl<'a> Item<'a> {
185 fn append_content(&mut self, text: &str) {
186 match self {
187 Item::Macro { content, .. }
188 | Item::ModDecl { content, .. }
189 | Item::Const { content, .. }
190 | Item::Type { content, .. }
191 | Item::Func { content, .. }
192 | Item::InnerDoc(content)
193 | Item::Use(content)
194 | Item::MacroInvocation(content) => {
195 *content = Cow::Owned(format!("{}{}", content, text));
196 }
197 Item::Impl { .. } | Item::Mod { .. } | Item::Trait { .. } => {
198 // Cannot add content to these items
199 }
200 }
201 }
202
203 fn item_order(&self) -> u8 {
204 match self {
205 Item::InnerDoc(_) => 0,
206 Item::Macro { .. } => 1,
207 Item::ModDecl { .. } => 2,
208 Item::Use(_) => 3,
209 Item::Const { .. } => 4,
210 Item::Type { .. } => 5,
211 Item::Trait { .. } => 5,
212 Item::Func { .. } => 6,
213 Item::Impl { .. } => 7,
214 Item::MacroInvocation(_) => 8,
215 Item::Mod { .. } => 9,
216 }
217 }
218
219 fn maybe_item(text: &'a str, node: Node<'a>, start: Option<usize>) -> Option<Self> {
220 let get_field_str = |field_name| {
221 node.child_by_field_name(field_name)
222 .map(|n| n.utf8_text(text.as_bytes()).unwrap())
223 };
224
225 let start = start.unwrap_or(node.start_byte());
226 let end = if node.utf8_text(text.as_bytes()).unwrap().ends_with('\n') {
227 node.end_byte() - 1
228 } else {
229 node.end_byte()
230 };
231 let content: Cow<'a, str> = Cow::Borrowed(&text[start..end]);
232 match node.kind() {
233 "attribute_item" => {
234 // Ignore and add to the next item
235 None
236 }
237 "block_comment" | "line_comment" => {
238 let comment = node.utf8_text(text.as_bytes()).unwrap();
239 if comment.starts_with("//!") || comment.starts_with("/*!") {
240 // Doc comment for the file (ensure that it's at the top of the file).
241 Some(Self::InnerDoc(content))
242 } else {
243 None // Move comment with the next item
244 }
245 }
246 "const_item" | "static_item" => {
247 let name = get_field_str("name").unwrap();
248 Some(Self::Const { name, content })
249 }
250 "associated_type" | "enum_item" | "struct_item" | "type_item" => {
251 let name = get_field_str("name").unwrap();
252 Some(Self::Type { name, content })
253 }
254 "function_item" | "function_signature_item" => {
255 let name = get_field_str("name").unwrap();
256 Some(Self::Func { name, content })
257 }
258 "trait_item" => {
259 let name = get_field_str("name").unwrap();
260 let content = SortableContent::within_node(text, node, Some(start), "body");
261 Some(Self::Trait { name, content })
262 }
263 "impl_item" => {
264 let name = TypeIdent::from_node(text, node.child_by_field_name("type").unwrap());
265 let trt = get_field_str("trait");
266 let content = SortableContent::within_node(text, node, Some(start), "body");
267 Some(Self::Impl { name, trt, content })
268 }
269 "inner_attribute_item" => {
270 // Should be at the top (treat like inner doc, to keep it in the chosen
271 // order compared to the module documentation).
272 Some(Self::InnerDoc(content))
273 }
274 "macro_definition" => {
275 let name = get_field_str("name").unwrap();
276 Some(Self::Macro { name, content })
277 }
278 "macro_invocation" => Some(Self::MacroInvocation(content)),
279 "mod_item" => {
280 let name = get_field_str("name").unwrap();
281 if node.child_by_field_name("body").is_some() {
282 let content = SortableContent::within_node(text, node, Some(start), "body");
283 Some(Self::Mod { name, content })
284 } else {
285 Some(Self::ModDecl { name, content })
286 }
287 }
288 "use_declaration" => Some(Self::Use(content)),
289 _ => panic!(
290 "unexpected node kind: {}\ncontent: {}",
291 node.kind(),
292 content
293 ),
294 }
295 }
296}
297
298impl Display for Item<'_> {
299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 match self {
301 Item::InnerDoc(content)
302 | Item::Macro { content, .. }
303 | Item::MacroInvocation(content)
304 | Item::ModDecl { content, .. }
305 | Item::Use(content)
306 | Item::Const { content, .. }
307 | Item::Type { content, .. }
308 | Item::Func { content, .. } => write!(f, "{content}"),
309 Item::Mod { content, .. }
310 | Item::Impl { content, .. }
311 | Item::Trait { content, .. } => {
312 write!(f, "{content}")
313 }
314 }
315 }
316}
317
318impl Ord for Item<'_> {
319 fn cmp(&self, other: &Self) -> Ordering {
320 use Item::*;
321
322 let self_order = self.item_order();
323 let other_order = other.item_order();
324 if self_order != other_order {
325 return self_order.cmp(&other_order);
326 }
327 match (self, other) {
328 (InnerDoc(_), InnerDoc(_)) => Ordering::Equal,
329 (Const { name: a, .. }, Const { name: b, .. })
330 | (Macro { name: a, .. }, Macro { name: b, .. })
331 | (Mod { name: a, .. }, Mod { name: b, .. })
332 | (ModDecl { name: a, .. }, ModDecl { name: b, .. })
333 | (
334 Type { name: a, .. } | Trait { name: a, .. },
335 Type { name: b, .. } | Trait { name: b, .. },
336 )
337 | (Func { name: a, .. }, Func { name: b, .. }) => a.cmp(b),
338 (Use(_), Use(_)) | (MacroInvocation(_), MacroInvocation(_)) => Ordering::Equal,
339 (
340 Impl {
341 name: a, trt: t_a, ..
342 },
343 Impl {
344 name: b, trt: t_b, ..
345 },
346 ) => {
347 let name_order = a.name.cmp(b.name);
348 if name_order == Ordering::Equal {
349 let trt_order = t_a.unwrap_or("").cmp(t_b.unwrap_or(""));
350 if trt_order == Ordering::Equal {
351 let a_parts = (a.generics.unwrap_or(""), a.reference_type.unwrap_or(""));
352 let b_parts = (b.generics.unwrap_or(""), b.reference_type.unwrap_or(""));
353 a_parts.cmp(&b_parts)
354 } else {
355 trt_order
356 }
357 } else {
358 name_order
359 }
360 }
361 _ => {
362 // eprintln!("{} -- {}", self, other);
363 unreachable!();
364 }
365 }
366 }
367}
368
369impl PartialOrd for Item<'_> {
370 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
371 Some(self.cmp(other))
372 }
373}
374
375impl<'a> Module<'a> {
376 pub fn from_node(text: &'a str, root: Node<'a>) -> Self {
377 assert!(matches!(root.kind(), "source_file" | "declaration_list"));
378 let mut cursor = root.walk();
379 cursor.goto_first_child();
380
381 let mut items: Vec<(bool, Item)> = Vec::new();
382 let mut start = None;
383 let mut last = None;
384 if cursor.node().kind() == "{" {
385 last = Some(cursor.node().end_byte());
386 cursor.goto_next_sibling();
387 }
388 loop {
389 if cursor.node().kind() == "}" {
390 assert!(!cursor.goto_next_sibling());
391 break;
392 }
393 let node = cursor.node();
394 // eprintln!("{} : {}\n\n", node.kind(), node.to_sexp());
395 let inbetween =
396 &text[last.unwrap_or(root.start_byte())..start.unwrap_or(node.start_byte())];
397 if node.kind() == "empty_statement" {
398 if let Some((_, it)) = items.last_mut() {
399 it.append_content(";");
400 }
401 debug_assert!(
402 inbetween.trim().is_empty(),
403 "unexpected skipped content: {:?}",
404 inbetween
405 );
406 start = None;
407 last = Some(node.end_byte());
408 } else if let Some(item) = Item::maybe_item(text, node, start) {
409 debug_assert!(
410 inbetween.trim().is_empty(),
411 "unexpected skipped content: {:?}",
412 inbetween
413 );
414 let newline_before = inbetween.contains("\n\n");
415 items.push((items.is_empty() || newline_before, item));
416 start = None;
417 last = Some(node.end_byte());
418 } else if start.is_none() {
419 start = Some(node.start_byte());
420 }
421 if !cursor.goto_next_sibling() {
422 break;
423 }
424 }
425
426 Self { items }
427 }
428
429 pub fn is_sorted(&self, print_diff: bool) -> bool {
430 for it in &self.items {
431 match &it.1 {
432 Item::Mod { content, .. }
433 | Item::Impl { content, .. }
434 | Item::Trait { content, .. } => {
435 if !content.is_sorted(print_diff) {
436 return false;
437 }
438 }
439 _ => {}
440 }
441 }
442 for window in self.items.windows(2) {
443 if window[0].1 > window[1].1 {
444 if print_diff {
445 eprintln!(
446 "Expected \n\"\"\"\n{}\n\"\"\"\n before \n\"\"\"\n{}\n\"\"\"",
447 window[1].1, window[0].1
448 );
449 }
450 return false;
451 }
452 }
453 true
454 }
455
456 pub fn sort(&mut self) {
457 for it in self.items.iter_mut() {
458 match &mut it.1 {
459 Item::Mod { content, .. }
460 | Item::Impl { content, .. }
461 | Item::Trait { content, .. } => content.sort(),
462 _ => {}
463 }
464 }
465 self.items.sort_unstable_by(|a, b| a.1.cmp(&b.1));
466 }
467}
468
469impl Display for Module<'_> {
470 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
471 let mut last = None;
472 for (newline, item) in &self.items {
473 if *newline || (last.is_some() && last != Some(item.item_order())) {
474 writeln!(f)?;
475 }
476 writeln!(f, "{}", item)?;
477 last = Some(item.item_order());
478 }
479 Ok(())
480 }
481}
482
483impl<'a> SortableContent<'a> {
484 fn is_sorted(&self, print_diff: bool) -> bool {
485 self.inner.is_sorted(print_diff)
486 }
487
488 fn sort(&mut self) {
489 self.inner.sort();
490 }
491
492 fn within_node(
493 text: &'a str,
494 node: Node<'a>,
495 start: Option<usize>,
496 child: &'static str,
497 ) -> Self {
498 let start = start.unwrap_or(node.start_byte());
499 let body = node.child_by_field_name(child).unwrap();
500
501 let mut cursor = body.walk();
502 cursor.goto_first_child();
503 assert_eq!(cursor.node().kind(), "{");
504 let before = Cow::Borrowed(&text[start..cursor.node().end_byte()]);
505
506 cursor.goto_parent();
507 cursor.goto_last_child();
508 assert_eq!(cursor.node().kind(), "}");
509 let after = Cow::Borrowed(&text[cursor.node().start_byte()..node.end_byte()]);
510
511 let inner = Module::from_node(text, body);
512 Self {
513 before,
514 inner,
515 after,
516 }
517 }
518}
519
520impl Display for SortableContent<'_> {
521 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
522 write!(f, "{}{}{}", self.before, self.inner, self.after)
523 }
524}
525
526impl<'a> TypeIdent<'a> {
527 fn from_node(text: &'a str, node: Node<'a>) -> Self {
528 let get_field_str = |field_name| {
529 node.child_by_field_name(field_name)
530 .map(|n| n.utf8_text(text.as_bytes()).unwrap())
531 };
532
533 match node.kind() {
534 "array_type" => {
535 let inner = node.child_by_field_name("element").unwrap();
536 let mut ty = TypeIdent::from_node(text, inner);
537 let reference_str = &text[node.start_byte()..inner.start_byte()];
538 ty.reference_type = Some(reference_str);
539 ty
540 }
541 "generic_type" => {
542 let name = get_field_str("type").unwrap();
543 let generics = get_field_str("type_arguments");
544 debug_assert!(generics.is_some());
545 Self {
546 name,
547 generics,
548 reference_type: None,
549 }
550 }
551 "reference_type" => {
552 let inner = node.child_by_field_name("type").unwrap();
553 let mut ty = TypeIdent::from_node(text, inner);
554 let reference_str = &text[node.start_byte()..inner.start_byte()];
555 ty.reference_type = Some(reference_str);
556 ty
557 }
558 "type_identifier" | "primitive_type" | "bounded_type" => Self {
559 name: node.utf8_text(text.as_bytes()).unwrap(),
560 generics: None,
561 reference_type: None,
562 },
563 _ => panic!(
564 "invalid type identifier node: {}, `{}'",
565 node.kind(),
566 node.utf8_text(text.as_bytes()).unwrap()
567 ),
568 }
569 }
570}