GPU-accelerated probabilistic graphical models in Rust
thrml-rs is a pure Rust implementation of GPU-accelerated sampling for probabilistic graphical models (PGMs),
ported from Extropic's THRML library, with a few tweaks.
- GPU Acceleration: Multiple backends:
  - WGPU (default): Metal (macOS), Vulkan (Linux/Windows)
  - CUDA: Native NVIDIA GPU support
- Block Gibbs Sampling: Efficient parallel sampling for PGMs
- Energy-Based Models: Ising models, discrete EBMs, Gaussian PGMs
- Mixed Variable Types: Spin, categorical, and continuous nodes
- Deterministic RNG: Reproducible sampling with ChaCha8-based key splitting
- Moment Estimation: Built-in observers for computing statistics
- Training Support: Contrastive divergence, KL gradient estimation
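To make "block Gibbs sampling" concrete, here is a minimal CPU-only sketch for an Ising chain with the same parameters as the 5-node example that follows. It does not use thrml-rs. It is plain Rust with a hand-rolled SplitMix64 generator standing in for the library's ChaCha8-based RngKey. It assumes the energy convention E(s) = -β(Σ b_i s_i + Σ J_ij s_i s_j), under which P(s_i = +1 | rest) = σ(2β h_i) with local field h_i = b_i + Σ_j J_ij s_j. Because a chain is bipartite, even-indexed spins are conditionally independent given the odd-indexed ones, so each color class forms a block that could be updated in parallel. The names SplitMix64, local_field and block_gibbs_sweep are hypothetical.

```rust
/// Hypothetical stand-in RNG: SplitMix64 (thrml-rs itself uses ChaCha8-based keys).
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Uniform sample in [0, 1) built from the top 53 bits.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Local field h_i = b_i + sum_j J_ij s_j on an open chain, where edge k joins nodes k and k + 1.
fn local_field(i: usize, spins: &[i8], biases: &[f64], weights: &[f64]) -> f64 {
    let mut h = biases[i];
    if i > 0 {
        h += weights[i - 1] * f64::from(spins[i - 1]);
    }
    if i + 1 < spins.len() {
        h += weights[i] * f64::from(spins[i + 1]);
    }
    h
}

/// One sweep of two-color block Gibbs: resample all even-indexed spins, then all odd-indexed ones.
/// Spins of the same color share no edge, so each block could be resampled in parallel.
fn block_gibbs_sweep(spins: &mut [i8], biases: &[f64], weights: &[f64], beta: f64, rng: &mut SplitMix64) {
    for color in 0..2 {
        // Fields for this block read only the other color, so compute them all first.
        let fields: Vec<(usize, f64)> = (color..spins.len())
            .step_by(2)
            .map(|i| (i, local_field(i, &*spins, biases, weights)))
            .collect();
        for (i, h) in fields {
            // P(s_i = +1 | neighbors) = sigmoid(2 * beta * h_i)
            let p_up = 1.0 / (1.0 + (-2.0 * beta * h).exp());
            spins[i] = if rng.next_f64() < p_up { 1 } else { -1 };
        }
    }
}

fn main() {
    // Same parameters as the 5-node chain in the README example.
    let biases = [0.1, 0.2, 0.0, -0.1, 0.3];
    let weights = [0.5, -0.3, 0.4, 0.2];
    let beta = 1.0;
    let (burn_in, samples) = (100usize, 10_000usize);

    let mut rng = SplitMix64(42); // fixed seed, so the run is reproducible
    let mut spins = vec![1i8; biases.len()];
    let mut mean = vec![0.0f64; biases.len()];

    for sweep in 0..burn_in + samples {
        block_gibbs_sweep(&mut spins, &biases, &weights, beta, &mut rng);
        if sweep >= burn_in {
            // Running estimate of the first moments <s_i>.
            for (m, &s) in mean.iter_mut().zip(&spins) {
                *m += f64::from(s) / samples as f64;
            }
        }
    }
    println!("estimated <s_i>: {mean:?}");
}
```

The program prints Monte Carlo estimates of ⟨s_i⟩ after discarding 100 burn-in sweeps. The fixed seed plays the role that deterministic key splitting plays in the library: rerunning gives identical samples.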
```rust
use thrml_core::{Node, NodeType, Block, backend::init_gpu_device};
use thrml_models::ising::{IsingEBM, IsingSamplingProgram, hinton_init};
use thrml_samplers::{RngKey, SamplingSchedule};
use burn::tensor::Tensor;

fn main() {
    // Initialize GPU
    let device = init_gpu_device();

    // Create a 5-node Ising chain: each node is linked to its successor
    let nodes: Vec<Node> = (0..5).map(|_| Node::new(NodeType::Spin)).collect();
    let edges: Vec<_> = nodes
        .windows(2)
        .map(|w| (w[0].clone(), w[1].clone()))
        .collect();

    // Define biases (one per node) and coupling weights (one per edge)
    let biases = Tensor::from_data([0.1f32, 0.2, 0.0, -0.1, 0.3], &device);
    let weights = Tensor::from_data([0.5f32, -0.3, 0.4, 0.2], &device);
    // Inverse temperature
    let beta = Tensor::from_data([1.0f32], &device);

    // Create the Ising model
    let model = IsingEBM::new(nodes.clone(), edges, biases, weights, beta);

    // Initialize using Hinton's method
    let key = RngKey::new(42);
    let blocks = vec![Block::new(nodes).unwrap()];
    let init_state = hinton_init(key, &model, &blocks, &[], &device);

    println!("Model initialized with {} nodes", model.nodes().len());
}
```

| Crate | Description |
|---|---|
| thrml-core | Core types: Node, Block, BlockSpec, GPU backend |
| thrml-samplers | Sampling algorithms: Gibbs, Bernoulli, Softmax, Gaussian |
| thrml-models | Model implementations: Ising, Discrete EBM, Continuous factors |
| thrml-observers | Observation utilities: State, Moments |
| thrml-examples | Example programs and utilities |
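As a hand-checkable companion to the Ising example above, the sketch below evaluates the energy of the same 5-node chain in plain Rust, without any thrml-rs types. It assumes the common Ising convention E(s) = -β(Σ_i b_i s_i + Σ_(i,j) J_ij s_i s_j) with s_i ∈ {-1, +1}. Whether IsingEBM uses exactly this sign convention is an assumption, not something this README states. The edges (i, i + 1) mirror the windows(2) pairing in the example, and ising_energy is a hypothetical helper.

```rust
/// Ising energy under the assumed convention
/// E(s) = -beta * (sum_i b_i s_i + sum_(i,j) J_ij s_i s_j), with s_i in {-1, +1}.
fn ising_energy(
    spins: &[i8],
    biases: &[f64],
    edges: &[(usize, usize)],
    weights: &[f64],
    beta: f64,
) -> f64 {
    let field: f64 = spins.iter().zip(biases).map(|(&s, &b)| b * f64::from(s)).sum();
    let coupling: f64 = edges
        .iter()
        .zip(weights)
        .map(|(&(i, j), &w)| w * f64::from(spins[i] * spins[j]))
        .sum();
    -beta * (field + coupling)
}

fn main() {
    // Same numbers as the README example: 5 nodes, 4 chain edges.
    let biases = [0.1, 0.2, 0.0, -0.1, 0.3];
    let weights = [0.5, -0.3, 0.4, 0.2];
    let edges: Vec<(usize, usize)> = (0..4).map(|i| (i, i + 1)).collect();
    let beta = 1.0;

    let all_up = [1i8, 1, 1, 1, 1];
    let mut flipped = all_up;
    flipped[2] = -1; // flip the middle spin

    let e_up = ising_energy(&all_up, &biases, &edges, &weights, beta);
    let e_flip = ising_energy(&flipped, &biases, &edges, &weights, beta);

    // Flipping spin i from +1 to -1 raises the energy by 2 * beta * h_i,
    // where h_i = b_i + sum_j J_ij s_j is its local field.
    let h2 = biases[2] + weights[1] * 1.0 + weights[2] * 1.0;
    println!(
        "E(all up) = {e_up:.2}, E(flip s2) = {e_flip:.2}, 2*beta*h2 = {:.2}",
        2.0 * beta * h2
    );
}
```

With these numbers the all-up state has energy -1.3, flipping the middle spin gives -1.1, and the difference 0.2 equals 2βh_2 with h_2 = 0.1. This identity is why Gibbs updates are cheap: each conditional depends only on a node's neighbors.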
Add to your Cargo.toml:
```toml
[dependencies]
thrml-core = "0.1"
thrml-samplers = "0.1"
thrml-models = "0.1"
thrml-observers = "0.1"
```

| Feature | Backend | Use Case |
|---|---|---|
| gpu (default) | WGPU | Metal (macOS), Vulkan (Linux), DX12 (Windows) |
| cuda | CUDA + WGPU | NVIDIA GPUs with native CUDA |
| cpu | ndarray + WGPU | Development/testing without GPU, or CPU fallback |
```bash
# Default: WGPU backend (Metal on macOS, Vulkan on Linux)
cargo build --release

# Enable CUDA support alongside WGPU
cargo build --release --features cuda

# Enable CPU backend (useful for testing or systems without GPU)
cargo build --release --features cpu
```

- Rust 1.89+ (stable), required by Burn 0.19
- WGPU backend: GPU with Metal (macOS) or Vulkan (Linux/Windows) support
- CUDA backend: NVIDIA GPU with CUDA toolkit installed
See the examples/ directory:

```bash
# Simple Ising chain demonstration
cargo run --release --example ising_chain

# Spin models with performance benchmarking
cargo run --release --example spin_models

# Categorical variable sampling with visualization
cargo run --release --example categorical_sampling

# Full API walkthrough tutorial
cargo run --release --example full_api_walkthrough

# Gaussian PGM sampling (continuous nodes)
cargo run --release --example gaussian_pgm

# Mixed Gaussian-Bernoulli model
cargo run --release --example gaussian_bernoulli_ebm

# Full MNIST training with contrastive divergence
cargo run --release --example train_mnist
```

THRML-RS leverages the Burn deep learning framework for GPU acceleration:

| Backend | Platform | GPU Support |
|---|---|---|
| WGPU-Metal | macOS | Apple Silicon, AMD, Intel |
| WGPU-Vulkan | Linux/Windows | NVIDIA, AMD, Intel |
| CUDA | Linux/Windows | NVIDIA (native) |
Key optimizations:
- Native Metal acceleration on Apple Silicon
- CUDA for maximum performance on NVIDIA GPUs
- Efficient tensor operations with automatic batching
- Fused GPU kernels for sampling operations
Contributions are welcome! Please see CONTRIBUTING.md for guidelines.
Licensed under either of:
- Apache License, Version 2.0 (LICENSE-APACHE)
- MIT license (LICENSE-MIT)
at your option.
This project is inspired by Extropic's THRML library. THRML-RS is an independent Rust implementation providing the same functionality with native GPU acceleration.