{
  lib,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,
  writeText,
  setuptools,
  filelock,
  huggingface-hub,
  importlib-metadata,
  numpy,
  pillow,
  regex,
  requests,
  safetensors,
  # optional dependencies
  accelerate,
  datasets,
  flax,
  jax,
  jaxlib,
  jinja2,
  peft,
  protobuf,
  tensorboard,
  torch,
  # test dependencies
  parameterized,
  pytest-timeout,
  pytest-xdist,
  pytestCheckHook,
  requests-mock,
  scipy,
  sentencepiece,
  torchsde,
  transformers,
  pythonAtLeast,
  diffusers,
}:

# Nix package for Hugging Face diffusers, built from the upstream GitHub tag.
buildPythonPackage rec {
  pname = "diffusers";
  version = "0.30.3";
  pyproject = true;

  disabled = pythonOlder "3.8";

  src = fetchFromGitHub {
    owner = "huggingface";
    repo = "diffusers";
    rev = "refs/tags/v${version}";
    hash = "sha256-/3lHJdsNblKb6xX03OluSCApMK3EXJbRLboBk8CjobE=";
  };

  build-system = [ setuptools ];

  dependencies = [
    filelock
    huggingface-hub
    importlib-metadata
    numpy
    pillow
    regex
    requests
    safetensors
  ];

  optional-dependencies = {
    flax = [
      flax
      jax
      jaxlib
    ];
    torch = [
      accelerate
      torch
    ];
    training = [
      accelerate
      datasets
      jinja2
      peft
      protobuf
      tensorboard
    ];
  };

  pythonImportsCheck = [ "diffusers" ];

  # The full test suite takes a few hours, so it is disabled for the regular
  # build and exposed separately through passthru.tests.pytest below.
  doCheck = false;

  passthru.tests.pytest = diffusers.overridePythonAttrs { doCheck = true; };

  nativeCheckInputs = [
    parameterized
    pytest-timeout
    pytest-xdist
    pytestCheckHook
    requests-mock
    scipy
    sentencepiece
    torchsde
    transformers
  ] ++ optional-dependencies.torch;

  preCheck =
    let
      # This pytest hook mocks and catches attempts at accessing the network.
      # Tests that try to access the network raise, get caught, are marked as
      # skipped and tagged as xfailed.
      # cf. python3Packages.shap
      conftestSkipNetworkErrors = writeText "conftest.py" ''
        from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
        import urllib3

        class NetworkAccessDeniedError(RuntimeError): pass
        def deny_network_access(*a, **kw):
            raise NetworkAccessDeniedError

        # Patch both connection classes so plain-HTTP requests are
        # intercepted as well, not only HTTPS ones.
        urllib3.connection.HTTPConnection._new_conn = deny_network_access
        urllib3.connection.HTTPSConnection._new_conn = deny_network_access

        def pytest_runtest_makereport(item, call):
            tr = orig_pytest_runtest_makereport(item, call)
            if call.excinfo is not None and call.excinfo.errisinstance(NetworkAccessDeniedError):
                tr.outcome = 'skipped'
                tr.wasxfail = "reason: Requires network access."
            return tr
      '';
    in
    ''
      export HOME=$TMPDIR
      # Append (rather than replace) so the upstream tests/conftest.py hooks stay active.
      cat ${conftestSkipNetworkErrors} >> tests/conftest.py
    '';

  pytestFlagsArray = [ "tests/" ];

  disabledTests =
    [
      # depends on current working directory
      "test_deprecate_stacklevel"
      # fails due to precision of floating point numbers
      "test_model_cpu_offload_forward_pass"
      # tries to run ruff which we have intentionally removed from nativeCheckInputs
      "test_is_copy_consistent"
    ]
    ++ lib.optionals (pythonAtLeast "3.12") [
      # RuntimeError: Dynamo is not supported on Python 3.12+
      "test_from_save_pretrained_dynamo"
    ];

  meta = {
    description = "State-of-the-art diffusion models for image and audio generation in PyTorch";
    mainProgram = "diffusers-cli";
    homepage = "https://github.com/huggingface/diffusers";
    changelog = "https://github.com/huggingface/diffusers/releases/tag/${lib.removePrefix "refs/tags/" src.rev}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
  };
}