{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,
  rustPlatform,

  # optional-dependencies
  numpy,
  torch,
  tensorflow,
  flax,
  jax,
  mlx,
  paddlepaddle,
  h5py,
  huggingface-hub,
  setuptools-rust,
  pytest,
  pytest-benchmark,
  hypothesis,

  # tests
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "safetensors";
  version = "0.5.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "huggingface";
    repo = "safetensors";
    tag = "v${version}";
    hash = "sha256-dtHHLiTgrg/a/SQ/Z1w0BsuFDClgrMsGiSTCpbJasUs=";
  };

  # The Python package is built from the bindings subdirectory, not the repo root.
  sourceRoot = "${src.name}/bindings/python";

  # Vendored Rust crates for the bindings; shares sourceRoot with the build so
  # the vendored set matches the Cargo project that is actually compiled.
  cargoDeps = rustPlatform.fetchCargoVendor {
    inherit pname src sourceRoot;
    hash = "sha256-hjV2cfS/0WFyAnATt+A8X8sQLzQViDzkNI7zN0ltgpU=";
  };

  nativeBuildInputs = [
    # Makes the vendored cargoDeps available to the build.
    rustPlatform.cargoSetupHook
    # Builds the Python wheel with maturin.
    rustPlatform.maturinBuildHook
  ];

  # `lib.fix` passes the finished attrset back in as `self`, so extras can be
  # composed from one another (most of them build on the `numpy` extra).
  optional-dependencies = lib.fix (self: {
    numpy = [ numpy ];
    torch = self.numpy ++ [ torch ];
    tensorflow = self.numpy ++ [ tensorflow ];
    # Only a single `tensorflow` is available here, so the "pinned" variant
    # resolves to exactly the same inputs as the `tensorflow` extra.
    pinned-tf = self.tensorflow;
    jax = self.numpy ++ [
      flax
      jax
    ];
    mlx = [ mlx ];
    paddlepaddle = self.numpy ++ [ paddlepaddle ];
    testing = self.numpy ++ [
      h5py
      huggingface-hub
      setuptools-rust
      pytest
      pytest-benchmark
      hypothesis
    ];
    # numpy is listed explicitly and also reaches this list through the other
    # extras, so it appears more than once.
    all =
      self.torch
      ++ self.numpy
      ++ self.pinned-tf
      ++ self.jax
      ++ self.paddlepaddle
      ++ self.testing;
    dev = self.all;
  });

  # Frameworks that are not listed here have their comparison tests disabled below.
  nativeCheckInputs = [
    h5py
    numpy
    pytestCheckHook
    torch
  ];
  pytestFlagsArray = [ "tests" ];

  disabledTestPaths =
    [
      # Flax, PaddlePaddle and TensorFlow are not in nativeCheckInputs; skip their
      # comparison tests rather than pull those frameworks (TensorFlow in
      # particular is onerous) into the test environment.
      "tests/test_flax_comparison.py"
      "tests/test_paddle_comparison.py"
      "tests/test_tf_comparison.py"
    ]
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
      # mlx is not in nativeCheckInputs either. These tests are only excluded on
      # Darwin; presumably they do not need mlx elsewhere — TODO confirm.
      "tests/test_mlx_comparison.py"
    ];

  pythonImportsCheck = [ "safetensors" ];

  meta = {
    homepage = "https://github.com/huggingface/safetensors";
    description = "Fast (zero-copy) and safe (unlike pickle) format for storing tensors";
    changelog = "https://github.com/huggingface/safetensors/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}