jax-js is an ML library and compiler for the web

High-performance WebGPU and WebAssembly kernels in JavaScript. Run AI training and inference, image algorithms, simulations, and numerical code on arrays, all JIT compiled in your browser.

Add jax-js to your project

Zero dependencies. Works in all major browsers.

npm install @jax-js/jax

Matrix multiplication

Benchmark chart: throughput in billions of floating-point operations per second (GFLOP/s).
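The matrix multiplication benchmark reports throughput in GFLOP/s. The page does not describe how the benchmark is measured, so the following is only a sketch of the usual calculation: multiplying an (M×K) matrix by a (K×N) matrix performs M·N·K multiply-adds, or 2·M·N·K floating-point operations, and dividing that count by wall-clock time gives the rate. The helpers `matmul`, `matmulFlops`, `gflopsPerSecond`, and `benchmark` are hypothetical names, and the naive triple loop is plain JavaScript, not a jax-js kernel.

```javascript
// Hypothetical helpers for illustration; not part of the jax-js API.

// Naive n×n matrix multiplication on row-major Float32Array buffers.
// The i-k-j loop order walks `b` and `c` contiguously in the inner loop.
function matmul(a, b, n) {
  const c = new Float32Array(n * n);
  for (let i = 0; i < n; i++) {
    for (let k = 0; k < n; k++) {
      const aik = a[i * n + k];
      for (let j = 0; j < n; j++) {
        c[i * n + j] += aik * b[k * n + j];
      }
    }
  }
  return c;
}

// An (M×K)·(K×N) product does M·N·K multiply-adds, i.e. 2·M·N·K flops.
function matmulFlops(m, n, k) {
  return 2 * m * n * k;
}

// Convert a flop count and an elapsed time in milliseconds into GFLOP/s.
function gflopsPerSecond(flops, elapsedMs) {
  return flops / (elapsedMs / 1000) / 1e9;
}

// Time one n×n multiplication and report its throughput.
function benchmark(n) {
  const a = new Float32Array(n * n).fill(1);
  const b = new Float32Array(n * n).fill(2);
  const start = performance.now();
  const c = matmul(a, b, n);
  const elapsed = performance.now() - start;
  return { c, gflops: gflopsPerSecond(matmulFlops(n, n, n), elapsed) };
}
```

For example, `console.log(benchmark(256).gflops.toFixed(2))` prints a machine-dependent number. A single timed run is noisy, so real benchmarks typically warm up first and average several runs.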

Like JAX and PyTorch in your browser

jax-js is an end-to-end ML library inspired by JAX, but written in pure JavaScript:

  • Runs completely client-side (Chrome, Firefox, iOS, Android).
  • Has close API compatibility with NumPy/JAX.
  • Is written from scratch, with zero external dependencies.
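jax-js runs client-side and ships both WebGPU and WebAssembly kernels. This page does not say how it picks between them, so the sketch below only shows the standard runtime checks such a choice could rely on: `navigator.gpu` is present in WebGPU-capable browsers, but `requestAdapter()` can still resolve to `null` when no GPU adapter is available, so the adapter has to be requested before committing to WebGPU. `pickBackend` and the `"js"` fallback label are hypothetical, not jax-js API.

```javascript
// Hypothetical helper: choose a compute backend from what the runtime exposes.
// `env` defaults to globalThis so the function can be tested with mock objects.
async function pickBackend(env = globalThis) {
  const gpu = env.navigator?.gpu;
  if (gpu) {
    // navigator.gpu may exist even when no adapter is usable; requestAdapter()
    // resolves to null in that case, so check before choosing WebGPU.
    const adapter = await gpu.requestAdapter();
    if (adapter) return "webgpu";
  }
  // WebAssembly is a global namespace object in all major browsers and Node.js.
  if (typeof env.WebAssembly === "object") return "wasm";
  // Hypothetical last resort: plain JavaScript loops.
  return "js";
}
```

Calling `pickBackend().then((backend) => console.log(backend))` in a browser with a working GPU adapter should log `"webgpu"`. Elsewhere it falls through to the WebAssembly check.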

jax-js is likely the most portable GPU ML framework, since it runs anywhere a browser can run. It is also simple yet optimized, with a lightweight compiler that translates your high-level operations into WebGPU and WebAssembly kernels.
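The compiler's internals are not described here. As a toy illustration of one idea array compilers commonly use, operator fusion, the sketch below turns an elementwise expression such as `x * 2 + 1` into a single WGSL compute shader, so the intermediate `x * 2` never gets its own GPU buffer. The tiny IR (`input`, `constant`, `binary`) and `compileElementwise` are hypothetical and are not the jax-js compiler.

```javascript
// Hypothetical toy IR for elementwise expressions; not jax-js internals.
const input = (name) => ({ kind: "input", name });
const constant = (value) => ({ kind: "const", value });
const binary = (op, lhs, rhs) => ({ kind: "binary", op, lhs, rhs });

// Render an expression as one WGSL expression evaluated at index `i`.
function emitExpr(node) {
  switch (node.kind) {
    case "input":
      return `${node.name}[i]`;
    case "const":
      // The `f` suffix makes the literal an f32 in WGSL (e.g. 2f, 0.5f).
      return `${node.value}f`;
    case "binary":
      return `(${emitExpr(node.lhs)} ${node.op} ${emitExpr(node.rhs)})`;
    default:
      throw new Error(`unknown node kind: ${node.kind}`);
  }
}

// Gather the distinct input buffer names an expression reads.
function collectInputs(node, names = new Set()) {
  if (node.kind === "input") names.add(node.name);
  if (node.kind === "binary") {
    collectInputs(node.lhs, names);
    collectInputs(node.rhs, names);
  }
  return names;
}

// Fuse the whole expression into one kernel: each invocation reads its inputs
// once, computes the result in registers, and writes a single output element.
function compileElementwise(expr) {
  const inputs = [...collectInputs(expr)];
  const lines = inputs.map(
    (name, idx) => `@group(0) @binding(${idx}) var<storage, read> ${name}: array<f32>;`,
  );
  lines.push(
    `@group(0) @binding(${inputs.length}) var<storage, read_write> dst: array<f32>;`,
    "",
    "@compute @workgroup_size(64)",
    "fn main(@builtin(global_invocation_id) gid: vec3<u32>) {",
    "  let i = gid.x;",
    "  if (i >= arrayLength(&dst)) { return; }",
    `  dst[i] = ${emitExpr(expr)};`,
    "}",
  );
  return lines.join("\n");
}
```

For `binary("+", binary("*", input("x"), constant(2)), constant(1))`, the generated kernel body is `dst[i] = ((x[i] * 2f) + 1f);`. Without fusion, the same computation would need two kernel launches and a temporary buffer for `x * 2`.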

The goal of jax-js is to make numerical code accessible and deployable to everyone, so compute-intensive apps can run fast and locally on consumer hardware.
