at 23.05-pre 7.6 kB view raw
#! /usr/bin/env python3

from itertools import chain
import json
import logging
from pathlib import Path
import os
import re
import subprocess
import sys
from typing import Dict, List, Optional, Set, TextIO
from urllib.request import urlopen
from urllib.error import HTTPError
import yaml

PKG_SET = "pkgs.python3Packages"

# If some requirements are matched by multiple or no Python packages, the
# following can be used to choose the correct one
PKG_PREFERENCES = {
    "dnspython": "dnspython",
    "google-api-python-client": "google-api-python-client",
    "psycopg2-binary": "psycopg2",
    "requests_toolbelt": "requests-toolbelt",
}

# Requirements missing from the airflow provider metadata
EXTRA_REQS = {
    "sftp": ["pysftp"],
}


def get_version() -> str:
    """Return the version string declared in the default.nix next to this script.

    Raises:
        ValueError: if default.nix contains no ``version = "...";`` assignment.
    """
    # os.path.join copes with an empty dirname (script invoked by its bare
    # file name); plain string concatenation would yield "/default.nix".
    default_nix = os.path.join(os.path.dirname(sys.argv[0]), "default.nix")
    with open(default_nix) as fh:
        # A version consists of digits, dots, and possibly a "b" (for beta)
        m = re.search(r'version = "([\d\.b]+)";', fh.read())
    if m is None:
        raise ValueError(f"Could not find a version assignment in {default_nix}")
    return m.group(1)


def get_file_from_github(version: str, path: str):
    """Download *path* from the apache/airflow GitHub repository and parse it.

    *version* is used as the ref component of the raw.githubusercontent.com
    URL. The response is parsed with ``yaml.safe_load``; this function is
    also called for the generated ``.json`` metadata file.
    """
    with urlopen(
        f"https://raw.githubusercontent.com/apache/airflow/{version}/{path}"
    ) as response:
        return yaml.safe_load(response)


def repository_root() -> Path:
    """Return the path four directories above the directory of this script."""
    return Path(os.path.dirname(sys.argv[0])) / "../../../.."


def dump_packages() -> Dict[str, Dict[str, str]]:
    """Return the decoded ``nix-env -qa --json`` dump of ``PKG_SET``.

    Raises:
        subprocess.CalledProcessError: if ``nix-env`` exits non-zero.
    """
    # Store a JSON dump of Nixpkgs' python3Packages
    output = subprocess.check_output(
        [
            "nix-env",
            "-f",
            repository_root(),
            "-qa",
            "-A",
            PKG_SET,
            "--arg",
            "config",
            "{ allowAliases = false; }",
            "--json",
        ]
    )
    return json.loads(output)


def remove_version_constraint(req: str) -> str:
    """Reduce a requirement specifier to its bare distribution name.

    Everything from the first version operator (``=``, ``>``, ``<``, ``~``,
    ``!``), extras bracket, parenthesis, environment-marker semicolon or
    whitespace onwards is dropped, e.g. ``"foo[bar]!=1.0; os_name == 'nt'"``
    becomes ``"foo"``.
    """
    return re.sub(r"[=><~!;\[(\s].*$", "", req.strip())


def name_to_attr_path(req: str, packages: Dict[str, Dict[str, str]]) -> Optional[str]:
    """Map a requirement name to the attribute path of the matching package.

    Args:
        req: requirement name without version constraint.
        packages: mapping of attribute path to package info; only the
            ``"name"`` entry of each value is read.

    Returns:
        The attribute path, or ``None`` when no package name matches.

    Raises:
        AssertionError: if more than one package matches.
    """
    if req in PKG_PREFERENCES:
        return f"{PKG_SET}.{PKG_PREFERENCES[req]}"
    names = [req]
    # E.g. python-mpd2 is actually called python3.6-mpd2
    # instead of python-3.6-python-mpd2 inside Nixpkgs
    if req.startswith(("python-", "python_")):
        names.append(req[len("python-") :])
    attr_paths = []
    for name in names:
        # Treat "-", "_" and "." as interchangeable separators and escape
        # everything else, so the requirement name cannot inject regex
        # metacharacters into the pattern.
        name_pattern = "[-_.]".join(
            re.escape(part) for part in re.split(r"[-_.]", name)
        )
        # python(major).(minor)-(pname)-(version or unstable-date)
        # we need the version qualifier, or we'll have multiple matches
        # (e.g. pyserial and pyserial-asyncio when looking for pyserial)
        pattern = re.compile(
            f"^python\\d+\\.\\d+-{name_pattern}-(?:\\d|unstable-.*)", re.I
        )
        attr_paths.extend(
            attr_path
            for attr_path, package in packages.items()
            if pattern.match(package["name"])
        )
    # Let's hope there's only one derivation with a matching name.
    # Raised explicitly so the check survives `python -O`.
    if len(attr_paths) > 1:
        raise AssertionError(
            f"{req} matches more than one derivation: {attr_paths}"
        )
    return attr_paths[0] if attr_paths else None


def provider_reqs_to_attr_paths(reqs: List, packages: Dict) -> List:
    """Translate requirement specifiers into attribute names inside ``PKG_SET``.

    Requirements whose name starts with ``apache-airflow`` are ignored.
    Requirements without a matching package are skipped with a warning.
    """
    prefix_len = len(PKG_SET + ".")
    attr_paths = []
    for req in map(remove_version_constraint, reqs):
        if req.startswith("apache-airflow"):
            continue
        attr_path = name_to_attr_path(req, packages)
        if attr_path is None:
            # If we can't find it, we just skip and warn the user
            logging.warning("Could not find package attr for %s", req)
            continue
        # Add attribute path without "python3Packages." prefix
        attr_paths.append(attr_path[prefix_len:])
    return attr_paths


