diff --git a/docs/index.md b/docs/index.md index b13a0826a..a16275c83 100644 --- a/docs/index.md +++ b/docs/index.md @@ -72,6 +72,7 @@ Basics Advanced usage Transformations <./grain.dataset> Performance debugging +JAX training tips Performance autotuning ``` diff --git a/docs/tutorials/jax_training_tutorial.ipynb b/docs/tutorials/jax_training_tutorial.ipynb new file mode 100644 index 000000000..7c783a221 --- /dev/null +++ b/docs/tutorials/jax_training_tutorial.ipynb @@ -0,0 +1,510 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "jx-intro-md" + }, + "source": [ + "# Plugging Grain into JAX training: batching + accelerator transfer\n", + "\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/grain/blob/main/docs/tutorials/jax_training_tutorial.ipynb)\n", + "\n", + "This guide covers the last mile between a Grain pipeline and a JAX training step: how to **batch** records into arrays of the right shape, and how to **move those batches onto your accelerators** efficiently with host-to-device prefetch and sharding across local devices." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jx-install" + }, + "outputs": [], + "source": [ + "# @test {\"output\": \"ignore\"}\n", + "!pip install grain\n", + "# @test {\"output\": \"ignore\"}\n", + "!pip install tensorflow_datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "jx-imports" + }, + "outputs": [], + "source": [ + "import grain\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "import tensorflow_datasets as tfds" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jx-minimal-md" + }, + "source": [ + "## 1. 
Minimal end-to-end pipeline\n", + "\n", + "The shortest pipeline you'd want for JAX training: source -> shuffle -> preprocess -> **batch** -> iterate -> **`device_put`** -> step." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "jx-minimal-code" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:OpenCV is not installed. We recommend using OpenCV because it is faster according to our benchmarks. Defaulting to PIL to decode images...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'image': ((128, 28, 28, 1), dtype('float32')), 'label': ((128,), dtype('int32'))}\n" + ] + } + ], + "source": [ + "source = tfds.data_source(\"mnist\", split=\"train\")\n", + "\n", + "ds = (\n", + " grain.MapDataset.source(source)\n", + " .seed(42)\n", + " .shuffle()\n", + " .map(lambda r: {\"image\": r[\"image\"].astype(np.float32) / 255.0,\n", + " \"label\": r[\"label\"]})\n", + " .batch(batch_size=128, drop_remainder=True) # new leading dim\n", + " .to_iter_dataset()\n", + ")\n", + "\n", + "for batch in ds:\n", + " batch = jax.device_put(batch) # default device\n", + " print(jax.tree.map(lambda x: (x.shape, x.dtype), batch))\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jx-minimal-notes" + }, + "source": [ + "A few things to notice:\n", + "\n", + "- `batch(...)` lives on `MapDataset`. It stacks PyTree leaves along a **new leading axis** (here `[128, 28, 28, 1]` for images, `[128]` for labels).\n", + "- `drop_remainder=True` guarantees a static batch shape, which lets `jax.jit` cache one compiled version of the step.\n", + "- `to_iter_dataset()` turns the random-access `MapDataset` into an `IterDataset`. Do this **after** any random-access transforms (shuffle, batch, repeat) and **before** any streaming transforms (prefetch, `device_put`)." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jx-batching-md" + }, + "source": [ + "## 2. Batching tips that matter for JAX\n", + "\n", + "**Stable shapes.** A `jax.jit`-compiled step is recompiled whenever its input shapes change. Pair `batch(drop_remainder=True)` with `.repeat()` so the loop never produces a short final batch:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "jx-repeat-code" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "length: 72057594037927935\n" + ] + } + ], + "source": [ + "ds = (\n", + " grain.MapDataset.source(source)\n", + " .seed(42)\n", + " .shuffle()\n", + " .repeat() # infinite stream\n", + " .map(lambda r: {\"image\": r[\"image\"].astype(np.float32) / 255.0,\n", + " \"label\": r[\"label\"]})\n", + " .batch(128, drop_remainder=True)\n", + ")\n", + "print(\"length:\", len(ds)) # sys.maxsize // 128, i.e. effectively infinite" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jx-collate-md" + }, + "source": [ + "**Custom collation.** The default `batch_fn` stacks leaves with `np.stack`. 
Pass your own when you need padding, ragged handling, or anything non-uniform:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "jx-collate-code" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(4, 5)\n" + ] + } + ], + "source": [ + "def pad_collate(items):\n", + " max_len = max(x[\"tokens\"].shape[0] for x in items)\n", + " tokens = np.stack([\n", + " np.pad(x[\"tokens\"], (0, max_len - x[\"tokens\"].shape[0]))\n", + " for x in items\n", + " ])\n", + " return {\"tokens\": tokens}\n", + "\n", + "# Toy stream of variable-length token sequences.\n", + "ragged = grain.MapDataset.source(\n", + " [{\"tokens\": np.arange(np.random.randint(2, 6))} for _ in range(16)]\n", + ")\n", + "ragged = ragged.batch(4, batch_fn=pad_collate, drop_remainder=True)\n", + "print(ragged[0][\"tokens\"].shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jx-pad-md" + }, + "source": [ + "If you would rather keep the final partial batch than drop it, look at `grain.experimental.batch_and_pad`: it pads a partial final batch up to the requested batch size with a sentinel value, so you keep one static shape without dropping data. It pads along the batch axis only, so variable-length sequences still need a collate function such as `pad_collate` above." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jx-transfer-md" + }, + "source": [ + "## 3. Moving batches to the accelerator\n", + "\n", + "There are three options, listed from simplest to most capable. Pick the simplest one that meets your needs." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jx-option-a-md" + }, + "source": [ + "### Option A: plain `jax.device_put`\n", + "\n", + "Fine for prototyping and small models:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "jx-option-a-code" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)\n", + "1 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)\n" + ] + } + ], + "source": [ + "ds = (\n", + " grain.MapDataset.source(source)\n", + " .seed(42).shuffle()\n", + " .map(lambda r: {\"image\": r[\"image\"].astype(np.float32) / 255.0,\n", + " \"label\": r[\"label\"]})\n", + " .batch(128, drop_remainder=True)\n", + " .to_iter_dataset()\n", + ")\n", + "\n", + "for step, batch in zip(range(2), ds):\n", + " batch = jax.device_put(batch)\n", + " print(step, batch[\"image\"].sharding)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jx-option-a-caveat" + }, + "source": [ + "The transfer happens on the main thread between every `next(...)`, so the host blocks while the device receives data. On a real training loop this can leave the accelerator idle." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jx-option-b-md" + }, + "source": [ + "### Option B: overlap host work with `ThreadPrefetchIterDataset`\n", + "\n", + "Run the pipeline's CPU work on a background thread so the next batch is ready by the time the device is done with the previous step:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "jx-option-b-code" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(128, 28, 28, 1) SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)\n" + ] + } + ], + "source": [ + "ds = (\n", + " grain.MapDataset.source(source)\n", + " .seed(42).shuffle()\n", + " .map(lambda r: {\"image\": r[\"image\"].astype(np.float32) / 255.0,\n", + " \"label\": r[\"label\"]})\n", + " .batch(128, drop_remainder=True)\n", + " .to_iter_dataset()\n", + ")\n", + "ds = grain.experimental.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=4)\n", + "ds = ds.map(jax.device_put) # transfer still on iter thread\n", + "\n", + "first = next(iter(ds))\n", + "print(first[\"image\"].shape, first[\"image\"].sharding)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jx-option-c-md" + }, + "source": [ + "### Option C: two-stage prefetch with `grain.experimental.device_put`\n", + "\n", + "The recommended pattern for real training. 
It runs a CPU buffer **and** a device-resident buffer, so a batch is already on the accelerator before the step asks for it:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "jx-option-c-code" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)\n", + "1 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)\n" + ] + } + ], + "source": [ + "ds = (\n", + " grain.MapDataset.source(source)\n", + " .seed(42).shuffle()\n", + " .map(lambda r: {\"image\": r[\"image\"].astype(np.float32) / 255.0,\n", + " \"label\": r[\"label\"]})\n", + " .batch(128, drop_remainder=True)\n", + " .to_iter_dataset()\n", + ")\n", + "\n", + "ds = grain.experimental.device_put(\n", + " ds=ds,\n", + " device=jax.devices()[0], # or a Sharding (see below)\n", + " cpu_buffer_size=4, # batches buffered on host\n", + " device_buffer_size=2, # batches buffered on device\n", + ")\n", + "\n", + "for step, batch in zip(range(2), ds):\n", + " # `batch` is already a jax.Array on-device.\n", + " print(step, batch[\"image\"].sharding)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jx-option-c-impl" + }, + "source": [ + "Under the hood this is just `ThreadPrefetch -> map(jax.device_put) -> ThreadPrefetch`." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jx-shard-arrays-md" + }, + "source": [ + "## 4. Multi-device: sharding a batch across accelerators\n", + "\n", + "For data-parallel training across all local devices, pass a `Sharding` to `device_put` instead of a single device. 
Each batch is split along its first axis:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "jx-shard-arrays-code" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NamedSharding(mesh=Mesh('data': 1, axis_types=(Auto,)), spec=PartitionSpec('data',), memory_kind=device)\n" + ] + } + ], + "source": [ + "devices = jax.devices()\n", + "mesh = jax.sharding.Mesh(np.array(devices), axis_names=(\"data\",))\n", + "sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(\"data\"))\n", + "\n", + "ds = (\n", + " grain.MapDataset.source(source)\n", + " .seed(42).shuffle().repeat()\n", + " .map(lambda r: {\"image\": r[\"image\"].astype(np.float32) / 255.0,\n", + " \"label\": r[\"label\"]})\n", + " .batch(128, drop_remainder=True)\n", + " .to_iter_dataset()\n", + ")\n", + "\n", + "ds = grain.experimental.device_put(\n", + " ds=ds,\n", + " device=sharding,\n", + " cpu_buffer_size=4,\n", + " device_buffer_size=2,\n", + ")\n", + "\n", + "batch = next(iter(ds))\n", + "print(batch[\"image\"].sharding)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jx-shard-arrays-notes" + }, + "source": [ + "Make sure `batch_size` is divisible by `len(devices)` — otherwise the sharding split fails. Inside your train step, decorate with `jax.jit` and JAX will compile a single SPMD program that handles the per-device slices automatically." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jx-template-md" + }, + "source": [ + "## 5. 
Putting it all together\n", + "\n", + "A realistic single-host, multi-device template:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "jx-template-code" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "final params: 0.3906955\n" + ] + } + ], + "source": [ + "BATCH = 256\n", + "devices = jax.devices()\n", + "mesh = jax.sharding.Mesh(np.array(devices), axis_names=(\"data\",))\n", + "sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(\"data\"))\n", + "\n", + "def preprocess(r):\n", + " return {\"image\": r[\"image\"].astype(np.float32) / 255.0,\n", + " \"label\": r[\"label\"]}\n", + "\n", + "ds = (\n", + " grain.MapDataset.source(source)\n", + " .seed(42).shuffle().repeat()\n", + " .map(preprocess)\n", + " .batch(BATCH, drop_remainder=True)\n", + " .to_iter_dataset()\n", + ")\n", + "\n", + "ds = grain.experimental.device_put(\n", + " ds=ds, device=sharding,\n", + " cpu_buffer_size=4, device_buffer_size=2,\n", + ")\n", + "\n", + "@jax.jit\n", + "def train_step(params, batch):\n", + " # Replace with your real loss/update.\n", + " return params + batch[\"image\"].mean()\n", + "\n", + "params = jnp.zeros(())\n", + "for step, batch in zip(range(3), ds):\n", + " params = train_step(params, batch)\n", + "print(\"final params:\", params)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [], + "toc_visible": true + }, + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "grain-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.14" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/tutorials/jax_training_tutorial.md b/docs/tutorials/jax_training_tutorial.md new file mode 100644 
index 000000000..34d801e05 --- /dev/null +++ b/docs/tutorials/jax_training_tutorial.md @@ -0,0 +1,293 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.19.2 +kernelspec: + display_name: grain-dev + language: python + name: python3 +--- + ++++ {"id": "jx-intro-md"} + +# Plugging Grain into JAX training: batching + accelerator transfer + +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/grain/blob/main/docs/tutorials/jax_training_tutorial.ipynb) + +This guide covers the last mile between a Grain pipeline and a JAX training step: how to **batch** records into arrays of the right shape, and how to **move those batches onto your accelerators** efficiently with host-to-device prefetch and sharding across local devices. + +```{code-cell} ipython3 +:id: jx-install + +# @test {"output": "ignore"} +!pip install grain +# @test {"output": "ignore"} +!pip install tensorflow_datasets +``` + +```{code-cell} ipython3 +:id: jx-imports + +import grain +import jax +import jax.numpy as jnp +import numpy as np +import tensorflow_datasets as tfds +``` + ++++ {"id": "jx-minimal-md"} + +## 1. Minimal end-to-end pipeline + +The shortest pipeline you'd want for JAX training: source -> shuffle -> preprocess -> **batch** -> iterate -> **`device_put`** -> step. 
+ +```{code-cell} ipython3 +:id: jx-minimal-code + +source = tfds.data_source("mnist", split="train") + +ds = ( + grain.MapDataset.source(source) + .seed(42) + .shuffle() + .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0, + "label": r["label"]}) + .batch(batch_size=128, drop_remainder=True) # new leading dim + .to_iter_dataset() +) + +for batch in ds: + batch = jax.device_put(batch) # default device + print(jax.tree.map(lambda x: (x.shape, x.dtype), batch)) + break +``` + ++++ {"id": "jx-minimal-notes"} + +A few things to notice: + +- `batch(...)` lives on `MapDataset`. It stacks PyTree leaves along a **new leading axis** (here `[128, 28, 28, 1]` for images, `[128]` for labels). +- `drop_remainder=True` guarantees a static batch shape, which lets `jax.jit` cache one compiled version of the step. +- `to_iter_dataset()` turns the random-access `MapDataset` into an `IterDataset`. Do this **after** any random-access transforms (shuffle, batch, repeat) and **before** any streaming transforms (prefetch, `device_put`). + ++++ {"id": "jx-batching-md"} + +## 2. Batching tips that matter for JAX + +**Stable shapes.** A `jax.jit`-compiled step is recompiled whenever its input shapes change. Pair `batch(drop_remainder=True)` with `.repeat()` so the loop never produces a short final batch: + +```{code-cell} ipython3 +:id: jx-repeat-code + +ds = ( + grain.MapDataset.source(source) + .seed(42) + .shuffle() + .repeat() # infinite stream + .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0, + "label": r["label"]}) + .batch(128, drop_remainder=True) +) +print("length:", len(ds)) # sys.maxsize // 128, i.e. effectively infinite +``` + ++++ {"id": "jx-collate-md"} + +**Custom collation.** The default `batch_fn` stacks leaves with `np.stack`. 
Pass your own when you need padding, ragged handling, or anything non-uniform: + +```{code-cell} ipython3 +:id: jx-collate-code + +def pad_collate(items): + max_len = max(x["tokens"].shape[0] for x in items) + tokens = np.stack([ + np.pad(x["tokens"], (0, max_len - x["tokens"].shape[0])) + for x in items + ]) + return {"tokens": tokens} + +# Toy stream of variable-length token sequences. +ragged = grain.MapDataset.source( + [{"tokens": np.arange(np.random.randint(2, 6))} for _ in range(16)] +) +ragged = ragged.batch(4, batch_fn=pad_collate, drop_remainder=True) +print(ragged[0]["tokens"].shape) +``` + ++++ {"id": "jx-pad-md"} + +If you would rather keep the final partial batch than drop it, look at `grain.experimental.batch_and_pad`: it pads a partial final batch up to the requested batch size with a sentinel value, so you keep one static shape without dropping data. It pads along the batch axis only, so variable-length sequences still need a collate function such as `pad_collate` above. + ++++ {"id": "jx-transfer-md"} + +## 3. Moving batches to the accelerator + +There are three options, listed from simplest to most capable. Pick the simplest one that meets your needs. + ++++ {"id": "jx-option-a-md"} + +### Option A: plain `jax.device_put` + +Fine for prototyping and small models: + +```{code-cell} ipython3 +:id: jx-option-a-code + +ds = ( + grain.MapDataset.source(source) + .seed(42).shuffle() + .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0, + "label": r["label"]}) + .batch(128, drop_remainder=True) + .to_iter_dataset() +) + +for step, batch in zip(range(2), ds): + batch = jax.device_put(batch) + print(step, batch["image"].sharding) +``` + ++++ {"id": "jx-option-a-caveat"} + +The transfer happens on the main thread between every `next(...)`, so the host blocks while the device receives data. On a real training loop this can leave the accelerator idle. 
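
To make the cost of that serial pattern concrete, here is a toy sketch that uses only the Python standard library. It is not how Grain implements prefetching: `prefetch_iter`, `slow_batches`, and `run` are hypothetical helpers invented for this illustration, and the `time.sleep` calls are stand-ins for host-side preprocessing and for the device step.

```python
import queue
import threading
import time


def prefetch_iter(iterable, buffer_size=4):
    """Yields items from `iterable`, produced ahead of time on a background thread.

    Hypothetical helper for illustration only. Errors raised by the producer
    are not propagated to the consumer in this sketch.
    """
    buf = queue.Queue(maxsize=buffer_size)
    done = object()  # sentinel marking the end of the stream

    def producer():
        for item in iterable:
            buf.put(item)  # blocks while the buffer is full
        buf.put(done)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is done:
            return
        yield item


def slow_batches(n, delay=0.01):
    """Stand-in for host-side loading and preprocessing."""
    for i in range(n):
        time.sleep(delay)
        yield i


def run(batches, step_delay=0.01):
    """Stand-in for a training loop whose step takes `step_delay` seconds."""
    seen = []
    start = time.perf_counter()
    for b in batches:
        time.sleep(step_delay)
        seen.append(b)
    return seen, time.perf_counter() - start


serial_items, serial_time = run(slow_batches(20))
overlap_items, overlap_time = run(prefetch_iter(slow_batches(20)))
print(f"serial: {serial_time:.3f}s, overlapped: {overlap_time:.3f}s")
```

Both runs see the same batches in the same order. The overlapped run should typically finish faster, because the producer prepares the next batch while the loop is busy with the current one; exact timings depend on the machine.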
+ ++++ {"id": "jx-option-b-md"} + +### Option B: overlap host work with `ThreadPrefetchIterDataset` + +Run the pipeline's CPU work on a background thread so the next batch is ready by the time the device is done with the previous step: + +```{code-cell} ipython3 +:id: jx-option-b-code + +ds = ( + grain.MapDataset.source(source) + .seed(42).shuffle() + .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0, + "label": r["label"]}) + .batch(128, drop_remainder=True) + .to_iter_dataset() +) +ds = grain.experimental.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=4) +ds = ds.map(jax.device_put) # transfer still on iter thread + +first = next(iter(ds)) +print(first["image"].shape, first["image"].sharding) +``` + ++++ {"id": "jx-option-c-md"} + +### Option C: two-stage prefetch with `grain.experimental.device_put` + +The recommended pattern for real training. It runs a CPU buffer **and** a device-resident buffer, so a batch is already on the accelerator before the step asks for it: + +```{code-cell} ipython3 +:id: jx-option-c-code + +ds = ( + grain.MapDataset.source(source) + .seed(42).shuffle() + .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0, + "label": r["label"]}) + .batch(128, drop_remainder=True) + .to_iter_dataset() +) + +ds = grain.experimental.device_put( + ds=ds, + device=jax.devices()[0], # or a Sharding (see below) + cpu_buffer_size=4, # batches buffered on host + device_buffer_size=2, # batches buffered on device +) + +for step, batch in zip(range(2), ds): + # `batch` is already a jax.Array on-device. + print(step, batch["image"].sharding) +``` + ++++ {"id": "jx-option-c-impl"} + +Under the hood this is just `ThreadPrefetch -> map(jax.device_put) -> ThreadPrefetch`. + ++++ {"id": "jx-shard-arrays-md"} + +## 4. Multi-device: sharding a batch across accelerators + +For data-parallel training across all local devices, pass a `Sharding` to `device_put` instead of a single device. 
Each batch is split along its first axis: + +```{code-cell} ipython3 +:id: jx-shard-arrays-code + +devices = jax.devices() +mesh = jax.sharding.Mesh(np.array(devices), axis_names=("data",)) +sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data")) + +ds = ( + grain.MapDataset.source(source) + .seed(42).shuffle().repeat() + .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0, + "label": r["label"]}) + .batch(128, drop_remainder=True) + .to_iter_dataset() +) + +ds = grain.experimental.device_put( + ds=ds, + device=sharding, + cpu_buffer_size=4, + device_buffer_size=2, +) + +batch = next(iter(ds)) +print(batch["image"].sharding) +``` + ++++ {"id": "jx-shard-arrays-notes"} + +Make sure `batch_size` is divisible by `len(devices)` — otherwise the sharding split fails. Inside your train step, decorate with `jax.jit` and JAX will compile a single SPMD program that handles the per-device slices automatically. + ++++ {"id": "jx-template-md"} + +## 5. Putting it all together + +A realistic single-host, multi-device template: + +```{code-cell} ipython3 +:id: jx-template-code + +BATCH = 256 +devices = jax.devices() +mesh = jax.sharding.Mesh(np.array(devices), axis_names=("data",)) +sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data")) + +def preprocess(r): + return {"image": r["image"].astype(np.float32) / 255.0, + "label": r["label"]} + +ds = ( + grain.MapDataset.source(source) + .seed(42).shuffle().repeat() + .map(preprocess) + .batch(BATCH, drop_remainder=True) + .to_iter_dataset() +) + +ds = grain.experimental.device_put( + ds=ds, device=sharding, + cpu_buffer_size=4, device_buffer_size=2, +) + +@jax.jit +def train_step(params, batch): + # Replace with your real loss/update. + return params + batch["image"].mean() + +params = jnp.zeros(()) +for step, batch in zip(range(3), ds): + params = train_step(params, batch) +print("final params:", params) +```
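
The divisibility requirement from section 4 applies to this template as well: `BATCH` has to divide evenly across `len(devices)`. The sketch below spells out that arithmetic in plain Python. `per_device_slices` is a hypothetical helper written for this illustration, not a JAX or Grain function, and it assumes an even, contiguous split of the leading axis over a one-dimensional device mesh.

```python
def per_device_slices(batch_size, num_devices):
    """Returns one contiguous slice of the leading (batch) axis per device.

    Hypothetical helper: it only illustrates the even-split arithmetic and
    does not call into JAX.
    """
    if num_devices <= 0:
        raise ValueError("num_devices must be positive")
    if batch_size % num_devices != 0:
        raise ValueError(
            f"batch_size={batch_size} is not divisible by num_devices={num_devices}"
        )
    per_device = batch_size // num_devices
    return [slice(i * per_device, (i + 1) * per_device) for i in range(num_devices)]


# 256 examples over 8 devices: each device gets 32 consecutive rows.
slices = per_device_slices(256, 8)
print(slices[0], slices[-1])

# A batch size that does not divide evenly is rejected up front.
try:
    per_device_slices(250, 8)
except ValueError as err:
    print("rejected:", err)
```

Running a check like this before building the pipeline gives a clearer error than a failed transfer later on. To see the actual placement of a real batch, you can inspect the `addressable_shards` attribute of a `jax.Array`.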