1#! /usr/bin/env python3
2
3from itertools import chain
4import json
5import logging
6from pathlib import Path
7import os
8import re
9import subprocess
10import sys
11from typing import Dict, List, Optional, Set, TextIO
12from urllib.request import urlopen
13from urllib.error import HTTPError
14import yaml
15
# Nixpkgs attribute set that Python requirements are resolved against; it is
# passed to `nix-env -A` and stripped from the front of the resulting
# attribute paths.
PKG_SET = "pkgs.python3Packages"

# If some requirements are matched by multiple or no Python packages, the
# following can be used to choose the correct one.
# Maps requirement name -> attribute name inside PKG_SET. Keys are looked up
# verbatim (no "-"/"_" normalization), so they must be spelled exactly as in
# the requirement string.
PKG_PREFERENCES = {
    "dnspython": "dnspython",
    "google-api-python-client": "google-api-python-client",
    "psycopg2-binary": "psycopg2",
    "requests_toolbelt": "requests-toolbelt",
}

# Requirements missing from the airflow provider metadata.
# Maps provider id -> attribute names inside PKG_SET that are appended
# unchanged to that provider's resolved dependencies.
EXTRA_REQS = {
    "sftp": ["pysftp"],
}
31
32
def get_version():
    """Return the Airflow version declared in the ``default.nix`` next to this script.

    The first ``version = "...";`` assignment is used. A version consists of
    digits, dots, and possibly a "b" (for beta).

    Raises:
        ValueError: if ``default.nix`` contains no such version assignment.
    """
    # os.path.join (not string concatenation) so that running the script from
    # its own directory, where dirname(argv[0]) == "", resolves to
    # "default.nix" in the cwd instead of "/default.nix" at the filesystem root.
    default_nix = os.path.join(os.path.dirname(sys.argv[0]), "default.nix")
    with open(default_nix, encoding="utf-8") as fh:
        contents = fh.read()
    m = re.search('version = "([\\d\\.b]+)";', contents)
    if m is None:
        raise ValueError(f"Could not find a version assignment in {default_nix}")
    return m.group(1)
38
39
def get_file_from_github(version: str, path: str):
    """Download ``path`` from apache/airflow at git ref ``version`` and parse it.

    The response body is parsed with ``yaml.safe_load``; callers use this for
    both ``.yaml`` and ``.json`` files. Errors raised by ``urlopen`` (e.g.
    ``HTTPError``) propagate to the caller.
    """
    url = f"https://raw.githubusercontent.com/apache/airflow/{version}/{path}"
    with urlopen(url) as resp:
        parsed = yaml.safe_load(resp)
    return parsed
45
46
def repository_root() -> Path:
    """Return the Nixpkgs checkout root: four directory levels above this script."""
    script_dir = Path(os.path.dirname(sys.argv[0]))
    return script_dir.joinpath("..", "..", "..", "..")
49
50
def dump_packages() -> Dict[str, Dict[str, str]]:
    """Return Nixpkgs' python3Packages as decoded ``nix-env -qa --json`` output.

    The query runs against the repository root with aliases disabled. The
    result maps attribute path -> package metadata (callers read its
    ``"name"`` key).
    """
    command = [
        "nix-env",
        "-f", repository_root(),
        "-qa",
        "-A", PKG_SET,
        "--arg", "config", "{ allowAliases = false; }",
        "--json",
    ]
    raw = subprocess.check_output(command)
    return json.loads(raw)
68
69
def remove_version_constraint(req: str) -> str:
    """Return the bare distribution name from a PEP 508 requirement string.

    Strips version specifiers (``>=``, ``==``, ``!=``, ``~=``, ``<``, ...),
    extras (``[...]``), environment markers (``; ...``), URL references
    (``@ ...``) and surrounding whitespace, e.g.
    ``'pyhive[hive_pure_sasl]>=0.7.0'`` -> ``'pyhive'`` and
    ``'PyGithub!=1.58'`` -> ``'PyGithub'``.

    If the string does not start with a valid distribution name, it is
    returned with surrounding whitespace removed and otherwise unchanged.
    """
    # PEP 508 names: ASCII letters/digits, with ".", "-" and "_" allowed in
    # the middle. Everything after the name is a specifier, extra or marker.
    match = re.match(r"\s*([A-Za-z0-9](?:[A-Za-z0-9._-]*[A-Za-z0-9])?)", req)
    if match is None:
        # Not a well-formed requirement; keep it recognisable for the
        # caller's "could not find package" warning.
        return req.strip()
    return match.group(1)
