Yes, you can customize the percentage of GPU memory that JAX preallocates. By default, JAX preallocates 75% of the total GPU memory when the first JAX operation is run. To change this, set the environment variable `XLA_PYTHON_CLIENT_MEM_FRACTION` to the fraction of total GPU memory you want preallocated, written as a decimal (for example, `0.5` for 50%).
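As a rough worked illustration, the preallocated pool is the chosen fraction of the device's total memory. The sketch below assumes a GPU with 16 GiB of total memory and uses a hypothetical helper, `preallocated_bytes`, which only does the arithmetic and does not query a real device or call JAX.

```python
GIB = 1024 ** 3  # bytes per GiB


def preallocated_bytes(total_bytes: int, fraction: float = 0.75) -> int:
    """Hypothetical helper: size of the pool JAX would preallocate for a fraction."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    return int(total_bytes * fraction)


# Assumed 16 GiB card: default fraction versus XLA_PYTHON_CLIENT_MEM_FRACTION=0.5
print(preallocated_bytes(16 * GIB) / GIB)       # → 12.0
print(preallocated_bytes(16 * GIB, 0.5) / GIB)  # → 8.0
```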
For example, to preallocate 50% of the GPU memory, you would set the environment variable as follows:
```python
import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'  # must run before JAX is imported
import jax  # import JAX only after the variable is set
```
To disable preallocation altogether, set `XLA_PYTHON_CLIENT_PREALLOCATE` to `false`. JAX will then allocate GPU memory as needed, which can reduce overall memory usage but is more prone to memory fragmentation, so a program that uses most of the available GPU memory may run out of memory with preallocation disabled[1][3][7].
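Both settings can be applied from inside a script, as long as it happens before JAX is imported. A minimal sketch, assuming a hypothetical helper `configure_jax_gpu_memory` (not part of JAX) that sets the two variables described above and refuses to run if `jax` has already been imported:

```python
import os
import sys


def configure_jax_gpu_memory(fraction=None, preallocate=True):
    """Hypothetical helper (not a JAX API): set JAX's GPU memory variables.

    Call it at the top of a script, before `import jax`.
    """
    if "jax" in sys.modules:
        # Conservative guard: these variables are meant to be set before JAX
        # is imported, so fail loudly instead of silently doing nothing.
        raise RuntimeError("configure GPU memory before importing jax")
    # "false" makes JAX allocate memory as needed instead of preallocating.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if fraction is not None:
        if not 0.0 < fraction <= 1.0:
            raise ValueError("fraction must be in (0, 1]")
        # The fraction only matters when preallocation is enabled.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(fraction)


configure_jax_gpu_memory(preallocate=False)  # allocate on demand
# import jax  # import JAX only after the call above
```

The guard checks `sys.modules` rather than trying to undo an earlier import, since the safe ordering in this answer is simply "configure first, import second".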
Note that these environment variables must be set before importing JAX; setting them afterwards may have no effect[7].
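When you cannot control import order inside a process (for example, when another module imports JAX first), one option is to set the variables in the environment of a fresh child process, so they are already in place before that process imports JAX. A minimal standard-library sketch; `run_with_jax_memory_settings` is a hypothetical name, and the `-c` command stands in for a real script:

```python
import os
import subprocess
import sys


def run_with_jax_memory_settings(args, fraction=None, preallocate=True):
    """Hypothetical launcher: run a Python command with JAX memory variables set.

    The child inherits the variables from startup, before it can import JAX.
    """
    env = dict(os.environ)  # copy so the parent's environment is untouched
    if fraction is not None:
        env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(fraction)
    if not preallocate:
        env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    return subprocess.run([sys.executable, *args], env=env,
                          capture_output=True, text=True, check=True)


# Stand-in for a training script: print the value the child process sees.
result = run_with_jax_memory_settings(
    ["-c", "import os; print(os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'])"],
    fraction=0.5,
)
print(result.stdout.strip())  # prints 0.5
```

The same effect is available from a shell by prefixing the command with the variable assignment; the Python version is shown only to keep the examples in one language.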
Citations:
[1] https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
[2] https://github.com/google/jax/issues/4310
[3] https://kolonist26-jax-kr.readthedocs.io/en/latest/gpu_memory_allocation.html
[4] https://dl.acm.org/doi/fullHtml/10.1145/3673038.3673122
[5] https://stackoverflow.com/questions/73322760/jax-gpu-memory-usage-even-with-cpu-allocation
[6] https://dl.acm.org/doi/pdf/10.1145/3673038.3673122
[7] https://github.com/jax-ml/jax/discussions/21962
[8] https://docs.nvidia.com/deeplearning/dali/user-guide/docs/operations/nvidia.dali.fn.decoders.image.html
[9] https://forum.pyro.ai/t/gpu-memory-preallocated-and-not-released-between-batches/3774
[10] https://github.com/jax-ml/jax/discussions/19014