# eigen-hhlo
Dense linear algebra decompositions on HHLO via XLA custom calls.
eigen-hhlo provides singular value decomposition (SVD), QR decomposition, symmetric eigenvalue decomposition, Cholesky factorization, and LU decomposition with partial pivoting as first-class operations within the HHLO ecosystem. Computations are compiled by XLA, dispatched through PJRT, and executed on the GPU via cuSOLVER.
Note on CPU support: A CPU backend via LAPACK is prepared but deferred pending upstream PJRT CPU plugin support. See Appendix: CPU Backend for details.
## Background
The HHLO ecosystem (a Haskell frontend for StableHLO/XLA) provides a typed EDSL for building MLIR modules and executing them through PJRT. While it covers neural network primitives, reductions, and control flow, it lacks dense linear algebra factorizations. StableHLO itself does not define decomposition ops; the standard way to add them is via `stablehlo.custom_call` backed by a runtime library. eigen-hhlo fills this gap by exposing cuSOLVER routines through XLA custom calls.
## Design
The Haskell EDSL emits `stablehlo.custom_call` operations whose `call_target_name` matches symbols in a CUDA shared library (e.g. `eigenhhlo_dpotrf`). At session startup, `withEigenGPU` registers these symbols with the PJRT CUDA plugin via `registerGpuCustomCall`. When XLA compiles the StableHLO module, it maps each custom call to the registered CUDA function; at runtime the wrapper handles device buffer management, cuSOLVER workspace allocation, and kernel launch.
- HHLO-only: No dependency on `hmatrix` or other host linear algebra libraries. Everything lives inside the XLA graph.
- GPU via cuSOLVER: C++ wrappers around cuSOLVER (`cusolverDnDpotrf`, `cusolverDnDgesvd`, `cusolverDnDgeqrf`/`cusolverDnDorgqr`, `cusolverDnDsyevd`, `cusolverDnDgetrf`) conforming to the XLA GPU custom-call ABI (`api_version = 3`).
- Type-safe shapes: All ops are statically shaped using GHC `TypeLits` (e.g. `Tensor '[m, n] 'F64`).
## Implementation

### C GPU Library

`cbits/gpu/eigenhhlo_cusolver.cu` wraps cuSOLVER routines for the XLA GPU custom-call ABI:
| Symbol | cuSOLVER routine | Operation |
| --- | --- | --- |
| `eigenhhlo_dpotrf` | `cusolverDnDpotrf` | Cholesky factorization |
| `eigenhhlo_dgesvd` | `cusolverDnDgesvd` | SVD |
| `eigenhhlo_dgeqrf` | `cusolverDnDgeqrf` | QR factorization (reflectors) |
| `eigenhhlo_dorgqr` | `cusolverDnDorgqr` | Generate Q from QR reflectors |
| `eigenhhlo_dsyevd` | `cusolverDnDsyevd` | Symmetric eigenvalue decomposition |
| `eigenhhlo_dgetrf` | `cusolverDnDgetrf` | LU with partial pivoting |
The opaque backend_config string carries dimensions and options (e.g. "n=2,uplo=L").
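For illustration, a builder for these opaque option strings might look like the following sketch (hypothetical names; the shipped `EigenHHLO.Core.BackendConfig` API may differ):

```haskell
import Data.List (intercalate)

-- Hypothetical sketch of an opaque backend_config builder; the real
-- EigenHHLO.Core.BackendConfig module may expose a different API.
data Uplo = Lower | Upper

renderUplo :: Uplo -> String
renderUplo Lower = "L"
renderUplo Upper = "U"

-- Render key=value pairs into the comma-separated opaque string
-- that the CUDA wrapper parses at runtime.
backendConfig :: [(String, String)] -> String
backendConfig kvs = intercalate "," [ k ++ "=" ++ v | (k, v) <- kvs ]

-- Example: the Cholesky config for an n-by-n matrix.
cholConfig :: Int -> Uplo -> String
cholConfig n uplo = backendConfig [("n", show n), ("uplo", renderUplo uplo)]
```

With this sketch, `cholConfig 2 Lower` produces exactly the `"n=2,uplo=L"` string shown above.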
### Haskell Modules
| Module | Purpose |
| --- | --- |
| `EigenHHLO.IR.Cholesky` | `cholBuilder` — typed `customCallRaw` wrapper for Cholesky |
| `EigenHHLO.IR.SVD` | `svdBuilder` — 3-output custom call (U, S, Vt) |
| `EigenHHLO.IR.QR` | `qrBuilder` / `qBuilder` — QR + explicit Q generation |
| `EigenHHLO.IR.Eigenvalue` | `eigBuilder` — eigenvalues + eigenvectors |
| `EigenHHLO.IR.LU` | `luBuilder` — LU factors + pivot indices |
| `EigenHHLO.EDSL.Decomposition` | Re-exports with simpler names (`chol`, `svd`, `qr`, etc.) |
| `EigenHHLO.Runtime.Session` | `withEigenGPU` — session setup + library loading |
| `EigenHHLO.Core.BackendConfig` | Opaque config string builders |
## Building

### ⚠️ Required Manual Step: Compile the CUDA Custom-Call Library
eigen-hhlo does NOT compile its CUDA code automatically through Cabal.
You must run the provided build.sh script manually before the package can execute any custom calls at runtime.
```shell
cd cbits/gpu && bash build.sh && cd ../..
# Produces: lib/libeigenhhlo_gpu.so
```
Why this is necessary:
The custom-call functions (eigenhhlo_dpotrf, eigenhhlo_dgesvd, etc.) live in a separate shared library that is loaded at runtime by withEigenGPU (via registerGpuCustomCall). Cabal's extra-source-files only ships the source code in the tarball; it does not invoke nvcc during the Haskell build. This is a deliberate design choice because CUDA compilation requires nvcc, which cannot be assumed to be present on every machine that installs the Haskell package.
Runtime path resolution:
The session helper resolves the custom-call library path using the same three-tier strategy as HHLO's getPluginPath:
- Environment variable (highest priority): `EIGENHHLO_GPU_LIB`
- Default relative path: `lib/libeigenhhlo_gpu.so`
- Clear runtime error with build instructions if the file is not found
This means you can run your executable from any working directory by setting an absolute path:
```shell
export EIGENHHLO_GPU_LIB=/home/you/projects/myapp/lib/libeigenhhlo_gpu.so
./myapp
```
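The three-tier lookup can be sketched as a small helper (hypothetical code; the real `EigenHHLO.Runtime.Session` implementation may differ in details):

```haskell
import System.Directory (doesFileExist)
import System.Environment (lookupEnv, setEnv)  -- setEnv only needed for demos

-- Hypothetical sketch of the three-tier path resolution performed by the
-- session helper; names and error wording are illustrative.
defaultGpuLib :: FilePath
defaultGpuLib = "lib/libeigenhhlo_gpu.so"

resolveGpuLib :: IO FilePath
resolveGpuLib = do
  fromEnv <- lookupEnv "EIGENHHLO_GPU_LIB"      -- tier 1: environment override
  case fromEnv of
    Just path -> pure path
    Nothing   -> do
      exists <- doesFileExist defaultGpuLib     -- tier 2: relative default
      if exists
        then pure defaultGpuLib
        else ioError . userError $              -- tier 3: actionable error
          "eigen-hhlo: " ++ defaultGpuLib
            ++ " not found; run `cd cbits/gpu && bash build.sh`"
            ++ " or set EIGENHHLO_GPU_LIB"
```

The environment variable always wins, which is what makes the `export`-then-run pattern above work from any directory.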
### Prerequisites
- GHC 9.6+
- Cabal 3.10+
- CUDA toolkit 11.8+ with cuSOLVER (for GPU backend)
### Step-by-step build
```shell
# 1. Build the GPU custom-call library (REQUIRED)
cd cbits/gpu && bash build.sh && cd ../..
# Produces: lib/libeigenhhlo_gpu.so

# 2. Build the Haskell project
cabal build all

# 3. Run tests
cabal test test:eigen-hhlo-test
```
### Build and run the examples
```shell
# 1. Compile the GPU custom-call library FIRST (required)
cd cbits/gpu && bash build.sh && cd ../..

# 2. Build all examples
cabal build eigen-hhlo:exes --flags=examples

# 3. Run an individual example
cabal run example-cholesky --flags=examples
```
## Usage

### Cholesky Example
```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}

import qualified Data.Text as T

import HHLO.Core.Types (DType(..))
import HHLO.IR.AST (FuncArg(..), TensorType(..))
import HHLO.IR.Builder (arg, moduleFromBuilder)
import HHLO.IR.Pretty (render)
import EigenHHLO.EDSL.Decomposition (chol)
import EigenHHLO.Runtime.Session (withEigenGPU)

main :: IO ()
main = withEigenGPU $ \esess -> do
  let modu = moduleFromBuilder @'[2,2] @'F64 "main"
               [ FuncArg "a" (TensorType [2,2] F64) ]
               $ do a <- arg @'[2,2] @'F64
                    chol esess a
  putStrLn $ T.unpack (render modu)
```
Output:
```mlir
module {
  func.func @main(%arg0: tensor<2x2xf64>) -> tensor<2x2xf64> {
    %0 = stablehlo.custom_call @eigenhhlo_dpotrf(%arg0)
           {call_target_name = "eigenhhlo_dpotrf", has_side_effect = false,
            backend_config = "n=2,uplo=L", api_version = 3 : i32}
           : (tensor<2x2xf64>) -> tensor<2x2xf64>
    return %0 : tensor<2x2xf64>
  }
}
```
### Available Operations
```haskell
-- Cholesky: A = L · Lᵀ
chol :: KnownNat n
     => EigenSession -> Tensor '[n,n] 'F64 -> Builder (Tensor '[n,n] 'F64)

-- SVD: A = U · diag(S) · Vt
svd :: (KnownNat m, KnownNat n, KnownNat k)
    => EigenSession -> Tensor '[m,n] 'F64
    -> Builder (Tensor '[m,k] 'F64, Tensor '[k] 'F64, Tensor '[k,n] 'F64)

-- QR factorization: A = Q · R
qr :: (KnownNat m, KnownNat n, KnownNat k)
   => EigenSession -> Tensor '[m,n] 'F64
   -> Builder (Tensor '[m,n] 'F64, Tensor '[k] 'F64)

-- Generate explicit Q matrix from QR reflectors
q :: (KnownNat m, KnownNat n, KnownNat k)
  => EigenSession -> Tensor '[m,n] 'F64 -> Tensor '[k] 'F64
  -> Builder (Tensor '[m,m] 'F64)

-- Symmetric eigenvalue: A = V · Λ · Vᵀ
eig :: KnownNat n
    => EigenSession -> Tensor '[n,n] 'F64
    -> Builder (Tensor '[n] 'F64, Tensor '[n,n] 'F64)

-- LU with partial pivoting
lu :: (KnownNat m, KnownNat n, KnownNat k)
   => EigenSession -> Tensor '[m,n] 'F64
   -> Builder (Tensor '[m,n] 'F64, Tensor '[k] 'I32)
```
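As a quick way to sanity-check device results on the host, the identity A = L · Lᵀ that `chol` computes can be reproduced with a naive reference implementation. This helper is purely illustrative and not part of the package:

```haskell
-- Naive O(n^3) reference Cholesky (lower factor) on plain lists, for
-- cross-checking GPU output on small matrices. Illustration only;
-- assumes the input is symmetric positive definite.
cholRef :: [[Double]] -> [[Double]]
cholRef a = l
  where
    n = length a
    -- Laziness lets l refer to its own earlier entries.
    l = [ [ entry i j | j <- [0 .. n - 1] ] | i <- [0 .. n - 1] ]
    entry i j
      | j > i     = 0
      | i == j    = sqrt (a !! i !! i
                            - sum [ (l !! i !! k) ^ (2 :: Int) | k <- [0 .. j - 1] ])
      | otherwise = (a !! i !! j
                       - sum [ l !! i !! k * l !! j !! k | k <- [0 .. j - 1] ])
                      / l !! j !! j
```

For the `[[4,2],[2,3]]` input from the example above, this yields `[[2,0],[1,sqrt 2]]`, and multiplying that factor by its transpose recovers the input.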
### Session Setup
```haskell
import EigenHHLO.Runtime.Session (withEigenGPU)

-- GPU: registers the cuSOLVER-backed custom-call symbols
-- with the PJRT CUDA plugin
withEigenGPU $ \esess -> do
  -- build, compile, run ...
  pure ()
```
## Current Status
- MLIR generation: ✅ Complete for all five decompositions.
- Test suite: ✅ 8/8 tests pass (5 MLIR generation + 3 GPU numerical end-to-end).
- GPU execution: ✅ Working on CUDA hardware via cuSOLVER custom calls.
## Project Structure
```
eigen-hhlo/
├── app/Main.hs                    -- CLI demo / smoke test (GPU)
├── cbits/
│   └── gpu/
│       ├── eigenhhlo_cusolver.cu  -- cuSOLVER wrappers (XLA GPU ABI)
│       └── build.sh               -- Builds lib/libeigenhhlo_gpu.so
├── examples/
│   ├── 01-cholesky.hs             -- Cholesky factorization (GPU)
│   ├── 02-svd.hs                  -- Singular value decomposition (GPU)
│   ├── 03-qr.hs                   -- QR factorization -> explicit Q (GPU)
│   ├── 04-eigenvalue.hs           -- Symmetric eigenvalue decomposition (GPU)
│   ├── 05-lu.hs                   -- LU with partial pivoting (GPU)
│   ├── 06-pipeline.hs             -- Composed pipeline (Gram -> Cholesky) (GPU)
│   └── 07-gpu-setup.hs            -- GPU custom-call registration demo
├── lib/
│   └── libeigenhhlo_gpu.so        -- GPU custom-call shared library
├── src/
│   ├── EigenHHLO/Core/
│   │   ├── BackendConfig.hs       -- Opaque config string builders
│   │   └── Types.hs               -- BackendType, EigenSession
│   ├── EigenHHLO/EDSL/
│   │   └── Decomposition.hs       -- User-facing API
│   ├── EigenHHLO/IR/
│   │   ├── Cholesky.hs
│   │   ├── Eigenvalue.hs
│   │   ├── LU.hs
│   │   ├── QR.hs
│   │   └── SVD.hs                 -- Typed custom_call builders
│   └── EigenHHLO/Runtime/
│       └── Session.hs             -- withEigenGPU
├── test/                          -- Tasty/HUnit MLIR generation + GPU numerical tests
├── cabal.project
├── eigen-hhlo.cabal
└── devlog/                        -- Design docs & analysis
```
## Appendix: CPU Backend
The CPU backend is deferred — the source code is present but end-to-end execution is not yet possible.
### Why it is deferred
The PJRT CPU plugin uses an internal C++ CustomCallTargetRegistry singleton to resolve custom calls. This registry can only be populated by code linked directly into the plugin binary itself. Our LAPACK wrapper library (libeigenhhlo_cpu.so) would be loaded separately via dlopen(RTLD_GLOBAL), and its symbols are globally visible — yet the plugin never looks them up. We verified this experimentally (LD_PRELOAD, dlsym tests, and symbol-table inspection all confirm the plugin ignores externally-loaded symbols).
What this means in practice:
- ✅ MLIR generation and compilation work correctly on the CPU code path.
- ❌ Runtime execution crashes with `PJRTException "No registered implementation for untyped custom call to eigenhhlo_* for Host"`.
- This is not a bug in eigen-hhlo; it is an architectural limitation of the current PJRT CPU plugin.
Paths forward for CPU:
- Build a custom PJRT CPU plugin from OpenXLA source with the LAPACK wrappers linked in.
- Wait for upstream — OpenXLA is tracking CPU custom-call registration APIs (see openxla/xla#26928).
### CPU source code
| File | Purpose |
| --- | --- |
| `cbits/cpu/eigenhhlo_lapack.c` | LAPACK wrappers (`dpotrf`, `dgesvd`, `dgeqrf`/`dorgqr`, `dsyevd`, `dgetrf`) for the XLA CPU ABI (api_version = 2) |
| `cbits/cpu/build.sh` | Build script for `lib/libeigenhhlo_cpu.so` |
| `lib/libeigenhhlo_cpu.so` | Output shared library (not usable until upstream support arrives) |
### Building the CPU library (optional)
If you want to experiment with the CPU backend:
```shell
# Prerequisites: LAPACK / OpenBLAS (liblapack.so.3, libopenblas.so.0)
cd cbits/cpu && bash build.sh && cd ../..
# Produces: lib/libeigenhhlo_cpu.so
```
The Haskell API also exposes `withEigenCPU`, which resolves the library path via `EIGENHHLO_CPU_LIB` (environment variable) or the default `lib/libeigenhhlo_cpu.so`, following the same three-tier strategy as the GPU backend.
## License
MIT