72
73
def name_to_attr_path(req: str, packages: Dict[str, Dict[str, str]]) -> Optional[str]:
    """Find the Nixpkgs attribute path of the derivation providing ``req``.

    Args:
        req: requirement name without version constraints, e.g. "pyserial".
        packages: mapping of attribute path -> package metadata with a
            ``"name"`` key (the derivation name), as produced by
            ``dump_packages``.

    Returns:
        The preferred path from PKG_PREFERENCES if ``req`` is listed there,
        otherwise the single attribute path whose derivation name matches,
        or ``None`` if nothing matches.

    Raises:
        ValueError: if more than one derivation matches.
    """
    if req in PKG_PREFERENCES:
        return f"{PKG_SET}.{PKG_PREFERENCES[req]}"
    names = [req]
    # E.g. python-mpd2 is actually called python3.6-mpd2
    # instead of python3.6-python-mpd2 inside Nixpkgs
    if req.startswith(("python-", "python_")):
        names.append(req[len("python-"):])
    attr_paths = []
    for name in names:
        # Treat "-", "_" and "." as equivalent separators (PEP 503 name
        # normalization) and escape every segment, so any other regex
        # metacharacter in the name matches literally.
        name_re = "[-_.]".join(re.escape(part) for part in re.split(r"[-_.]", name))
        # Derivation names look like python(major).(minor)-(pname)-(version or
        # unstable-date); we need the version qualifier, or we'll have
        # multiple matches (e.g. pyserial and pyserial-asyncio when looking
        # for pyserial)
        pattern = re.compile(rf"^python\d+\.\d+-{name_re}-(?:\d|unstable-.*)", re.I)
        attr_paths.extend(
            attr_path
            for attr_path, package in packages.items()
            if pattern.match(package["name"])
        )
    # An explicit check instead of `assert`, which is stripped under -O.
    if len(attr_paths) > 1:
        raise ValueError(f"{req} matches more than one derivation: {attr_paths}")
    return attr_paths[0] if attr_paths else None
101
102
def provider_reqs_to_attr_paths(reqs: List, packages: Dict) -> List:
    """Translate Airflow provider requirements into attribute names in PKG_SET.

    Version constraints are stripped, and requirements starting with
    ``apache-airflow`` are dropped. Requirements with no matching derivation
    are skipped with a warning. The returned names lack the ``PKG_SET + "."``
    prefix and keep the input order, duplicates included.
    """
    bare_names = [
        name
        for name in map(remove_version_constraint, reqs)
        if not name.startswith("apache-airflow")
    ]
    prefix_len = len(PKG_SET + ".")
    resolved = []
    for name in bare_names:
        attr_path = name_to_attr_path(name, packages)
        if attr_path is None:
            # Unresolvable requirements are not fatal; just tell the user
            logging.warning("Could not find package attr for %s", name)
            continue
        resolved.append(attr_path[prefix_len:])
    return resolved
119
120
def get_cross_provider_reqs(
    provider: str,
    provider_reqs: Dict,
    cross_provider_deps: Dict,
    seen: Optional[List] = None,
) -> Set:
    """Collect the requirements of ``provider`` and of every provider it reaches.

    Args:
        provider: provider id to start from.
        provider_reqs: provider id -> iterable of that provider's own requirements.
        cross_provider_deps: provider id -> list of provider ids it depends on.
        seen: provider ids to exclude from the traversal. The starting
            provider's own requirements are always included.

    Returns:
        The union of the requirements of all providers reachable from
        ``provider`` through ``cross_provider_deps``.
    """
    # There are circular cross-provider dependencies, so track visited
    # providers. A single visited set keeps the walk linear in providers plus
    # edges. The previous per-path "seen" list re-explored shared
    # sub-dependencies once per path, which grows exponentially.
    visited = set(seen or ())
    visited.add(provider)
    pending = [provider]
    reqs: Set = set()
    while pending:
        current = pending.pop()
        reqs.update(provider_reqs[current])
        for dep in cross_provider_deps[current]:
            if dep not in visited:
                visited.add(dep)
                pending.append(dep)
    return reqs
