# nixpkgs derivation for the Hugging Face `diffusers` Python package.
{
  lib,
  stdenv,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,
  fetchpatch,
  writeText,
  setuptools,
  wheel,
  filelock,
  huggingface-hub,
  importlib-metadata,
  numpy,
  pillow,
  regex,
  requests,
  safetensors,
  # optional dependencies
  accelerate,
  datasets,
  flax,
  jax,
  jaxlib,
  jinja2,
  peft,
  protobuf,
  tensorboard,
  torch,
  # test dependencies
  parameterized,
  pytest-timeout,
  pytest-xdist,
  pytestCheckHook,
  requests-mock,
  scipy,
  sentencepiece,
  torchsde,
  transformers,
  pythonAtLeast,
}:

buildPythonPackage rec {
  pname = "diffusers";
  version = "0.27.2";
  pyproject = true;

  disabled = pythonOlder "3.8";

  src = fetchFromGitHub {
    owner = "huggingface";
    repo = "diffusers";
    rev = "refs/tags/v${version}";
    hash = "sha256-aRnbU3jN40xaCsoMFyRt1XB+hyIYMJP2b/T1yZho90c=";
  };

  patches = [
    # fix python3.12 build
    (fetchpatch {
      # https://github.com/huggingface/diffusers/pull/7455
      name = "001-remove-distutils.patch";
      url = "https://github.com/huggingface/diffusers/compare/363699044e365ef977a7646b500402fa585e1b6b...3c67864c5acb30413911730b1ed4a9ad47c0a15c.patch";
      hash = "sha256-Qyvyp1GyTVXN+A+lA1r2hf887ubTtaUknbKd4r46NZQ=";
    })
    (fetchpatch {
      # https://github.com/huggingface/diffusers/pull/7461
      name = "002-fix-removed-distutils.patch";
      url = "https://github.com/huggingface/diffusers/commit/efbbbc38e436a1abb1df41a6eccfd6f9f0333f97.patch";
      hash = "sha256-scdtpX1RYFFEDHcaMb+gDZSsPafkvnIO/wQlpzrQhLA=";
    })
  ];

  build-system = [
    setuptools
    wheel
  ];

  dependencies = [
    filelock
    huggingface-hub
    importlib-metadata
    numpy
    pillow
    regex
    requests
    safetensors
  ];

  passthru.optional-dependencies = {
    flax = [
      flax
      jax
      jaxlib
    ];
    torch = [
      accelerate
      torch
    ];
    training = [
      accelerate
      datasets
      jinja2
      peft
      protobuf
      tensorboard
    ];
  };

  pythonImportsCheck = [ "diffusers" ];

  # tests crash due to torch segmentation fault
  doCheck = !(stdenv.isLinux && stdenv.isAarch64);

  # The torch extra is needed because most of the test suite exercises
  # PyTorch-backed models.
  nativeCheckInputs = [
    parameterized
    pytest-timeout
    pytest-xdist
    pytestCheckHook
    requests-mock
    scipy
    sentencepiece
    torchsde
    transformers
  ] ++ passthru.optional-dependencies.torch;

  preCheck =
    let
      # This pytest hook mocks and catches attempts at accessing the network
      # tests that try to access the network will raise, get caught, be marked as skipped and tagged as xfailed.
      # cf. python3Packages.shap
      conftestSkipNetworkErrors = writeText "conftest.py" ''
        from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
        import urllib3

        class NetworkAccessDeniedError(RuntimeError): pass
        def deny_network_access(*a, **kw):
          raise NetworkAccessDeniedError

        urllib3.connection.HTTPSConnection._new_conn = deny_network_access

        def pytest_runtest_makereport(item, call):
          tr = orig_pytest_runtest_makereport(item, call)
          if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
            tr.outcome = 'skipped'
            tr.wasxfail = "reason: Requires network access."
          return tr
      '';
    in
    ''
      # HOME is pointed at the build's temp dir so anything the tests write
      # under $HOME lands in a writable location.
      export HOME=$TMPDIR
      # Append the network-denying hook to the project's existing conftest.
      cat ${conftestSkipNetworkErrors} >> tests/conftest.py
    '';

  pytestFlagsArray = [ "tests/" ];

  disabledTests =
    [
      # depends on current working directory
      "test_deprecate_stacklevel"
      # fails due to precision of floating point numbers
      "test_model_cpu_offload_forward_pass"
      # tries to run ruff which we have intentionally removed from nativeCheckInputs
      "test_is_copy_consistent"
    ]
    ++ lib.optionals (pythonAtLeast "3.12") [
      # RuntimeError: Dynamo is not supported on Python 3.12+
      "test_from_save_pretrained_dynamo"
    ];

  meta = {
    description = "State-of-the-art diffusion models for image and audio generation in PyTorch";
    mainProgram = "diffusers-cli";
    homepage = "https://github.com/huggingface/diffusers";
    # Build the URL from the bare tag name: src.rev carries a "refs/tags/"
    # prefix, which would produce a non-existent .../releases/tag/refs/tags/... link.
    changelog = "https://github.com/huggingface/diffusers/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
  };
}