1#!/usr/bin/env python3
2
3"""
4Update a Python package expression by passing in the `.nix` file, or the directory containing it.
5You can pass in multiple files or paths.
6
7You'll likely want to use
8``
9 $ ./update-python-libraries ../../pkgs/development/python-modules/**/default.nix
10``
11to update all non-pinned libraries in that folder.
12"""
13
14import argparse
15import collections
16import json
17import logging
18import os
19import re
20import subprocess
21from concurrent.futures import ThreadPoolExecutor as Pool
22from typing import Any, Optional
23
24import requests
25from packaging.specifiers import SpecifierSet
26from packaging.version import InvalidVersion
27from packaging.version import Version as _Version
28
INDEX = "https://pypi.io/pypi"
"""url of PyPI"""

# NOTE(review): EXTENSIONS appears unused in this file — confirm before removing.
EXTENSIONS = ["tar.gz", "tar.bz2", "tar", "zip", ".whl"]
"""Permitted file extensions. These are evaluated from left to right and the first occurrence is returned."""

# Whether prerelease versions are acceptable when determining the latest version.
PRERELEASES = False

# Flipped to True in main() when more than one package is passed on the CLI;
# lets packages opt out via `skipBulkUpdate` (see _skip_bulk_update).
BULK_UPDATE = False

# Names of the external tools this script shells out to.
NIX = "nix"
NIX_PREFETCH_URL = "nix-prefetch-url"
NIX_PREFETCH_GIT = "nix-prefetch-git"
GIT = "git"

# Absolute path to the top of the git checkout we are running inside of.
NIXPKGS_ROOT = (
    subprocess.check_output([GIT, "rev-parse", "--show-toplevel"])
    .decode("utf-8")
    .strip()
)

logging.basicConfig(level=logging.INFO)
51
52
class Version(_Version, collections.abc.Sequence):
    """PEP 440 version that is also a sequence of its release components.

    The original string is kept because normalization is lossy:
    `str(Version("0.04.21"))` would become "0.4.21".
    https://github.com/avian2/unidecode/issues/13#issuecomment-354538882
    """

    def __init__(self, version):
        super().__init__(version)
        # Preserve the version exactly as written in the Nix expression.
        self.raw_version = version

    def __getitem__(self, index):
        return self._version.release[index]

    def __len__(self):
        return len(self._version.release)

    def __iter__(self):
        return iter(self._version.release)
68
69
70def _get_values(attribute, text):
71 """Match attribute in text and return all matches.
72
73 :returns: List of matches.
74 """
75 regex = rf'{re.escape(attribute)}\s+=\s+"(.*)";'
76 regex = re.compile(regex)
77 values = regex.findall(text)
78 return values
79
80
def _get_attr_value(attr_path: str) -> Optional[Any]:
    """Evaluate `attr_path` against nixpkgs and return the JSON-decoded value.

    Returns None when evaluation fails (e.g. the attribute does not exist)
    or the output is not valid JSON.
    """
    try:
        response = subprocess.check_output(
            [
                NIX,
                "--extra-experimental-features",
                "nix-command",
                "eval",
                "-f",
                f"{NIXPKGS_ROOT}/default.nix",
                "--json",
                f"{attr_path}",
            ],
            # nix writes evaluation warnings/errors to stderr; silence them.
            stderr=subprocess.DEVNULL,
        )
        return json.loads(response.decode())
    except (subprocess.CalledProcessError, ValueError):
        # CalledProcessError: evaluation failed; ValueError: not valid JSON.
        return None
99
100
def _get_unique_value(attribute, text):
    """Match attribute in text and return the single unique match.

    :raises ValueError: when zero or more than one match is found.
    :returns: Single match.
    """
    values = _get_values(attribute, text)
    if len(values) > 1:
        raise ValueError("found too many values for {}".format(attribute))
    if not values:
        raise ValueError("no value found for {}".format(attribute))
    return values[0]
