1{
2 lib,
3 buildPythonPackage,
4 fetchFromGitHub,
5
6 # build-system
7 setuptools,
8 setuptools-scm,
9
10 # dependencies
11 torch,
12 triton,
13
14 # optional-dependencies
15 accelerate,
16 datasets,
17 fire,
18 huggingface-hub,
19 pandas,
20 pytestCheckHook,
21 tqdm,
22 transformers,
23}:
24
let
  # Upstream has no tags, so the source is pinned to an exact commit.
  pinnedCommit = "b616b222976b235647790a16d0388338b9e18941";

  # Packages exposed through the `all` extra.
  # `deepspeed` is not yet packaged in nixpkgs; once it is, append:
  #   ++ lib.optionals (!stdenv.hostPlatform.isDarwin) [
  #     deepspeed
  #   ]
  # (that would also require adding `stdenv` and `deepspeed` to the function arguments).
  allExtras = [
    accelerate
    datasets
    fire
    huggingface-hub
    pandas
    tqdm
    transformers
  ];
in
# Nix packaging of Apple's Cut Cross-Entropy (CCE) Python library, built from
# the GitHub sources with a pyproject-based build.
buildPythonPackage {
  pname = "cut-cross-entropy";
  version = "25.5.1";
  pyproject = true;

  # The `ml-cross-entropy` project on PyPI is published by a third party.
  # Apple recommends installing directly from the repository's main branch,
  # hence the GitHub fetch instead of a PyPI fetch.
  src = fetchFromGitHub {
    owner = "apple";
    repo = "ml-cross-entropy";
    rev = pinnedCommit;
    hash = "sha256-BVPon+T7chkpozX/IZU3KZMw1zRzlYVvF/22JWKjT2Y=";
  };

  build-system = [
    setuptools
    setuptools-scm
  ];

  dependencies = [
    torch
    triton
  ];

  optional-dependencies = {
    transformers = [ transformers ];
    all = allExtras;
  };

  # Verify that the top-level module can be imported after installation.
  pythonImportsCheck = [ "cut_cross_entropy" ];

  # Run the test suite with pytest during the check phase.
  nativeCheckInputs = [ pytestCheckHook ];

  meta = {
    description = "Memory-efficient cross-entropy loss implementation using Cut Cross-Entropy (CCE)";
    homepage = "https://github.com/apple/ml-cross-entropy";
    license = lib.licenses.aml;
    maintainers = [ lib.maintainers.hoh ];
  };
}