···2727 # It will rebuild itself using the version of this package (NSS) and if
2828 # an update is required do the required changes to the expression.
2929 # Example: nix-shell ./maintainers/scripts/update.nix --argstr package cacert
3030- version = "3.68";
3030+ version = "3.70";
31313232in
3333stdenv.mkDerivation rec {
···36363737 src = fetchurl {
3838 url = "mirror://mozilla/security/nss/releases/NSS_${lib.replaceStrings [ "." ] [ "_" ] version}_RTM/src/${pname}-${version}.tar.gz";
3939- sha256 = "0nvj7h2brcw21p1z99nrsxka056d0r1yy9nqqg0lw0w3mhnb60n4";
3939+ sha256 = "sha256-K4mruGAe5AW+isW1cD1x8fs4pRw6ZKPYNDh/eLMlURs=";
4040 };
41414242 depsBuildBuild = [ buildPackages.stdenv.cc ];
···11+{ buildPythonPackage, fetchFromGitHub, lib
22+# propagatedBuildInputs
33+, absl-py, numpy, opt-einsum
44+# checkInputs
55+, jaxlib, pytestCheckHook
66+}:
77+88+buildPythonPackage rec {
99+ pname = "jax";
1010+ version = "0.2.19";
1111+1212+ # Fetching from pypi doesn't allow us to run the test suite. See https://discourse.nixos.org/t/pythonremovetestsdir-hook-being-run-before-checkphase/14612/3.
1313+ src = fetchFromGitHub {
1414+ owner = "google";
1515+ repo = pname;
1616+ rev = "jax-v${version}";
1717+ sha256 = "sha256-pVn62G7pydR7ybkf7gSbu0FlEq2c0US6H2GTBAljup4=";
1818+ };
1919+2020+ # jaxlib is _not_ included in propagatedBuildInputs because there are
2121+ # different versions of jaxlib depending on the desired target hardware. The
2222+ # JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the
2323+ # CPU wheel is packaged.
2424+ propagatedBuildInputs = [ absl-py numpy opt-einsum ];
2525+2626+ checkInputs = [ jaxlib pytestCheckHook ];
2727+ # NOTE: Don't run the tests in the expiremental directory as they require flax
2828+ # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
2929+ # Not a big deal, this is how the JAX docs suggest running the test suite
3030+ # anyhow.
3131+ pytestFlagsArray = [ "-W ignore::DeprecationWarning" "tests/" ];
3232+3333+ meta = with lib; {
3434+ description = "Differentiate, compile, and transform Numpy code";
3535+ homepage = "https://github.com/google/jax";
3636+ license = licenses.asl20;
3737+ maintainers = with maintainers; [ samuela ];
3838+ };
3939+}
···11+# For the moment we only support the CPU and GPU backends of jaxlib. The TPU
22+# backend will require some additional work. Those wheels are located here:
33+# https://storage.googleapis.com/jax-releases/libtpu_releases.html.
44+55+# For future reference, the easiest way to test the GPU backend is to run
66+# NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib.override { cudaSupport = true; }"
77+# export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
88+# python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'"
99+# python -c "from jax import random; random.PRNGKey(0)"
1010+# python -c "from jax import random; x = random.normal(random.PRNGKey(0), (100, 100)); x @ x"
1111+# There's no convenient way to test the GPU backend in the derivation since the
1212+# nix build environment blocks access to the GPU. See also:
1313+# * https://github.com/google/jax/issues/971#issuecomment-508216439
1414+# * https://github.com/google/jax/issues/5723#issuecomment-913038780
1515+1616+{ addOpenGLRunpath, autoPatchelfHook, buildPythonPackage, config, fetchPypi
1717+, fetchurl, isPy39, lib, stdenv
1818+# propagatedBuildInputs
1919+, absl-py, flatbuffers, scipy, cudatoolkit_11
2020+# Options:
2121+, cudaSupport ? config.cudaSupport or false
2222+}:
2323+2424+assert cudaSupport -> lib.versionAtLeast cudatoolkit_11.version "11.1";
2525+2626+let
2727+ device = if cudaSupport then "gpu" else "cpu";
2828+in
2929+buildPythonPackage rec {
3030+ pname = "jaxlib";
3131+ version = "0.1.71";
3232+ format = "wheel";
3333+3434+ # At the time of writing (8/19/21), there are releases for 3.7-3.9. Supporting
3535+ # all of them is a pain, so we focus on 3.9, the current nixpkgs python3
3636+ # version.
3737+ disabled = !isPy39;
3838+3939+ src = {
4040+ cpu = fetchurl {
4141+ url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl";
4242+ sha256 = "sha256:0rqhs6qabydizlv5d3rb20dbv6612rr7dqfniy9r6h4kazdinsn6";
4343+ };
4444+ gpu = fetchurl {
4545+ url = "https://storage.googleapis.com/jax-releases/cuda111/jaxlib-${version}+cuda111-cp39-none-manylinux2010_x86_64.whl";
4646+ sha256 = "sha256:065kyzjsk9m84d138p99iymdiiicm1qz8a3iwxz8rspl43rwrw89";
4747+ };
4848+ }.${device};
4949+5050+ # Prebuilt wheels are dynamically linked against things that nix can't find.
5151+ # Run `autoPatchelfHook` to automagically fix them.
5252+ nativeBuildInputs = [ autoPatchelfHook ] ++ lib.optional cudaSupport addOpenGLRunpath;
5353+ # Dynamic link dependencies
5454+ buildInputs = [ stdenv.cc.cc ];
5555+5656+ # jaxlib contains shared libraries that open other shared libraries via dlopen
5757+ # and these implicit dependencies are not recognized by ldd or
5858+ # autoPatchelfHook. That means we need to sneak them into rpath. This step
5959+ # must be done after autoPatchelfHook and the automatic stripping of
6060+ # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the
6161+ # patchPhase. Dependencies:
6262+ # * libcudart.so.11.0 -> cudatoolkit_11.lib
6363+ # * libcublas.so.11 -> cudatoolkit_11
6464+ # * libcuda.so.1 -> opengl driver in /run/opengl-driver/lib
6565+ preInstallCheck = lib.optional cudaSupport ''
6666+ shopt -s globstar
6767+6868+ addOpenGLRunpath $out/**/*.so
6969+7070+ for file in $out/**/*.so; do
7171+ rpath=$(patchelf --print-rpath $file)
7272+ # For some reason `makeLibraryPath` on `cudatoolkit_11` maps to
7373+ # <cudatoolkit_11.lib>/lib which is different from <cudatoolkit_11>/lib.
7474+ patchelf --set-rpath "$rpath:${cudatoolkit_11}/lib:${lib.makeLibraryPath [ cudatoolkit_11.lib ]}" $file
7575+ done
7676+ '';
7777+7878+ # pip dependencies and optionally cudatoolkit.
7979+ propagatedBuildInputs = [ absl-py flatbuffers scipy ] ++ lib.optional cudaSupport cudatoolkit_11;
8080+8181+ pythonImportsCheck = [ "jaxlib" ];
8282+8383+ meta = with lib; {
8484+ description = "XLA library for JAX";
8585+ homepage = "https://github.com/google/jax";
8686+ license = licenses.asl20;
8787+ maintainers = with maintainers; [ samuela ];
8888+ platforms = [ "x86_64-linux" ];
8989+ };
9090+}