1{
2 lib,
3 stdenv,
4 buildPythonPackage,
5 fetchPypi,
6 pythonOlder,
7 numpy,
8 scikit-learn,
9 scipy,
10 tabulate,
11 torch,
12 tqdm,
13 flaky,
14 pandas,
15 pytestCheckHook,
16 safetensors,
17 pythonAtLeast,
18}:
19
20buildPythonPackage rec {
21 pname = "skorch";
22 version = "1.0.0";
23 format = "setuptools";
24
25 src = fetchPypi {
26 inherit pname version;
27 hash = "sha256-JcplwaeYlGRAJXRNac1Ya/hgWoHE+NWjZhCU9eaSyRQ=";
28 };
29
30 disabled = pythonOlder "3.8";
31
32 propagatedBuildInputs = [
33 numpy
34 scikit-learn
35 scipy
36 tabulate
37 torch
38 tqdm
39 ];
40
41 nativeCheckInputs = [
42 flaky
43 pandas
44 pytestCheckHook
45 safetensors
46 ];
47
48 # patch out pytest-cov dep/invocation
49 postPatch = ''
50 substituteInPlace setup.cfg \
51 --replace "--cov=skorch" "" \
52 --replace "--cov-report=term-missing" "" \
53 --replace "--cov-config .coveragerc" ""
54 '';
55
56 disabledTests =
57 [
58 # on CPU, these expect artifacts from previous GPU run
59 "test_load_cuda_params_to_cpu"
60 # failing tests
61 "test_pickle_load"
62 ]
63 ++ lib.optionals stdenv.hostPlatform.isDarwin [
64 # there is a problem with the compiler selection
65 "test_fit_and_predict_with_compile"
66 ]
67 ++ lib.optionals (pythonAtLeast "3.11") [
68 # Python 3.11+ not yet supported for torch.compile
69 # https://github.com/pytorch/pytorch/blob/v2.0.1/torch/_dynamo/eval_frame.py#L376-L377
70 "test_fit_and_predict_with_compile"
71 ];
72
73 disabledTestPaths =
74 [
75 # tries to import `transformers` and download HuggingFace data
76 "skorch/tests/test_hf.py"
77 ]
78 ++ lib.optionals (stdenv.hostPlatform.system != "x86_64-linux") [
79 # torch.distributed is disabled by default in darwin
80 # aarch64-linux also failed these tests
81 "skorch/tests/test_history.py"
82 ];
83
84 pythonImportsCheck = [ "skorch" ];
85
86 meta = with lib; {
87 description = "Scikit-learn compatible neural net library using Pytorch";
88 homepage = "https://skorch.readthedocs.io";
89 changelog = "https://github.com/skorch-dev/skorch/blob/master/CHANGES.md";
90 license = licenses.bsd3;
91 maintainers = with maintainers; [ bcdarwin ];
92 };
93}