jax-js is a machine learning framework for the browser. It aims to bring JAX-style, high-performance CPU and GPU kernels to JavaScript, so you can run numerical applications on the web.
npm i @jax-js/jax
Under the hood, it translates array operations into a compiler representation, then synthesizes kernels in WebAssembly and WebGPU.
The library is written from scratch, with zero external dependencies. It maintains close API compatibility with NumPy/JAX. Since everything runs client-side, jax-js is likely the most portable GPU ML framework: it runs anywhere a browser does.
You can use jax-js as an array API, just like NumPy.
import { numpy as np } from "@jax-js/jax";
// Array operations, compatible with NumPy.
const x = np.array([1, 2, 3]);
const y = x.mul(4); // [4, 8, 12]
It also lets you take derivatives with grad, just like in JAX (along with other transformations such as vmap and jit).
import { grad, numpy as np } from "@jax-js/jax";
// Calculate derivatives with reverse-mode AD.
const norm = (a) => a.ref.mul(a).sum();
const x = np.array([1, 2, 3]);
const xnorm = norm(x.ref); // 1^2 + 2^2 + 3^2 = 14
const xgrad = grad(norm)(x); // [2, 4, 6]
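The other transformations work similarly. Below is a minimal sketch of vmap, assuming it follows JAX semantics and maps a function over the leading axis of its input; treat the exact calling convention as illustrative.
import { numpy as np, vmap } from "@jax-js/jax";
// Apply the scalar-valued norm to each row of a batch at once.
// (Sketch: assumes vmap maps over the leading axis, as in JAX.)
const norm = (a) => a.ref.mul(a).sum();
const batch = np.array([[1, 2, 3], [4, 5, 6]]);
const norms = vmap(norm)(batch); // [14, 77]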
The default backend runs on the CPU, but on browsers with WebGPU support, including Chrome and iOS Safari, you can switch to the GPU for better performance.
import { defaultDevice, init, numpy as np } from "@jax-js/jax";
// Initialize the GPU backend.
await init("webgpu");
// Change the default backend to GPU.
defaultDevice("webgpu");
const x = np.ones([4096, 4096]);
const y = np.dot(x.ref, x); // JIT-compiled into a matrix multiplication kernel
Most common JAX APIs are supported. See the compatibility table for a full breakdown of what features are available.
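For instance, transformations should compose just as they do in JAX. The sketch below assumes jit and grad can be nested in the same way; see the compatibility table for what is actually supported.
import { grad, jit, numpy as np } from "@jax-js/jax";
// JIT-compile the gradient of the norm function from the earlier example.
// (Sketch: assumes grad and jit compose as they do in JAX.)
const norm = (a) => a.ref.mul(a).sum();
const normGrad = jit(grad(norm));
const g = normGrad(np.array([1, 2, 3])); // [2, 4, 6]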
If you want to use jax-js in vanilla JavaScript (without a bundler), just import it from a module script tag. This is the easiest way to get started on a blank HTML page.
<script type="module">
import { numpy as np } from "https://esm.sh/@jax-js/jax";
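// np is now available directly in the page, e.g.:
const y = np.array([1, 2, 3]).mul(4); // [4, 8, 12]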
</script>
We haven't spent a ton of time optimizing yet, but performance is generally pretty good. jit is very helpful for fusing operations together, and among web ML libraries it's a feature unique to jax-js. The default kernel-tuning heuristics get about 3000 GFLOP/s for matrix multiplication on an M4 Pro chip (try it).
For that example, it reaches around the same GFLOP/s as TensorFlow.js and ONNX Runtime Web, which both ship libraries of handwritten custom kernels (versus jax-js, which generates kernels with an ML compiler).
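To give a rough sense of what fusion buys you, here is a small sketch (assuming jit wraps an array function as in JAX): the chained elementwise operations below can be compiled into a single generated kernel instead of one pass per operation.
import { jit, numpy as np } from "@jax-js/jax";
// Trace the function once and fuse the elementwise chain into one kernel.
// (Sketch: assumes jit's calling convention mirrors JAX.)
const f = jit((a) => a.ref.mul(a).mul(2).sum());
const y = f(np.array([1, 2, 3])); // 2 * (1 + 4 + 9) = 28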
If you make something cool with jax-js, don't be a stranger! We can feature it here.
The following technical details are for contributing to jax-js and modifying its internals.
This repository is managed by pnpm. You can compile and build all packages in
watch mode with:
pnpm install
pnpm run build:watch
Then you can run tests in a headless browser using Vitest.
pnpm exec playwright install
pnpm test
We currently pin an older version of Playwright that supports WebGPU in headless mode; on newer versions, the WebGPU tests are skipped.
To start a Vite dev server running the website, demos and REPL:
pnpm -C website dev
Contributions are welcome in the following areas:
- class Array {} wrappers
- Kernel tuning heuristics (tuner.ts)
- jit() support via Jaxprs and kernel fusion
- dispose() / refcount / linear types stuff
- dispose() for saved "const" tracers in Jaxprs