1#!/usr/bin/env python
2
3import argparse
4import base64
5import datetime
6import json
7import logging
8import os
9import subprocess
10import sys
11from collections.abc import Callable
12from dataclasses import asdict, dataclass, replace
13from pathlib import Path
14from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
15from urllib.request import Request, urlopen
16
17import git
18from packaging.version import Version, parse
19
# Remote index of published Azure CLI extensions. get_extension_index() caches
# it locally and re-downloads it once the cached copy is older than one day.
INDEX_URL = "https://azcliextensionsync.blob.core.windows.net/index1/index.json"

# Module logger; main() attaches a stderr handler via logging.basicConfig().
logger = logging.getLogger(__name__)
23
24
@dataclass(frozen=True)
class Ext:
    """One Azure CLI extension entry, as kept in extensions-generated.json.

    Frozen, so instances are hashable and can live in sets; equality compares
    all fields.
    """

    # Extension name; also the key of the entry in the generated JSON file.
    pname: str
    # Parsed version. NOTE: _read_extension_set/_write_extension_set briefly
    # store a plain str here (via dataclasses.replace) while (de)serializing.
    version: Version
    # Package download URL (the index's "downloadUrl" when built by
    # _transform_dict_to_obj).
    url: str
    # SRI string "sha256-<base64>" converted from the index's hex sha256Digest
    # when built by _transform_dict_to_obj.
    hash: str
    # Index summary with trailing periods stripped (when built by
    # _transform_dict_to_obj).
    description: str
32
33
def _read_cached_index(path: Path) -> Tuple[datetime.datetime, Any]:
    """Read the cached index file and report when it was cached.

    Args:
        path: cache file previously written by ``_write_index_to_cache``.

    Returns:
        ``(cache_date, data)`` where ``data`` is the raw, unparsed file text and
        ``cache_date`` is the timestamp stored under ``"cache_date"``. When that
        key is missing or empty, ``datetime.datetime.min`` is returned so the
        cache is treated as outdated (and refreshed) instead of crashing.

    Raises:
        FileNotFoundError: if ``path`` does not exist (callers rely on this to
            trigger the initial download).
    """
    with open(path, "r", encoding="utf-8") as f:
        data = f.read()

    # .get(): a cache file without a stamp is simply "outdated", not an error.
    cache_date_str = json.loads(data).get("cache_date")
    if cache_date_str:
        cache_date = datetime.datetime.fromisoformat(cache_date_str)
    else:
        cache_date = datetime.datetime.min
    return cache_date, data
45
46
def _write_index_to_cache(data: Any, path: Path) -> None:
    """Persist a freshly downloaded index, stamped with the current local time.

    ``data`` is the raw JSON document (text or bytes, e.g. the body returned by
    ``_fetch_remote_index``). Its top-level object is re-serialized with an
    extra ``"cache_date"`` key holding ``datetime.now().isoformat()`` (naive
    local time), pretty-printed with a 2-space indent.
    """
    stamped = {**json.loads(data), "cache_date": datetime.datetime.now().isoformat()}
    with open(path, "w") as out:
        out.write(json.dumps(stamped, indent=2))
52
53
def _fetch_remote_index(timeout: Optional[float] = 60.0) -> Any:
    """Download the raw extension index from ``INDEX_URL``.

    Args:
        timeout: seconds a single blocking socket operation (connect or read)
            may take before ``urlopen`` gives up. Without a timeout a stalled
            connection would hang the tool forever; pass ``None`` to wait
            indefinitely.

    Returns:
        The response body as undecoded ``bytes``.
    """
    request = Request(INDEX_URL)
    with urlopen(request, timeout=timeout) as resp:
        return resp.read()
