Like many PyTorch users, you may have heard great things about JAX — its high performance, the elegance of its functional programming approach, and its powerful, built-in support for parallel computation. However, you may have also struggled to find what you need to get started: a straightforward, easy-to-follow tutorial to help you understand the basics of JAX by connecting its new concepts to the PyTorch building blocks that you’re already familiar with. So, we created one!
In this tutorial, we explore the basics of the JAX ecosystem through the lens of a PyTorch user, focusing on training a simple neural network in both frameworks for the classic machine learning (ML) task of predicting which passengers survived the Titanic disaster. Along the way, we introduce JAX by demonstrating how many things, from model definitions and instantiation to training, map to their PyTorch equivalents.
You can follow along with full code examples in the accompanying notebook: https://www.kaggle.com/code/anfalatgoogle/pytorch-developer-s-guide-to-jax-fundamentals
Modularity with JAX
As a PyTorch user, you might initially find JAX’s highly modular ecosystem quite different from what you are used to. JAX focuses on being a high-performance numerical computation library with support for automatic differentiation. Unlike PyTorch, it does not try to provide explicit built-in support for defining neural networks, optimizers, etc. Instead, JAX is designed to be flexible, allowing you to bring in your frameworks of choice to add to its functionality.
In this tutorial, we use the Flax neural network library and the Optax optimization library, both popular and well supported. We show how to train a neural network with the new Flax NNX API for a very PyTorch-esque experience, and then show how to do the same thing with the older, but still widely used, Linen API.
Functional programming
Before we dive into our tutorial, let’s talk about JAX’s rationale for using functional programming, as opposed to the object-oriented programming that PyTorch and other frameworks use. Briefly, functional programming focuses on pure functions that cannot mutate state and cannot have side effects, i.e., they always produce the same output for the same input. In JAX, this manifests through significant usage of composable functions and immutable arrays.
The predictability of pure functions unlocks many benefits in JAX, such as just-in-time (JIT) compilation, where the XLA compiler can significantly optimize code on GPUs or TPUs for major speed-ups. It also makes sharding and parallelizing operations much easier. You can learn more from the official JAX tutorials.
Do not be deterred if you’re new to functional programming — as you will soon see, Flax NNX hides much of it behind standard Pythonic idioms.
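To make the ideas of immutable arrays and pure functions concrete, here is a minimal sketch using only core JAX. The function name `scale_and_shift` and the values are illustrative assumptions, not part of the tutorial's model.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5.0)

# JAX arrays are immutable: in-place item assignment raises a TypeError.
try:
    x[0] = 10.0
    mutated_in_place = True
except TypeError:
    mutated_in_place = False

# The functional alternative returns a new array and leaves `x` untouched.
y = x.at[0].set(10.0)

# A pure function (hypothetical example): the output depends only on the inputs.
def scale_and_shift(values, scale, shift):
    return values * scale + shift

# Because the function is pure, it can be compiled with jax.jit
# and still returns the same result as the eager version.
fast_scale_and_shift = jax.jit(scale_and_shift)
eager_result = scale_and_shift(x, 2.0, 1.0)
jitted_result = fast_scale_and_shift(x, 2.0, 1.0)
```

The `.at[...].set(...)` pattern is how JAX expresses "updates" without mutating the original array.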
Data loading
Data loading in JAX is very straightforward: just do what you already do in PyTorch. You can use a PyTorch dataset/dataloader with a simple collate_fn to convert things to the NumPy-like arrays that underlie all JAX computation.
```python
import torch
from torch.utils.data import Dataset, DataLoader, default_collate
import jax.numpy as jnp
from jax.tree_util import tree_map

# Dataset Definition
class TitanicDataset(Dataset):
    def __init__(self, samples, labels):
        self.df = samples
        self.labels = labels

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        x = torch.tensor(self.df.iloc[idx].values, dtype=torch.float32)
        y = torch.tensor(self.labels.iloc[idx], dtype=torch.float32)
        return x, y

def numpy_collate(batch):
    return tree_map(jnp.asarray, default_collate(batch))

# Create Train dataset, dataloader from a pandas dataframe
train_dataset = TitanicDataset(X_train, y_train)
train_dataloader_jax = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=numpy_collate)
train_dataloader_torch = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Create Eval dataset, dataloader from a pandas dataframe
eval_dataset = TitanicDataset(X_test, y_test)
eval_dataloader_jax = DataLoader(eval_dataset, batch_size=64, shuffle=False, collate_fn=numpy_collate)
eval_dataloader_torch = DataLoader(eval_dataset, batch_size=64, shuffle=False)
```
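The key piece of the collate function is `tree_map`, which applies a function to every leaf of a nested container. The sketch below isolates that step; it uses NumPy arrays as stand-ins for the collated PyTorch tensors (an assumption made so the example runs without PyTorch), with a made-up batch of 4 samples and 8 features.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

# Stand-in for a collated (features, labels) batch; shapes are illustrative.
batch = [
    np.ones((4, 8), dtype=np.float32),
    np.zeros((4,), dtype=np.float32),
]

# tree_map keeps the container structure and converts each leaf to a JAX array.
jax_batch = tree_map(jnp.asarray, batch)

all_jax = all(isinstance(leaf, jax.Array) for leaf in jax_batch)
features_shape = jax_batch[0].shape
labels_shape = jax_batch[1].shape
```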
Model definition
With Flax’s NNX API, defining your neural networks is very similar to doing so in PyTorch. Here we define a simple multilayer perceptron with two hidden layers in both frameworks, starting with PyTorch.
```python
import torch.nn as nn

class TitanicNeuralNet(nn.Module):
    def __init__(self, num_hidden_1, num_hidden_2):
        super().__init__()
        self.linear1 = nn.Linear(8, num_hidden_1)
        self.dropout = nn.Dropout(0.01)
        self.relu = nn.LeakyReLU()
        self.linear2 = nn.Linear(num_hidden_1, num_hidden_2)
        self.linear3 = nn.Linear(num_hidden_2, 1, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.dropout(x)
        x = self.relu(x)
        out = self.linear3(x)
        return out
```
NNX model definitions are very similar to the PyTorch code above. Both make use of __init__ to define the layers of the model, while __call__ corresponds to forward.
```python
from flax import nnx

class TitanicNNX(nnx.Module):
    def __init__(self, num_hidden_1, num_hidden_2, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(8, num_hidden_1, rngs=rngs)
        self.dropout = nnx.Dropout(0.01, rngs=rngs)
        self.relu = nnx.leaky_relu
        self.linear2 = nnx.Linear(num_hidden_1, num_hidden_2, rngs=rngs)
        self.linear3 = nnx.Linear(num_hidden_2, 1, use_bias=False, rngs=rngs)

    def __call__(self, x):
        x = self.linear1(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.dropout(x)
        x = self.relu(x)
        out = self.linear3(x)
        return out
```
Model initialization and usage
Model initialization in NNX is nearly identical to PyTorch. In both frameworks, when you create an instance of the model class, the model parameters are eagerly (vs. lazily) initialized and tied to the instance itself. The only difference in NNX is that you need to pass in a pseudorandom number generator (PRNG) key, supplied here through nnx.Rngs, when instantiating the model. In keeping with its functional nature, JAX avoids implicit global random state and requires you to pass PRNG keys explicitly. This makes random number generation easily reproducible, parallelizable, and vectorizable. See the JAX docs for more details.
```python
# PyTorch
torch_model = TitanicNeuralNet(num_hidden_1=32, num_hidden_2=16)
```
```python
# Flax NNX
flax_model = TitanicNNX(num_hidden_1=32, num_hidden_2=16, rngs=nnx.Rngs(0))
```
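For readers new to explicit PRNG keys, the following sketch shows the underlying behavior with core JAX only: the same key always yields the same numbers, and new randomness comes from splitting the key. The seed and shapes are arbitrary choices for illustration; the assumption is that nnx.Rngs takes care of this kind of key bookkeeping for you in the model above.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reusing a key is fully reproducible: both draws are identical.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# To get fresh random numbers, split the key and use the subkey.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

same_key_same_values = bool(jnp.array_equal(a, b))
new_key_new_values = not bool(jnp.array_equal(a, c))
```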
Actually using the models to process a batch of data is equivalent between the two frameworks:
```python
# PyTorch
torch_model(sample_data)
```
```python
# Flax NNX
flax_model(sample_data)
```
Training step and backpropagation
There are some key differences in training loops between PyTorch and Flax NNX. To demonstrate, let’s build up to the full NNX training loop step by step.
Setup
```python
# PyTorch
optimizer = optim.Adam(model.parameters(), lr=0.01)
```
```python
# Flax NNX
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=0.01))
```
In both frameworks, we create Optimizers and have the flexibility to specify our optimization algorithm. While PyTorch requires passing in model parameters, Flax NNX allows you to just pass in the model directly and handles all interactions with the underlying Optax optimizer.
Forward + backward pass
```python
# PyTorch
logits = model(batch)
loss = torch.nn.BCEWithLogitsLoss()(logits.squeeze(), labels)
loss.backward()
```
```python
# Flax NNX
def loss_fn(model):
    logits = model(batch)
    loss = optax.sigmoid_binary_cross_entropy(logits.squeeze(), labels).mean()
    return loss
grad_fn = nnx.value_and_grad(loss_fn)
loss, grads = grad_fn(model)
```
Perhaps the biggest difference between PyTorch and JAX is how to do a full forward/backward pass. With PyTorch, you calculate the gradients with loss.backward(), triggering autograd to follow the computation graph from loss to compute the gradients.
JAX’s automatic differentiation is instead much closer to the raw math, where you have gradients of functions. Specifically, nnx.value_and_grad/nnx.grad take in a function, loss_fn, and return a function, grad_fn. Then, grad_fn itself returns the gradient of the output of loss_fn with respect to its input (nnx.value_and_grad also returns the loss value alongside the gradient, as in the snippet above).
In our example, loss_fn is doing exactly what is being done in PyTorch: first, it gets the logits from the forward pass and then calculates the familiar loss. From there, grad_fn calculates the gradient of loss with respect to the parameters of model. In mathematical terms, the grads that are returned are ∂J/∂θ, where J is the loss and θ are the model parameters. This is exactly what is happening in PyTorch under the hood: whereas PyTorch is “storing” the gradients in the tensor’s .grad attribute when you do loss.backward(), JAX and Flax NNX follow the functional approach of not mutating state and just return the gradients to you directly.
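The "gradient of a function" idea is easiest to see on a toy problem that can be checked by hand. This sketch uses core `jax.value_and_grad` (the function that the Linen section also uses) on a hypothetical one-weight linear model with a mean squared error loss; the model, data, and parameter names are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss: mean squared error of a 1-D linear model.
def loss_fn(params, x, y):
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.array(2.0), "b": jnp.array(0.0)}
x = jnp.array([1.0, 2.0])
y = jnp.array([1.0, 2.0])

# value_and_grad returns a new function that computes both the loss and
# its gradient with respect to the first argument (params).
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)

# By hand: pred = [2, 4], errors = [1, 2], so loss = (1 + 4) / 2 = 2.5,
# dL/dw = mean(2 * error * x) = (2 + 8) / 2 = 5, dL/db = mean(2 * error) = 3.
```

Note that `grads` has the same dictionary structure as `params`, which is what makes it straightforward to hand the gradients to an optimizer.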
Optimizer step
```python
# PyTorch
optimizer.step()
```
```python
# Flax NNX
optimizer.update(grads)
```
In PyTorch, optimizer.step() updates the weights in place using the gradients. NNX also does an in-place update of the weights, but requires the grads you calculated in the backward pass to be passed in directly. This is the same optimization step that is done in PyTorch, just slightly more explicit, in keeping with JAX’s underlying functional nature.
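To illustrate what "passing the gradients in explicitly" means in a purely functional setting, here is a minimal sketch of a parameter update written with `jax.tree_util.tree_map`. It uses plain gradient descent rather than Adam to keep the arithmetic easy to follow, and it is not how nnx.Optimizer is implemented; `sgd_update` and the values are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: plain gradient descent (not Adam), written functionally.
# It takes params and grads and returns new params without mutating anything.
def sgd_update(params, grads, learning_rate):
    return jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )

params = {"w": jnp.array(2.0), "b": jnp.array(0.0)}
grads = {"w": jnp.array(5.0), "b": jnp.array(3.0)}

new_params = sgd_update(params, grads, learning_rate=0.1)
# w: 2.0 - 0.1 * 5.0 = 1.5, b: 0.0 - 0.1 * 3.0 = -0.3
```

The original `params` dictionary is left unchanged; the caller decides what to do with `new_params`.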
Full training loop
You now have everything you need to construct a full training loop in JAX/Flax NNX. As a reference, let’s first see the familiar PyTorch loop:
```python
import torch
import torch.optim as optim
from tqdm import tqdm

def train(model, train_dataloader, eval_dataloader, num_epochs):
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.BCEWithLogitsLoss()
    for epoch in (pbar := tqdm(range(num_epochs))):
        pbar.set_description(f"Epoch {epoch}")
        model.train()
        for batch, labels in train_dataloader:
            optimizer.zero_grad()

            logits = model(batch)

            loss = criterion(logits.squeeze(), labels)

            loss.backward()

            optimizer.step()

        pbar.set_postfix(train_accuracy=eval(model, train_dataloader), eval_accuracy=eval(model, eval_dataloader))

def eval(model, eval_dataloader):
    model.eval()
    num_correct = 0
    num_samples = 0
    for batch, labels in eval_dataloader:
        logits = model(batch)
        preds = torch.round(torch.sigmoid(logits))
        num_correct += (preds.squeeze() == labels).sum().item()
        num_samples += labels.shape[0]
    return num_correct / num_samples
```
And now the full NNX training loop:
```python
import jax.numpy as jnp
import optax
from flax import nnx
from tqdm import tqdm

def train(model, train_dataloader, eval_dataloader, num_epochs):
    optimizer = nnx.Optimizer(model, optax.adam(learning_rate=0.01))

    for epoch in (pbar := tqdm(range(num_epochs))):
        pbar.set_description(f"Epoch {epoch}")
        model.train()
        for batch in train_dataloader:
            train_step(model, optimizer, batch)

        pbar.set_postfix(train_accuracy=eval(model, train_dataloader), eval_accuracy=eval(model, eval_dataloader))

@nnx.jit
def train_step(model, optimizer, batch):
    def loss_fn(model):
        logits = model(batch[0])
        loss = optax.sigmoid_binary_cross_entropy(logits.squeeze(), batch[1]).mean()
        return loss
    grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = grad_fn(model)
    optimizer.update(grads)

def eval(model, eval_dataloader):
    model.eval()
    total = 0
    num_correct = 0
    for batch in eval_dataloader:
        res = eval_step(model, batch)
        total += res.shape[0]
        num_correct += jnp.sum(res)
    return num_correct / total

@nnx.jit
def eval_step(model, batch):
    logits = model(batch[0])
    logits = logits.squeeze()
    preds = jnp.round(nnx.sigmoid(logits))
    return preds == batch[1]
```
The key takeaway is that the training loops are very similar between PyTorch and JAX/Flax NNX, with most of the differences boiling down to object-oriented versus functional programming. Although there’s a slight learning curve to functional programming and thinking about gradients of functions, it enables many of the aforementioned benefits in JAX, e.g., JIT compilation and automatic parallelization. For example, just adding the @nnx.jit decorators to the above functions speeds up training the model for 500 epochs from 6.25 minutes to just 1.8 minutes with a P100 GPU on Kaggle! You’ll see similar speedups with the same code across CPUs, TPUs, and even non-NVIDIA GPUs.
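One way to build intuition for where JIT speedups come from is to observe that the Python body of a jitted function runs only while JAX traces it; later calls with the same input shapes reuse the compiled version. The sketch below demonstrates this with core `jax.jit` and a hypothetical `predict` function; the `trace_log` list exists only to count how often the Python body executes.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records each time the Python body is actually executed

@jax.jit
def predict(weights, inputs):
    # Python side effect: runs during tracing, not on every call.
    trace_log.append("traced")
    return jnp.dot(inputs, weights)

weights = jnp.ones((3,))
inputs = jnp.ones((4, 3))

# Five calls with identical shapes: traced and compiled once, then reused.
for _ in range(5):
    out = predict(weights, inputs)
traces_same_shape = len(trace_log)

# A new input shape triggers one more trace and compilation.
out_new_shape = predict(weights, jnp.ones((8, 3)))
traces_after_new_shape = len(trace_log)
```

This is also why side effects such as printing or mutating Python state inside a jitted function behave differently than in eager code.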
Flax Linen reference
As previously mentioned, the JAX ecosystem is very flexible and lets you bring in your framework of choice. Although NNX is the recommended solution for new users, the Flax Linen API is still widely used today, including in powerful frameworks like MaxText and MaxDiffusion. While NNX is far more Pythonic and hides much of the complexity of state management, Linen adheres much more closely to pure functional programming.
Being comfortable with both is greatly beneficial if you want to participate in the JAX ecosystem. To help, let’s replicate much of our NNX code with Linen, and include comments highlighting the main differences.
```python
# Model definition
import flax.linen as nn

# `initializer` is a kernel initializer defined elsewhere (not shown in this excerpt)

# Input dimensions for relevant layers are inferred during init below
class TitanicNeuralNet(nn.Module):
    num_hidden_1: int
    num_hidden_2: int

    def setup(self):
        self.linear1 = nn.Dense(features=self.num_hidden_1, kernel_init=initializer)
        self.linear2 = nn.Dense(features=self.num_hidden_2, kernel_init=initializer)
        self.linear3 = nn.Dense(features=1, use_bias=False, kernel_init=initializer)
        self.dropout1 = nn.Dropout(0.01)
        self.dropout2 = nn.Dropout(0.01)

    def __call__(self, x, training):
        x = self.linear1(x)
        x = self.dropout1(x, deterministic=not training)
        x = nn.leaky_relu(x)
        x = self.linear2(x)
        x = self.dropout2(x, deterministic=not training)
        x = nn.leaky_relu(x)
        x = self.linear3(x)
        return x
```
```python
# Model Init
import jax

# Params are independent of the model definition and init requires sample data for shape inference
rng = jax.random.PRNGKey(42)
new_rng, subkey, subdropout = jax.random.split(rng, num=3)
flax_model = TitanicNeuralNet(num_hidden_1=32, num_hidden_2=16)
params = flax_model.init(subkey, sample_data, True)

# Model is called using apply, and both params and data must be passed in,
# in very functional programming style. Similarly, you distinguish between
# train/eval with a boolean and pass in PRNG
flax_model.apply(params, sample_data, True, rngs={'dropout': subdropout})
```
```python
# Full Training Loop
import jax
import jax.numpy as jnp
import optax
from jax import jit
from flax.training import train_state
from tqdm import tqdm

# TrainState as a convenience wrapper to help with the management of parameters and gradients
optimizer = optax.adam(learning_rate=0.01)

state = train_state.TrainState.create(
    apply_fn=flax_model.apply,
    params=params,
    tx=optimizer,
)

def train(state, train_dataloader, eval_dataloader, subdropout, num_epochs):
    for epoch in (pbar := tqdm(range(num_epochs))):
        pbar.set_description(f"Epoch {epoch}")
        for batch in train_dataloader:
            state, loss = train_step(state, batch, subdropout)

        pbar.set_postfix(train_accuracy=eval(state, train_dataloader), eval_accuracy=eval(state, eval_dataloader))

    return state

def eval(state, eval_dataloader):
    total = 0
    num_correct = 0
    for batch in eval_dataloader:
        res = eval_step(state, batch)
        total += res.shape[0]
        num_correct += jnp.sum(res)
    return num_correct / total

@jit
def train_step(state, batch, subdropout):
    def loss_fn(params):
        logits = state.apply_fn(params, batch[0], True, rngs={'dropout': subdropout})
        loss = optax.sigmoid_binary_cross_entropy(logits.squeeze(), batch[1]).mean()
        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    # Pass grads to TrainState to get new TrainState with updated parameters, in functional programming style
    state = state.apply_gradients(grads=grads)
    return state, loss

@jit
def eval_step(state, batch):
    logits = state.apply_fn(state.params, batch[0], False)
    logits = logits.squeeze()
    preds = jnp.round(nn.sigmoid(logits))
    return preds == batch[1]
```
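Because PRNG keys are explicit, the training loop above passes the same dropout key to every step, which means the same key drives dropout each time. If you would rather use a different dropout mask per step, one possible approach (an assumption on our part, not something the loop above does) is to derive a per-step key with `jax.random.fold_in`. The sketch below shows only the key handling, with `jax.random.bernoulli` standing in for a dropout mask.

```python
import jax
import jax.numpy as jnp

base_key = jax.random.PRNGKey(42)

# Derive a distinct, reproducible key for each (hypothetical) training step.
step_keys = [jax.random.fold_in(base_key, step) for step in range(3)]

# Stand-in for dropout masks: each step key yields a different random mask.
masks = [jax.random.bernoulli(k, p=0.5, shape=(1000,)) for k in step_keys]

# Folding in the same step number again reproduces the same key.
key_step_1_again = jax.random.fold_in(base_key, 1)

keys_differ = not bool(jnp.array_equal(step_keys[0], step_keys[1]))
masks_differ = not bool(jnp.array_equal(masks[0], masks[1]))
reproducible = bool(jnp.array_equal(step_keys[1], key_step_1_again))
```

In the loop above, this would amount to passing a step-specific key to train_step in place of the fixed subdropout.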
Next steps
With the JAX/Flax knowledge you’ve gained from this blog post, you are now ready to write your own neural network. You can get started right away in Google Colab or Kaggle. Find a challenge on Kaggle and write a brand new model with Flax NNX, or start training a large language model (LLM) with MaxText — the possibilities are endless.
And we have just scratched the surface with JAX and Flax. To learn more about JIT, automatic vectorization, custom gradients, and more, check out the documentation for both JAX and Flax!