# nixpkgs derivation for the Hugging Face `diffusers` Python library.
{
  lib,
  stdenv,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,
  fetchpatch,
  writeText,
  setuptools,
  wheel,
  filelock,
  huggingface-hub,
  importlib-metadata,
  numpy,
  pillow,
  regex,
  requests,
  safetensors,
  # optional dependencies
  accelerate,
  datasets,
  flax,
  jax,
  jaxlib,
  jinja2,
  peft,
  protobuf,
  tensorboard,
  torch,
  # test dependencies
  parameterized,
  pytest-timeout,
  pytest-xdist,
  pytestCheckHook,
  requests-mock,
  scipy,
  sentencepiece,
  torchsde,
  transformers,
  pythonAtLeast,
}:

buildPythonPackage rec {
  pname = "diffusers";
  version = "0.27.2";
  pyproject = true;

  disabled = pythonOlder "3.8";

  src = fetchFromGitHub {
    owner = "huggingface";
    repo = "diffusers";
    rev = "refs/tags/v${version}";
    hash = "sha256-aRnbU3jN40xaCsoMFyRt1XB+hyIYMJP2b/T1yZho90c=";
  };

  patches = [
    # Fix the Python 3.12 build (distutils is no longer in the 3.12 stdlib);
    # both patches come from the upstream PRs linked below.
    (fetchpatch {
      # https://github.com/huggingface/diffusers/pull/7455
      name = "001-remove-distutils.patch";
      url = "https://github.com/huggingface/diffusers/compare/363699044e365ef977a7646b500402fa585e1b6b...3c67864c5acb30413911730b1ed4a9ad47c0a15c.patch";
      hash = "sha256-Qyvyp1GyTVXN+A+lA1r2hf887ubTtaUknbKd4r46NZQ=";
    })
    (fetchpatch {
      # https://github.com/huggingface/diffusers/pull/7461
      name = "002-fix-removed-distutils.patch";
      url = "https://github.com/huggingface/diffusers/commit/efbbbc38e436a1abb1df41a6eccfd6f9f0333f97.patch";
      hash = "sha256-scdtpX1RYFFEDHcaMb+gDZSsPafkvnIO/wQlpzrQhLA=";
    })
  ];

  build-system = [
    setuptools
    wheel
  ];

  dependencies = [
    filelock
    huggingface-hub
    importlib-metadata
    numpy
    pillow
    regex
    requests
    safetensors
  ];

  passthru.optional-dependencies = {
    flax = [
      flax
      jax
      jaxlib
    ];
    torch = [
      accelerate
      torch
    ];
    training = [
      accelerate
      datasets
      jinja2
      peft
      protobuf
      tensorboard
    ];
  };

  pythonImportsCheck = [ "diffusers" ];

  # Tests crash due to a torch segmentation fault on aarch64-linux.
  doCheck = !(stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64);

  # The torch extra is needed by the test suite itself.
  nativeCheckInputs = [
    parameterized
    pytest-timeout
    pytest-xdist
    pytestCheckHook
    requests-mock
    scipy
    sentencepiece
    torchsde
    transformers
  ] ++ passthru.optional-dependencies.torch;

  preCheck =
    let
      # This pytest hook mocks and catches attempts to access the network:
      # any test that opens an HTTPS connection through urllib3 raises
      # NetworkAccessDeniedError, which is caught and reported as skipped
      # (tagged as xfailed) instead of failing the build.
      # cf. python3Packages.shap
      # The Python below must stay consistently indented: Nix strips the common
      # leading indentation of '' strings, leaving the function bodies intact.
      conftestSkipNetworkErrors = writeText "conftest.py" ''
        from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
        import urllib3

        class NetworkAccessDeniedError(RuntimeError): pass
        def deny_network_access(*a, **kw):
          raise NetworkAccessDeniedError

        urllib3.connection.HTTPSConnection._new_conn = deny_network_access

        def pytest_runtest_makereport(item, call):
          tr = orig_pytest_runtest_makereport(item, call)
          if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
            tr.outcome = 'skipped'
            tr.wasxfail = "reason: Requires network access."
          return tr
      '';
    in
    ''
      # Give the tests a writable HOME (presumably for caches they write).
      export HOME=$TMPDIR
      # Append the network-denying hook to the project's own test conftest.
      cat ${conftestSkipNetworkErrors} >> tests/conftest.py
    '';

  pytestFlagsArray = [ "tests/" ];

  disabledTests =
    [
      # depends on the current working directory
      "test_deprecate_stacklevel"
      # fails due to floating-point precision
      "test_model_cpu_offload_forward_pass"
      # tries to run ruff, which we have intentionally left out of nativeCheckInputs
      "test_is_copy_consistent"
    ]
    ++ lib.optionals (pythonAtLeast "3.12") [
      # RuntimeError: Dynamo is not supported on Python 3.12+
      "test_from_save_pretrained_dynamo"
    ];

  meta = {
    description = "State-of-the-art diffusion models for image and audio generation in PyTorch";
    mainProgram = "diffusers-cli";
    homepage = "https://github.com/huggingface/diffusers";
    # Use the plain tag: src.rev carries a "refs/tags/" prefix, which would
    # produce a ".../releases/tag/refs/tags/v..." URL.
    changelog = "https://github.com/huggingface/diffusers/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
  };
}