114
115
116def _get_line_and_value(attribute, text, value=None):
117 """Match attribute in text. Return the line and the value of the attribute."""
118 if value is None:
119 regex = rf"({re.escape(attribute)}\s+=\s+\"(.*)\";)"
120 else:
121 regex = rf"({re.escape(attribute)}\s+=\s+\"({re.escape(value)})\";)"
122 regex = re.compile(regex)
123 results = regex.findall(text)
124 n = len(results)
125 if n > 1:
126 raise ValueError("found too many values for {}".format(attribute))
127 elif n == 1:
128 return results[0]
129 else:
130 raise ValueError("no value found for {}".format(attribute))
131
132
def _replace_value(attribute, value, text, oldvalue=None):
    """Replace the value assigned to `attribute` in `text` with `value`.

    When `oldvalue` is given, only the assignment holding that exact value
    is rewritten.
    """
    # `oldvalue=None` simply means "match any value" downstream.
    old_line, old_value = _get_line_and_value(attribute, text, oldvalue)
    new_line = old_line.replace(old_value, value)
    return text.replace(old_line, new_line)
142
143
def _fetch_page(url):
    """GET `url` and return the JSON-decoded body.

    :raises ValueError: on any non-200 response.
    """
    response = requests.get(url)
    if response.status_code != requests.codes.ok:
        raise ValueError("request for {} failed".format(url))
    return response.json()
150
151
def _fetch_github(url):
    """GET a GitHub API `url` and return the JSON-decoded body.

    Sends the GITHUB_API_TOKEN environment variable (when set) as a token,
    which lifts the anonymous API rate limit.

    :raises ValueError: on any non-200 response.
    """
    headers = {}
    if token := os.environ.get("GITHUB_API_TOKEN"):
        headers["Authorization"] = f"token {token}"
    response = requests.get(url, headers=headers)
    if response.status_code != requests.codes.ok:
        raise ValueError("request for {} failed".format(url))
    return response.json()
163
164
def _hash_to_sri(algorithm, value):
    """Convert a hash to its SRI representation.

    Shells out to `nix hash to-sri` so the output always matches what Nix
    itself would produce (e.g. "sha256-...").
    """
    return (
        subprocess.check_output(
            [
                NIX,
                "--extra-experimental-features",
                "nix-command",
                "hash",
                "to-sri",
                "--type",
                algorithm,
                value,
            ]
        )
        .decode()
        .strip()
    )
183
184
def _skip_bulk_update(attr_name: str) -> bool:
    """Return True when `<attr_name>.skipBulkUpdate` is set in nixpkgs."""
    flag = _get_attr_value(f"{attr_name}.skipBulkUpdate")
    return bool(flag)
187
188
# Maps an update target to the index of the version component it pins,
# e.g. target "minor" keeps component 0 (major) fixed and allows changes
# from component 1 onward (see _determine_latest_version).
SEMVER = {
    "major": 0,
    "minor": 1,
    "patch": 2,
}
194
195
def _determine_latest_version(current_version, target, versions):
    """Determine the latest version, given `target`.

    E.g. with target "minor" and current version 1.2.3 only versions
    strictly below 2 are considered.

    :param current_version: version string currently in the expression.
    :param target: one of the SEMVER keys ("major", "minor", "patch").
    :param versions: iterable of candidate version strings.
    :raises ValueError: when no acceptable candidate version remains.
    :returns: raw (unnormalized) string of the newest acceptable version.
    """
    current_version = Version(current_version)

    def _parse_versions(versions):
        # Silently skip anything that is not a valid PEP 440 version.
        for v in versions:
            try:
                yield Version(v)
            except InvalidVersion:
                pass

    versions = _parse_versions(versions)

    index = SEMVER[target]

    # Build an exclusive upper bound from the components above the target
    # level, e.g. current 1.2.3 with target "minor" gives ceiling 2.
    ceiling = list(current_version[0:index])
    if len(ceiling) == 0:
        ceiling = None
    else:
        ceiling[-1] += 1
        ceiling = Version(".".join(map(str, ceiling)))

    # We do not want prereleases
    versions = SpecifierSet(prereleases=PRERELEASES).filter(versions)

    if ceiling is not None:
        versions = SpecifierSet(f"<{ceiling}").filter(versions)

    try:
        # `max` alone suffices; sorting first was a redundant O(n log n) pass.
        return max(versions).raw_version
    except ValueError as e:
        # max() raises ValueError on an empty iterable; re-raise with context.
        raise ValueError(
            "no version newer than {} found for target {}".format(
                current_version.raw_version, target
            )
        ) from e
