This is a machine learning and deep learning research codebase built on JAX/Flax and Bazel. The repository hosts multiple independent research projects that share common infrastructure, currently centered on two main modules: Reinforcement Learning (RL) and Generative Models (Generative).
- `src/projects/`: Source code for all specific research projects.
  - `rl/`: Reinforcement Learning module (implementations of DQN and related components).
  - `generative/`: Generative Models module (U-Net based generative models such as DDPM diffusion models and Flow Matching).
- `src/core/`: Core components and infrastructure (e.g., base model classes in `model.py`, distributed training wrappers in `distributed.py`, and training state management in `train_state.py`).
- `src/utilities/`: General utility libraries (including logging, visualization, and training helper functions).
- `src/data/`: Data processing modules (e.g., data pipelines using HuggingFace `datasets`).
- `MODULE.bazel`: Bazel dependency and environment configuration file, managing the Python version and related package dependencies.
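The real `src/core/train_state.py` is presumably built on Flax/optax, but as a rough, dependency-free illustration of what "training state management" typically tracks, here is a plain-Python sketch (the class and method names below are hypothetical, not the repo's actual API):

```python
from dataclasses import dataclass, field, replace
from typing import Any, Dict


@dataclass(frozen=True)
class TrainState:
    """Illustrative stand-in for a Flax-style train state.

    Immutable: each update returns a new state, mirroring how
    JAX/Flax code threads state through pure functions.
    """
    step: int
    params: Dict[str, Any]
    opt_state: Dict[str, Any] = field(default_factory=dict)

    def apply_gradients(self, new_params: Dict[str, Any]) -> "TrainState":
        # A real implementation would run an optax update here; this
        # sketch just swaps in updated parameters and bumps the step.
        return replace(self, step=self.step + 1, params=new_params)
```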
Before running any code, ensure that system-level dependencies are correctly loaded and the appropriate driver versions are enabled:
- Load CUDA and cuDNN modules (typically required on clusters/servers):

  ```shell
  module load cuda/12.6
  module load cudnn
  ```
- Bazel Build System: This project uses Bazelisk (a wrapper around Bazel) for unified build-version management.
  - All Python dependencies are declared via `rules_python` in `MODULE.bazel` and are fetched automatically during the build.
  - There is no need to run `pip install` manually; Bazel isolates the execution environment properly.
The primary way to execute code is through the `bazelisk run` command. Below is a breakdown of the frequently used "long command", for better understanding:
```shell
CUDA_VISIBLE_DEVICES=0 \
NCCL_P2P_LEVEL=NVL \
NCCL_SHM_DISABLE=0 \
XLA_PYTHON_CLIENT_MEM_FRACTION=.9 \
bazelisk run --config=cuda //src/projects/rl:main -- --work_dir logs/
```

- GPU and JAX Environment Variables:
  - `CUDA_VISIBLE_DEVICES=0`: Restricts the process to only use the GPU with index `0`.
  - `NCCL_P2P_LEVEL=NVL` & `NCCL_SHM_DISABLE=0`: Tune NCCL's P2P, shared-memory, and NVLink strategies for multi-GPU or single-GPU communication (essential communication settings for JAX).
  - `XLA_PYTHON_CLIENT_MEM_FRACTION=.9`: Instructs JAX/XLA to pre-allocate 90% of GPU memory, preventing OOM (Out Of Memory) errors and memory fragmentation.
- Bazelisk Command:
  - `bazelisk run --config=cuda`: Compiles and runs using the CUDA-enabled Bazel configuration (usually defined in the `.bazelrc` at the root directory).
  - `//src/projects/rl:main`: The Bazel target to run. `//` denotes the repository root; the target corresponds to the `ml_py_binary` with `name = "main"` in the `src/projects/rl/BUILD` file.
- User Arguments (passed to the Python script after `--`):
  - `--work_dir logs/`: Arguments handed to `main.py` in Python, such as the path for saving logs and models.
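Two details of the breakdown above are worth making concrete: the environment variables take effect only if they are set before JAX is first imported, and everything after `--` arrives as ordinary command-line arguments in the Python process. A minimal sketch of how a launcher or entry point might handle both (the repo may actually use `absl.flags` or fiddle; `argparse` and the defaults below are illustrative):

```python
import argparse
import os

# JAX/XLA and NCCL read these at import/init time, so set them before
# `import jax`. Values mirror the long command above; adjust per machine.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
os.environ.setdefault("NCCL_P2P_LEVEL", "NVL")
os.environ.setdefault("NCCL_SHM_DISABLE", "0")
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", ".9")


def parse_args(argv=None):
    # Everything after `--` in `bazelisk run ... -- <args>` lands in
    # sys.argv of the Python binary and can be parsed normally.
    parser = argparse.ArgumentParser()
    parser.add_argument("--work_dir", default="logs/")
    parser.add_argument("--num_episodes", type=int, default=5000)
    parser.add_argument("--batch_size", type=int, default=512)
    return parser.parse_args(argv)
```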
- Run Reinforcement Learning (RL - DQN):

  ```shell
  module load cuda/12.6
  module load cudnn
  CUDA_VISIBLE_DEVICES=0 NCCL_P2P_LEVEL=NVL NCCL_SHM_DISABLE=0 XLA_PYTHON_CLIENT_MEM_FRACTION=.9 \
  bazelisk run --config=cuda //src/projects/rl:main -- \
      --work_dir logs/rl_run \
      --num_episodes 5000 \
      --batch_size 512
  ```
- Run Generative Models (Generative - DDPM / Mean Flow):

  ```shell
  module load cuda/12.6
  module load cudnn
  CUDA_VISIBLE_DEVICES=0 NCCL_P2P_LEVEL=NVL NCCL_SHM_DISABLE=0 XLA_PYTHON_CLIENT_MEM_FRACTION=.9 \
  bazelisk run --config=cuda //src/projects/generative:main -- \
      --work_dir logs/generative_run \
      --distributed False
  ```

  (Note: if you run `generative` with specific fiddle configuration files, you may also need to pass them via arguments like `--experiment=xxx`; see `main.py` for details.)
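The `--experiment` flag presumably selects a named configuration; as a rough stdlib illustration of that name-to-config dispatch pattern (the experiment names and fields below are made up, and the real configs live in fiddle files, see `main.py`):

```python
# Hypothetical experiment registry; illustrative names and fields only.
EXPERIMENTS = {
    "ddpm_cifar10": {"model": "unet", "sampler": "ddpm", "train_steps": 500_000},
    "mean_flow_cifar10": {"model": "unet", "sampler": "mean_flow", "train_steps": 500_000},
}


def get_experiment(name: str) -> dict:
    try:
        # Return a copy so callers cannot mutate the shared registry.
        return dict(EXPERIMENTS[name])
    except KeyError:
        raise ValueError(
            f"Unknown --experiment={name!r}; choices: {sorted(EXPERIMENTS)}"
        ) from None
```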
Because this project uses Bazel to manage all file dependencies, whenever you add a new `.py` file you MUST register it in the `BUILD` file within the same directory; otherwise, Bazel won't find your new code at execution time.
Open `src/projects/rl/BUILD`, and you will see two main types of Bazel rule macros:

- `ml_py_library`: Defines a library file or module (to be imported by other code).
- `ml_py_binary`: Defines an executable entry script (like `main.py`, which can be run directly via `bazelisk run`).
Suppose you create a new algorithm named `ppo.py` under `src/projects/rl/`:
- Register the Python Library: add an `ml_py_library` block in `src/projects/rl/BUILD`:

  ```python
  ml_py_library(
      name = "ppo",
      srcs = ["ppo.py"],
      deps = [
          "flax",
          "jax",
          "optax",
          "//src/core:model",  # Depend on other local Bazel modules
      ],
  )
  ```
  Third-party libraries (like `flax`, `jax`) in `deps` are pre-defined in `MODULE.bazel`, so you can reference them by name alone. Local libraries require the full path (like `//src/core:model`).
- Import into the Main Program: to use your new algorithm in `main.py`, you must add `:ppo` to the dependencies list of `main` in the `BUILD` file:

  ```python
  ml_py_binary(
      name = "main",
      srcs = ["main.py"],
      deps = [
          # ... other existing dependencies
          ":ppo",  # <-- Introduce the newly defined library as a dependency
      ],
  )
  ```
- Run Again: after modifying the `BUILD` file and importing the module in your code, just rerun the previous `bazelisk run` command. Bazel will automatically detect file changes and rebuild your execution environment.
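Since `ppo.py` in the walkthrough above is hypothetical, here is a self-contained sketch of one building block any PPO implementation needs, Generalized Advantage Estimation, in plain Python (a real version in this repo would use JAX arrays and `jax.lax.scan`; the function name and signature are illustrative):

```python
def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over one trajectory.

    `values` has one extra entry: the bootstrap value of the state
    reached after the final reward.
    """
    assert len(values) == len(rewards) + 1
    advantages = [0.0] * len(rewards)
    gae = 0.0
    # Walk backward, accumulating discounted TD residuals.
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages
```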