# skorch: scikit-learn compatible neural-net library on top of PyTorch,
# built from the PyPI source distribution.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchPypi,
  fetchpatch,
  pythonOlder,
  numpy,
  scikit-learn,
  scipy,
  tabulate,
  torch,
  tqdm,
  flaky,
  pandas,
  pytestCheckHook,
  safetensors,
  pythonAtLeast,
}:

buildPythonPackage rec {
  pname = "skorch";
  version = "0.15.0";
  format = "setuptools";

  disabled = pythonOlder "3.8";

  src = fetchPypi {
    inherit pname version;
    hash = "sha256-39XVBlCmbg162z9uL84GZrU+v+M8waXbGdVV72ZYf84=";
  };

  patches = [
    # Upstream commit that unbreaks the test suite with scikit-learn 1.4.
    # Remove at the next skorch release.
    (fetchpatch {
      name = "unbreak-tests-with-sklearn-1.4";
      url = "https://github.com/skorch-dev/skorch/commit/1f7a779d0aa78589e17262c206f5775f2fcd75f8.diff";
      hash = "sha256-X3SgjgDeq3PlBI13cC56LIL1dV1e+Z3tsBj9sz5pizo=";
    })
  ];

  propagatedBuildInputs = [
    numpy
    scikit-learn
    scipy
    tabulate
    torch
    tqdm
  ];

  nativeCheckInputs = [
    flaky
    pandas
    pytestCheckHook
    safetensors
  ];

  # pytest-cov is not among the check inputs, so strip its options from
  # setup.cfg. `--replace-fail` aborts the build if upstream changes or drops
  # any of these strings, instead of silently leaving a stale option behind.
  postPatch = ''
    substituteInPlace setup.cfg \
      --replace-fail "--cov=skorch" "" \
      --replace-fail "--cov-report=term-missing" "" \
      --replace-fail "--cov-config .coveragerc" ""
  '';

  disabledTests =
    [
      # on CPU, this expects artifacts from a previous GPU run
      "test_load_cuda_params_to_cpu"
      # failing test; NOTE(review): the failure reason is not recorded here,
      # re-check whether it still fails on the next update
      "test_pickle_load"
    ]
    # A single condition so the test name is listed only once even when both
    # apply (Darwin with Python >= 3.11):
    #  - Darwin: there is a problem with the compiler selection.
    #  - Python 3.11+: not yet supported by torch.compile, see
    #    https://github.com/pytorch/pytorch/blob/v2.0.1/torch/_dynamo/eval_frame.py#L376-L377
    ++ lib.optionals (stdenv.hostPlatform.isDarwin || pythonAtLeast "3.11") [
      "test_fit_and_predict_with_compile"
    ];

  disabledTestPaths =
    [
      # tries to import `transformers` and download HuggingFace data
      "skorch/tests/test_hf.py"
    ]
    ++ lib.optionals (stdenv.hostPlatform.system != "x86_64-linux") [
      # torch.distributed is disabled by default on Darwin, and these tests
      # also failed on aarch64-linux, so they only run on x86_64-linux
      "skorch/tests/test_history.py"
    ];

  pythonImportsCheck = [ "skorch" ];

  meta = {
    description = "Scikit-learn compatible neural net library using Pytorch";
    homepage = "https://skorch.readthedocs.io";
    changelog = "https://github.com/skorch-dev/skorch/blob/master/CHANGES.md";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}