58
59
def get_extension_index(cache_dir: Path) -> Any:
    """Return the parsed extension index, served from a one-day disk cache.

    The cache lives at ``cache_dir / "index.json"``. It is downloaded when the
    file is missing and re-downloaded when its stamp is more than one day old.
    After each download the loop starts over and reads the fresh file back, so
    the returned data always comes from disk.
    """
    index_file = cache_dir / "index.json"
    max_age = datetime.timedelta(days=1)

    while True:
        os.makedirs(cache_dir, exist_ok=True)
        try:
            cached_at, raw_index = _read_cached_index(index_file)
        except FileNotFoundError:
            logger.info("index has not been cached, downloading from source")
            logger.info("creating index cache in %s", index_file)
            _write_index_to_cache(_fetch_remote_index(), index_file)
            continue

        outdated = bool(cached_at) and datetime.datetime.now() - cached_at > max_age
        if not outdated:
            break

        logger.info(
            "cache is outdated (%s), refreshing",
            datetime.datetime.now() - cached_at,
        )
        _write_index_to_cache(_fetch_remote_index(), index_file)

    logger.info("using index cache from %s", index_file)
    return json.loads(raw_index)
85
86
def _read_extension_set(extensions_generated: Path) -> Set[Ext]:
    """Load extensions-generated.json as a set of ``Ext`` with parsed versions.

    The file is a JSON object whose values are the ``Ext`` field mappings; the
    keys are ignored. Every ``version`` string is run through ``parse``.
    """
    with open(extensions_generated, "r") as f:
        entries = json.load(f)

    as_written = {Ext(**fields) for fields in entries.values()}
    return {replace(ext, version=parse(ext.version)) for ext in as_written}
98
99
def _write_extension_set(extensions_generated: Path, extensions: Set[Ext]) -> None:
    """Write ``extensions`` to extensions-generated.json.

    Versions are stringified, entries are keyed and ordered by ``pname``, and
    the JSON is indented by 2 spaces and terminated with a newline.
    """
    stringified = {replace(ext, version=str(ext.version)) for ext in extensions}
    ordered = sorted(stringified, key=lambda e: e.pname)
    with open(extensions_generated, "w") as f:
        json.dump({ext.pname: asdict(ext) for ext in ordered}, f, indent=2)
        f.write("\n")
107
108
def _convert_hash_digest_from_hex_to_b64_sri(s: str) -> str:
    """Turn a hex-encoded digest into an SRI string ``sha256-<base64>``.

    The ``sha256-`` prefix is fixed; the digest length is not checked.

    Raises:
        ValueError: if ``s`` is not valid hex (logged, then re-raised).
    """
    try:
        raw = bytes.fromhex(s)
    except ValueError as exc:
        logger.error("not a hex value: %s", str(exc))
        raise
    encoded = base64.b64encode(raw).decode("utf-8")
    return "sha256-" + encoded
117
118
def _commit(repo: git.Repo, message: str, files: List[Path], actor: git.Actor) -> None:
    """Stage ``files`` and commit them with ``actor`` as author and committer.

    Paths are resolved to absolute paths before staging. If the staged index
    does not differ from HEAD, nothing is committed and a warning is logged.
    """
    repo.index.add([str(path.resolve()) for path in files])
    if not repo.index.diff("HEAD"):
        logger.warning("no changes in working tree to commit")
        return
    logger.info(f'committing to nixpkgs "{message}"')
    repo.index.commit(message, author=actor, committer=actor)
126
127
def _filter_invalid(o: Dict[str, Any]) -> bool:
    """Return True if an index entry has every field this tool reads.

    The first missing field is logged and the entry rejected. A missing
    summary is logged at INFO; every other gap is logged at WARNING.
    """
    if "metadata" not in o:
        logger.warning("extension without metadata")
        return False
    meta = o["metadata"]
    if "name" not in meta:
        logger.warning("extension without name")
        return False
    if "version" not in meta:
        logger.warning(f"{meta['name']} without version")
        return False

    label = f"{meta['name']} {meta['version']}"
    # (container, required key, log level, message suffix), checked in order.
    required_fields = (
        (meta, "azext.minCliCoreVersion", logging.WARNING, "does not have azext.minCliCoreVersion"),
        (meta, "summary", logging.INFO, "without summary"),
        (o, "downloadUrl", logging.WARNING, "without downloadUrl"),
        (o, "sha256Digest", logging.WARNING, "without sha256Digest"),
    )
    for container, key, level, reason in required_fields:
        if key not in container:
            logger.log(level, f"{label} {reason}")
            return False
    return True
