HHLO — Haskell Frontend for StableHLO
HHLO is a Haskell library and runtime for building, compiling, and executing machine learning programs targeting StableHLO, the portable, versioned intermediate representation of the OpenXLA ecosystem.
Instead of replicating JAX's Python-based tracing infrastructure, HHLO generates StableHLO MLIR text directly from Haskell and compiles it to CPU or GPU via the PJRT plugin interface.
Design
HHLO is structured in five layers:
┌─────────────────────────────────────┐
│ Convenience (HHLO.Session) │ One-liners: withCPU, compile, run
├─────────────────────────────────────┤
│ EDSL (HHLO.EDSL.Ops) │ Type-safe frontend: add, matmul, relu, etc.
├─────────────────────────────────────┤
│ IR Builder (HHLO.IR.Builder) │ Stateful monad for constructing MLIR
├─────────────────────────────────────┤
│ Pretty Printer (HHLO.IR.Pretty) │ Emits StableHLO MLIR text
├─────────────────────────────────────┤
│ PJRT Runtime (HHLO.Runtime.*) │ Compile → Execute on CPU or GPU
└─────────────────────────────────────┘
Text Emission + PJRT
The library emits StableHLO MLIR text directly and hands it to PJRT_Client_Compile. This is the same path used by JAX's C++ backend and avoids the heavy dependency of building LLVM/MLIR from source.
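To make the text-emission idea concrete, here is a minimal, self-contained sketch that renders a tiny hypothetical AST into StableHLO MLIR text. TType, Op, Func and the render functions are illustrative names, not HHLO's HHLO.IR.AST types. The sketch only handles element-wise ops whose operands share the result type.

```haskell
import Data.List (intercalate)

-- Hypothetical mini-AST (illustrative only; HHLO's real AST lives in HHLO.IR.AST).
data TType = TType [Int] String            -- dimensions and element type, e.g. [2, 2] "f32"

data Op = Op String String [String] TType  -- result name, op name, operands, result type

-- Function name, argument types, body, returned value, return type.
data Func = Func String [TType] [Op] String TType

-- tensor<2x2xf32>: each dimension followed by 'x', then the element type.
renderType :: TType -> String
renderType (TType dims ety) = "tensor<" ++ concatMap (\d -> show d ++ "x") dims ++ ety ++ ">"

-- Element-wise op: every operand is assumed to share the result type.
renderOp :: Op -> String
renderOp (Op res name args ty) =
  "    " ++ res ++ " = " ++ name ++ " " ++ intercalate ", " args
    ++ " : (" ++ intercalate ", " (map (const (renderType ty)) args) ++ ") -> " ++ renderType ty

-- Wrap the body in func.func and module, numbering the arguments %arg0, %arg1, ...
renderFunc :: Func -> String
renderFunc (Func name argTys body ret retTy) =
  unlines $
    [ "module {"
    , "  func.func @" ++ name ++ "(" ++ intercalate ", " params ++ ") -> " ++ renderType retTy ++ " {"
    ]
      ++ map renderOp body
      ++ [ "    return " ++ ret ++ " : " ++ renderType retTy
         , "  }"
         , "}"
         ]
  where
    params = [ "%arg" ++ show i ++ ": " ++ renderType t | (i, t) <- zip [0 :: Int ..] argTys ]

-- c = a + b on two 2x2 f32 tensors
addProgram :: Func
addProgram =
  Func "main" [t, t] [Op "%0" "stablehlo.add" ["%arg0", "%arg1"] t] "%0" t
  where
    t = TType [2, 2] "f32"
```

In GHCi, putStr (renderFunc addProgram) prints the same module as the Output listing in the EDSL Quick Start. The point of the sketch is that emitting the text is ordinary string generation, so no MLIR libraries have to be linked.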
Phantom Types
Every tensor carries its shape and dtype as phantom type parameters:
Tensor '[2, 3] 'F32 -- 2×3 matrix of Float32
Matmul, broadcast, and conv shapes are checked at compile time via type families.
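A minimal sketch of the idea, assuming a simplified 2-D Tensor newtype whose payload is a list of rows (not HHLO's actual definition). It relies on plain unification of the shared dimension k rather than on a type family, which is already enough to reject mismatched inner dimensions at compile time.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}

import Data.List (transpose)
import GHC.TypeLits (Nat)

-- Illustrative dtype; DataKinds promotes it so 'F32 can appear in types.
data DType = F32 | I32

-- Hypothetical tensor: shape and dtype exist only at the type level,
-- the payload is a plain list of rows (2-D only, for brevity).
newtype Tensor (shape :: [Nat]) (dt :: DType) = Tensor [[Float]]

-- The inner dimension k must be the same in both arguments,
-- so mismatched shapes are a type error, not a runtime error.
matmul :: Tensor '[m, k] dt -> Tensor '[k, n] dt -> Tensor '[m, n] dt
matmul (Tensor a) (Tensor b) =
  Tensor [ [ sum (zipWith (*) row col) | col <- transpose b ] | row <- a ]

toRows :: Tensor s dt -> [[Float]]
toRows (Tensor rows) = rows

a23 :: Tensor '[2, 3] 'F32
a23 = Tensor [[1, 2, 3], [4, 5, 6]]

b32 :: Tensor '[3, 2] 'F32
b32 = Tensor [[1, 0], [0, 1], [1, 1]]
```

matmul a23 b32 type-checks with result type Tensor '[2, 2] 'F32, while matmul a23 a23 is rejected because the inner dimensions 3 and 2 differ.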
ForeignPtr Finalizers
PJRT buffers and executables are managed by ForeignPtr finalizers that automatically call PJRT_Buffer_Destroy and PJRT_LoadedExecutable_Destroy when values are garbage-collected, so you can let references drop out of scope without explicit cleanup.
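The mechanism can be sketched with base alone. Handle, wrapHandle and demo are hypothetical; the finalizer below only logs and frees memory, where HHLO's finalizers would call the PJRT destroy functions. finalizeForeignPtr forces the finalizer so the effect is observable; in normal use the garbage collector triggers it.

```haskell
import Data.IORef (IORef, modifyIORef, newIORef, readIORef)
import qualified Foreign.Concurrent as FC
import Foreign.ForeignPtr (ForeignPtr, finalizeForeignPtr, withForeignPtr)
import Foreign.Marshal.Alloc (free, mallocBytes)
import Foreign.Ptr (Ptr)

-- Hypothetical opaque handle standing in for a PJRT buffer or executable.
data Handle

-- Attach a Haskell finalizer to a raw pointer. Here it logs and frees the
-- memory; a PJRT binding would call the matching *_Destroy function instead.
wrapHandle :: IORef [String] -> String -> Ptr Handle -> IO (ForeignPtr Handle)
wrapHandle logRef label p = FC.newForeignPtr p $ do
  modifyIORef logRef (("destroyed " ++ label) :)
  free p

-- Allocate a fake handle, use it, then force the finalizer instead of waiting for GC.
demo :: IO [String]
demo = do
  logRef <- newIORef []
  raw <- mallocBytes 64
  fp <- wrapHandle logRef "buffer" raw
  withForeignPtr fp $ \_ -> pure ()  -- the pointer is valid inside this scope
  finalizeForeignPtr fp              -- normally the garbage collector does this
  readIORef logRef
```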
Dynamic Output Counts
The runtime queries the compiled executable for its actual number of outputs via PJRT_Executable_NumOutputs instead of guessing or hardcoding a maximum.
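A sketch of the pattern with base's Foreign.Marshal.Array: the output array is allocated with exactly the queried length. The query action stands in for PJRT_Executable_NumOutputs, CInt values stand in for output buffer handles, and withOutputs and fakeExecute are hypothetical names.

```haskell
import Foreign.C.Types (CInt)
import Foreign.Marshal.Array (allocaArray, peekArray, pokeArray)
import Foreign.Ptr (Ptr)

-- Size the output array from a count queried at runtime
-- instead of a hardcoded maximum.
withOutputs :: IO Int -> (Int -> Ptr CInt -> IO ()) -> IO [CInt]
withOutputs queryNumOutputs execute = do
  n <- queryNumOutputs
  allocaArray n $ \arr -> do
    execute n arr  -- the callee writes exactly n entries
    peekArray n arr

-- Hypothetical stand-in for the execute call: writes n dummy handles.
fakeExecute :: Int -> Ptr CInt -> IO ()
fakeExecute n arr = pokeArray arr (map fromIntegral [10 .. 9 + n])
```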
Async Execution
HHLO.Runtime.Async provides true non-blocking execution: executeAsync returns buffer handles immediately, bufferReady polls for completion, and awaitBuffers blocks until device-side computation finishes.
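The three-call shape (start, poll, wait) can be illustrated with base's MVar. This is a host-only analogy rather than HHLO's implementation, which is built on PJRT_Event; Pending, launch, isReady and await are hypothetical names.

```haskell
import Control.Concurrent (forkIO)
import Control.Concurrent.MVar (MVar, newEmptyMVar, putMVar, readMVar, tryReadMVar)
import Data.Maybe (isJust)

-- Hypothetical handle standing in for a device buffer plus its completion event.
newtype Pending a = Pending (MVar a)

-- Like executeAsync: start the work and return a handle immediately.
launch :: IO a -> IO (Pending a)
launch work = do
  box <- newEmptyMVar
  _ <- forkIO (work >>= putMVar box)
  pure (Pending box)

-- Like bufferReady: a non-blocking completion check.
isReady :: Pending a -> IO Bool
isReady (Pending box) = isJust <$> tryReadMVar box

-- Like awaitBuffers: block until the result is available.
await :: Pending a -> IO a
await (Pending box) = readMVar box
```

In this sketch an exception inside the launched action leaves the MVar empty, so await would block forever; a real implementation needs to propagate errors through the handle.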
Device Enumeration & Selection
HHLO.Runtime.Device lets you discover and select specific GPUs at runtime:
addressableDevices api client -- list all devices
deviceKind api dev -- "cpu" or "NVIDIA GeForce RTX 5090"
defaultGPUDevice api client -- first non-CPU device
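The selection rule documented for defaultGPUDevice (first non-CPU device) can be modelled as a pure function over a hypothetical Device record; HHLO's real devices wrap PJRT device pointers and obtain their kind string via deviceKind.

```haskell
import Data.List (find)

-- Hypothetical device record: an index plus the kind string reported by the plugin.
data Device = Device { devId :: Int, devKind :: String }
  deriving (Eq, Show)

-- First device whose kind is not "cpu", or Nothing on a CPU-only machine.
firstNonCPU :: [Device] -> Maybe Device
firstNonCPU = find ((/= "cpu") . devKind)
```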
Multi-GPU Inference Scaling
HHLO.Runtime.Execute provides executeReplicas for running the same compiled model concurrently across multiple GPUs:
-- Compile once, requesting one replica per device
exec <- compileWithOptions api client mlirText
          (defaultCompileOptions { optNumReplicas = numDevs })
-- Launch independent forward passes on all GPUs
executeReplicas api exec
[ (gpu0, [bufA0, bufB0])
, (gpu1, [bufA1, bufB1])
, ...
]
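executeReplicas is HHLO's own function; the sketch below only illustrates the fan-out/fan-in shape of such a call, with one thread per (device, inputs) pair and results collected in input order, using base's forkIO and MVar. runReplicas and its runOn argument are hypothetical and say nothing about how HHLO schedules work internally.

```haskell
import Control.Concurrent (forkIO)
import Control.Concurrent.MVar (newEmptyMVar, putMVar, readMVar)
import Control.Exception (SomeException, throwIO, try)

-- Catch any exception so a failing replica cannot leave its result slot empty.
tryAny :: IO a -> IO (Either SomeException a)
tryAny = try

-- One thread per (device, inputs) pair; results come back in input order,
-- and the first failure (in input order) is rethrown to the caller.
runReplicas :: (dev -> inp -> IO out) -> [(dev, inp)] -> IO [out]
runReplicas runOn jobs = do
  slots <- mapM spawn jobs
  mapM (\slot -> readMVar slot >>= either throwIO pure) slots
  where
    spawn (dev, inp) = do
      slot <- newEmptyMVar
      _ <- forkIO (tryAny (runOn dev inp) >>= putMVar slot)
      pure slot
```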
Multi-Result Operations
The AST Operation type supports multiple results, enabling ops like stablehlo.rng_bit_generator and multi-value control flow:
-- Two-result operation
(newState, output) <- rngBitGenerator state
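In MLIR's textual form, an operation with several results binds all of them on the left-hand side. A minimal sketch with a hypothetical Operation record (HHLO's real AST may be shaped differently) renders a list of result names, so single- and multi-result ops share one code path; the rng_bit_generator example follows the generic syntax used in the StableHLO specification.

```haskell
import Data.List (intercalate)

-- Hypothetical operation record with any number of results.
data Operation = Operation
  { opResults  :: [String]  -- SSA names bound by the op
  , opName     :: String
  , opOperands :: [String]
  , opAttrs    :: String    -- pre-rendered attribute dictionary, or ""
  , opSig      :: String    -- pre-rendered function type
  }

-- Generic MLIR form: %r0, %r1 = "dialect.op"(%x) {attrs} : (ins) -> (outs)
renderOperation :: Operation -> String
renderOperation op =
  intercalate ", " (opResults op) ++ " = " ++ show (opName op)
    ++ "(" ++ intercalate ", " (opOperands op) ++ ")"
    ++ (if null (opAttrs op) then "" else " " ++ opAttrs op)
    ++ " : " ++ opSig op

rngOp :: Operation
rngOp = Operation
  { opResults  = ["%new_state", "%output"]
  , opName     = "stablehlo.rng_bit_generator"
  , opOperands = ["%state"]
  , opAttrs    = "{rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>}"
  , opSig      = "(tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)"
  }
```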
Convenience Layer
HHLO.ModuleBuilder and HHLO.Session provide a high-level API that eliminates PJRT boilerplate for the common case:
import HHLO.ModuleBuilder
import HHLO.Session
-- Build + compile + run in four lines
main = withCPU $ \sess -> do
let modu = buildModule @2 @1 "mul" $ \x y -> multiply x y
compiled <- compile sess modu
result <- run sess compiled (hostFromList @'[2] [2.0, 3.0],
hostFromList @'[2] [4.0, 5.0])
print (hostToList result) -- [8.0, 15.0]
No FuncArg, no natVal, no render, no toDeviceF32, no explicit shape lists. The low-level API remains available for expert users who need full control.
Multi-Value Control Flow
whileLoop2 / conditional2 carry multiple typed tensors through loops and conditionals without manual packing:
-- Loop with two accumulators: counter and running sum
(resultCounter, resultSum) <- whileLoop2 counter0 sum0
(\c s -> compare c limit "LT")
(\c s -> do
cNext <- add c one
sNext <- add s cNext
returnTuple2 cNext sNext)
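To pin down what the loop above computes, here is a host-side reference semantics for a two-value while loop. whileLoop2Ref and countAndSum are hypothetical helpers, and counter0 = 0, sum0 = 0 and one = 1 are assumed. A reference like this is handy for checking device results in tests.

```haskell
-- Reference semantics for a two-value while loop (not HHLO's builder):
-- apply the body while the condition holds on both carried values.
whileLoop2Ref :: (a -> b -> Bool) -> (a -> b -> (a, b)) -> a -> b -> (a, b)
whileLoop2Ref cond body = go
  where
    go c s
      | cond c s  = uncurry go (body c s)
      | otherwise = (c, s)

-- The loop above: increment the counter, then add the new counter to the sum.
countAndSum :: Int -> (Int, Int)
countAndSum limit =
  whileLoop2Ref (\c _ -> c < limit) (\c s -> let c' = c + 1 in (c', s + c')) 0 0
```

With limit = 5 the loop stops at counter 5 with a running sum of 1 + 2 + 3 + 4 + 5 = 15.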
Random Number Generation
Three RNG primitives are exposed in the EDSL:
uniform <- rngUniform a b -- uniform in [a, b)
normal <- rngNormal -- standard normal (mean 0, std 1)
(newSt, bits) <- rngBitGenerator state -- Threefry bit generator
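For checking outputs on the host, the intended distributions can be written as scalar reference functions. The affine map and the Box-Muller transform below are standard textbook constructions shown purely as an illustration; HHLO's actual lowering of rngUniform and rngNormal may use a different method.

```haskell
-- Affine map from a draw u in [0, 1) onto [a, b) (up to floating-point rounding).
uniformIn :: Double -> Double -> Double -> Double
uniformIn a b u = a + (b - a) * u

-- Box-Muller transform: u1 in (0, 1] and u2 in [0, 1) give one N(0, 1) draw.
boxMuller :: Double -> Double -> Double
boxMuller u1 u2 = sqrt (-2 * log u1) * cos (2 * pi * u2)
```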
Extended Math Primitives
Element-wise ops covering the full HBayesian requirements:
y <- sqrt x -- square root
y <- rsqrt x -- reciprocal sqrt
y <- sin x -- sine
y <- cos x -- cosine
y <- tan x -- tangent
y <- pow x e -- element-wise power
y <- log1p x -- log(1+x)
y <- floor x -- floor
y <- ceil x -- ceiling
y <- sigmoid x -- 1 / (1 + exp(-x))
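Scalar reference definitions for a few of the ops above, following the formulas in their comments, can serve as a host-side oracle in tests. The helper names are hypothetical; log1p comes from base's Numeric module.

```haskell
import Numeric (log1p)

-- 1 / sqrt x
rsqrtRef :: Double -> Double
rsqrtRef x = 1 / sqrt x

-- 1 / (1 + exp(-x))
sigmoidRef :: Double -> Double
sigmoidRef x = 1 / (1 + exp (negate x))

-- log(1 + x), accurate for small x
log1pRef :: Double -> Double
log1pRef = log1p

-- Compare a device result against the reference, element by element, with a tolerance.
closeTo :: Double -> [Double] -> [Double] -> Bool
closeTo tol xs ys =
  length xs == length ys && and (zipWith (\x y -> abs (x - y) <= tol) xs ys)
```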
Shape-Preserving Comparisons
compare and its wrappers return Tensor s 'Bool (same shape as inputs), matching StableHLO semantics:
mask <- equal x y -- element-wise equality
mask <- greaterThan x y -- element-wise >
mask <- lessThanOrEqual x y -- element-wise <=
Convenience ops for scalar manipulation:
s <- sumAll x -- reduce all dimensions to scalar
v <- slice1 vec i -- extract scalar from 1-D tensor
packed <- pack2 a b -- pack two scalars into [2]
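Reference semantics on flat, row-major lists make the shape rule explicit: comparisons return one Bool per element, while the scalar helpers collapse or pack values. All helper names are hypothetical, and slice1Ref assumes zero-based indexing.

```haskell
-- Element-wise comparisons keep the input shape and yield booleans.
equalRef, greaterThanRef, lessThanOrEqualRef :: [Double] -> [Double] -> [Bool]
equalRef = zipWith (==)
greaterThanRef = zipWith (>)
lessThanOrEqualRef = zipWith (<=)

-- sumAll collapses every dimension into a single scalar.
sumAllRef :: [Double] -> Double
sumAllRef = sum

-- slice1 extracts the scalar at index i of a 1-D tensor (zero-based, assumed).
slice1Ref :: [Double] -> Int -> Double
slice1Ref xs i = xs !! i

-- pack2 builds a 2-element vector from two scalars.
pack2Ref :: Double -> Double -> [Double]
pack2Ref a b = [a, b]
```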
Installation
System Requirements
- GHC 9.6+ and Cabal 3.10+
- Linux x86_64 (other platforms supported by PJRT artifacts may work)
- curl, tar, and a standard C toolchain (gcc or clang)
- libstdc++ and libdl (usually present on Linux)
From Hackage
HHLO is published on Hackage. You can add it directly to your .cabal file:
build-depends: hhlo >= 0.4
Or with cabal:
cabal install hhlo
Download PJRT Plugins
Run the provided script to download prebuilt PJRT plugins:
./pjrt_script.sh
This downloads libpjrt_cpu.so from the zml/pjrt-artifacts nightly builds into deps/pjrt/. If you have an NVIDIA GPU with nvidia-smi available, the CUDA plugin is also fetched automatically.
Build the Project
cabal build all
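This compiles the library and the hhlo-demo executable. The example-* executables are only built when the examples flag is enabled (it defaults to False), so use cabal build all --flag=examples to include them; cabal test builds and runs the test suite.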
Usage
CPU (works out of the box)
cabal run example-add --flag=examples
cabal test
Note: All example-* executables are guarded by the examples flag in hhlo.cabal (defaults to False). Append --flag=examples to every cabal run example-* command.
GPU (requires runtime libraries)
The PJRT CUDA plugin depends on NVIDIA runtime libraries: cuDNN, NCCL, and NVSHMEM. These are commonly available via conda, pip, or system packages.
If you already have them (e.g. via PyTorch or JAX installations), simply run:
./setup_gpu_env.sh
source ~/.bashrc
This idempotent script auto-discovers the library directories and adds them to LD_LIBRARY_PATH in ~/.bashrc. After that, GPU examples work directly:
cabal run example-gpu-add --flag=examples
cabal run example-gpu-matmul-bench --flag=examples
cabal run example-multi-gpu-inference --flag=examples
EDSL Quick Start
Convenience layer (recommended)
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}
import HHLO.ModuleBuilder
import HHLO.Session
-- Build a program: c = a + b
program = buildModule @2 @1 "add" $ \a b -> add a b
main = withCPU $ \sess -> do
compiled <- compile sess program
result <- run sess compiled
( hostFromList @'[2,2] @'F32 [1, 2, 3, 4]
, hostFromList @'[2,2] @'F32 [5, 6, 7, 8]
)
print (hostToList result) -- [6.0, 8.0, 10.0, 12.0]
Low-level API (full control)
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications #-}
import HHLO.Core.Types
import HHLO.EDSL.Ops
import HHLO.IR.AST (FuncArg(..), Module, TensorType(..))
import HHLO.IR.Builder
import HHLO.IR.Pretty
import qualified Data.Text.IO as T
-- Build a program: c = a + b
program :: Module
program = moduleFromBuilder @'[2,2] @'F32 "main"
  [ FuncArg "a" (TensorType [2, 2] F32)
  , FuncArg "b" (TensorType [2, 2] F32)
  ]
  $ do
    a <- arg
    b <- arg
    c <- add a b
    return c
main :: IO ()
-- render produces Text; Data.Text.IO provides putStrLn for it
main = T.putStrLn (render program)
Output:
module {
func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = stablehlo.add %arg0, %arg1 : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
}
Running the Demo
cabal run hhlo-demo
The demo builds a stablehlo.add program via the EDSL, compiles it with PJRT CPU, creates F32 input buffers, executes, and reads back the result:
=== HHLO End-to-End Demo ===
Loading PJRT CPU plugin...
Plugin loaded.
...
Result: [6.0,8.0,10.0,12.0]
SUCCESS: Results match expected values!
Running Examples
Standalone examples are provided in examples/:
| # | Command | Description |
|---|---------|-------------|
| 1 | cabal run example-add --flag=examples | Element-wise c = a + b |
| 2 | cabal run example-matmul --flag=examples | 2×3 @ 3×2 matrix multiply |
| 3 | cabal run example-chain-ops --flag=examples | (a + b) * (a - b) |
| 4 | cabal run example-async --flag=examples | Async executeAsync + relu |
| 5 | cabal run example-mlp --flag=examples | 2-layer MLP |
| 6 | cabal run example-mlp-batched --flag=examples | Batched MLP |
| 7 | cabal run example-tuple --flag=examples | Multi-result func.func |
| 8 | cabal run example-reduce --flag=examples | reduceSum over all dimensions |
| 9 | cabal run example-softmax --flag=examples | 1-D and batched 2-D softmax |
| 10 | cabal run example-conv2d --flag=examples | NHWC conv2d |
| 11 | cabal run example-batch-norm --flag=examples | Batch norm inference |
| 12 | cabal run example-while --flag=examples | whileLoop count-up |
| 13 | cabal run example-conditional --flag=examples | conditional if-then-else |
| 14 | cabal run example-gather --flag=examples | gather rows from matrix |
| 15 | cabal run example-scatter --flag=examples | scatter replace into vector |
| 16 | cabal run example-slice --flag=examples | slice sub-array extraction |
| 17 | cabal run example-pad --flag=examples | pad with edge/interior padding |
| 18 | cabal run example-dynamic-slice --flag=examples | dynamicSlice runtime indices |
| 19 | cabal run example-sort --flag=examples | sort 1-D ascending |
| 20 | cabal run example-select --flag=examples | Element-wise ternary select |
| 21 | cabal run example-map --flag=examples | map with custom computation |
| 22 | cabal run example-new-ops-smoke-test --flag=examples | Smoke test for newer ops |
| 23 | cabal run example-resnet --flag=examples | ResNet-18 toy (8×8 input) |
| 24 | cabal run example-alexnet --flag=examples | AlexNet toy (16×16 input) |
| 25 | cabal run example-transformer --flag=examples | Transformer encoder (1×4×16) |
| 26 | cabal run example-unet --flag=examples | UNet segmentation toy (16×16) |
| 30 | cabal run example-rng-uniform --flag=examples | rngUniform random floats [0,1) |
| 31 | cabal run example-rng-normal --flag=examples | rngNormal standard normal distribution |
| 32 | cabal run example-rng-bit-generator --flag=examples | rngBitGenerator Threefry PRNG |
| 33 | cabal run example-multi-value-loop --flag=examples | whileLoop2 with two loop-carried values |
| 27 | cabal run example-gpu-add --flag=examples | GPU smoke test |
| 28 | cabal run example-gpu-matmul-bench --flag=examples | GPU 4096×4096 benchmark |
| 29 | cabal run example-multi-gpu-inference --flag=examples | Multi-GPU concurrent matmul |
Tests
CPU Tests (default)
cabal test
Runs 155 tests across three tiers:
- Tier 1 — Golden tests — Verify rendered MLIR text for EDSL ops, IR constructs, NN layers, and control flow.
- Tier 2 — End-to-end runtime tests — Load the PJRT CPU plugin, compile StableHLO programs, execute them, and verify numerical results. Covers arithmetic, matmul, reductions, data movement, and NN ops.
- Tier 3 — Runtime integration tests — Buffer metadata queries, async execution, and error handling.
GPU Tests
HHLO_TEST_GPU=1 cabal test
Runs the full 155 CPU tests plus 6 additional GPU integration tests:
- EndToEnd.GPU — GPU availability and device enumeration
- Runtime.BufferGPU — Buffer round-trip and metadata queries on GPU
- Runtime.AsyncGPU — Async execution and bufferReady polling on GPU
- Runtime.MultiGPU — Concurrent executeReplicas across all GPUs
Sample output:
HHLO Tests
EDSL.Ops
Binary element-wise
add: OK
...
EndToEnd.Arithmetic
relu: OK (0.02s)
...
Runtime.Buffer
buffer round-trip f32: OK
Runtime.Async
buffer ready after sync execute: OK (0.02s)
EndToEnd.GPU
gpu available: OK
Runtime.BufferGPU
gpu buffer round-trip f32: OK
Runtime.AsyncGPU
gpu executeAsync + await: OK
Runtime.MultiGPU
execute replicas on all GPUs: OK
All 147 tests passed (16.27s)
Project Structure
.
├── app/ # hhlo-demo executable
├── cbits/ # C shim around PJRT C API
│ ├── pjrt_c_api.h # Upstream PJRT header
│ ├── pjrt_shim.c # Thin wrapper exposing flat C functions
│ └── pjrt_shim.h # C header for the shim
├── deps/
│ └── pjrt/ # Downloaded PJRT plugins (.so files)
│ └── lib_symlinks/ # Compatibility symlinks for missing library versions
├── doc/ # Architecture and design documents
├── examples/ # Standalone example programs (01–33)
├── src/HHLO/
│ ├── Core/Types.hs # DType, Shape, HostType type families
│ ├── IR/
│ │ ├── AST.hs # MLIR AST (Operation, Function, Module)
│ │ ├── Builder.hs # Stateful Builder monad + Tensor/Tuple GADTs
│ │ └── Pretty.hs # MLIR text pretty-printer
│ ├── EDSL/Ops.hs # Type-safe frontend ops (50+ ops)
│ └── Runtime/
│ ├── PJRT/
│ │ ├── FFI.hs # C FFI declarations
│ │ ├── Types.hs # Opaque pointer newtypes + buffer type constants
│ │ ├── Error.hs # PJRT error handling
│ │ └── Plugin.hs # Backend-agnostic plugin loading (withPJRT)
│ ├── Device.hs # Device enumeration & selection
│ ├── Compile.hs # MLIR → PJRT executable (with `CompileOptions`)
│ ├── Execute.hs # Synchronous + device-targeted + multi-GPU replica execution
│ ├── Async.hs # Non-blocking execution with PJRT_Event
│ └── Buffer.hs # Host↔device buffer transfers + metadata queries
├── test/
│ ├── Test/
│ │ ├── EDSL/Ops.hs
│ │ ├── IR/
│ │ │ ├── Builder.hs
│ │ │ ├── Pretty.hs
│ │ │ ├── PrettyOps.hs
│ │ │ ├── PrettyNN.hs
│ │ │ └── PrettyControlFlow.hs
│ │ ├── Runtime/
│ │ │ ├── EndToEnd*.hs # CPU E2E test modules
│ │ │ ├── EndToEndGPU.hs # GPU availability test
│ │ │ ├── Buffer.hs
│ │ │ ├── BufferGPU.hs # GPU buffer integration tests
│ │ │ ├── Async.hs
│ │ │ ├── AsyncGPU.hs # GPU async tests
│ │ │ ├── MultiGPU.hs # Multi-GPU inference scaling tests
│ │ │ └── Errors.hs
│ │ └── Utils.hs
│ └── Main.hs
├── hhlo.cabal
├── pjrt_script.sh # Downloads PJRT plugins
├── setup_gpu_env.sh # Auto-configures LD_LIBRARY_PATH for GPU
└── README.md
License
MIT License — see LICENSE.