# nixpkgs package definition for dalle-mini, built from the PyPI sdist.
{
  lib,
  buildPythonPackage,
  fetchPypi,
  fetchpatch,

  # build-system
  setuptools,

  # dependencies
  einops,
  emoji,
  flax,
  ftfy,
  jax,
  jaxlib,
  orbax-checkpoint,
  pillow,
  pydantic,
  transformers,
  unidecode,
  wandb,
}:

buildPythonPackage rec {
  pname = "dalle-mini";
  version = "0.1.5";
  # Build via the PEP 517 (pyproject) path; this requires an explicit
  # `build-system` below so the build backend is available at build time.
  pyproject = true;

  src = fetchPypi {
    inherit pname version;
    hash = "sha256-k4XILjNNz0FPcAzwPEeqe5Lj24S2Y139uc9o/1IUS1c=";
  };

  # Fix incompatibility with recent JAX versions.
  # See https://github.com/borisdayma/dalle-mini/pull/338
  patches = [
    (fetchpatch {
      url = "https://github.com/borisdayma/dalle-mini/pull/338/commits/22ffccf03f3e207731a481e3e42bdb564ceebb69.patch";
      hash = "sha256-LIOyfeq/oVYukG+1rfy5PjjsJcjADCjn18x/hVmLkPY=";
    })
  ];

  # Without a build backend listed here, the pyproject build cannot run.
  # NOTE(review): setuptools assumed to be the upstream backend (sdist ships
  # setuptools metadata) — confirm if upstream ever switches backends.
  build-system = [
    setuptools
  ];

  # Drop upstream version pins on these packages so the versions provided by
  # the Python package set are accepted by the runtime-deps check.
  pythonRelaxDeps = [
    "transformers"
    "jax"
    "flax"
  ];

  # Upstream declares a requirement named `orbax`; it is removed from the
  # metadata, and `orbax-checkpoint` is listed in `dependencies` instead.
  pythonRemoveDeps = [
    "orbax"
  ];

  dependencies = [
    einops
    emoji
    flax
    ftfy
    jax
    jaxlib
    orbax-checkpoint
    pillow
    pydantic
    transformers
    unidecode
    wandb
  ];

  doCheck = false; # no upstream tests

  # Minimal smoke test in lieu of a test suite: the top-level module imports.
  pythonImportsCheck = [ "dalle_mini" ];

  meta = {
    description = "Generate images from a text prompt";
    homepage = "https://github.com/borisdayma/dalle-mini";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ r-burns ];
  };
}