155
156
def _filter_compatible(o: Dict[str, Any], cli_version: Version) -> bool:
    """Return True if ``cli_version`` is at least the entry's ``azext.minCliCoreVersion``."""
    min_cli_version = parse(o["metadata"]["azext.minCliCoreVersion"])
    is_compatible = cli_version >= min_cli_version
    return bool(is_compatible)
160
161
def _transform_dict_to_obj(o: Dict[str, Any]) -> Ext:
    """Convert a validated index entry into an ``Ext``.

    The version is parsed, the hex sha256 digest becomes an SRI hash and
    trailing periods are stripped from the summary.
    """
    meta = o["metadata"]
    fields = {
        "pname": meta["name"],
        "version": parse(meta["version"]),
        "url": o["downloadUrl"],
        "hash": _convert_hash_digest_from_hex_to_b64_sri(o["sha256Digest"]),
        "description": meta["summary"].rstrip("."),
    }
    return Ext(**fields)
171
172
def _get_latest_version(
    versions: Iterable[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    """Return the entry with the highest ``metadata.version``.

    Args:
        versions: index entries, any iterable (callers pass a lazy ``filter``
            object, not a dict as the old annotation claimed).

    Returns:
        The entry whose version ranks highest under ``packaging`` version
        ordering (first one wins on ties), or ``None`` when ``versions`` is
        empty.
    """
    return max(versions, key=lambda e: parse(e["metadata"]["version"]), default=None)
175
176
def find_extension_version(
    extVersions: dict,
    cli_version: Version,
    ext_name: Optional[str] = None,
    requirements: bool = False,
) -> Optional[Dict[str, Any]]:
    """Pick the newest valid, CLI-compatible entry from ``extVersions``.

    ``extVersions`` is iterated as a sequence of index entries (presumably all
    published versions of one extension; verify against the index format).

    Returns ``None`` when no entry survives validation and the compatibility
    check. It also returns ``None`` when ``ext_name`` is given and the newest
    entry has a different metadata name, or when the newest entry declares
    ``run_requires`` and ``requirements`` is False. Older entries without
    requirements are not considered as a fallback.
    """
    candidates = (
        entry
        for entry in extVersions
        if _filter_invalid(entry) and _filter_compatible(entry, cli_version)
    )
    latest = _get_latest_version(candidates)
    if not latest:
        return None

    meta = latest["metadata"]
    name_mismatch = bool(ext_name) and meta["name"] != ext_name
    unwanted_requirements = not requirements and "run_requires" in meta
    if name_mismatch or unwanted_requirements:
        return None
    return latest
193
194
def find_and_transform_extension_version(
    extVersions: dict,
    cli_version: Version,
    ext_name: Optional[str] = None,
    requirements: bool = False,
) -> Optional[Ext]:
    """Like ``find_extension_version``, but return the match as an ``Ext`` (or None)."""
    entry = find_extension_version(extVersions, cli_version, ext_name, requirements)
    return None if not entry else _transform_dict_to_obj(entry)
206
207
def _diff_sets(
    set_local: Set[Ext], set_remote: Set[Ext]
) -> Tuple[Set[Ext], Set[Ext], Set[Tuple[Ext, Ext]]]:
    """Compare two extension collections by ``pname``.

    Returns:
        ``(only_local, only_remote, pairs)``. ``pairs`` holds a
        ``(local, remote)`` tuple for every pname present on both sides,
        whether or not the two differ. If one side has several entries with
        the same pname, the last one iterated wins.
    """
    by_name_local = {ext.pname: ext for ext in set_local}
    by_name_remote = {ext.pname: ext for ext in set_remote}

    only_local = {ext for name, ext in by_name_local.items() if name not in by_name_remote}
    only_remote = {ext for name, ext in by_name_remote.items() if name not in by_name_local}
    pairs = {
        (ext, by_name_remote[name])
        for name, ext in by_name_local.items()
        if name in by_name_remote
    }
    return only_local, only_remote, pairs
221
222
def _filter_updated(e: Tuple[Ext, Ext]) -> bool:
    """Return True when the (local, remote) pair differs in any field."""
    local_ext, remote_ext = e
    return local_ext != remote_ext
226
227
@dataclass(frozen=True)
class AttrPos:
    """Source location of a Nix attribute definition.

    Built by nix_unsafe_get_attr_pos from builtins.unsafeGetAttrPos output.
    """

    # Path of the Nix file that defines the attribute.
    file: str
    # Zero-based index into the file's readlines(): nix_unsafe_get_attr_pos
    # subtracts one from the line Nix reports.
    line: int
    # Column exactly as reported by Nix (not adjusted); not read by any helper
    # shown here.
    column: int
233
234
def nix_get_value(attr_path: str) -> Optional[str]:
    """Evaluate ``attr_path`` against the Nix expression in the current directory.

    Runs ``nix-instantiate --eval --strict --json`` on
    ``with import ./. { }; <attr_path>`` and decodes the JSON result.

    Returns:
        - a JSON string result, properly unescaped;
        - any other non-null JSON value, as its JSON text;
        - ``None`` if evaluation fails or the attribute evaluates to ``null``.

        Mapping ``null`` to ``None`` matters: main() substitutes the returned
        text into files, and the literal ``"null"`` would match unrelated text.
    """
    try:
        output = subprocess.run(
            [
                "nix-instantiate",
                "--eval",
                "--strict",
                "--json",
                "-E",
                f"with import ./. {{ }}; {attr_path}",
            ],
            stdout=subprocess.PIPE,
            text=True,
            check=True,
        ).stdout.rstrip()
    except subprocess.CalledProcessError as e:
        logger.error("failed to nix-instantiate: %s", e)
        return None

    try:
        value = json.loads(output)
    except json.JSONDecodeError:
        # Not valid JSON: fall back to the raw text minus surrounding quotes.
        return output.strip('"')
    if value is None:
        logger.error("failed to nix-instantiate: %s evaluated to null", attr_path)
        return None
    if isinstance(value, str):
        return value
    return output
258
259
def nix_unsafe_get_attr_pos(attr: str, attr_path: str) -> Optional[AttrPos]:
    """Locate where ``attr`` of ``attr_path`` is defined, via builtins.unsafeGetAttrPos.

    The expression is evaluated with ``nix-instantiate`` relative to ``./.``.
    Returns ``None`` (after logging) if evaluation fails or Nix reports no
    position. The returned line is shifted down by one so it can index the
    list produced by ``readlines()``.
    """
    expr = f'with import ./. {{ }}; (builtins.unsafeGetAttrPos "{attr}" {attr_path})'
    cmd = ["nix-instantiate", "--eval", "--strict", "--json", "-E", expr]
    try:
        completed = subprocess.run(cmd, stdout=subprocess.PIPE, text=True, check=True)
    except subprocess.CalledProcessError as e:
        logger.error("failed to unsafeGetAttrPos: %s", e)
        return None

    output = completed.stdout.rstrip()
    if output == "null":
        logger.error("failed to unsafeGetAttrPos: nix-instantiate returned 'null'")
        return None

    location = json.loads(output)
    return AttrPos(
        file=location["file"],
        line=location["line"] - 1,
        column=location["column"],
    )
283
284
def edit_file(file: str, rewrite: Callable[[str], str]) -> None:
    """Rewrite ``file`` in place, passing every line (newline included) through ``rewrite``.

    All lines are rewritten before the file is reopened for writing.
    """
    with open(file, "r") as src:
        original_lines = src.readlines()
    new_lines = list(map(rewrite, original_lines))
    with open(file, "w") as dst:
        dst.writelines(new_lines)
291
292
def edit_file_at_pos(pos: AttrPos, rewrite: Callable[[str], str]) -> None:
    """Apply ``rewrite`` to the single line ``pos.line`` of ``pos.file``, in place.

    ``pos.line`` is used directly as a list index into ``readlines()``.
    """
    with open(pos.file, "r") as src:
        content = src.readlines()
    index = pos.line
    content[index] = rewrite(content[index])
    with open(pos.file, "w") as dst:
        dst.writelines(content)
299
300
def read_value_at_pos(pos: AttrPos) -> str:
    """Return the value of the Nix binding on line ``pos.line`` of ``pos.file``."""
    with open(pos.file, "r") as src:
        target_line = src.readlines()[pos.line]
    return value_from_nix_line(target_line)
305
306
def value_from_nix_line(line: str) -> str:
    """Extract the value of a single-line Nix binding like ``  version = "1.2.3";``.

    Everything after the first ``=`` is taken; surrounding whitespace, then
    ``;``, then ``"`` are stripped. Splitting only on the first ``=`` keeps
    values that themselves contain ``=`` intact (e.g. base64 SRI hashes with
    ``=`` padding).

    Raises:
        IndexError: if ``line`` contains no ``=``.
    """
    return line.split("=", 1)[1].strip().strip(";").strip('"')
309
310
def replace_value_in_nix_line(new: str) -> Callable[[str], str]:
    """Return a line rewriter that swaps a binding's value for ``new``.

    The rewriter reads the current value with ``value_from_nix_line`` and
    replaces only its first occurrence to the right of the first ``=``. The
    attribute name is never modified, even if it contains the old value as a
    substring (e.g. ``v1 = "1";``). An empty value (``x = "";``) has its
    empty quotes filled instead of ``new`` being spliced between every
    character.
    """

    def rewrite(line: str) -> str:
        old = value_from_nix_line(line)
        lhs, sep, rhs = line.partition("=")
        if not old:
            return lhs + sep + rhs.replace('""', f'"{new}"', 1)
        return lhs + sep + rhs.replace(old, new, 1)

    return rewrite
313
314
def _setup_logging() -> None:
    """Log INFO and above to stderr with timestamps and source locations."""
    sh = logging.StreamHandler(sys.stderr)
    sh.setFormatter(
        logging.Formatter(
            "[%(asctime)s] [%(levelname)8s] --- %(message)s (%(filename)s:%(lineno)s)",
            "%Y-%m-%d %H:%M:%S",
        )
    )
    logging.basicConfig(level=logging.INFO, handlers=[sh])


def _parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        prog="azure-cli.extensions-tool",
        description="Script to handle Azure CLI extension updates",
    )
    # required=True: every mode parses this value, so a missing flag should be
    # a usage error rather than a TypeError from parse(None).
    parser.add_argument(
        "--cli-version",
        type=str,
        required=True,
        help="version of azure-cli (required)",
    )
    parser.add_argument("--extension", type=str, help="name of extension to query")
    parser.add_argument(
        "--cache-dir",
        type=Path,
        help="path where to cache the extension index",
        default=Path(os.getenv("XDG_CACHE_HOME", Path.home() / ".cache"))
        / "azure-cli-extensions-tool",
    )
    # NOTE(review): args.requirements is never read in this file; the generated
    # set update always excludes extensions that declare run_requires.
    parser.add_argument(
        "--requirements",
        action=argparse.BooleanOptionalAction,
        help="whether to list extensions that have requirements",
    )
    parser.add_argument(
        "--commit",
        action=argparse.BooleanOptionalAction,
        help="whether to commit changes to git",
    )
    parser.add_argument(
        "--init",
        action=argparse.BooleanOptionalAction,
        help="whether you want to init a new extension",
    )
    return parser.parse_args()


def _git_actor(repo: git.Repo) -> git.Actor:
    """Build the commit author/committer identity from the repo's git config."""
    config = repo.config_reader()
    # Workaround for https://github.com/gitpython-developers/GitPython/issues/1923
    author = config.get_value("user", "name").lstrip('"').rstrip('"')
    email = config.get_value("user", "email").lstrip('"').rstrip('"')
    return git.Actor(author, email)


def _run_init(
    extension: Optional[str],
    cli_version: Version,
    extensions_remote: Dict[str, Any],
) -> None:
    """Print, as JSON, the newest CLI-compatible index entry of ``extension``.

    Exits with status 1 when no name is given or no matching entry exists.
    """
    if not extension:
        logger.error("extension name is required for --init")
        sys.exit(1)

    # Look the entry up by index key so an unknown name is reported cleanly.
    ext_versions = extensions_remote.get(extension)
    ext = None
    if ext_versions is not None:
        ext = find_extension_version(
            ext_versions,
            cli_version,
            extension,
            requirements=True,
        )
    if not ext:
        logger.error(f"Extension {extension} not found in index")
        sys.exit(1)

    metadata = ext["metadata"]
    # Not every extension declares runtime requirements or a license; emit
    # an empty list / null instead of raising KeyError.
    run_requires = metadata.get("run_requires") or [{}]
    ext_translated = {
        "pname": metadata["name"],
        "version": metadata["version"],
        "url": ext["downloadUrl"],
        "hash": _convert_hash_digest_from_hex_to_b64_sri(ext["sha256Digest"]),
        "description": metadata["summary"].rstrip("."),
        "license": metadata.get("license"),
        "requirements": run_requires[0].get("requires", []),
    }
    print(json.dumps(ext_translated, indent=2))


def _run_single_update(
    extension: str,
    cli_version: Version,
    extensions_remote: Dict[str, Any],
    repo: git.Repo,
    actor: Optional[git.Actor],
) -> None:
    """Bump version and hash of one manually packaged extension in its Nix file.

    Commits the change when ``actor`` is given. Exits with status 1 if the
    extension or its Nix attributes cannot be found.
    """
    logger.info(f"updating extension: {extension}")

    ext: Optional[Ext] = None
    for ext_versions in extensions_remote.values():
        candidate = find_and_transform_extension_version(
            ext_versions, cli_version, extension, requirements=True
        )
        if candidate:
            ext = candidate
            break
    if not ext:
        logger.error(f"Extension {extension} not found in index")
        sys.exit(1)

    attr_path = f"azure-cli-extensions.{ext.pname}"
    version_pos = nix_unsafe_get_attr_pos("version", attr_path)
    if not version_pos:
        logger.error(
            f"no position for attribute 'version' found on attribute path {ext.pname}"
        )
        sys.exit(1)
    current_version = parse(read_value_at_pos(version_pos))

    if ext.version == current_version:
        logger.info(
            f"no update needed for {ext.pname}, latest version is {ext.version}"
        )
        return

    # Resolve the current hash before editing anything, so a failed lookup
    # cannot leave the file with a bumped version but a stale hash.
    current_hash = nix_get_value(f"{attr_path}.src.outputHash")
    if not current_hash:
        logger.error(
            f"no attribute 'src.outputHash' found on attribute path {ext.pname}"
        )
        sys.exit(1)

    logger.info("updated extensions:")
    logger.info(f" {ext.pname} {current_version} -> {ext.version}")
    edit_file_at_pos(version_pos, replace_value_in_nix_line(str(ext.version)))
    edit_file(version_pos.file, lambda line: line.replace(current_hash, ext.hash))

    if actor is not None:
        commit_msg = (
            f"azure-cli-extensions.{ext.pname}: {current_version} -> {ext.version}"
        )
        _commit(repo, commit_msg, [Path(version_pos.file)], actor)


def _run_generated_update(
    extension: Optional[str],
    cli_version: Version,
    extensions_remote: Dict[str, Any],
    repo: git.Repo,
    actor: Optional[git.Actor],
) -> None:
    """Sync extensions-generated.json with the remote index.

    Each init/update/removal is written separately (and committed when
    ``actor`` is given). Changes are processed in ``pname`` order so log
    output and commit order are deterministic. When ``extension`` is set, only
    that pname is considered.
    """
    logger.info("updating generated extension set")

    extensions_remote_filtered = set()
    for ext_versions in extensions_remote.values():
        candidate = find_and_transform_extension_version(
            ext_versions, cli_version, extension
        )
        if candidate:
            extensions_remote_filtered.add(candidate)

    extension_file = (
        Path(repo.working_dir) / "pkgs/by-name/az/azure-cli/extensions-generated.json"
    )
    extensions_local = _read_extension_set(extension_file)
    if extension:
        extensions_local_filtered = {
            e for e in extensions_local if e.pname == extension
        }
    else:
        extensions_local_filtered = extensions_local

    removed, init, updated = _diff_sets(
        extensions_local_filtered, extensions_remote_filtered
    )
    init_sorted = sorted(init, key=lambda e: e.pname)
    removed_sorted = sorted(removed, key=lambda e: e.pname)
    updated_sorted = sorted(
        filter(_filter_updated, updated), key=lambda pair: pair[0].pname
    )

    logger.info("initialized extensions:")
    for ext in init_sorted:
        logger.info(f" {ext.pname} {ext.version}")
    logger.info("removed extensions:")
    for ext in removed_sorted:
        logger.info(f" {ext.pname} {ext.version}")
    logger.info("updated extensions:")
    for prev, new in updated_sorted:
        logger.info(f" {prev.pname} {prev.version} -> {new.version}")

    def persist(commit_msg: str) -> None:
        # Write the current state; commit it only when a git identity was given.
        _write_extension_set(extension_file, extensions_local)
        if actor is not None:
            _commit(repo, commit_msg, [extension_file], actor)

    for ext in init_sorted:
        extensions_local.add(ext)
        persist(f"azure-cli-extensions.{ext.pname}: init at {ext.version}")

    for prev, new in updated_sorted:
        extensions_local.remove(prev)
        extensions_local.add(new)
        persist(f"azure-cli-extensions.{prev.pname}: {prev.version} -> {new.version}")

    for ext in removed_sorted:
        extensions_local.remove(ext)
        # TODO: Add additional check why this is removed
        # TODO: Add an alias to extensions manual?
        persist(f"azure-cli-extensions.{ext.pname}: remove")


def main() -> None:
    """CLI entry point.

    Dispatches to one of three modes:
    - ``--init``: print the index entry for a new extension;
    - ``--extension``: update one manually packaged extension;
    - otherwise: regenerate the extension set.
    """
    _setup_logging()
    args = _parse_args()
    cli_version = parse(args.cli_version)

    repo = git.Repo(Path(".").resolve(), search_parent_directories=True)
    # The git identity is only needed when committing; resolving it lazily lets
    # dry runs work without user.name/user.email configured.
    actor = _git_actor(repo) if args.commit else None

    index = get_extension_index(args.cache_dir)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if index.get("formatVersion") != "1":
        logger.error(
            "unsupported index formatVersion %r, only '1' is supported",
            index.get("formatVersion"),
        )
        sys.exit(1)
    extensions_remote = index["extensions"]

    if args.init:
        # init just prints the json of the extension version that matches the cli version.
        _run_init(args.extension, cli_version, extensions_remote)
        return

    if args.extension:
        _run_single_update(args.extension, cli_version, extensions_remote, repo, actor)
        return

    _run_generated_update(args.extension, cli_version, extensions_remote, repo, actor)
512
513
# Run the tool only when executed as a script, not when imported.
if __name__ == "__main__":
    main()