mnist + jax-js
Let's train a neural network to classify MNIST digits in your browser with jax-js.
The model is either a 3-layer MLP or a 4-layer convolutional neural network, trained with Adam. Each epoch consists of 60 (MLP) or 240 (ConvNet) randomized batches drawn from the 60,000-image training set.
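The batching and optimizer described above can be sketched in plain TypeScript. This is not the demo's code and does not use the jax-js API. It assumes every training image is used exactly once per epoch, so the batch size is 60,000 / 60 = 1,000 for the MLP and 60,000 / 240 = 250 for the ConvNet. The names `batchSize`, `epochBatches` and `adamStep` are hypothetical, and the Adam hyperparameters are the common defaults, not values confirmed by the demo.

```typescript
// Hypothetical helpers for illustration only; these are NOT jax-js APIs.

/** Examples per batch when `numExamples` is split evenly into `numBatches` batches. */
function batchSize(numExamples: number, numBatches: number): number {
  if (numExamples % numBatches !== 0) {
    throw new Error("numExamples must be divisible by numBatches");
  }
  return numExamples / numBatches;
}

/** Shuffle indices 0..n-1 with Fisher-Yates, then split them into equal batches. */
function epochBatches(
  n: number,
  numBatches: number,
  rand: () => number = Math.random,
): number[][] {
  const idx = Array.from({ length: n }, (_, i) => i);
  for (let i = n - 1; i > 0; i--) {
    const j = Math.floor(rand() * (i + 1)); // uniform in [0, i]
    [idx[i], idx[j]] = [idx[j], idx[i]];
  }
  const size = batchSize(n, numBatches);
  const batches: number[][] = [];
  for (let b = 0; b < numBatches; b++) {
    batches.push(idx.slice(b * size, (b + 1) * size));
  }
  return batches;
}

/**
 * One Adam update applied in place to `param`.
 * `m` and `v` are the running first and second moment estimates; `t` is the 1-based step.
 * Hyperparameter defaults are the usual ones from the Adam paper (an assumption here).
 */
function adamStep(
  param: Float32Array,
  grad: Float32Array,
  m: Float32Array,
  v: Float32Array,
  t: number,
  lr = 1e-3,
  beta1 = 0.9,
  beta2 = 0.999,
  eps = 1e-8,
): void {
  for (let i = 0; i < param.length; i++) {
    m[i] = beta1 * m[i] + (1 - beta1) * grad[i];
    v[i] = beta2 * v[i] + (1 - beta2) * grad[i] * grad[i];
    const mHat = m[i] / (1 - Math.pow(beta1, t)); // bias-corrected first moment
    const vHat = v[i] / (1 - Math.pow(beta2, t)); // bias-corrected second moment
    param[i] -= (lr * mHat) / (Math.sqrt(vHat) + eps);
  }
}
```

For example, `epochBatches(60000, 60)` yields 60 shuffled batches of 1,000 indices each, and calling `adamStep` once per batch gives 60 optimizer steps per MLP epoch. On the very first step, the bias correction makes each parameter move by roughly `lr` in the direction opposite its gradient's sign.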
Note: This demo requires a WebGPU-enabled browser and works best in Chrome.
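If you are unsure whether your browser qualifies, a page can check for WebGPU before starting. The sketch below is a hypothetical helper, not part of jax-js or this demo. It relies on the standard WebGPU behavior that `navigator.gpu` is absent when WebGPU is not exposed, and that `requestAdapter()` resolves to `null` when no suitable GPU adapter is available. The small `GPULike` and `NavigatorLike` interfaces are stand-ins so the code compiles without the `@webgpu/types` package.

```typescript
// Minimal structural types so this compiles without @webgpu/types installed.
interface GPULike {
  requestAdapter(): Promise<unknown>;
}
interface NavigatorLike {
  gpu?: GPULike;
}

/** Resolves to true if WebGPU is exposed and a GPU adapter can be obtained. */
async function hasWebGPU(nav: NavigatorLike): Promise<boolean> {
  if (!nav.gpu) return false; // browser does not expose WebGPU at all
  try {
    const adapter = await nav.gpu.requestAdapter();
    return adapter !== null; // null means no suitable adapter was found
  } catch {
    return false; // treat unexpected errors as "no WebGPU"
  }
}
```

In a browser you would call it as `hasWebGPU(navigator as unknown as NavigatorLike)` and show a fallback message when it resolves to `false`.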
[Plots: training loss; test loss and accuracy]