# Python bindings for huggingface/safetensors, built from the Rust sources
# under bindings/python with maturin.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,
  rustPlatform,

  # optional-dependencies
  numpy,
  torch,
  tensorflow,
  flax,
  jax,
  mlx,
  paddlepaddle,
  h5py,
  huggingface-hub,
  setuptools-rust,
  pytest,
  pytest-benchmark,
  hypothesis,

  # tests
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "safetensors";
  version = "0.6.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "huggingface";
    repo = "safetensors";
    tag = "v${version}";
    hash = "sha256-wAr/jvr0w+vOHjjqE7cPcAM/IMz+58YhfoJ2XC4987M=";
  };

  # The Python package is built from the bindings/python subdirectory of the repo.
  sourceRoot = "${src.name}/bindings/python";

  # Rust dependencies are vendored from the Cargo.lock kept next to this file.
  cargoDeps = rustPlatform.importCargoLock {
    lockFile = ./Cargo.lock;
  };

  # Place that same lockfile into the build directory (sourceRoot) so the Rust
  # build sees the lockfile cargoDeps was generated from.
  # NOTE(review): presumably bindings/python ships no Cargo.lock upstream —
  # re-check on version bumps.
  postPatch = ''
    ln -s ${./Cargo.lock} Cargo.lock
  '';

  nativeBuildInputs = [
    rustPlatform.cargoSetupHook
    rustPlatform.maturinBuildHook
  ];

  # Extras, presumably mirroring upstream's optional-dependencies. `lib.fix`
  # lets extras build on each other through `self` (most include numpy).
  optional-dependencies = lib.fix (self: {
    numpy = [ numpy ];
    torch = self.numpy ++ [
      torch
    ];
    tensorflow = self.numpy ++ [
      tensorflow
    ];
    # Exact version pins cannot be honoured here; this is the unpinned
    # nixpkgs tensorflow, identical to the `tensorflow` extra.
    pinned-tf = self.numpy ++ [
      tensorflow
    ];
    jax = self.numpy ++ [
      flax
      jax
    ];
    mlx = [
      mlx
    ];
    paddlepaddle = self.numpy ++ [
      paddlepaddle
    ];
    testing = self.numpy ++ [
      h5py
      huggingface-hub
      setuptools-rust
      pytest
      pytest-benchmark
      hypothesis
    ];
    all =
      self.torch ++ self.numpy ++ self.pinned-tf ++ self.jax ++ self.paddlepaddle ++ self.testing;
    dev = self.all;
  });

  # Only numpy, torch and h5py are provided as test-time frameworks.
  nativeCheckInputs = [
    h5py
    numpy
    pytestCheckHook
    torch
  ];

  enabledTestPaths = [ "tests" ];

  disabledTests = [
    # AttributeError: module 'torch' has no attribute 'float4_e2m1fn_x2'
    "test_odd_dtype_fp4"

    # AssertionError: 'No such file or directory: notafile' != 'No such file or directory: "notafile"'
    "test_file_not_found"

    # AssertionError:
    # 'Erro[41 chars] 5]: index 20 out of bounds for tensor dimension #1 of size 5'
    # != 'Erro[41 chars] 5]: SliceOutOfRange { dim_index: 1, asked: 20, dim_size: 5 }'
    "test_numpy_slice"
  ];

  # Flax, PaddlePaddle and TensorFlow are deliberately left out of
  # nativeCheckInputs (they are heavyweight), so skip their comparison tests.
  disabledTestPaths = [
    "tests/test_flax_comparison.py"
    "tests/test_paddle_comparison.py"
    "tests/test_tf_comparison.py"
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # mlx is not in nativeCheckInputs either. Only excluded on Darwin,
    # presumably because the module skips itself on other platforms — verify.
    "tests/test_mlx_comparison.py"
  ];

  pythonImportsCheck = [ "safetensors" ];

  meta = {
    homepage = "https://github.com/huggingface/safetensors";
    description = "Fast (zero-copy) and safe (unlike pickle) format for storing tensors";
    # Derived from the fetched tag so the two cannot drift apart.
    changelog = "https://github.com/huggingface/safetensors/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}