Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
1#!/usr/bin/env python
2
3import argparse
4import base64
5import datetime
6import json
7import logging
8import os
9import subprocess
10import sys
11from collections.abc import Callable
12from dataclasses import asdict, dataclass, replace
13from pathlib import Path
14from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
15from urllib.request import Request, urlopen
16
17import git
18from packaging.version import Version, parse
19
20INDEX_URL = "https://azcliextensionsync.blob.core.windows.net/index1/index.json"
21
22logger = logging.getLogger(__name__)
23
24
@dataclass(frozen=True)
class Ext:
    """One Azure CLI extension release, as tracked in extensions-generated.json."""

    pname: str  # extension name from the index metadata ("name")
    version: Version  # parsed version (serialized back to str on disk)
    url: str  # "downloadUrl" of the extension artifact
    hash: str  # SRI-style digest: "sha256-<base64>"
    description: str  # metadata "summary" with trailing "." stripped
32
33
34def _read_cached_index(path: Path) -> Tuple[datetime.datetime, Any]:
35 with open(path, "r") as f:
36 data = f.read()
37
38 j = json.loads(data)
39 cache_date_str = j["cache_date"]
40 if cache_date_str:
41 cache_date = datetime.datetime.fromisoformat(cache_date_str)
42 else:
43 cache_date = datetime.datetime.min
44 return cache_date, data
45
46
47def _write_index_to_cache(data: Any, path: Path) -> None:
48 j = json.loads(data)
49 j["cache_date"] = datetime.datetime.now().isoformat()
50 with open(path, "w") as f:
51 json.dump(j, f, indent=2)
52
53
def _fetch_remote_index() -> Any:
    """Download and return the raw extension index bytes from INDEX_URL."""
    with urlopen(Request(INDEX_URL)) as response:
        return response.read()
58
59
def get_extension_index(cache_dir: Path) -> Any:
    """Return the parsed extension index, kept in a day-fresh local cache.

    Downloads and rewrites the cache when it is missing or older than one
    day, then recurses once to read the freshly written file back.
    """
    os.makedirs(cache_dir, exist_ok=True)
    index_file = cache_dir / "index.json"

    def refresh() -> Any:
        # Re-download, overwrite the cache, and re-enter to read it back.
        _write_index_to_cache(_fetch_remote_index(), index_file)
        return get_extension_index(cache_dir)

    try:
        cached_at, raw = _read_cached_index(index_file)
    except FileNotFoundError:
        logger.info("index has not been cached, downloading from source")
        logger.info("creating index cache in %s", index_file)
        return refresh()

    if cached_at and datetime.datetime.now() - cached_at > datetime.timedelta(days=1):
        logger.info(
            "cache is outdated (%s), refreshing",
            datetime.datetime.now() - cached_at,
        )
        return refresh()

    logger.info("using index cache from %s", index_file)
    return json.loads(raw)
85
86
def _read_extension_set(extensions_generated: Path) -> Set[Ext]:
    """Load the generated extension set, parsing version strings into Versions."""
    with open(extensions_generated, "r") as f:
        raw = json.load(f)
    # Each entry's "version" arrives as a string; swap in the parsed form.
    return {
        replace(Ext(**entry), version=parse(entry["version"]))
        for entry in raw.values()
    }
98
99
100def _write_extension_set(extensions_generated: Path, extensions: Set[Ext]) -> None:
101 set_without_ver = {replace(ext, version=str(ext.version)) for ext in extensions}
102 ls = list(set_without_ver)
103 ls.sort(key=lambda e: e.pname)
104 with open(extensions_generated, "w") as f:
105 json.dump({ext.pname: asdict(ext) for ext in ls}, f, indent=2)
106 f.write("\n")
107
108
109def _convert_hash_digest_from_hex_to_b64_sri(s: str) -> str:
110 try:
111 b = bytes.fromhex(s)
112 except ValueError as err:
113 logger.error("not a hex value: %s", str(err))
114 raise err
115
116 return f"sha256-{base64.b64encode(b).decode('utf-8')}"
117
118
def _commit(repo: git.Repo, message: str, files: List[Path], actor: git.Actor) -> None:
    """Stage *files* and commit them with *message*; skip when nothing changed."""
    repo.index.add([str(path.resolve()) for path in files])
    if not repo.index.diff("HEAD"):
        logger.warning("no changes in working tree to commit")
        return
    logger.info('committing to nixpkgs "%s"', message)
    repo.index.commit(message, author=actor, committer=actor)
