JAX/Flax Integration

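If you use `flax.linen`, modules can be instantiated directly from `tyro.cli`: each field of the module becomes a command-line option, and the docstring written after each field becomes that option's help text.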

 1from flax import linen as nn
 2from jax import numpy as jnp
 3
 4import tyro
 5
 6
 7class Classifier(nn.Module):
 8    layers: int
 9    """Layers in our network."""
10    units: int = 32
11    """Hidden unit count."""
12    output_dim: int = 10
13    """Number of classes."""
14
15    @nn.compact
16    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:  # type: ignore
17        for i in range(self.layers - 1):
18            x = nn.Dense(
19                self.units,
20                kernel_init=nn.initializers.kaiming_normal(),
21            )(x)
22            x = nn.relu(x)
23
24        x = nn.Dense(
25            self.output_dim,
26            kernel_init=nn.initializers.xavier_normal(),
27        )(x)
28        x = nn.sigmoid(x)
29        return x
30
31
def train(model: Classifier, num_iterations: int = 1000) -> None:
    """Train a model.

    Args:
        model: Model to train.
        num_iterations: Number of training iterations.
    """
    # tyro builds the --help text from the docstring above, so edit it with care.
    # Demo stub: nothing is trained; the parsed arguments are only echoed.
    # One print with sep="\n" writes the same two lines as two separate prints.
    print(f"{model=}", f"{num_iterations=}", sep="\n")
41
42
if __name__ == "__main__":
    # tyro derives the CLI (options, defaults, help text) from train()'s
    # signature and docstring, then calls train() with the parsed values.
    tyro.cli(train)

python 04_additional/10_flax.py --help
usage: 10_flax.py [-h] [OPTIONS]

Train a model.

╭─ options ──────────────────────────────────────────────────────────────╮
│ -h, --help              show this help message and exit                │
│ --num-iterations INT    Number of training iterations. (default: 1000) │
╰────────────────────────────────────────────────────────────────────────╯
╭─ model options ────────────────────────────────────────────────────────╮
│ Model to train.                                                        │
│ ─────────────────────────────────────────────────────────              │
│ --model.layers INT      Layers in our network. (required)              │
│ --model.units INT       Hidden unit count. (default: 32)               │
│ --model.output-dim INT  Number of classes. (default: 10)               │
╰────────────────────────────────────────────────────────────────────────╯

python 04_additional/10_flax.py --model.layers 4
model=Classifier(
    # attributes
    layers = 4
    units = 32
    output_dim = 10
)
num_iterations=1000