# diffusers: Hugging Face's library of diffusion models for image and audio
# generation. The test suite is slow, so it is disabled for the normal build
# (doCheck = false) and exposed separately as passthru.tests.pytest.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  filelock,
  huggingface-hub,
  importlib-metadata,
  numpy,
  pillow,
  regex,
  requests,
  safetensors,

  # optional dependencies
  accelerate,
  datasets,
  flax,
  jax,
  jaxlib,
  jinja2,
  peft,
  protobuf,
  tensorboard,
  torch,

  # tests
  writeText,
  parameterized,
  pytest-timeout,
  pytest-xdist,
  pytestCheckHook,
  requests-mock,
  scipy,
  sentencepiece,
  torchsde,
  transformers,
  pythonAtLeast,
  diffusers,
}:

buildPythonPackage rec {
  pname = "diffusers";
  version = "0.32.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "huggingface";
    repo = "diffusers";
    tag = "v${version}";
    hash = "sha256-TwmII38EA0Vux+Jh39pTAA6r+FRNuKHQWOOqsEe2Z+E=";
  };

  build-system = [ setuptools ];

  dependencies = [
    filelock
    huggingface-hub
    importlib-metadata
    numpy
    pillow
    regex
    requests
    safetensors
  ];

  optional-dependencies = {
    flax = [
      flax
      jax
      jaxlib
    ];
    torch = [
      accelerate
      torch
    ];
    training = [
      accelerate
      datasets
      jinja2
      peft
      protobuf
      tensorboard
    ];
  };

  pythonImportsCheck = [ "diffusers" ];

  # The full test suite takes a few hours; run it via passthru.tests.pytest.
  doCheck = false;

  # All optional-dependency groups are pulled in so their tests can run.
  nativeCheckInputs = [
    parameterized
    pytest-timeout
    pytest-xdist
    pytestCheckHook
    requests-mock
    scipy
    sentencepiece
    torchsde
    transformers
  ] ++ lib.flatten (builtins.attrValues optional-dependencies);

  preCheck =
    let
      # Pytest hook that denies network access during the test run.
      # Any urllib3 connection attempt (plain HTTP or HTTPS) raises
      # NetworkAccessDeniedError; a test that fails with that error is then
      # reported as skipped and tagged as xfailed instead of failing.
      # HTTPSConnection is patched explicitly as well, in case it overrides
      # _new_conn rather than inheriting it from HTTPConnection.
      # cf. python3Packages.shap
      conftestSkipNetworkErrors = writeText "conftest.py" ''
        from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
        import urllib3

        class NetworkAccessDeniedError(RuntimeError):
            pass

        def deny_network_access(*a, **kw):
            raise NetworkAccessDeniedError

        urllib3.connection.HTTPConnection._new_conn = deny_network_access
        urllib3.connection.HTTPSConnection._new_conn = deny_network_access

        def pytest_runtest_makereport(item, call):
            tr = orig_pytest_runtest_makereport(item, call)
            if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
                tr.outcome = 'skipped'
                tr.wasxfail = "reason: Requires network access."
            return tr
      '';
    in
    ''
      # Tests (and the libraries they use) may write caches under $HOME.
      export HOME=$(mktemp -d)
      # Append the network-denying hook to the project's existing conftest.
      cat ${conftestSkipNetworkErrors} >> tests/conftest.py
    '';

  pytestFlagsArray = [ "tests/" ];

  disabledTests =
    [
      # depends on current working directory
      "test_deprecate_stacklevel"
      # fails due to precision of floating point numbers
      "test_full_loop_no_noise"
      "test_model_cpu_offload_forward_pass"
      # tries to run ruff which we have intentionally removed from nativeCheckInputs
      "test_is_copy_consistent"

      # Require unpackaged torchao:
      # importlib.metadata.PackageNotFoundError: No package metadata was found for torchao
      "test_load_attn_procs_raise_warning"
      "test_save_attn_procs_raise_warning"
      "test_save_load_lora_adapter_0"
      "test_save_load_lora_adapter_1"
      "test_wrong_adapter_name_raises_error"
    ]
    ++ lib.optionals (pythonAtLeast "3.13") [
      # RuntimeError: Dynamo is not supported on Python 3.13+
      "test_from_save_pretrained_dynamo"
    ];

  # Runs the (slow) test suite on demand by re-enabling doCheck.
  passthru.tests.pytest = diffusers.overridePythonAttrs { doCheck = true; };

  meta = {
    description = "State-of-the-art diffusion models for image and audio generation in PyTorch";
    mainProgram = "diffusers-cli";
    homepage = "https://github.com/huggingface/diffusers";
    changelog = "https://github.com/huggingface/diffusers/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
  };
}