# Package expression for vqgan-jax, built from a pinned upstream Git commit.
#
# Arguments are the usual nixpkgs library helpers plus the Python packages
# this derivation references.
{ lib
, buildPythonPackage
, fetchFromGitHub
, flax
, jax
, jaxlib
, transformers
}:

# No attribute below refers to a sibling attribute, so a plain (non-`rec`)
# attribute set is sufficient.
buildPythonPackage {
  pname = "vqgan-jax";
  # Not a tagged release: the version is a date stamp and `src.rev` pins the
  # exact commit.
  version = "unstable-2022-04-20";

  src = fetchFromGitHub {
    owner = "patil-suraj";
    repo = "vqgan-jax";
    rev = "1be20eee476e5d35c30e4ec3ed12222018af8ce4";
    hash = "sha256-OZihAXpE0UsgauQ38XDmAF+lrIgz05uK0ro8SCdVsPc=";
  };

  format = "setuptools";

  # NOTE(review): jaxlib is listed as a build input rather than a propagated
  # one, presumably so that consumers can choose which jaxlib variant they
  # use at runtime -- confirm before moving it.
  buildInputs = [
    jaxlib
  ];

  propagatedBuildInputs = [
    flax
    jax
    transformers
  ];

  # The check phase is disabled; only the import check below is run.
  doCheck = false;

  pythonImportsCheck = [
    "vqgan_jax"
  ];

  meta = {
    description = "JAX implementation of VQGAN";
    homepage = "https://github.com/patil-suraj/vqgan-jax";
    # license unknown: https://github.com/patil-suraj/vqgan-jax/issues/9
    license = lib.licenses.unfree;
    maintainers = with lib.maintainers; [ r-burns ];
  };
}