# skorch: a scikit-learn compatible neural-network library built on PyTorch.
{ lib
, stdenv
, buildPythonPackage
, fetchPypi
, pytestCheckHook
, flaky
, numpy
, pandas
, torch
, scikit-learn
, scipy
, tabulate
, tqdm
}:

buildPythonPackage rec {
  pname = "skorch";
  version = "0.13.0";

  src = fetchPypi {
    inherit pname version;
    hash = "sha256-k9Zs4uqskHLqVHOKK7dIOmBSUmbDpOMuPS9eSdxNjO0=";
  };

  propagatedBuildInputs = [ numpy torch scikit-learn scipy tabulate tqdm ];
  nativeCheckInputs = [ flaky pandas pytestCheckHook ];

  # Strip the pytest-cov options from setup.cfg: pytest-cov is not among
  # nativeCheckInputs, so leaving them in would break the test invocation.
  postPatch = ''
    substituteInPlace setup.cfg \
      --replace "--cov=skorch" "" \
      --replace "--cov-report=term-missing" "" \
      --replace "--cov-config .coveragerc" ""
  '';

  disabledTests = [
    # On CPU, this test expects artifacts left over from a previous GPU run.
    "test_load_cuda_params_to_cpu"
    # Known failure; the root cause is not recorded here (TODO: investigate).
    "test_pickle_load"
  ] ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # Fails on Darwin because of a problem with compiler selection.
    "test_fit_and_predict_with_compile"
  ];

  disabledTestPaths = [
    # Tries to import `transformers` and download HuggingFace data,
    # which is unavailable in the build sandbox.
    "skorch/tests/test_hf.py"
  ] ++ lib.optionals (stdenv.hostPlatform.system != "x86_64-linux") [
    # torch.distributed is disabled by default on Darwin, and these tests
    # also fail on aarch64-linux, so only run them on x86_64-linux.
    "skorch/tests/test_history.py"
  ];

  pythonImportsCheck = [ "skorch" ];

  meta = {
    description = "Scikit-learn compatible neural net library using Pytorch";
    homepage = "https://skorch.readthedocs.io";
    # Pin the changelog to the packaged release instead of the moving master branch.
    changelog = "https://github.com/skorch-dev/skorch/blob/v${version}/CHANGES.md";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}