126
127
128def _filter_invalid(o: Dict[str, Any]) -> bool:
129 if "metadata" not in o:
130 logger.warning("extension without metadata")
131 return False
132 metadata = o["metadata"]
133 if "name" not in metadata:
134 logger.warning("extension without name")
135 return False
136 if "version" not in metadata:
137 logger.warning(f"{metadata['name']} without version")
138 return False
139 if "azext.minCliCoreVersion" not in metadata:
140 logger.warning(
141 f"{metadata['name']} {metadata['version']} does not have azext.minCliCoreVersion"
142 )
143 return False
144 if "summary" not in metadata:
145 logger.info(f"{metadata['name']} {metadata['version']} without summary")
146 return False
147 if "downloadUrl" not in o:
148 logger.warning(f"{metadata['name']} {metadata['version']} without downloadUrl")
149 return False
150 if "sha256Digest" not in o:
151 logger.warning(f"{metadata['name']} {metadata['version']} without sha256Digest")
152 return False
153
154 return True
155
156
def _filter_compatible(o: Dict[str, Any], cli_version: Version) -> bool:
    """True when *cli_version* satisfies the extension's minimum CLI version."""
    required = parse(o["metadata"]["azext.minCliCoreVersion"])
    return bool(cli_version >= required)
160
161
def _transform_dict_to_obj(o: Dict[str, Any]) -> Ext:
    """Build an Ext from a release dict that already passed _filter_invalid."""
    metadata = o["metadata"]
    # Description drops the trailing period to match nixpkgs conventions.
    return Ext(
        pname=metadata["name"],
        version=parse(metadata["version"]),
        url=o["downloadUrl"],
        hash=_convert_hash_digest_from_hex_to_b64_sri(o["sha256Digest"]),
        description=metadata["summary"].rstrip("."),
    )
171
172
def _get_latest_version(versions: Iterable[dict]) -> Optional[dict]:
    """Return the release dict with the highest metadata.version, or None.

    Annotation fix: callers pass a (possibly empty) filter iterator, not a
    dict, and ``max(..., default=None)`` yields None for empty input, so
    the signature is Iterable[dict] -> Optional[dict].
    """
    return max(versions, key=lambda e: parse(e["metadata"]["version"]), default=None)
175
176
def processExtension(
    extVersions: dict,
    cli_version: Version,
    ext_name: Optional[str] = None,
    requirements: bool = False,
) -> Optional[Ext]:
    """Pick the newest valid, CLI-compatible release from *extVersions*.

    Returns None when no release qualifies, when *ext_name* is given and
    does not match, or when the release declares run_requires and
    *requirements* is False.
    """
    candidates = [
        release
        for release in extVersions
        if _filter_invalid(release) and _filter_compatible(release, cli_version)
    ]
    latest = _get_latest_version(candidates)
    if not latest:
        return None
    metadata = latest["metadata"]
    if ext_name and metadata["name"] != ext_name:
        return None
    if not requirements and "run_requires" in metadata:
        return None

    return _transform_dict_to_obj(latest)
194
195
196def _diff_sets(
197 set_local: Set[Ext], set_remote: Set[Ext]
198) -> Tuple[Set[Ext], Set[Ext], Set[Tuple[Ext, Ext]]]:
199 local_exts = {ext.pname: ext for ext in set_local}
200 remote_exts = {ext.pname: ext for ext in set_remote}
201 only_local = local_exts.keys() - remote_exts.keys()
202 only_remote = remote_exts.keys() - local_exts.keys()
203 both = remote_exts.keys() & local_exts.keys()
204 return (
205 {local_exts[pname] for pname in only_local},
206 {remote_exts[pname] for pname in only_remote},
207 {(local_exts[pname], remote_exts[pname]) for pname in both},
208 )
209
210
211def _filter_updated(e: Tuple[Ext, Ext]) -> bool:
212 prev, new = e
213 return prev != new
214
215
@dataclass(frozen=True)
class AttrPos:
    """Source position of a nix attribute definition."""

    file: str  # path to the .nix file
    line: int  # 0-based line index (nix reports 1-based; adjusted in nix_unsafe_get_attr_pos)
    column: int  # column as reported by nix
221
222
def nix_get_value(attr_path: str) -> Optional[str]:
    """Evaluate *attr_path* against the local nixpkgs tree as a string.

    Returns None (after logging) when nix-instantiate fails. The JSON
    output is merely de-quoted, not parsed, so this suits simple string
    attributes only.
    """
    cmd = [
        "nix-instantiate",
        "--eval",
        "--strict",
        "--json",
        "-E",
        f"with import ./. {{ }}; {attr_path}",
    ]
    try:
        proc = subprocess.run(cmd, stdout=subprocess.PIPE, text=True, check=True)
    except subprocess.CalledProcessError as e:
        logger.error("failed to nix-instantiate: %s", e)
        return None
    return proc.stdout.rstrip().strip('"')
