HHLO — Haskell Frontend for StableHLO
HHLO is a Haskell library and runtime for building, compiling, and executing machine learning programs targeting StableHLO, the portable, versioned intermediate representation of the OpenXLA ecosystem.
Instead of replicating JAX's Python-based tracing infrastructure, HHLO generates StableHLO MLIR text directly from Haskell and compiles it to CPU or GPU via the PJRT plugin interface.
Design
HHLO is structured in four layers:
┌─────────────────────────────────────┐
│ EDSL (HHLO.EDSL.Ops)                │  Type-safe frontend: add, matmul, relu, etc.
├─────────────────────────────────────┤
│ IR Builder (HHLO.IR.Builder)        │  Stateful monad for constructing MLIR
├─────────────────────────────────────┤
│ Pretty Printer (HHLO.IR.Pretty)     │  Emits StableHLO MLIR text
├─────────────────────────────────────┤
│ PJRT Runtime (HHLO.Runtime.*)       │  Compile → Execute on CPU or GPU
└─────────────────────────────────────┘
Text Emission + PJRT
The library emits StableHLO MLIR text directly and hands it to PJRT_Client_Compile. This is the same path used by JAX's C++ backend and avoids the heavy dependency of building LLVM/MLIR from source.
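To make the text-emission idea concrete, here is a minimal sketch that renders a one-op StableHLO module using plain strings. It does not use HHLO's actual printer: TensorTy, renderTy, and renderBinaryFunc are hypothetical names invented for this illustration, and only binary element-wise ops on identically typed operands are covered.

```haskell
module Main where

-- | A ranked tensor type such as tensor<2x2xf32>.
-- Illustrative only; this is not HHLO's own type.
data TensorTy = TensorTy [Int] String

-- | Render the MLIR spelling of a tensor type.
-- Each dimension is followed by 'x', then the element type.
renderTy :: TensorTy -> String
renderTy (TensorTy dims dtype) =
  "tensor<" ++ concatMap (\d -> show d ++ "x") dims ++ dtype ++ ">"

-- | Render a module with one function whose body is a single
-- binary StableHLO op applied to two arguments of the same type.
renderBinaryFunc :: String -> String -> TensorTy -> String
renderBinaryFunc name op ty = unlines
  [ "module {"
  , "  func.func @" ++ name ++ "(%arg0: " ++ t ++ ", %arg1: " ++ t ++ ") -> " ++ t ++ " {"
  , "    %0 = stablehlo." ++ op ++ " %arg0, %arg1 : (" ++ t ++ ", " ++ t ++ ") -> " ++ t
  , "    return %0 : " ++ t
  , "  }"
  , "}"
  ]
  where
    t = renderTy ty

main :: IO ()
main = putStr (renderBinaryFunc "main" "add" (TensorTy [2, 2] "f32"))
```

The resulting string is the kind of payload that would be handed to PJRT_Client_Compile; no MLIR libraries are linked to produce it.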
Phantom Types
Every tensor carries its shape and dtype as phantom type parameters:
Tensor '[2, 3] 'F32 -- 2×3 matrix of Float32
Matmul, broadcast, and conv shapes are checked at compile time via type families.
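The following sketch shows one way a type family can check matmul shapes at compile time. It is a simplified stand-in rather than HHLO's definitions: the Tensor newtype here wraps only a string, and MatMulShape, KnownShape, and shapeOf are names assumed for the example.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Main where

import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

data DType = F32 | I32

-- | Shape and dtype are phantom parameters; the payload is just a name.
newtype Tensor (shape :: [Nat]) (dtype :: DType) = Tensor String

-- | Result shape of a 2-D matmul. The repeated 'k' means the equation
-- only applies when the inner dimensions agree.
type family MatMulShape (a :: [Nat]) (b :: [Nat]) :: [Nat] where
  MatMulShape '[m, k] '[k, n] = '[m, n]

matmul :: Tensor a d -> Tensor b d -> Tensor (MatMulShape a b) d
matmul (Tensor x) (Tensor y) = Tensor ("matmul(" ++ x ++ ", " ++ y ++ ")")

-- | Reify a type-level shape as a runtime list.
class KnownShape (s :: [Nat]) where
  shapeVal :: Proxy s -> [Integer]

instance KnownShape '[] where
  shapeVal _ = []

instance (KnownNat n, KnownShape ns) => KnownShape (n ': ns) where
  shapeVal _ = natVal (Proxy :: Proxy n) : shapeVal (Proxy :: Proxy ns)

shapeOf :: forall s d. KnownShape s => Tensor s d -> [Integer]
shapeOf _ = shapeVal (Proxy :: Proxy s)

main :: IO ()
main = do
  let a = Tensor "%a" :: Tensor '[2, 3] 'F32
      b = Tensor "%b" :: Tensor '[3, 4] 'F32
      c = matmul a b :: Tensor '[2, 4] 'F32
  print (shapeOf c)  -- prints [2,4]
```

In this sketch, multiplying a '[2, 3] tensor by a '[4, 2] tensor leaves MatMulShape unreduced, so the program is rejected as soon as the result is used at a concrete shape. A production library would likely add a custom TypeError equation for a clearer message.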
ForeignPtr Finalizers
PJRT buffers and executables are managed by ForeignPtr finalizers that automatically call PJRT_Buffer_Destroy and PJRT_LoadedExecutable_Destroy when values are garbage-collected. You can simply let references drop out of scope; no explicit cleanup is required.
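The finalizer pattern can be sketched with base alone. This example attaches a Haskell finalizer to malloc'd memory instead of a real PJRT handle; Buffer and newBuffer are hypothetical names, and HHLO's actual bindings may well register C finalizers rather than use Foreign.Concurrent as done here.

```haskell
module Main where

import Data.IORef (IORef, modifyIORef', newIORef, readIORef)
import Data.Word (Word8)
import Foreign.Concurrent (newForeignPtr)
import Foreign.ForeignPtr (ForeignPtr, finalizeForeignPtr, withForeignPtr)
import Foreign.Marshal.Alloc (free, mallocBytes)
import Foreign.Storable (peek, poke)

-- | Stand-in for a device buffer handle (hypothetical, not HHLO's type).
newtype Buffer = Buffer (ForeignPtr Word8)

-- | Allocate raw memory and attach a finalizer that frees it.
-- The counter records how many times the "destroy" action has run.
newBuffer :: IORef Int -> Int -> IO Buffer
newBuffer destroyed nbytes = do
  p <- mallocBytes nbytes
  fp <- newForeignPtr p (free p >> modifyIORef' destroyed (+ 1))
  pure (Buffer fp)

main :: IO ()
main = do
  destroyed <- newIORef 0
  Buffer fp <- newBuffer destroyed 16
  -- Use the resource while it is guaranteed to be alive.
  withForeignPtr fp $ \p -> poke p 42 >> peek p >>= print
  -- Normally the garbage collector triggers the finalizer; it is
  -- forced here only so the effect is observable deterministically.
  finalizeForeignPtr fp
  readIORef destroyed >>= print  -- prints 1
```

Finalizers run at the collector's discretion, so code that needs a resource released at a specific point should still arrange that explicitly.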
Dynamic Output Counts
The runtime queries the compiled executable for its actual number of outputs via PJRT_Executable_NumOutputs instead of guessing or hardcoding a maximum.
Async Execution
HHLO.Runtime.Async provides true non-blocking execution: executeAsync returns buffer handles immediately, bufferReady polls for completion, and awaitBuffers blocks until device-side computation finishes.
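The launch, poll, and await trio described above has a familiar shape that can be mimicked with an MVar and a background thread. This is only an analogy under stated assumptions: Pending, launchAsync, isReady, and await are hypothetical names, and the real module presumably waits on PJRT events rather than Haskell threads.

```haskell
module Main where

import Control.Concurrent (forkIO, threadDelay)
import Control.Concurrent.MVar (MVar, isEmptyMVar, newEmptyMVar, putMVar, readMVar)

-- | Handle to a result that may not have been produced yet.
newtype Pending a = Pending (MVar a)

-- | Start the work in the background and return a handle immediately.
launchAsync :: IO a -> IO (Pending a)
launchAsync work = do
  slot <- newEmptyMVar
  _ <- forkIO (work >>= putMVar slot)
  pure (Pending slot)

-- | Non-blocking poll: has the result arrived?
isReady :: Pending a -> IO Bool
isReady (Pending slot) = not <$> isEmptyMVar slot

-- | Block until the result is available, leaving it in place
-- so the handle can be awaited more than once.
await :: Pending a -> IO a
await (Pending slot) = readMVar slot

main :: IO ()
main = do
  -- Simulate a slow relu over a small vector.
  h <- launchAsync (threadDelay 100000 >> pure (map (max 0) [-1, 2, -3, 4 :: Double]))
  early <- isReady h
  putStrLn ("ready right after launch: " ++ show early)
  result <- await h
  putStrLn ("result: " ++ show result)
```

One limitation of this sketch: if the background work throws, the slot is never filled and await blocks indefinitely, so real code would need to propagate errors through the handle.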
Device Enumeration & Selection
HHLO.Runtime.Device lets you discover and select specific GPUs at runtime:
addressableDevices api client -- list all devices
deviceKind api dev -- "cpu" or "NVIDIA GeForce RTX 5090"
defaultGPUDevice api client -- first non-CPU device
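The "first non-CPU device" rule is simple enough to sketch as a pure function. The Device record and firstNonCpu below are hypothetical and work on plain values; the real functions take the api and client handles shown above and run against PJRT.

```haskell
module Main where

import Data.Char (toLower)
import Data.List (find)

-- | Minimal stand-in for a device description (not HHLO's type).
data Device = Device
  { deviceId :: Int
  , kind     :: String
  } deriving (Eq, Show)

-- | Assumption: CPU devices report the kind "cpu", compared case-insensitively.
isCpu :: Device -> Bool
isCpu d = map toLower (kind d) == "cpu"

-- | Pick the first device that is not a CPU, if any.
firstNonCpu :: [Device] -> Maybe Device
firstNonCpu = find (not . isCpu)

main :: IO ()
main = do
  let devices =
        [ Device 0 "cpu"
        , Device 1 "NVIDIA GeForce RTX 5090"
        ]
  print (firstNonCpu devices)
  print (firstNonCpu [Device 0 "cpu"])  -- prints Nothing
```

Returning Maybe makes the CPU-only case explicit, so callers can fall back to the CPU instead of failing at execution time.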
Multi-GPU Inference Scaling
HHLO.Runtime.Execute provides executeReplicas for running the same compiled model concurrently across multiple GPUs:
exec <- compileWithOptions api client mlirText
          (defaultCompileOptions { optNumReplicas = numDevs })
-- Launch independent forward passes on all GPUs
executeReplicas api exec
  [ (gpu0, [bufA0, bufB0])
  , (gpu1, [bufA1, bufB1])
  , ...
  ]
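The fan-out and fan-in behind replica execution can be sketched with one thread per device. This is not how executeReplicas is implemented, which the source does not show; runReplicas is a hypothetical helper that takes the per-device action as a parameter so the example stays self-contained.

```haskell
module Main where

import Control.Concurrent (forkIO)
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)

-- | Run one job per (device, input) pair concurrently and collect
-- the results in the same order as the input list.
runReplicas :: (dev -> input -> IO out) -> [(dev, input)] -> IO [out]
runReplicas run jobs = do
  slots <- mapM launch jobs  -- fan out: start every job first
  mapM takeMVar slots        -- fan in: then wait for each result
  where
    launch (dev, input) = do
      slot <- newEmptyMVar
      _ <- forkIO (run dev input >>= putMVar slot)
      pure slot

main :: IO ()
main = do
  -- Pretend each "device" sums its own input vector.
  let fakeExecute :: Int -> [Double] -> IO (Int, Double)
      fakeExecute dev xs = pure (dev, sum xs)
  results <- runReplicas fakeExecute [(0, [1, 2]), (1, [3, 4])]
  print results  -- prints [(0,3.0),(1,7.0)]
```

Launching every job before waiting on any of them is what lets the replicas overlap; waiting inside the launch loop would serialize them.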
Installation
System Requirements
- GHC 9.6+ and Cabal 3.10+
- Linux x86_64 (other platforms supported by PJRT artifacts may work)
- curl, tar, and a standard C toolchain (gcc or clang)
- libstdc++ and libdl (usually present on Linux)
Download PJRT Plugins
Run the provided script to download prebuilt PJRT plugins:
./pjrt_script.sh
This downloads libpjrt_cpu.so from the zml/pjrt-artifacts nightly builds into deps/pjrt/. If you have an NVIDIA GPU with nvidia-smi available, the CUDA plugin is also fetched automatically.
Build the Project
cabal build all
This compiles the library, the demo, the examples, and the test suite.
Usage
CPU (works out of the box)
cabal run example-add
cabal test
GPU (requires runtime libraries)
The PJRT CUDA plugin depends on NVIDIA runtime libraries: cuDNN, NCCL, and NVSHMEM. These are commonly available via conda, pip, or system packages.
If you already have them (e.g. via PyTorch or JAX installations), simply run:
./setup_gpu_env.sh
source ~/.bashrc
This idempotent script auto-discovers the libraries and appends their locations to ~/.bashrc via LD_LIBRARY_PATH. After that, GPU examples work directly:
cabal run example-gpu-add
cabal run example-gpu-matmul-bench
cabal run example-multi-gpu-inference
EDSL Quick Start
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications #-}

import HHLO.Core.Types
import HHLO.EDSL.Ops
import HHLO.IR.AST (FuncArg(..), TensorType(..))
import HHLO.IR.Builder
import HHLO.IR.Pretty
import qualified Data.Text.IO as T  -- putStrLn for Text lives in Data.Text.IO

-- Build a program: c = a + b
program :: Module
program = moduleFromBuilder @'[2,2] @'F32 "main"
    [ FuncArg "a" (TensorType [2, 2] F32)
    , FuncArg "b" (TensorType [2, 2] F32)
    ]
    $ do
        a <- arg
        b <- arg
        c <- add a b
        return c

main :: IO ()
main = T.putStrLn (render program)
Output:
module {
  func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
    %0 = stablehlo.add %arg0, %arg1 : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
    return %0 : tensor<2x2xf32>
  }
}
Running the Demo
cabal run hhlo-demo
The demo builds a stablehlo.add program via the EDSL, compiles it with PJRT CPU, creates F32 input buffers, executes, and reads back the result:
=== HHLO End-to-End Demo ===
Loading PJRT CPU plugin...
Plugin loaded.
...
Result: [6.0,8.0,10.0,12.0]
SUCCESS: Results match expected values!
Running Examples
Standalone examples are provided in examples/:
| # | Command | Description |
|---|---------|-------------|
| 1 | cabal run example-add | Element-wise c = a + b |
| 2 | cabal run example-matmul | 2×3 @ 3×2 matrix multiply |
| 3 | cabal run example-chain-ops | (a + b) * (a - b) |
| 4 | cabal run example-async | Async executeAsync + relu |
| 5 | cabal run example-mlp | 2-layer MLP |
| 6 | cabal run example-mlp-batched | Batched MLP |
| 7 | cabal run example-tuple | Multi-result func.func (MLIR print-only) |
| 8 | cabal run example-reduce | reduceSum over all dimensions |
| 9 | cabal run example-softmax | 1-D and batched 2-D softmax |
| 10 | cabal run example-conv2d | NHWC conv2d |
| 11 | cabal run example-batch-norm | Batch norm inference |
| 12 | cabal run example-while | whileLoop count-up |
| 13 | cabal run example-conditional | conditional if-then-else |
| 14 | cabal run example-gather | gather rows from matrix |
| 15 | cabal run example-scatter | scatter replace into vector |
| 16 | cabal run example-slice | slice sub-array extraction |
| 17 | cabal run example-pad | pad with edge/interior padding |
| 18 | cabal run example-dynamic-slice | dynamicSlice runtime indices |
| 19 | cabal run example-sort | sort 1-D ascending |
| 20 | cabal run example-select | Element-wise ternary select |
| 21 | cabal run example-map | map with custom computation |
| 22 | cabal run example-new-ops-smoke-test | Smoke test for newer ops |
| 23 | cabal run example-resnet | ResNet-18 toy (8×8 input) |
| 24 | cabal run example-alexnet | AlexNet toy (16×16 input) |
| 25 | cabal run example-transformer | Transformer encoder (1×4×16) |
| 26 | cabal run example-unet | UNet segmentation toy (16×16) |
| 27 | cabal run example-gpu-add | GPU smoke test |
| 28 | cabal run example-gpu-matmul-bench | GPU 4096×4096 benchmark |
| 29 | cabal run example-multi-gpu-inference | Multi-GPU concurrent matmul |
Tests
CPU Tests (default)
cabal test
Runs 115 tests across three tiers:
- Tier 1 — Golden tests — Verify rendered MLIR text for EDSL ops, IR constructs, NN layers, and control flow.
- Tier 2 — End-to-end runtime tests — Load the PJRT CPU plugin, compile StableHLO programs, execute them, and verify numerical results. Covers arithmetic, matmul, reductions, data movement, and NN ops.
- Tier 3 — Runtime integration tests — Buffer metadata queries, async execution, and error handling.
GPU Tests
HHLO_TEST_GPU=1 cabal test
Runs all 115 CPU tests plus 6 additional GPU integration tests, organized in four groups:
- EndToEnd.GPU: GPU availability and device enumeration
- Runtime.BufferGPU: Buffer round-trip and metadata queries on GPU
- Runtime.AsyncGPU: Async execution and bufferReady polling on GPU
- Runtime.MultiGPU: Concurrent executeReplicas across all GPUs
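Gating extra test groups on an environment variable can be sketched as follows. How the real suite reads HHLO_TEST_GPU is not shown in this README, so the rule that only the exact value "1" enables GPU groups is an assumption, and selectGroups is a hypothetical helper. The group names are taken from the sample output.

```haskell
module Main where

import System.Environment (lookupEnv)

-- | Group names as they appear in the sample test output.
cpuGroups, gpuGroups :: [String]
cpuGroups = ["EDSL.Ops", "EndToEnd.Arithmetic", "Runtime.Buffer", "Runtime.Async"]
gpuGroups = ["EndToEnd.GPU", "Runtime.BufferGPU", "Runtime.AsyncGPU", "Runtime.MultiGPU"]

-- | Assumption: only the exact value "1" turns the GPU groups on.
gpuRequested :: Maybe String -> Bool
gpuRequested = (== Just "1")

-- | CPU groups always run; GPU groups are appended on request.
selectGroups :: Maybe String -> [String]
selectGroups flag
  | gpuRequested flag = cpuGroups ++ gpuGroups
  | otherwise         = cpuGroups

main :: IO ()
main = do
  flag <- lookupEnv "HHLO_TEST_GPU"
  mapM_ putStrLn (selectGroups flag)
```

Keeping the decision in a pure function of Maybe String makes the gating itself easy to test without touching the process environment.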
Sample output:
HHLO Tests
EDSL.Ops
Binary element-wise
add: OK
...
EndToEnd.Arithmetic
relu: OK (0.02s)
...
Runtime.Buffer
buffer round-trip f32: OK
Runtime.Async
buffer ready after sync execute: OK (0.02s)
EndToEnd.GPU
gpu available: OK
Runtime.BufferGPU
gpu buffer round-trip f32: OK
Runtime.AsyncGPU
gpu executeAsync + await: OK
Runtime.MultiGPU
execute replicas on all GPUs: OK
All 121 tests passed (16.27s)
Project Structure
.
├── app/                        # hhlo-demo executable
├── cbits/                      # C shim around PJRT C API
│   ├── pjrt_c_api.h            # Upstream PJRT header
│   ├── pjrt_shim.c             # Thin wrapper exposing flat C functions
│   └── pjrt_shim.h             # C header for the shim
├── deps/
│   └── pjrt/                   # Downloaded PJRT plugins (.so files)
│       └── lib_symlinks/       # Compatibility symlinks for missing library versions
├── doc/                        # Architecture and design documents
├── examples/                   # Standalone example programs (01–29)
├── src/HHLO/
│   ├── Core/Types.hs           # DType, Shape, HostType type families
│   ├── IR/
│   │   ├── AST.hs              # MLIR AST (Operation, Function, Module)
│   │   ├── Builder.hs          # Stateful Builder monad + Tensor/Tuple GADTs
│   │   └── Pretty.hs           # MLIR text pretty-printer
│   ├── EDSL/Ops.hs             # Type-safe frontend ops (50+ ops)
│   └── Runtime/
│       ├── PJRT/
│       │   ├── FFI.hs          # C FFI declarations
│       │   ├── Types.hs        # Opaque pointer newtypes + buffer type constants
│       │   ├── Error.hs        # PJRT error handling
│       │   └── Plugin.hs       # Backend-agnostic plugin loading (withPJRT)
│       ├── Device.hs           # Device enumeration & selection
│       ├── Compile.hs          # MLIR → PJRT executable (with CompileOptions)
│       ├── Execute.hs          # Synchronous + device-targeted + multi-GPU replica execution
│       ├── Async.hs            # Non-blocking execution with PJRT_Event
│       └── Buffer.hs           # Host↔device buffer transfers + metadata queries
├── test/
│   ├── Test/
│   │   ├── EDSL/Ops.hs
│   │   ├── IR/
│   │   │   ├── Builder.hs
│   │   │   ├── Pretty.hs
│   │   │   ├── PrettyOps.hs
│   │   │   ├── PrettyNN.hs
│   │   │   └── PrettyControlFlow.hs
│   │   ├── Runtime/
│   │   │   ├── EndToEnd*.hs    # CPU E2E test modules
│   │   │   ├── EndToEndGPU.hs  # GPU availability test
│   │   │   ├── Buffer.hs
│   │   │   ├── BufferGPU.hs    # GPU buffer integration tests
│   │   │   ├── Async.hs
│   │   │   ├── AsyncGPU.hs     # GPU async tests
│   │   │   ├── MultiGPU.hs     # Multi-GPU inference scaling tests
│   │   │   └── Errors.hs
│   │   └── Utils.hs
│   └── Main.hs
├── hhlo.cabal
├── pjrt_script.sh              # Downloads PJRT plugins
├── setup_gpu_env.sh            # Auto-configures LD_LIBRARY_PATH for GPU
└── README.md
Architecture Docs
The doc/ directory contains detailed design documents:
| Document | Contents |
|----------|----------|
| implementation-design.md | Four-layer architecture and design decisions |
| progress-and-remaining-work.md | Current status, completed features, and backlog |
| test-suite-documentation.md | Test catalog and tier descriptions |
License
MIT License — see LICENSE.