mnist + jax-js

Let's train a neural network to classify MNIST digits in your browser with jax-js.

The model is either a 3-layer MLP or a 4-layer convolutional neural network (ConvNet), trained with the Adam optimizer. Each epoch makes one pass over the 60,000-image training set in 60 (MLP) or 240 (ConvNet) randomized batches, i.e., 1,000 or 250 images per batch.
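The batching arithmetic and the optimizer can be made concrete with a minimal sketch. It is written in plain Python, not jax-js, and does not use the jax-js API. Each epoch shuffles the example indices and splits them into equal-size batches, so 60,000 images in 60 batches gives 1,000 per batch, and in 240 batches gives 250. Adam then updates the parameters once per batch. The helper names `epoch_batches`, `Adam`, and `train_linear` are hypothetical. The model is a toy linear fit rather than the demo's MLP or ConvNet. The Adam hyperparameters are the commonly used defaults, which is an assumption because the demo's actual settings are not stated.

```python
import random


def epoch_batches(num_examples, num_batches, rng):
    """Shuffle example indices and split them into equal-size batches."""
    if num_examples % num_batches != 0:
        raise ValueError("num_examples must divide evenly into num_batches")
    indices = list(range(num_examples))
    rng.shuffle(indices)  # a fresh random order every epoch
    size = num_examples // num_batches
    return [indices[i * size:(i + 1) * size] for i in range(num_batches)]


class Adam:
    """Adam over a flat list of floats (hypothetical helper, not jax-js).

    Defaults are the commonly used ones; the demo's values are assumed, not known.
    """

    def __init__(self, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
        self.lr, self.b1, self.b2, self.eps = lr, b1, b2, eps
        self.m = None  # running mean of gradients (first moment)
        self.v = None  # running mean of squared gradients (second moment)
        self.t = 0     # step count, used for bias correction

    def step(self, params, grads):
        if self.m is None:
            self.m = [0.0] * len(params)
            self.v = [0.0] * len(params)
        self.t += 1
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * g
            self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * g * g
            m_hat = self.m[i] / (1 - self.b1 ** self.t)  # bias-corrected moments
            v_hat = self.v[i] / (1 - self.b2 ** self.t)
            new_params.append(p - self.lr * m_hat / (v_hat ** 0.5 + self.eps))
        return new_params


def train_linear(xs, ys, epochs, num_batches, seed=0):
    """Fit y = w*x + b with minibatch Adam (toy stand-in for the real model)."""
    rng = random.Random(seed)
    params = [0.0, 0.0]  # [w, b]
    opt = Adam(lr=0.05)
    for _ in range(epochs):
        for batch in epoch_batches(len(xs), num_batches, rng):
            # Gradient of the mean squared error over this batch only.
            gw = gb = 0.0
            for i in batch:
                err = params[0] * xs[i] + params[1] - ys[i]
                gw += 2 * err * xs[i]
                gb += 2 * err
            n = len(batch)
            params = opt.step(params, [gw / n, gb / n])
    return params


# Toy data: y = 3x + 1 on 100 evenly spaced points in [-1, 1).
xs = [(i - 50) / 50 for i in range(100)]
ys = [3 * x + 1 for x in xs]
w, b = train_linear(xs, ys, epochs=200, num_batches=10)
print(f"w={w:.3f}, b={b:.3f}")  # should approach w=3, b=1
```

Reshuffling every epoch gives each batch a different mix of examples from one pass to the next, while every example is still used exactly once per epoch.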

Note: This demo requires a WebGPU-enabled browser and works best in Chrome.

Train Loss

Test Loss & Accuracy

Inference Demo