246
247
def nix_unsafe_get_attr_pos(attr: str, attr_path: str) -> Optional[AttrPos]:
    """Locate where attribute *attr* of *attr_path* is defined in nix sources.

    Returns an AttrPos with a 0-based line, or None when evaluation fails
    or nix cannot determine a position ("null" output).
    """
    cmd = [
        "nix-instantiate",
        "--eval",
        "--strict",
        "--json",
        "-E",
        f'with import ./. {{ }}; (builtins.unsafeGetAttrPos "{attr}" {attr_path})',
    ]
    try:
        raw = subprocess.run(
            cmd, stdout=subprocess.PIPE, text=True, check=True
        ).stdout.rstrip()
    except subprocess.CalledProcessError as e:
        logger.error("failed to unsafeGetAttrPos: %s", e)
        return None
    if raw == "null":
        logger.error("failed to unsafeGetAttrPos: nix-instantiate returned 'null'")
        return None
    pos = json.loads(raw)
    # nix positions are 1-based; store 0-based so the line indexes lists directly.
    return AttrPos(pos["file"], pos["line"] - 1, pos["column"])
271
272
def edit_file(file: str, rewrite: Callable[[str], str]) -> None:
    """Apply *rewrite* to every line of *file*, writing the result in place."""
    with open(file, "r") as src:
        rewritten = [rewrite(line) for line in src.readlines()]
    with open(file, "w") as dst:
        dst.writelines(rewritten)
279
280
def edit_file_at_pos(pos: AttrPos, rewrite: Callable[[str], str]) -> None:
    """Rewrite only the line at *pos* (0-based pos.line) inside pos.file."""
    path = pos.file
    with open(path, "r") as src:
        content = src.readlines()
    content[pos.line] = rewrite(content[pos.line])
    with open(path, "w") as dst:
        dst.writelines(content)
287
288
def read_value_at_pos(pos: AttrPos) -> str:
    """Return the nix attribute value found on the line at *pos*."""
    with open(pos.file, "r") as f:
        target = f.readlines()[pos.line]
    return value_from_nix_line(target)
293
294
def value_from_nix_line(line: str) -> str:
    """Extract the value from a nix `name = "value";` line.

    Splits on the FIRST '=' only: values that themselves contain '='
    (e.g. base64 SRI hashes with '=' padding) must survive intact —
    the old unbounded split truncated them.
    """
    return line.split("=", 1)[1].strip().strip(";").strip('"')
297
298
def replace_value_in_nix_line(new: str) -> Callable[[str], str]:
    """Return a rewriter that swaps a nix line's current value for *new*."""

    def _rewrite(line: str) -> str:
        return line.replace(value_from_nix_line(line), new)

    return _rewrite