def get_cross_provider_reqs(
    provider: str,
    provider_reqs: Dict,
    cross_provider_deps: Dict,
    seen: Optional[List] = None,
) -> Set:
    """Collect the requirements of *provider* and of its cross-provider deps.

    Args:
        provider: provider whose requirements are collected.
        provider_reqs: mapping of provider to its own requirements.
        cross_provider_deps: mapping of provider to the providers it
            depends on.
        seen: providers already on the current recursion path.

    Raises:
        KeyError: if a provider is missing from either mapping.
    """
    # Unfortunately there are circular cross-provider dependencies, so keep a
    # list of ones we've seen already
    visited = (seen or []) + [provider]
    reqs = set(provider_reqs[provider])
    for dep in cross_provider_deps[provider]:
        if dep not in visited:
            reqs.update(
                get_cross_provider_reqs(
                    dep, provider_reqs, cross_provider_deps, visited
                )
            )
    return reqs


def get_provider_reqs(version: str, packages: Dict) -> Dict:
    """Return, per provider, the set of package attributes it requires.

    Reads ``generated/provider_dependencies.json`` for *version*, resolves
    each provider's ``deps`` against *packages*, appends ``EXTRA_REQS`` and
    merges in the requirements of its ``cross-providers-deps`` (with
    ``common.sql`` left out of that list).
    """
    provider_dependencies = get_file_from_github(
        version, "generated/provider_dependencies.json"
    )
    provider_reqs = {}
    cross_provider_deps = {}
    for provider, provider_data in provider_dependencies.items():
        provider_reqs[provider] = provider_reqs_to_attr_paths(
            provider_data["deps"], packages
        ) + EXTRA_REQS.get(provider, [])
        cross_provider_deps[provider] = [
            d for d in provider_data["cross-providers-deps"] if d != "common.sql"
        ]
    # Add transitive cross-provider reqs
    return {
        provider: get_cross_provider_reqs(
            provider, provider_reqs, cross_provider_deps
        )
        for provider in provider_reqs
    }


def get_provider_yaml(version: str, provider: str) -> Dict:
    """Fetch and parse the provider.yaml of *provider* for *version*.

    Dots in the provider name are turned into directory separators. An HTTP
    error is logged as a warning and yields an empty dict.
    """
    provider_dir = provider.replace(".", "/")
    path = f"airflow/providers/{provider_dir}/provider.yaml"
    try:
        return get_file_from_github(version, path)
    except HTTPError:
        logging.warning("Couldn't get provider yaml for %s", provider)
        return {}



def get_provider_imports(version: str, providers) -> Dict:
    """Return, per provider, the Python modules listed in its provider.yaml.

    For each provider the ``python-modules`` lists of all ``hooks`` entries
    are collected first, followed by those of all ``operators`` entries.

    Raises:
        KeyError: if a hook or operator entry has no ``python-modules`` key.
    """
    provider_imports = {}
    for provider in providers:
        # Fall back to an empty mapping when the parsed document is None
        # (e.g. an empty YAML file) so the lookups below do not fail.
        provider_yaml = get_provider_yaml(version, provider) or {}
        imports: List[str] = []
        for section in ("hooks", "operators"):
            imports.extend(
                chain.from_iterable(
                    entry["python-modules"]
                    for entry in provider_yaml.get(section) or []
                )
            )
        provider_imports[provider] = imports
    return provider_imports


def to_nix_expr(provider_reqs: Dict, provider_imports: Dict, fh: TextIO) -> None:
    """Write the provider metadata to *fh* as a Nix attribute set.

    Each provider becomes one attribute (dots in its name replaced by
    underscores) holding sorted, double-quoted ``deps`` and ``imports`` lists.

    Raises:
        KeyError: if a provider of *provider_reqs* is missing from
            *provider_imports*.
    """
    # NOTE(review): leading whitespace inside the literals below is kept
    # exactly as found; confirm it was not collapsed when this file was copied.
    fh.write("# Warning: generated by update-providers.py, do not update manually\n")
    fh.write("{\n")
    for provider, reqs in provider_reqs.items():
        provider_name = provider.replace(".", "_")
        fh.write(f" {provider_name} = {{\n")
        deps = " ".join(sorted(f'"{req}"' for req in reqs))
        fh.write(" deps = [ " + deps + " ];\n")
        imports = " ".join(sorted(f'"{imp}"' for imp in provider_imports[provider]))
        fh.write(" imports = [ " + imports + " ];\n")
        fh.write(" };\n")
    fh.write("}\n")


def main() -> None:
    """Regenerate providers.nix in the current working directory."""
    logging.basicConfig(level=logging.INFO)
    version = get_version()
    packages = dump_packages()
    logging.info("Generating providers.nix for version %s", version)
    provider_reqs = get_provider_reqs(version, packages)
    provider_imports = get_provider_imports(version, provider_reqs)
    # Open the output only after all data is gathered, so a failure above
    # does not truncate an existing providers.nix.
    with open("providers.nix", "w") as fh:
        to_nix_expr(provider_reqs, provider_imports, fh)


if __name__ == "__main__":
    main()