this repo has no description
at trunk 211 lines 6.2 kB view raw
1#!/usr/bin/env python3 2# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com) 3 4import argparse 5import os 6import re 7import shutil 8import subprocess 9import sys 10from abc import ABC, abstractmethod 11from pathlib import Path 12from typing import List 13 14 15class VCS(ABC): 16 @abstractmethod 17 def root(self) -> Path: 18 """Return the root of the repo""" 19 raise NotImplementedError("Root of the VCS not supported") 20 21 @abstractmethod 22 def has_changes(self) -> bool: 23 """Return whether there are modified files with respect to the last commit""" 24 raise NotImplementedError("Checking if there are changes is not supported") 25 26 @abstractmethod 27 def get_commit_title(self) -> str: 28 """Return the title of the last commit""" 29 raise NotImplementedError("Getting the commit title is not supported") 30 31 @abstractmethod 32 def get_changed_files(self, from_commit: bool = False) -> List[Path]: 33 pass 34 35 @abstractmethod 36 def authored_previous(self) -> bool: 37 pass 38 39 @staticmethod 40 def infer_vcs(dir: Path): 41 """Return the VCS inferred from the given directory""" 42 path = Path(os.getcwd()) 43 while True: 44 if Path(path, ".git").exists(): 45 raise RuntimeError("git is not supported") 46 if Path(path, ".hg").exists(): 47 return Mercurial() 48 if str(path) == path.root: 49 raise RuntimeError(f"Couldn't find vcs root from {os.getcwd()}") 50 path = path.parent 51 52 53class Mercurial(VCS): 54 files_re = re.compile("(\.cpp|\.h|\.py)$") 55 56 def __init__(self) -> None: 57 self.exe = get_exe("hg") 58 59 def root(self) -> Path: 60 return Path(subprocess.check_output([str(self.exe), "root"]).decode().strip()) 61 62 def has_changes(self) -> bool: 63 return ( 64 subprocess.check_output( 65 # Ask for all files except untracked. 66 [str(self.exe), "status", "-mard"] 67 ) 68 .decode() 69 .strip() 70 != "" 71 ) 72 73 def get_commit_title(self) -> str: 74 return ( 75 subprocess.check_output([str(self.exe), "log", "-r.", "-T{desc|firstline}"]) 76 .decode() 77 .strip() 78 ) 79 80 def get_changed_files(self, from_commit: bool = False) -> List[Path]: 81 files = ( 82 subprocess.check_output( 83 [str(self.exe), "status", "--rev", ".^" if from_commit else ".", "-man"] 84 ) 85 .decode() 86 .strip() 87 .split("\n") 88 ) 89 return [Path(f) for f in files if (type(self).files_re.search(f) is not None)] 90 91 def authored_previous(self) -> bool: 92 author = ( 93 subprocess.check_output([str(self.exe), "log", "-r.", "-T{author}"]) 94 .decode() 95 .strip() 96 ) 97 return os.environ["USER"] in author 98 99 100class Formatter(ABC): 101 @abstractmethod 102 def format(self, file: Path) -> str: 103 raise NotImplementedError("Formatting not supported") 104 105 @staticmethod 106 def for_file(file: Path): 107 if file.suffix == ".py": 108 return BlackFormat() 109 elif file.suffix in {".cpp", ".h"}: 110 return ClangFormat() 111 112 113class ClangFormat(Formatter): 114 def __init__(self) -> None: 115 self.exe = get_exe("clang-format") 116 117 def format(self, file: Path) -> str: 118 return subprocess.check_output( 119 [str(self.exe), "-i", "-style=file", str(file)] 120 ).decode() 121 122 123class BlackFormat(Formatter): 124 def __init__(self) -> None: 125 self.exe = get_exe("black") 126 127 def format(self, file: Path) -> str: 128 return subprocess.check_output([str(self.exe), "--quiet", str(file)]).decode() 129 130 131def main() -> int: 132 args = parse_args() 133 vcs = VCS.infer_vcs(Path(".")) 134 vcs_root = vcs.root() 135 assert vcs_root.exists() 136 files: List[Path] = [] 137 if args.all: 138 print("Formatting all files...") 139 files += list(vcs_root.glob("**/*.py")) 140 files += list(vcs_root.glob("**/*.cpp")) 141 files += list(vcs_root.glob("**/*.h")) 142 elif vcs.has_changes(): 143 print("Formatting only modified files...") 144 files = vcs.get_changed_files() 145 elif vcs.authored_previous(): 146 print("There are no modified files, but you authored the last commit:") 147 print(vcs.get_commit_title()) 148 reply = input("Format files in that commit? [Y/n] ") 149 if not reply or reply.lower() == "y": 150 files = vcs.get_changed_files(from_commit=True) 151 else: 152 print( 153 "You have no modified files and you didn't recently commit on this " 154 "branch." 155 ) 156 print(f"To format all files: {sys.argv[0]} -a") 157 return 1 158 159 if not files: 160 print("No files to format") 161 return 0 162 163 files = [f for f in sorted(files) if file_should_be_formatted(f)] 164 print( 165 "Formatting {num_files} {extensions} file{plural}".format( 166 num_files=len(files), 167 extensions=" and ".join(sorted({f.suffix for f in files})), 168 plural="s" if len(files) != 1 else "", 169 ) 170 ) 171 172 for f in files: 173 Formatter.for_file(f).format(f) 174 # Don't buffer stdout because otherwise the dots aren't written out 175 # to show progress. 176 sys.stdout.write(".") 177 sys.stdout.flush() 178 print() 179 print("Done") 180 return 0 181 182 183def parse_args(): 184 parser = argparse.ArgumentParser() 185 parser.add_argument( 186 "--all", "-a", action="store_true", help="Format all files in the repo" 187 ) 188 return parser.parse_args() 189 190 191def file_should_be_formatted(file: Path) -> bool: 192 """Return true if the given file should be formatted""" 193 file_str = str(file) 194 return ( 195 "benchmarks" not in file_str 196 and "third-party" not in file_str 197 and "ext/config" not in file_str 198 and "ext/Include" not in file_str 199 and "library/importlib" not in file_str 200 ) 201 202 203def get_exe(exe: str) -> Path: 204 path = shutil.which(exe) 205 if path is None: 206 raise RuntimeError(f"Couldn't find {exe} in PATH") 207 return Path(path) 208 209 210if __name__ == "__main__": 211 sys.exit(main())