225
226
def _get_latest_version_pypi(attr_path, package, extension, current_version, target):
    """Get latest version and hash from PyPI.

    :returns: (version, sha256 or None, None) — the third element (tag
        prefix) is always None for PyPI, it only applies to GitHub tags.
    """
    url = "{}/{}/json".format(INDEX, package)
    # Renamed from `json`, which shadowed the imported json module.
    data = _fetch_page(url)

    # Ignore versions for which every uploaded file has been yanked.
    versions = {
        version
        for version, releases in data["releases"].items()
        if not all(release["yanked"] for release in releases)
    }
    version = _determine_latest_version(current_version, target, versions)

    try:
        releases = data["releases"][version]
    except KeyError as e:
        raise KeyError(
            "Could not find version {} for {}".format(version, package)
        ) from e
    # Take the sha256 of the first file whose name matches the extension.
    for release in releases:
        if release["filename"].endswith(extension):
            # TODO: In case of wheel we need to do further checks!
            sha256 = release["digests"]["sha256"]
            break
    else:
        sha256 = None
    return version, sha256, None
253
254
def _get_latest_version_github(attr_path, package, extension, current_version, target):
    """Get the latest version and hash from GitHub releases (or tags).

    Determines the repository from `<attr_path>.src.meta.homepage`, picks the
    newest acceptable stable release — falling back to plain tags when the
    repository has no stable releases — and prefetches the source hash.

    :returns: (version, hash, prefix) where `prefix` is the non-digit text
        preceding the version in the tag name (e.g. "v").
    :raises ValueError: when the homepage cannot be determined, no releases
        or tags exist, or prefetching fails.
    """

    def strip_prefix(tag):
        # "v1.2.3" -> "1.2.3": drop leading non-digit characters.
        return re.sub("^[^0-9]*", "", tag)

    def get_prefix(string):
        # Complement of strip_prefix: the leading non-digit characters, or "".
        matches = re.findall(r"^([^0-9]*)", string)
        return next(iter(matches), "")

    try:
        homepage = subprocess.check_output(
            [
                NIX,
                "--extra-experimental-features",
                "nix-command",
                "eval",
                "-f",
                f"{NIXPKGS_ROOT}/default.nix",
                "--raw",
                f"{attr_path}.src.meta.homepage",
            ]
        ).decode("utf-8")
    except Exception as e:
        raise ValueError(f"Unable to determine homepage: {e}")
    # NOTE(review): assumes homepage starts with "https://github.com/" and
    # contains exactly "owner/repo" after it — confirm for edge cases.
    owner_repo = homepage[len("https://github.com/") :]  # remove prefix
    owner, repo = owner_repo.split("/")

    url = f"https://api.github.com/repos/{owner}/{repo}/releases"
    all_releases = _fetch_github(url)
    releases = list(filter(lambda x: not x["prerelease"], all_releases))

    if len(releases) == 0:
        logging.warning(f"{homepage} does not contain any stable releases, looking for tags instead...")
        url = f"https://api.github.com/repos/{owner}/{repo}/tags"
        all_tags = _fetch_github(url)
        # Releases are used with a couple of fields that tags possess as well. We will fake these releases.
        releases = [{'tag_name': tag['name'], 'tarball_url': tag['tarball_url']} for tag in all_tags]

    if len(releases) == 0:
        raise ValueError(f"{homepage} does not contain any stable releases neither tags, stopping now.")

    versions = map(lambda x: strip_prefix(x["tag_name"]), releases)
    version = _determine_latest_version(current_version, target, versions)

    # Find the release object whose stripped tag equals the chosen version.
    release = next(filter(lambda x: strip_prefix(x["tag_name"]) == version, releases))
    prefix = get_prefix(release["tag_name"])

    # Mirror the original fetcher's options (submodules, LFS, .git dir) when
    # the source uses nix-prefetch-git rather than a plain tarball download.
    fetcher = _get_attr_value(f"{attr_path}.src.fetcher")
    if fetcher is not None and fetcher.endswith("nix-prefetch-git"):
        # some attributes require using the fetchgit
        git_fetcher_args = []
        if _get_attr_value(f"{attr_path}.src.fetchSubmodules"):
            git_fetcher_args.append("--fetch-submodules")
        if _get_attr_value(f"{attr_path}.src.fetchLFS"):
            git_fetcher_args.append("--fetch-lfs")
        if _get_attr_value(f"{attr_path}.src.leaveDotGit"):
            git_fetcher_args.append("--leave-dotGit")

        algorithm = "sha256"
        cmd = [
            NIX_PREFETCH_GIT,
            f"https://github.com/{owner}/{repo}.git",
            "--hash",
            algorithm,
            "--rev",
            f"refs/tags/{release['tag_name']}",
        ]
        cmd.extend(git_fetcher_args)
        response = subprocess.check_output(cmd)
        document = json.loads(response.decode())
        hash = _hash_to_sri(algorithm, document[algorithm])
    else:
        try:
            hash = (
                subprocess.check_output(
                    [
                        NIX_PREFETCH_URL,
                        "--type",
                        "sha256",
                        "--unpack",
                        f"{release['tarball_url']}",
                    ],
                    stderr=subprocess.DEVNULL,
                )
                .decode("utf-8")
                .strip()
            )
        except (subprocess.CalledProcessError, UnicodeError):
            # this may fail if they have both a branch and a tag of the same name, attempt tag name
            tag_url = str(release["tarball_url"]).replace(
                "tarball", "tarball/refs/tags"
            )
            try:
                hash = (
                    subprocess.check_output(
                        [NIX_PREFETCH_URL, "--type", "sha256", "--unpack", tag_url],
                        stderr=subprocess.DEVNULL,
                    )
                    .decode("utf-8")
                    .strip()
                )
            except subprocess.CalledProcessError:
                raise ValueError("nix-prefetch-url failed")

    return version, hash, prefix
