{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  pytestCheckHook,
  setuptools,
  numpy,
  jaxlib,
  jax,
  torch,
  dask,
  sparse,
  array-api-strict,
  config,
  cudaSupport ? config.cudaSupport,
  cupy,
}:

buildPythonPackage rec {
  pname = "array-api-compat";
  version = "1.11.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "data-apis";
    repo = "array-api-compat";
    tag = version;
    hash = "sha256-qGf1XDhRx9hJJP0LcZF7lA8tl+LKYNCw0xTqGjsZYj8=";
  };

  build-system = [ setuptools ];

  # Test-time dependencies. cupy is only added when CUDA support is enabled
  # (cudaSupport defaults to the global config.cudaSupport).
  nativeCheckInputs = [
    pytestCheckHook
    numpy
    jaxlib
    jax
    torch
    dask
    sparse
    array-api-strict
  ] ++ lib.optionals cudaSupport [ cupy ];

  pythonImportsCheck = [ "array_api_compat" ];

  # CUDA (used via cupy) is not available in the build sandbox, so deselect
  # every test whose name matches "cupy". pytestCheckHook turns this list into
  # `-k "not cupy"` itself, which avoids hand-quoting the expression inside
  # pytestFlagsArray (the embedded quotes only worked because the hook
  # re-parsed its arguments through the shell).
  disabledTests = [ "cupy" ];

  meta = {
    homepage = "https://data-apis.org/array-api-compat";
    changelog = "https://github.com/data-apis/array-api-compat/releases/tag/${src.tag}";
    description = "Compatibility layer for NumPy to support the Python array API";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ berquist ];
  };
}