This repository has no description.
1#!/usr/bin/env python3
2# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
3
4import argparse
5import os
6import re
7import shutil
8import subprocess
9import sys
10from abc import ABC, abstractmethod
11from pathlib import Path
12from typing import List
13
14
class VCS(ABC):
    """Minimal interface over a version-control system used by this tool."""

    @abstractmethod
    def root(self) -> Path:
        """Return the root of the repo."""
        raise NotImplementedError("Root of the VCS not supported")

    @abstractmethod
    def has_changes(self) -> bool:
        """Return whether there are modified files with respect to the last commit."""
        raise NotImplementedError("Checking if there are changes is not supported")

    @abstractmethod
    def get_commit_title(self) -> str:
        """Return the title of the last commit."""
        raise NotImplementedError("Getting the commit title is not supported")

    @abstractmethod
    def get_changed_files(self, from_commit: bool = False) -> List[Path]:
        """Return the changed files that are candidates for formatting.

        When ``from_commit`` is true, the changes made by the last commit are
        included as well as the uncommitted ones.
        """
        raise NotImplementedError("Listing changed files is not supported")

    @abstractmethod
    def authored_previous(self) -> bool:
        """Return whether the current user authored the last commit."""
        raise NotImplementedError("Checking the last commit's author is not supported")

    @staticmethod
    def infer_vcs(dir: Path) -> "VCS":
        """Return the VCS managing ``dir``.

        Walks from ``dir`` (resolved to an absolute path) up towards the
        filesystem root, looking for a ``.git`` or ``.hg`` marker.

        Raises:
            RuntimeError: if the nearest marker is ``.git`` (unsupported), or
                if no marker is found before reaching the filesystem root.
        """
        start = Path(dir).resolve()
        path = start
        while True:
            if (path / ".git").exists():
                raise RuntimeError("git is not supported")
            if (path / ".hg").exists():
                return Mercurial()
            # A filesystem root is its own parent on every platform ("/" as
            # well as "C:\\"); comparing against ``path.root`` only works on
            # POSIX and loops forever on Windows drive roots.
            if path.parent == path:
                raise RuntimeError(f"Couldn't find vcs root from {start}")
            path = path.parent
51
52
class Mercurial(VCS):
    """VCS implementation that shells out to the ``hg`` executable."""

    # Only files with these suffixes are reported by get_changed_files.
    # Raw string: "\." in a plain string literal is an invalid escape sequence.
    files_re = re.compile(r"(\.cpp|\.h|\.py)$")

    def __init__(self) -> None:
        self.exe = get_exe("hg")

    def _run(self, *args: str) -> str:
        """Run ``hg`` with ``args`` and return its decoded, stripped stdout."""
        return subprocess.check_output([str(self.exe), *args]).decode().strip()

    def root(self) -> Path:
        """Return the repository root as printed by ``hg root``."""
        return Path(self._run("root"))

    def has_changes(self) -> bool:
        """Return True if any tracked file is modified, added, removed or missing."""
        # -m/-a/-r/-d select tracked-file states only; untracked files are ignored.
        return self._run("status", "-mard") != ""

    def get_commit_title(self) -> str:
        """Return the first line of the description of revision ``.``."""
        return self._run("log", "-r.", "-T{desc|firstline}")

    def get_changed_files(self, from_commit: bool = False) -> List[Path]:
        """Return modified/added files whose suffix matches ``files_re``.

        With ``from_commit``, the status is computed against ``.^`` instead of
        ``.``, so the changes of revision ``.`` are included too.

        NOTE(review): paths are used exactly as printed by ``hg status``;
        confirm they resolve correctly when run from a subdirectory.
        """
        rev = ".^" if from_commit else "."
        # -m/-a: modified and added only (removed files cannot be formatted);
        # -n: print bare file names without the status prefix.
        output = self._run("status", "--rev", rev, "-man")
        # splitlines() also strips "\r", which would otherwise defeat the
        # "$"-anchored suffix regex on CRLF output.
        return [
            Path(line)
            for line in output.splitlines()
            if type(self).files_re.search(line) is not None
        ]

    def authored_previous(self) -> bool:
        """Return True if ``$USER`` appears in the author of revision ``.``."""
        author = self._run("log", "-r.", "-T{author}")
        # USER can be unset (e.g. on Windows or in minimal environments).
        # An empty name would be trivially "in" any author string, so treat
        # it as "not the author" rather than raising KeyError.
        user = os.environ.get("USER", "")
        return bool(user) and user in author
98
99
class Formatter(ABC):
    """Base class for wrappers around an external source-code formatter."""

    @abstractmethod
    def format(self, file: Path) -> str:
        """Format ``file``; implementations return the tool's decoded stdout."""
        raise NotImplementedError("Formatting not supported")

    @staticmethod
    def for_file(file: Path) -> "Formatter":
        """Return the formatter responsible for ``file``, chosen by suffix.

        ``.py`` files use BlackFormat; ``.cpp`` and ``.h`` files use ClangFormat.

        Raises:
            ValueError: if no formatter handles the file's suffix.
        """
        suffix = file.suffix
        if suffix == ".py":
            return BlackFormat()
        if suffix in {".cpp", ".h"}:
            return ClangFormat()
        # Fail loudly here instead of returning None, which would only show up
        # later as a confusing AttributeError at the call site.
        raise ValueError(f"No formatter available for {file}")