359
360
# Dispatch table: Nix fetcher name -> function returning
# (version, hash, tag-prefix). fetchurl sources are assumed to point at
# PyPI (enforced in _determine_extension), so they share the PyPI path.
FETCHERS = {
    "fetchFromGitHub": _get_latest_version_github,
    "fetchPypi": _get_latest_version_pypi,
    "fetchurl": _get_latest_version_pypi,
}


DEFAULT_SETUPTOOLS_EXTENSION = "tar.gz"


# Maps a package's `format` attribute to the file extension of the source
# artifact to look for on PyPI.
FORMATS = {
    "setuptools": DEFAULT_SETUPTOOLS_EXTENSION,
    "wheel": "whl",
    "pyproject": "tar.gz",
    "flit": "tar.gz",
}
377
378
def _determine_fetcher(text):
    """Return the single known fetcher used in the expression.

    :raises ValueError: when no known fetcher is present, or when known
        fetchers occur more than once in total.
    """
    # Count total occurrences first: even the same fetcher twice is ambiguous.
    occurrences = sum(text.count(f"src = {name}") for name in FETCHERS)
    if occurrences == 0:
        raise ValueError("no fetcher.")
    if occurrences > 1:
        raise ValueError("multiple fetchers.")
    # Exactly one occurrence: identify which fetcher it is.
    return next(name for name in FETCHERS if f"src = {name}" in text)
393
394
def _determine_extension(text, fetcher):
    """Determine what extension is used in the expression.

    If we use:
    - fetchPypi, we check if format is specified.
    - fetchurl, we determine the extension from the url.
    - fetchFromGitHub we simply use `.tar.gz`.

    :raises ValueError: for format='other' packages or non-PyPI fetchurl urls.
    """
    if fetcher == "fetchPypi":
        try:
            src_format = _get_unique_value("format", text)
        except ValueError:
            src_format = None  # format was not given

        try:
            extension = _get_unique_value("extension", text)
        except ValueError:
            extension = None  # extension was not given

        # An explicit `extension` attribute wins; otherwise derive it from
        # the (possibly defaulted) format via the FORMATS table.
        if extension is None:
            if src_format is None:
                src_format = "setuptools"
            elif src_format == "other":
                raise ValueError("Don't know how to update a format='other' package.")
            extension = FORMATS[src_format]

    elif fetcher == "fetchurl":
        url = _get_unique_value("url", text)
        # NOTE(review): splitext keeps only the last suffix, so "foo.tar.gz"
        # yields ".gz" — confirm this matches the endswith comparison used
        # against PyPI filenames downstream.
        extension = os.path.splitext(url)[1]
        if "pypi" not in url:
            raise ValueError("url does not point to PyPI.")

    elif fetcher == "fetchFromGitHub":
        extension = "tar.gz"

    return extension
