{
  lib,
  buildPythonPackage,

  # build-system
  pybind11,
  setuptools,

  # dependencies
  ctranslate2-cpp,
  numpy,
  pyyaml,

  # tests
  pytestCheckHook,
  torch,
  transformers,
  wurlitzer,
}:

# Python bindings for CTranslate2, built from the `python/` subdirectory of the
# same source tree as the C++ library and linked against `ctranslate2-cpp`.
buildPythonPackage rec {
  # Reuse name, version and source from the C++ library so the bindings are
  # always built from exactly the same upstream revision.
  inherit (ctranslate2-cpp) pname version src;

  # Build through the PEP 517 path (pypa build); the legacy
  # `format = "setuptools"` mode is deprecated in nixpkgs.
  pyproject = true;

  # The bindings live in a subdirectory of the upstream repository:
  # https://github.com/OpenNMT/CTranslate2/tree/master/python
  sourceRoot = "${src.name}/python";

  build-system = [
    pybind11
    setuptools
  ];

  # The compiled extension module links against the C++ library.
  buildInputs = [ ctranslate2-cpp ];

  dependencies = [
    numpy
    pyyaml
  ];

  pythonImportsCheck = [
    # https://opennmt.net/CTranslate2/python/overview.html
    "ctranslate2"
    "ctranslate2.converters"
    "ctranslate2.models"
    "ctranslate2.specs"
  ];

  nativeCheckInputs = [
    pytestCheckHook
    torch
    transformers
    wurlitzer
  ];

  preCheck = ''
    # Run tests against the installed build result, not the in-tree sources
    # (the source package would shadow the installed one on sys.path).
    rm -rf ctranslate2

    # Some tests write to $HOME, which is not writable in the sandbox.
    export HOME=$TMPDIR
  '';

  disabledTests = [
    # AssertionError: assert 'int8' in {'float32'}
    "test_get_supported_compute_types"
    # TensorFlow (tf) is not available for Python 3.12 yet.
    # Remove these once https://github.com/NixOS/nixpkgs/pull/325224 is fixed.
    # NOTE(review): these appear to live in tests/test_opennmt_tf.py, which
    # disabledTestPaths already ignores; confirm and drop them if redundant.
    "test_opennmt_tf_model_conversion"
    "test_opennmt_tf_model_quantization"
    "test_opennmt_tf_model_conversion_invalid_vocab"
    "test_opennmt_tf_model_conversion_invalid_dir"
    "test_opennmt_tf_shared_embeddings_conversion"
    "test_opennmt_tf_postnorm_transformer_conversion"
    "test_opennmt_tf_gpt_conversion"
    "test_opennmt_tf_multi_features"
  ];

  disabledTestPaths = [
    # TODO: ModuleNotFoundError: No module named 'opennmt'
    "tests/test_opennmt_tf.py"
    # Needs network access, which the build sandbox does not provide:
    # OSError: We couldn't connect to 'https://huggingface.co' to load this file
    "tests/test_transformers.py"
  ];

  meta = {
    description = "Fast inference engine for Transformer models";
    homepage = "https://github.com/OpenNMT/CTranslate2";
    changelog = "https://github.com/OpenNMT/CTranslate2/blob/${src.rev}/CHANGELOG.md";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ hexa ];
  };
}