140
141
def get_provider_reqs(version: str, packages: Dict) -> Dict:
    """Compute the full set of Nix dependencies for every Airflow provider.

    Reads ``generated/provider_dependencies.json`` for ``version`` and maps
    each provider's direct requirements to attribute names, plus any
    EXTRA_REQS entries. The requirements of the providers it depends on are
    then folded in, except for ``common.sql``.

    Returns:
        provider id -> set of attribute names.
    """
    dependency_info = get_file_from_github(
        version, "generated/provider_dependencies.json"
    )
    direct = {}
    graph = {}
    for name, info in dependency_info.items():
        direct[name] = [
            *provider_reqs_to_attr_paths(info["deps"], packages),
            *EXTRA_REQS.get(name, []),
        ]
        # common.sql is excluded from the cross-provider edges, so its
        # requirements are never pulled in transitively
        graph[name] = [
            dep for dep in info["cross-providers-deps"] if dep != "common.sql"
        ]
    # Add transitive cross-provider reqs
    return {
        name: get_cross_provider_reqs(name, direct, graph) for name in direct
    }
162
163
def get_provider_yaml(version: str, provider: str) -> Dict:
    """Fetch and parse ``provider.yaml`` for ``provider`` at Airflow ``version``.

    Dotted provider ids map to nested directories, e.g. ``"apache.hive"``
    becomes ``airflow/providers/apache/hive``. If the download fails with an
    HTTP error, a warning is logged and ``{}`` is returned.
    """
    subdir = "/".join(provider.split("."))
    try:
        parsed = get_file_from_github(
            version, f"airflow/providers/{subdir}/provider.yaml"
        )
    except HTTPError:
        logging.warning("Couldn't get provider yaml for %s", provider)
        parsed = {}
    return parsed
172
173
def get_provider_imports(version: str, providers) -> Dict:
    """Collect the Python modules declared by each provider's hooks and operators.

    For every provider id, the ``python-modules`` lists of all ``hooks``
    entries and then all ``operators`` entries in its ``provider.yaml`` are
    concatenated. Providers whose yaml lacks a section contribute nothing
    for it.

    Returns:
        provider id -> list of module names.
    """
    result = {}
    for provider in providers:
        meta = get_provider_yaml(version, provider)
        modules: List[str] = []
        # Hook modules first, then operator modules
        for section in ("hooks", "operators"):
            if section not in meta:
                continue
            for entry in meta[section]:
                modules.extend(entry["python-modules"])
        result[provider] = modules
    return result
194
195
def to_nix_expr(provider_reqs: Dict, provider_imports: Dict, fh: TextIO) -> None:
    """Write a Nix attribute set describing every provider to ``fh``.

    Each provider becomes an attribute (dots in its id replaced by
    underscores) holding two lists of Nix strings: ``deps`` from
    ``provider_reqs`` and ``imports`` from ``provider_imports``. Both are
    sorted so the generated file is stable.
    """

    def quoted_sorted(items) -> str:
        # Sort the quoted forms, exactly as they appear in the output
        return " ".join(sorted(f'"{item}"' for item in items))

    fh.write("# Warning: generated by update-providers.py, do not update manually\n")
    fh.write("{\n")
    for provider, reqs in provider_reqs.items():
        attr_name = provider.replace(".", "_")
        fh.write(f" {attr_name} = {{\n")
        fh.write(f" deps = [ {quoted_sorted(reqs)} ];\n")
        fh.write(f" imports = [ {quoted_sorted(provider_imports[provider])} ];\n")
        fh.write(" };\n")
    fh.write("}\n")
212
213
def main() -> None:
    """Regenerate ``providers.nix`` for the Airflow version in ``default.nix``.

    The output is written next to this script, the same directory
    ``default.nix`` is read from, so the result does not depend on the
    current working directory.
    """
    logging.basicConfig(level=logging.INFO)
    version = get_version()
    packages = dump_packages()
    logging.info("Generating providers.nix for version %s", version)
    provider_reqs = get_provider_reqs(version, packages)
    provider_imports = get_provider_imports(version, provider_reqs.keys())
    # default.nix and the Nixpkgs root are located relative to the script;
    # writing into the cwd would drop providers.nix in the wrong place when
    # the script is invoked from another directory.
    output_path = os.path.join(os.path.dirname(sys.argv[0]), "providers.nix")
    with open(output_path, "w", encoding="utf-8") as fh:
        to_nix_expr(provider_reqs, provider_imports, fh)


if __name__ == "__main__":
    main()