431
432
def _update_package(path, target):
    """Update the Nix expression at `path` to the latest acceptable version.

    Reads the expression, determines the fetcher and current version, asks
    the matching FETCHERS function for the newest version/hash, and rewrites
    version, hash and (for GitHub) tag/changelog in place.

    :param path: path to the .nix file to update.
    :param target: one of the SEMVER keys ("major", "minor", "patch").
    :returns: a result dict describing the update, or False when the package
        is already up to date.
    :raises ValueError: on skipped/bulk-excluded packages, downgrades,
        missing files, or unresolvable hashes.
    """
    # Read the expression
    with open(path, "r") as f:
        text = f.read()

    # Determine pname. Many files have more than one pname
    pnames = _get_values("pname", text)

    # Determine version.
    version = _get_unique_value("version", text)

    # First we check how many fetchers are mentioned.
    fetcher = _determine_fetcher(text)

    extension = _determine_extension(text, fetcher)

    # Attempt a fetch using each pname, e.g. backports-zoneinfo vs backports.zoneinfo
    successful_fetch = False
    for pname in pnames:
        # when invoked as an updateScript, UPDATE_NIX_ATTR_PATH will be set
        # this allows us to work with packages which live outside of python-modules
        attr_path = os.environ.get("UPDATE_NIX_ATTR_PATH", f"python3Packages.{pname}")

        if BULK_UPDATE and _skip_bulk_update(attr_path):
            raise ValueError(f"Bulk update skipped for {pname}")
        elif _get_attr_value(f"{attr_path}.cargoDeps") is not None:
            raise ValueError(f"Cargo dependencies are unsupported, skipping {pname}")
        try:
            new_version, new_sha256, prefix = FETCHERS[fetcher](
                attr_path, pname, extension, version, target
            )
            successful_fetch = True
            break
        except ValueError:
            logging.exception(f"Failed to fetch releases for {pname}")
            continue

    if not successful_fetch:
        raise ValueError(f"Unable to find correct package using these pnames: {pnames}")

    if new_version == version:
        logging.info("Path {}: no update available for {}.".format(path, pname))
        return False
    elif Version(new_version) <= Version(version):
        raise ValueError("downgrade for {}.".format(pname))
    if not new_sha256:
        raise ValueError("no file available for {}.".format(pname))

    text = _replace_value("version", new_version, text)

    # hashes from pypi are hex-encoded sha256's, normalize it to sri to avoid merge conflicts
    # sri hashes have been the default format since nix 2.4+
    sri_hash = _hash_to_sri("sha256", new_sha256)

    # retrieve the old output hash for a more precise match
    if old_hash := _get_attr_value(f"{attr_path}.src.outputHash"):
        # fetchers can specify a sha256, or a sri hash
        try:
            text = _replace_value("hash", sri_hash, text, old_hash)
        except ValueError:
            text = _replace_value("sha256", sri_hash, text, old_hash)
    else:
        raise ValueError(f"Unable to retrieve old hash for {pname}")

    if fetcher == "fetchFromGitHub":
        # in the case of fetchFromGitHub, it's common to see `rev = version;` or `rev = "v${version}";`
        # in which no string value is meant to be substituted. However, we can just overwrite the previous value.
        regex = r"((?:rev|tag)\s+=\s+[^;]*;)"
        regex = re.compile(regex)
        matches = regex.findall(text)
        n = len(matches)

        if n == 0:
            raise ValueError("Unable to find rev value for {}.".format(pname))
        else:
            # forcefully rewrite rev, in case tagging conventions changed for a release
            match = matches[0]
            text = text.replace(match, f'tag = "{prefix}${{version}}";')
            # in case there's no prefix, just rewrite without interpolation
            text = text.replace('"${version}";', "version;")

        # update changelog to reference the src.tag
        if result := re.search("changelog = \"[^\"]+\";", text):
            cl_old = result[0]
            cl_new = re.sub(r"v?\$\{(version|src.rev)\}", "${src.tag}", cl_old)
            text = text.replace(cl_old, cl_new)

    with open(path, "w") as f:
        f.write(text)

    logging.info(
        "Path {}: updated {} from {} to {}".format(
            path, pname, version, new_version
        )
    )

    result = {
        "path": path,
        "target": target,
        "pname": pname,
        "old_version": version,
        "new_version": new_version,
        #'fetcher' : fetcher,
    }

    return result