111
112
class ClangFormat(Formatter):
    """Formatter backed by the ``clang-format`` executable."""

    def __init__(self) -> None:
        # Resolve the executable once, when the formatter is created.
        self.exe = get_exe("clang-format")

    def format(self, file: Path) -> str:
        """Rewrite ``file`` in place (``-i``) and return clang-format's stdout.

        ``-style=file`` makes clang-format read the project's ``.clang-format``
        configuration instead of a built-in style.
        """
        command = [str(self.exe), "-i", "-style=file", str(file)]
        raw_output = subprocess.check_output(command)
        return raw_output.decode()
121
122
class BlackFormat(Formatter):
    """Formatter backed by the ``black`` executable."""

    def __init__(self) -> None:
        # Resolve the executable once, when the formatter is created.
        self.exe = get_exe("black")

    def format(self, file: Path) -> str:
        """Run black on ``file`` with ``--quiet`` and return its decoded stdout."""
        command = [str(self.exe), "--quiet", str(file)]
        raw_output = subprocess.check_output(command)
        return raw_output.decode()
129
130
def main() -> int:
    """Format the files selected by the command line and the VCS state.

    Selection order: every .py/.cpp/.h file in the repo with ``--all``;
    otherwise the uncommitted changes; otherwise, after confirmation, the
    files of the last commit if the current user authored it.

    Returns:
        The process exit status: 0 on success or when there is nothing to
        format, 1 when there are no changes and no recent commit by the user.
    """
    args = parse_args()
    vcs = VCS.infer_vcs(Path("."))
    vcs_root = vcs.root()
    # Explicit check instead of assert: asserts are stripped under "python -O".
    if not vcs_root.exists():
        raise RuntimeError(f"VCS root {vcs_root} does not exist")

    files: List[Path] = []
    if args.all:
        print("Formatting all files...")
        for pattern in ("**/*.py", "**/*.cpp", "**/*.h"):
            files += vcs_root.glob(pattern)
    elif vcs.has_changes():
        print("Formatting only modified files...")
        files = vcs.get_changed_files()
    elif vcs.authored_previous():
        print("There are no modified files, but you authored the last commit:")
        print(vcs.get_commit_title())
        reply = input("Format files in that commit? [Y/n] ").strip().lower()
        # An empty reply accepts the default ("Y").
        if reply in ("", "y", "yes"):
            files = vcs.get_changed_files(from_commit=True)
    else:
        print(
            "You have no modified files and you didn't recently commit on this "
            "branch."
        )
        print(f"To format all files: {sys.argv[0]} -a")
        return 1

    def in_scope(path: Path) -> bool:
        # Judge exclusions on the repo-relative path, so that directories
        # above the repo root (e.g. a checkout under ~/benchmarks/) do not
        # exclude every file found by --all.
        try:
            relative = path.relative_to(vcs_root)
        except ValueError:
            relative = path
        return file_should_be_formatted(relative)

    # Filter before the emptiness check so that a selection made up only of
    # excluded files reports "No files to format" instead of "Formatting 0 files".
    files = [f for f in sorted(files) if in_scope(f)]
    if not files:
        print("No files to format")
        return 0

    print(
        "Formatting {num_files} {extensions} file{plural}".format(
            num_files=len(files),
            extensions=" and ".join(sorted({f.suffix for f in files})),
            plural="s" if len(files) != 1 else "",
        )
    )

    # Creating a formatter searches PATH for its executable, so build one per
    # suffix instead of one per file.
    formatters = {}
    for f in files:
        if f.suffix not in formatters:
            formatters[f.suffix] = Formatter.for_file(f)
        formatters[f.suffix].format(f)
        # Flush after every dot so the progress indicator appears immediately
        # rather than when the output buffer fills.
        sys.stdout.write(".")
        sys.stdout.flush()
    print()
    print("Done")
    return 0
181
182
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options.

    Args:
        argv: Arguments to parse, excluding the program name. When None
            (the default), argparse reads ``sys.argv[1:]``.

    Returns:
        A namespace with a boolean ``all`` attribute.
    """
    parser = argparse.ArgumentParser(
        description="Format modified Python and C++ files with black and clang-format."
    )
    parser.add_argument(
        "--all", "-a", action="store_true", help="Format all files in the repo"
    )
    return parser.parse_args(argv)
189
190
# A file is skipped if its POSIX-style path contains any of these fragments.
_EXCLUDED_PATH_FRAGMENTS = (
    "benchmarks",
    "third-party",
    "ext/config",
    "ext/Include",
    "library/importlib",
)


def file_should_be_formatted(file: Path) -> bool:
    """Return true if the given file should be formatted.

    The path is compared in POSIX form (forward slashes), so fragments such as
    "ext/config" also match Windows paths, whose ``str()`` uses backslashes.
    """
    file_str = file.as_posix()
    return not any(fragment in file_str for fragment in _EXCLUDED_PATH_FRAGMENTS)
201
202
def get_exe(exe: str) -> Path:
    """Locate ``exe`` via ``shutil.which`` and return it as a Path.

    Raises:
        RuntimeError: if ``exe`` cannot be found on PATH.
    """
    located = shutil.which(exe)
    if located is not None:
        return Path(located)
    raise RuntimeError(f"Couldn't find {exe} in PATH")
208
209
if __name__ == "__main__":
    # Use main()'s return value as the process exit status
    # (raising SystemExit is exactly what sys.exit does).
    raise SystemExit(main())