written by Eric J. Ma on 2024-06-01 | tags: cuda jax conda environment variables cudnn python gpu dynamic libraries nvidia software installation
In this blog post, I share how to resolve CUDA backend initialization issues when installing JAX with CUDA, specifically addressing outdated cuDNN versions. I detail a method using Conda environments to manage CUDA installations and set environment variables correctly, offering two solutions: configuring LD_LIBRARY_PATH through Conda's activate.d and deactivate.d scripts, or directly within a Python session using a .env file. Both approaches aim to ensure that JAX utilizes the correct CUDA libraries, but each has its tradeoffs regarding portability. Curious about which method might work best for your setup?
UPDATE 9 June 2024: With thanks to the kind efforts of Vyas Ramasubramani, there is actually no need to set LD_LIBRARY_PATH at all, as long as one's conda environment is set up correctly! I detail what works right at the end of the blog post.
If you're like me, you've tried to install JAX pre-compiled against CUDA, and have probably also wrestled with issues that look like this:

    CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
    Outdated cuDNN installation found.
    Version JAX was built against: 8907
    Minimum supported: 8900
    Installed version: 8600
    The local installation version must be no lower than 8900..(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

And then you may have sat there scratching your head wondering, "but wait, I swear that I have the latest CUDA drivers, CUDA libraries, and cuDNN! Why is CUDA so hard?!"
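The version numbers in that message are packed integers. As a minimal sketch, assuming the cuDNN 8.x convention of major * 1000 + minor * 100 + patch (other major versions may pack the number differently), a hypothetical helper named decode_cudnn_version turns them into dotted versions:

```python
def decode_cudnn_version(packed: int) -> str:
    """Decode a packed cuDNN version integer into a dotted string.

    Assumes the cuDNN 8.x layout: major * 1000 + minor * 100 + patch.
    """
    major, remainder = divmod(packed, 1000)
    minor, patch = divmod(remainder, 100)
    return f"{major}.{minor}.{patch}"


# The three numbers quoted in the error message above.
for label, packed in [
    ("built against", 8907),
    ("minimum supported", 8900),
    ("installed", 8600),
]:
    print(f"{label}: {decode_cudnn_version(packed)}")
```

Read this way, the message says that the cuDNN being picked up is 8.6.0, while this build of JAX needs at least 8.9.0.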
Well, now, there's a way out! Independent of your system-level CUDA installation, if you're using Conda environments, you can manage your own environment-level CUDA installation! And as it turns out, this is mostly a matter of setting environment variables correctly.
environment.yml file

Your environment.yml file should look like this:

    name: my-env
    channels:
      - nvidia # order matters
      - conda-forge
    dependencies:
      - ...
      - cuda-libraries # this is the cuda metapackage
      - cudatoolkit # explicitly needed so conda/mamba installs the cuda-compiled jaxlib
      - cudnn # this is specifically for cudnn
      - cuda-nvcc # ensures that a compatible nvidia C compiler is available!
      - jaxlib # ensure that this is installed by conda/mamba, not by pip!
      - jax
      - ...

Now, run mamba env update -f environment.yml to install the packages.
What will happen is the following:
- The CUDA libraries will be installed into the $CONDA_PREFIX/lib directory.
- jax and jaxlib will be installed as Python packages as well, within the environment specified by the $CONDA_PREFIX environment variable.

You may be tempted to run your JAX code at this point, but you'll still run into the aforementioned error.
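To see what actually landed in the environment, one can list the shared libraries under $CONDA_PREFIX/lib. The sketch below is only illustrative: find_shared_libraries is a hypothetical helper, and the libcudnn and libcudart file name stems are my assumption about which files are worth looking for.

```python
import os
from pathlib import Path


def find_shared_libraries(prefix, stems=("libcudnn", "libcudart")):
    """Return sorted names of shared libraries in <prefix>/lib matching `stems`."""
    lib_dir = Path(prefix) / "lib"
    if not lib_dir.is_dir():
        return []
    return sorted(
        path.name
        for path in lib_dir.iterdir()
        # Keep only files such as libcudnn.so.8 that look like shared objects.
        if path.name.startswith(tuple(stems)) and ".so" in path.name
    )


if __name__ == "__main__":
    prefix = os.environ.get("CONDA_PREFIX")
    if prefix is None:
        print("CONDA_PREFIX is not set; activate a conda environment first.")
    else:
        for name in find_shared_libraries(prefix):
            print(name)
```

If the list comes back empty inside an activated environment, the CUDA packages were probably not installed by conda/mamba in the first place.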
Ensure LD_LIBRARY_PATH includes $CONDA_PREFIX/lib
Now, we need to set environment variables, specifically the LD_LIBRARY_PATH environment variable. LD_LIBRARY_PATH is used to specify a list of directories to look for "dynamic libraries" before searching the standard UNIX library paths. At this point, options multiply: we may need to make judgment calls or tradeoffs. As I see it, there are two sane places where we can configure LD_LIBRARY_PATH to be correct.
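Before choosing either place, it helps to know whether the variable is already correct. Here is a minimal sketch, assuming a Linux-style colon-separated LD_LIBRARY_PATH; the function name conda_lib_on_library_path is hypothetical.

```python
import os


def conda_lib_on_library_path(environ=None):
    """Return True if $CONDA_PREFIX/lib is one of the LD_LIBRARY_PATH entries."""
    environ = os.environ if environ is None else environ
    prefix = environ.get("CONDA_PREFIX")
    if not prefix:
        return False
    target = os.path.normpath(os.path.join(prefix, "lib"))
    # LD_LIBRARY_PATH is a colon-separated list of directories.
    entries = environ.get("LD_LIBRARY_PATH", "").split(":")
    # normpath makes "/opt/env/lib/" and "/opt/env/lib" compare equal.
    return any(os.path.normpath(entry) == target for entry in entries if entry)


print(conda_lib_on_library_path())
```

The function accepts an optional mapping so that it can be checked against made-up values as well as the real environment.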
activate.d and deactivate.d

activate.d and deactivate.d are folders housing shell scripts that are automatically run whenever we do conda activate and conda deactivate, respectively. Jaerock Kwon wrote a constructive blog post on ensuring that these environment variables are set correctly. Essentially, it is the shell script below, which should be run after activating an environment:

    mkdir -p $CONDA_PREFIX/etc/conda/activate.d
    echo 'export OLD_LD_LIBRARY_PATH=${LD_LIBRARY_PATH}' > \
        $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
    echo 'export LD_LIBRARY_PATH=$CONDA_PREFIX/lib/:$OLD_LD_LIBRARY_PATH' >> \
        $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh

    mkdir -p $CONDA_PREFIX/etc/conda/deactivate.d
    echo 'export LD_LIBRARY_PATH=${OLD_LD_LIBRARY_PATH}' > \
        $CONDA_PREFIX/etc/conda/deactivate.d/env_vars.sh
    echo 'unset OLD_LD_LIBRARY_PATH' >> \
        $CONDA_PREFIX/etc/conda/deactivate.d/env_vars.sh

The tradeoff here is that this is not portable from machine to machine. We'd have to run that chunk of code every single time we create a new environment that uses CUDA libraries, or whenever we recreate the environment on a new machine or in a Docker container. Ideally, we would specify this information within environment.yml (link found via StackOverflow).
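One way to soften the re-run burden might be to keep those commands as a small script inside the repository. The sketch below writes the same two env_vars.sh files from Python; write_env_var_hooks is a hypothetical name, and it assumes the hook contents are exactly the lines produced by the shell commands above.

```python
from pathlib import Path

# Contents mirror what the echo commands above write into each env_vars.sh.
ACTIVATE = (
    "export OLD_LD_LIBRARY_PATH=${LD_LIBRARY_PATH}\n"
    "export LD_LIBRARY_PATH=$CONDA_PREFIX/lib/:$OLD_LD_LIBRARY_PATH\n"
)
DEACTIVATE = (
    "export LD_LIBRARY_PATH=${OLD_LD_LIBRARY_PATH}\n"
    "unset OLD_LD_LIBRARY_PATH\n"
)


def write_env_var_hooks(prefix):
    """Write env_vars.sh into <prefix>/etc/conda/activate.d and deactivate.d."""
    written = []
    for folder, body in (("activate.d", ACTIVATE), ("deactivate.d", DEACTIVATE)):
        hook_dir = Path(prefix) / "etc" / "conda" / folder
        hook_dir.mkdir(parents=True, exist_ok=True)
        hook = hook_dir / "env_vars.sh"
        hook.write_text(body)
        written.append(hook)
    return written
```

Inside an activated environment, it would be called as write_env_var_hooks(os.environ["CONDA_PREFIX"]). It still has to be re-run for every new environment, so this only reduces the typing, not the portability problem.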
Create a .env file and load it into your Python session

As it turns out, there is another way! Before importing JAX and executing JAX code, we can directly set the LD_LIBRARY_PATH environment variable within our Python session through environment variables that we load at runtime. To start, we need to create a .env file in our repository:

    # These are the contents of /path/to/repo/.env
    XLA_FLAGS="--xla_gpu_cuda_data_dir=${CONDA_PREFIX}/lib"
    LD_LIBRARY_PATH="${CONDA_PREFIX}/lib:${LD_LIBRARY_PATH}"

Then, within a Jupyter notebook (or a Python session more generally), we use python-dotenv to inject those environment variables into the currently running Python session:

    from dotenv import load_dotenv

    load_dotenv()

We should be able to verify that LD_LIBRARY_PATH is set correctly by running:

    import os

    print(os.getenv("LD_LIBRARY_PATH"))
    print(os.getenv("XLA_FLAGS"))

What gets printed should follow the pattern set in the .env file.
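To make "follow the pattern" concrete, here is a standard-library sketch that computes the values one would expect to see. It assumes that ${VAR} placeholders are filled in from the surrounding environment and that an unset variable expands to an empty string; expected_values and ENV_TEMPLATE are hypothetical names and not part of python-dotenv.

```python
import os
from string import Template

# The two lines from the .env file, written as templates.
ENV_TEMPLATE = {
    "XLA_FLAGS": "--xla_gpu_cuda_data_dir=${CONDA_PREFIX}/lib",
    "LD_LIBRARY_PATH": "${CONDA_PREFIX}/lib:${LD_LIBRARY_PATH}",
}


def expected_values(environ=None):
    """Expand the .env templates against `environ` (defaults to os.environ)."""
    values = dict(os.environ if environ is None else environ)
    # Assumption: an unset variable expands to an empty string.
    for name in ("CONDA_PREFIX", "LD_LIBRARY_PATH"):
        values.setdefault(name, "")
    return {
        key: Template(template).substitute(values)
        for key, template in ENV_TEMPLATE.items()
    }


for key, value in expected_values().items():
    print(f"{key}={value}")
```

Comparing this output with what os.getenv returns after load_dotenv() is a quick way to spot a variable that did not get set the way the .env file intended.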
Now, you can create NumPy arrays and shouldn't observe any issues with outdated cuDNN:

    import jax.numpy as np

    a = np.linspace(0, 3, 1000)

To verify that you're also using the GPU, run the following at the terminal:

    nvidia-smi

And you should see something like the following:

    ❯ nvidia-smi
    Fri May 31 12:31:08 2024
    +-----------------------------------------------------------------------------+
    | NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
    |-------------------------------+----------------------+----------------------+
    | GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
    | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
    |                               |                      |               MIG M. |
    |===============================+======================+======================|
    |   0  NVIDIA A10G         On   | 00000000:00:1E.0 Off |                    0 |
    |  0%   28C    P0    59W / 300W |  17434MiB / 23028MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+

    +-----------------------------------------------------------------------------+
    | Processes:                                                                  |
    |  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
    |        ID   ID                                                   Usage      |
    |=============================================================================|
    |    0   N/A  N/A   1947397      C   ...vironment-name/bin/python    17432MiB |
    +-----------------------------------------------------------------------------+

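The same check can be made from inside Python by asking JAX which backend and devices it picked up. This is a minimal sketch: describe_jax_devices is a hypothetical helper, and it returns None when JAX is not importable so that it can run anywhere.

```python
import importlib.util


def describe_jax_devices():
    """Return JAX's default backend and devices, or None if JAX is not installed."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    return {
        "backend": jax.default_backend(),
        "devices": [str(device) for device in jax.devices()],
    }


print(describe_jax_devices())
```

On a working setup, the reported backend should be a GPU one, not cpu. A silent fallback to cpu usually means the CUDA libraries were not found.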
Both are good, and both come with a common tradeoff that we need to remember: the environment configuration isn't portable from machine to machine.
In the case of the env_vars.sh scripts housed under activate.d and deactivate.d,
they need to be recreated any time the conda environment is recreated.
This means that if one deletes the environment and recreates it,
or simply moves to a different machine and recreates it,
one will need to re-run the shell commands listed above
to recreate the environment variables.
In the case of .env configuration, this is much more lightweight,
but as a matter of common idioms, .env files are not committed to source control,
which makes distributing them a hassle.
My preference is to use .env files,
as they are more easily accessible to us as we work within a project repository.
My solution to the problem of portability
is to ensure that the specific configuration of individual .env files
is at least initialized as part of a standard project structure
(akin to what I do with pyds-cli or what cookie-cutter-data-science does),
so that newly created repositories come with these environment variables.
And in the absolute worst case,
one can simply copy/paste those two lines between two repos' .env files.
UPDATE (9 June 2024): Thanks to Vyas Ramasubramani from NVIDIA,
who did a deep dive after reading my blog post,
it turns out it is unnecessary to set environment variables
as long as the correct CUDA packages are installed.
cudatoolkit was necessary with CUDA11, but is no longer necessary with CUDA12.
An example environment.yml file that I verified as working on my system is:

    name: my-env
    channels:
      - conda-forge
    dependencies:
      - python=3.10
      - cuda-libraries # this is the cuda metapackage
      - cudnn # this is specifically for cudnn
      - cuda-nvcc # ensures that a compatible nvidia C compiler is available!
      - jaxlib
      - cuda-version # necessary for pulling in packages that are CUDA12-compatible!
      - jax

As I learned through the discussion thread that Vyas created,
cudatoolkit is needed for CUDA11, while cuda-version is usable for any version of CUDA running back to version 9.2.
This turns out to be the result of
hard work by open source community members who also work at NVIDIA
(see this thread for an example).
Back in graduate school, NVIDIA's stack used to be confusing to me,
but I'm thankful that many people are hard at work making it easier to use!
(This is also why I love the internet:
If I have a misconception and write about it,
there will be wonderful people like Vyas to correct me!)
@article{
ericmjl-2024-how-environments,
author = {Eric J. Ma},
title = {How to manage CUDA libraries within Conda environments},
year = {2024},
month = {06},
day = {01},
howpublished = {\url{https://ericmjl.github.io}},
journal = {Eric J. Ma's Blog},
url = {https://ericmjl.github.io/blog/2024/6/1/how-to-manage-cuda-libraries-within-conda-environments},
}