539
540
def _update(path, target):
    """Update the Nix expression at `path` (a .nix file or a directory
    containing a default.nix).

    :returns: result dict from _update_package on success, False otherwise.
    """
    # A directory argument means "its default.nix".
    nix_file = os.path.join(path, "default.nix") if os.path.isdir(path) else path

    if not os.path.isfile(nix_file):
        logging.info("Path {}: does not exist.".format(nix_file))
        return False

    if not nix_file.endswith(".nix"):
        logging.info("Path {}: does not end with `.nix`.".format(nix_file))
        return False

    try:
        return _update_package(nix_file, target)
    except ValueError as err:
        # Treat per-package failures as non-fatal so bulk runs continue.
        logging.warning("Path {}: {}".format(nix_file, err))
        return False
561
562
def _commit(path, pname, old_version, new_version, pkgs_prefix="python: ", **kwargs):
    """Commit the updated expression at `path` with a conventional message.

    Reverts the working-tree changes (`git checkout`) when committing fails,
    so a failed run does not leave a dirty tree.

    :returns: True on success.
    :raises subprocess.CalledProcessError: when git add/commit fails.
    """

    msg = f"{pkgs_prefix}{pname}: {old_version} -> {new_version}"

    if changelog := _get_attr_value(f"{pkgs_prefix}{pname}.meta.changelog"):
        msg += f"\n\n{changelog}"

    msg += "\n\nThis commit was automatically generated using update-python-libraries."

    try:
        subprocess.check_call([GIT, "add", path])
        subprocess.check_call([GIT, "commit", "-m", msg])
    except subprocess.CalledProcessError:
        subprocess.check_call([GIT, "checkout", path])
        # BUG FIX: the previous code did
        # `raise subprocess.CalledProcessError(f"Could not commit {path}")`,
        # but CalledProcessError requires (returncode, cmd) — constructing it
        # with a bare message produced a malformed exception. Log the context
        # and re-raise the original error instead (same type for callers).
        logging.error(f"Could not commit {path}")
        raise

    return True
581
582
def main():
    """CLI entry point: parse arguments, update packages, optionally commit."""
    epilog = """
environment variables:
  GITHUB_API_TOKEN\tGitHub API token used when updating github packages
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter, epilog=epilog
    )
    parser.add_argument("package", type=str, nargs="+")
    parser.add_argument("--target", type=str, choices=SEMVER.keys(), default="major")
    parser.add_argument(
        "--commit", action="store_true", help="Create a commit for each package update"
    )
    parser.add_argument(
        "--use-pkgs-prefix",
        action="store_true",
        help="Use python3Packages.${pname}: instead of python: ${pname}: when making commits",
    )

    args = parser.parse_args()
    target = args.target

    packages = list(map(os.path.abspath, args.package))

    # Multiple packages imply a bulk update, which individual packages may
    # opt out of via `skipBulkUpdate` (see _skip_bulk_update).
    if len(packages) > 1:
        global BULK_UPDATE
        BULK_UPDATE = True

    logging.info("Updating packages...")

    # Use threads to update packages concurrently
    with Pool() as p:
        results = list(filter(bool, p.map(lambda pkg: _update(pkg, target), packages)))

    logging.info("Finished updating packages.")

    commit_options = {}
    if args.use_pkgs_prefix:
        logging.info("Using python3Packages. prefix for commits")
        commit_options["pkgs_prefix"] = "python3Packages."

    # Commits are created sequentially.
    if args.commit:
        logging.info("Committing updates...")
        # list forces evaluation
        list(map(lambda x: _commit(**x, **commit_options), results))
        logging.info("Finished committing updates")

    count = len(results)
    logging.info("{} package(s) updated".format(count))


if __name__ == "__main__":
    main()