JAX/Flax Integration
If you use `flax.linen`, modules can be instantiated directly from `tyro.cli`.
1from flax import linen as nn
2from jax import numpy as jnp
3
4import tyro
5
6
class Classifier(nn.Module):
    layers: int
    """Layers in our network."""
    units: int = 32
    """Hidden unit count."""
    output_dim: int = 10
    """Number of classes."""

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:  # type: ignore
        """Apply the network to ``x``.

        Builds ``self.layers - 1`` hidden ``nn.Dense(self.units)`` layers,
        each followed by ReLU, then one output ``nn.Dense(self.output_dim)``
        layer followed by a sigmoid. Because ``range()`` of a non-positive
        count is empty, any ``self.layers`` value below 1 still produces just
        the single output layer.

        Args:
            x: Input array; the Dense layers act on its last axis.

        Returns:
            Sigmoid activations whose last axis has size ``self.output_dim``.
        """
        # Hidden stack: the iteration index is not needed, hence ``_``.
        for _ in range(self.layers - 1):
            x = nn.Dense(
                self.units,
                kernel_init=nn.initializers.kaiming_normal(),
            )(x)
            x = nn.relu(x)

        # Output head.
        x = nn.Dense(
            self.output_dim,
            kernel_init=nn.initializers.xavier_normal(),
        )(x)
        # NOTE(review): sigmoid scores each of the ``output_dim`` outputs
        # independently (multi-label style); if the classes are mutually
        # exclusive, softmax may be intended — confirm.
        return nn.sigmoid(x)
30
31
def train(model: Classifier, num_iterations: int = 1000) -> None:
    """Train a model.

    Args:
        model: Model to train.
        num_iterations: Number of training iterations.
    """
    # No training happens here: the example only echoes the arguments that
    # tyro parsed. `!r` reproduces the repr() formatting that the `{expr=}`
    # f-string specifier applies by default, and `sep="\n"` keeps one value
    # per line exactly as two separate print() calls would.
    print(
        f"model={model!r}",
        f"num_iterations={num_iterations!r}",
        sep="\n",
    )
41
42
if __name__ == "__main__":
    # tyro builds the CLI from `train`'s signature, type annotations, and
    # docstring: `num_iterations` becomes `--num-iterations`, and the fields
    # of the `Classifier` parameter become the nested `--model.*` flags.
    # It then calls `train` with the parsed values.
    tyro.cli(train)
python 04_additional/10_flax.py --help
usage: 10_flax.py [-h] [OPTIONS]

Train a model.

╭─ options ──────────────────────────────────────────────────────────────╮
│ -h, --help            show this help message and exit                  │
│ --num-iterations INT  Number of training iterations. (default: 1000)   │
╰────────────────────────────────────────────────────────────────────────╯
╭─ model options ────────────────────────────────────────────────────────╮
│ Model to train.                                                        │
│ ─────────────────────────────────────────────────────────              │
│ --model.layers INT      Layers in our network. (required)              │
│ --model.units INT       Hidden unit count. (default: 32)               │
│ --model.output-dim INT  Number of classes. (default: 10)               │
╰────────────────────────────────────────────────────────────────────────╯
python 04_additional/10_flax.py --model.layers 4
model=Classifier(
    # attributes
    layers = 4
    units = 32
    output_dim = 10
)
num_iterations=1000