301
302
def main() -> None:
    """Entry point: refresh azure-cli extension pins in nixpkgs.

    With --extension, updates that single extension's version and hash in
    its own nix file; otherwise regenerates extensions-generated.json,
    optionally committing one change per extension (--commit).
    """
    sh = logging.StreamHandler(sys.stderr)
    sh.setFormatter(
        logging.Formatter(
            "[%(asctime)s] [%(levelname)8s] --- %(message)s (%(filename)s:%(lineno)s)",
            "%Y-%m-%d %H:%M:%S",
        )
    )
    logging.basicConfig(level=logging.INFO, handlers=[sh])

    parser = argparse.ArgumentParser(
        prog="azure-cli.extensions-tool",
        description="Script to handle Azure CLI extension updates",
    )
    parser.add_argument(
        "--cli-version", type=str, help="version of azure-cli (required)"
    )
    parser.add_argument("--extension", type=str, help="name of extension to query")
    parser.add_argument(
        "--cache-dir",
        type=Path,
        help="path where to cache the extension index",
        default=Path(os.getenv("XDG_CACHE_HOME", Path.home() / ".cache"))
        / "azure-cli-extensions-tool",
    )
    parser.add_argument(
        "--requirements",
        action=argparse.BooleanOptionalAction,
        help="whether to list extensions that have requirements",
    )
    parser.add_argument(
        "--commit",
        action=argparse.BooleanOptionalAction,
        help="whether to commit changes to git",
    )
    args = parser.parse_args()
    cli_version = parse(args.cli_version)

    repo = git.Repo(Path(".").resolve(), search_parent_directories=True)
    # Workaround for https://github.com/gitpython-developers/GitPython/issues/1923
    author = repo.config_reader().get_value("user", "name").lstrip('"').rstrip('"')
    email = repo.config_reader().get_value("user", "email").lstrip('"').rstrip('"')
    actor = git.Actor(author, email)

    index = get_extension_index(args.cache_dir)
    assert index["formatVersion"] == "1"  # only support formatVersion 1
    extensions_remote = index["extensions"]

    if args.extension:
        logger.info(f"updating extension: {args.extension}")

        # BUGFIX: this used to be `ext = Optional[Ext]`, binding the (truthy)
        # typing object as a sentinel — the not-found check below could then
        # never fire and the code crashed later on `ext.pname`.
        ext: Optional[Ext] = None
        for _ext_name, extension in extensions_remote.items():
            extension = processExtension(
                extension, cli_version, args.extension, requirements=True
            )
            if extension:
                ext = extension
                break
        if not ext:
            logger.error(f"Extension {args.extension} not found in index")
            sys.exit(1)  # sys.exit: `exit()` is a site-module convenience

        version_pos = nix_unsafe_get_attr_pos(
            "version", f"azure-cli-extensions.{ext.pname}"
        )
        if not version_pos:
            logger.error(
                f"no position for attribute 'version' found on attribute path {ext.pname}"
            )
            sys.exit(1)
        version = read_value_at_pos(version_pos)
        current_version = parse(version)

        if ext.version == current_version:
            logger.info(
                f"no update needed for {ext.pname}, latest version is {ext.version}"
            )
            return
        logger.info("updated extensions:")
        logger.info(f"  {ext.pname} {current_version} -> {ext.version}")
        edit_file_at_pos(version_pos, replace_value_in_nix_line(str(ext.version)))

        current_hash = nix_get_value(f"azure-cli-extensions.{ext.pname}.src.outputHash")
        if not current_hash:
            logger.error(
                f"no attribute 'src.outputHash' found on attribute path {ext.pname}"
            )
            sys.exit(1)
        # The hash may appear anywhere in the file, so rewrite every line.
        edit_file(version_pos.file, lambda line: line.replace(current_hash, ext.hash))

        if args.commit:
            commit_msg = (
                f"azure-cli-extensions.{ext.pname}: {current_version} -> {ext.version}"
            )
            _commit(repo, commit_msg, [Path(version_pos.file)], actor)
        return

    logger.info("updating generated extension set")

    extensions_remote_filtered = set()
    for _ext_name, extension in extensions_remote.items():
        extension = processExtension(extension, cli_version, args.extension)
        if extension:
            extensions_remote_filtered.add(extension)

    extension_file = (
        Path(repo.working_dir) / "pkgs/by-name/az/azure-cli/extensions-generated.json"
    )
    extensions_local = _read_extension_set(extension_file)
    extensions_local_filtered = set()
    if args.extension:
        extensions_local_filtered = filter(
            lambda ext: args.extension == ext.pname, extensions_local
        )
    else:
        extensions_local_filtered = extensions_local

    removed, init, updated = _diff_sets(
        extensions_local_filtered, extensions_remote_filtered
    )
    updated = set(filter(_filter_updated, updated))

    logger.info("initialized extensions:")
    for ext in init:
        logger.info(f"  {ext.pname} {ext.version}")
    logger.info("removed extensions:")
    for ext in removed:
        logger.info(f"  {ext.pname} {ext.version}")
    logger.info("updated extensions:")
    for prev, new in updated:
        logger.info(f"  {prev.pname} {prev.version} -> {new.version}")

    # One write + (optional) commit per extension keeps the git history
    # reviewable: init, update, and removal each get their own commit.
    for ext in init:
        extensions_local.add(ext)
        commit_msg = f"azure-cli-extensions.{ext.pname}: init at {ext.version}"
        _write_extension_set(extension_file, extensions_local)
        if args.commit:
            _commit(repo, commit_msg, [extension_file], actor)

    for prev, new in updated:
        extensions_local.remove(prev)
        extensions_local.add(new)
        commit_msg = (
            f"azure-cli-extensions.{prev.pname}: {prev.version} -> {new.version}"
        )
        _write_extension_set(extension_file, extensions_local)
        if args.commit:
            _commit(repo, commit_msg, [extension_file], actor)

    for ext in removed:
        extensions_local.remove(ext)
        # TODO: Add additional check why this is removed
        # TODO: Add an alias to extensions manual?
        commit_msg = f"azure-cli-extensions.{ext.pname}: remove"
        _write_extension_set(extension_file, extensions_local)
        if args.commit:
            _commit(repo, commit_msg, [extension_file], actor)
461
462
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()