
How to manage CUDA libraries within Conda environments

written by Eric J. Ma on 2024-06-01 | tags: cuda jax conda environment variables cudnn python gpu dynamic libraries nvidia software installation


In this blog post, I share how to resolve CUDA backend initialization issues when installing JAX with CUDA, specifically addressing outdated cuDNN versions. I detail a method using Conda environments to manage CUDA installations and set environment variables correctly, offering two solutions: configuring LD_LIBRARY_PATH through Conda's activate.d and deactivate.d scripts, or directly within a Python session using a .env file. Both approaches aim to ensure that JAX utilizes the correct CUDA libraries, but each has its tradeoffs regarding portability. Curious about which method might work best for your setup?

UPDATE 9 June 2024: With thanks to the kind efforts of Vyas Ramasubramani, there is actually no need to set LD_LIBRARY_PATH at all, as long as one's conda environment is set up correctly! I detail what works right at the end of the blog post.

If you're like me, you've tried to install JAX pre-compiled against CUDA, and have probably also wrestled with issues that look like this:

CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated cuDNN installation found.
Version JAX was built against: 8907
Minimum supported: 8900
Installed version: 8600
The local installation version must be no lower than 8900..(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

And then you may have sat there scratching your head wondering, "but wait, I swear that I have the latest CUDA drivers, CUDA libraries, and cuDNN! Why is CUDA so hard?!"
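
If you want to check for yourself which cuDNN the dynamic loader is actually resolving to (the "Installed version: 8600" above), one quick diagnostic is to load the library and ask it directly. This is only a sketch: the exact soname (libcudnn.so.8 vs. libcudnn.so.9) depends on what is installed, and it assumes a libcudnn that the loader can find on its default search path.

import ctypes

# Load whichever libcudnn the dynamic loader finds first; adjust the soname
# to match your installation (libcudnn.so.8, libcudnn.so.9, ...).
cudnn = ctypes.CDLL("libcudnn.so.8")
cudnn.cudnnGetVersion.restype = ctypes.c_size_t
print(cudnn.cudnnGetVersion())  # e.g. 8600, in the same format as the error above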

Well, there's a way out! As it turns out, if you're using Conda environments, you can manage an environment-level CUDA installation that is independent of your system-level one, and this is mostly a matter of setting environment variables correctly.

Step 1: Ensure that the CUDA metapackage is available in your environment.yml file

Your environment.yml file should look like this:

name: my-env
channels:
    - nvidia # order matters
    - conda-forge
dependencies:
    - ...
    - cuda-libraries # this is the cuda metapackage
    - cudatoolkit # explicitly needed so conda/mamba installs the cuda-compiled jaxlib
    - cudnn # this is specifically for cudnn
    - cuda-nvcc # ensures that a compatible nvidia C compiler is available!
    - jaxlib # ensure that this is installed by conda/mamba, not by pip!
    - jax
    - ...

Now, run mamba env update -f environment.yml to install the packages.

What will happen is the following:

  • The CUDA libraries will be installed into your $CONDA_PREFIX/lib directory.
  • jax and jaxlib will be installed as Python packages as well, within the environment specified by the $CONDA_PREFIX environment variable.
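
If you want to confirm that the libraries actually landed inside the environment, here is a quick check (a sketch; the exact filenames depend on the cuDNN build that conda resolved):

import os
from pathlib import Path

# The conda-installed cuDNN shared objects should now live under $CONDA_PREFIX/lib.
lib_dir = Path(os.environ["CONDA_PREFIX"]) / "lib"
print(sorted(p.name for p in lib_dir.glob("libcudnn*")))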

You may be tempted to run your JAX code at this point, but you'll still run into the aforementioned error.

Step 2: Ensure that your LD_LIBRARY_PATH includes $CONDA_PREFIX/lib

Now, we need to set environment variables, specifically the LD_LIBRARY_PATH environment variable. The LD_LIBRARY_PATH is used to specify a list of directories to look for "dynamic libraries" before searching the standard UNIX library paths. At this point, options multiply: we may need to make judgment calls or tradeoffs. As I see it, there are two sane places where we can configure LD_LIBRARY_PATH to be correct.

Option 1: Use conda's activate.d and deactivate.d

activate.d and deactivate.d are folders housing shell scripts that are automatically run whenever we do conda activate and conda deactivate, respectively. Jaerock Kwon wrote a constructive blog post on ensuring that these environment variables are set correctly. Essentially, it is the shell script below, which should be run after activating an environment:

mkdir -p $CONDA_PREFIX/etc/conda/activate.d
echo 'export OLD_LD_LIBRARY_PATH=${LD_LIBRARY_PATH}' > \
    $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
echo 'export LD_LIBRARY_PATH=$CONDA_PREFIX/lib/:$OLD_LD_LIBRARY_PATH' >> \
    $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh

mkdir -p $CONDA_PREFIX/etc/conda/deactivate.d
echo 'export LD_LIBRARY_PATH=${OLD_LD_LIBRARY_PATH}' > \
    $CONDA_PREFIX/etc/conda/deactivate.d/env_vars.sh
echo 'unset OLD_LD_LIBRARY_PATH' >> \
    $CONDA_PREFIX/etc/conda/deactivate.d/env_vars.sh
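
After deactivating and re-activating the environment, you can confirm that the hook took effect. A minimal check from inside a Python session (assuming the environment is currently active) looks like this:

import os

# After `conda deactivate && conda activate my-env`, the environment's lib
# directory should be the first entry on LD_LIBRARY_PATH.
ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
print(ld_library_path.split(":")[0])  # expected: $CONDA_PREFIX/lib/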

The tradeoff here is that this is not portable from machine to machine. We'd have to run that chunk of code every single time we create a new environment that uses CUDA libraries, or whenever we recreate the environment on a new machine or in a Docker container. Ideally, we would specify this information within environment.yml (link found via StackOverflow).

Option 2: Explicitly define this in .env and load it into your Python session

As it turns out, there is another way! Before importing JAX and executing JAX code, we can directly set the LD_LIBRARY_PATH environment variable within our Python session through environment variables that we load at runtime. To start, we need to create a .env file in our repository:

# These are the contents of /path/to/repo/.env
XLA_FLAGS="--xla_gpu_cuda_data_dir=${CONDA_PREFIX}/lib"
LD_LIBRARY_PATH="${CONDA_PREFIX}/lib:${LD_LIBRARY_PATH}"

Then, within a Jupyter notebook (or a Python session more generally), we use python-dotenv to inject those environment variables into the currently running Python session:

from dotenv import load_dotenv

# Load the variables from .env into os.environ; do this before importing jax.
load_dotenv()

We should be able to verify that LD_LIBRARY_PATH is set correctly by running:

import os

print(os.getenv("LD_LIBRARY_PATH"))
print(os.getenv("XLA_FLAGS"))

What gets printed should follow the pattern set in the .env file.

Now, you can create arrays with jax.numpy and shouldn't observe any issues with an outdated cuDNN:

import jax.numpy as np

a = np.linspace(0, 3, 1000)
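
Before reaching for nvidia-smi, you can also ask JAX directly which devices it sees; if the CUDA backend initialized correctly, the list should contain a GPU device rather than only a CPU device:

import jax

# Expect one or more CUDA/GPU devices here (the exact repr varies by JAX version);
# if you only see a CpuDevice, the CUDA backend did not initialize.
print(jax.devices())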

To verify that you're also using GPU, at the terminal, run the following:

nvidia-smi

And you should see something like the following:

$ nvidia-smi
Fri May 31 12:31:08 2024
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A10G         On   | 00000000:00:1E.0 Off |                    0 |
|  0%   28C    P0    59W / 300W |  17434MiB / 23028MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A   1947397      C   ...vironment-name/bin/python    17432MiB |
+-----------------------------------------------------------------------------+

Which way should I use?

Both are good, and both come with a common tradeoff that we need to remember: the environment configuration isn't portable from machine to machine.

In the case of the env_vars.sh scripts housed under activate.d and deactivate.d, they need to be recreated any time the conda environment is recreated. This means that if one deletes the environment and recreates it, or simply moves to a different machine and recreates it, one will need to re-run the shell commands listed above to recreate the environment variables.

In the case of .env configuration, this is much more lightweight, but as a matter of common idioms, .env files are not committed to source control, which makes distributing them a hassle.

My preference is to use .env files, as they are more easily accessible to us as we work within a project repository. My solution to the problem of portability is to ensure that the specific configuration of individual .env files is at least initialized as part of a standard project structure (akin to what I do with pyds-cli or what cookie-cutter-data-science does), so that newly created repositories come with these environment variables. And in the absolute worst case, one can simply copy/paste those two lines between two repos' .env files.

UPDATE (9 June 2024): Thanks to Vyas Ramasubramani from NVIDIA, who did a deep dive after reading my blog post, it turns out it is unnecessary to set environment variables as long as the correct CUDA packages are installed. cudatoolkit was necessary with CUDA11, but is no longer necessary with CUDA12. An example environment.yml file that I verified as working on my system is:

name: my-env
channels:
    - conda-forge
dependencies:
    - python=3.10
    - cuda-libraries # this is the cuda metapackage
    - cudnn # this is specifically for cudnn
    - cuda-nvcc # ensures that a compatible nvidia C compiler is available!
    - jaxlib
    - cuda-version # necessary for pulling in packages that are CUDA12-compatible!
    - jax
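
To double-check that this environment initializes the CUDA backend without any LD_LIBRARY_PATH tweaks, a minimal check (run in a fresh Python session inside the environment) is:

import jax
import jax.numpy as np

# No environment-variable juggling: the CUDA backend should initialize on first use.
a = np.linspace(0, 3, 1000)
print(jax.default_backend())  # expected: "gpu" (or "cuda", depending on the JAX version)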

As I learned through the discussion thread that Vyas created, cudatoolkit is needed for CUDA11, while cuda-version is usable for any version of CUDA running back to version 9.2. This turns out to be the result of hard work by open source community members who also work at NVIDIA (see this thread for an example). Back in graduate school, NVIDIA's stack used to be confusing to me, but I'm thankful that many people are hard at work making it easier to use! (This is also why I love the internet: If I have a misconception and write about it, there will be wonderful people like Vyas to correct me!)


Cite this blog post:
@article{
    ericmjl-2024-how-environments,
    author = {Eric J. Ma},
    title = {How to manage CUDA libraries within Conda environments},
    year = {2024},
    month = {06},
    day = {01},
    howpublished = {\url{https://ericmjl.github.io}},
    journal = {Eric J. Ma's Blog},
    url = {https://ericmjl.github.io/blog/2024/6/1/how-to-manage-cuda-libraries-within-conda-environments},
}
  
