PyTorch / JAX
In these examples, we show some patterns for using `tyro.cli()` with PyTorch and JAX.
PyTorch Parallelism
The `console_outputs=` argument can be set to `False` to suppress helptext and error message printing. This is useful in PyTorch distributed training scripts, where you only want to print the helptext from the main process:
# Alternative snippets: use the one that matches your framework. Each passes
# a boolean that is True only on the process that should print tyro's
# helptext and error messages.

# HuggingFace Accelerate. `accelerator` is assumed to be an existing
# `accelerate.Accelerator` instance — confirm in your script.
args = tyro.cli(Args, console_outputs=accelerator.is_main_process)
# PyTorch DDP. `rank` is assumed to hold this process's global rank
# (e.g. from `torch.distributed.get_rank()`); only rank 0 prints.
args = tyro.cli(Args, console_outputs=(rank == 0))
# PyTorch Lightning. `trainer` is assumed to be an existing Lightning
# `Trainer`; only the global-zero process prints.
args = tyro.cli(Args, console_outputs=trainer.is_global_zero)
1# 01_pytorch_parallelism.py
2import dataclasses
3
4import tyro
5
# CLI arguments for this example. tyro builds the command-line interface
# from this dataclass: the class docstring and the string literal directly
# under each field are rendered as helptext, so their wording and position
# are part of the CLI's visible behavior.
@dataclasses.dataclass
class Args:
    """Description.
    This should show up in the helptext!"""

    field1: int
    """A field."""

    field2: int = 3
    """A numeric field, with a default value."""
16
if __name__ == "__main__":
    # console_outputs=False tells tyro not to print helptext or error
    # messages from this process. A distributed script would pass a
    # per-process flag here instead (e.g. `rank == 0`).
    args = tyro.cli(Args, console_outputs=False)
    print(args)
$ python ./01_pytorch_parallelism.py --help

(No helptext is printed, because the script passes `console_outputs=False`.)
JAX/Flax Integration
If you use `flax.linen`, modules can be instantiated directly from `tyro.cli()`.
1# 02_flax.py
2from flax import linen as nn
3from jax import numpy as jnp
4
5import tyro
6
class Classifier(nn.Module):
    layers: int
    """Layers in our network."""
    units: int = 32
    """Hidden unit count."""
    output_dim: int = 10
    """Number of classes."""

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:  # type: ignore
        # Hidden stack: (layers - 1) Dense -> ReLU pairs. The initializer
        # factory is pure, so building it once is equivalent to building it
        # per layer; submodules are still created in the same order.
        hidden_init = nn.initializers.kaiming_normal()
        for _ in range(self.layers - 1):
            x = nn.relu(nn.Dense(self.units, kernel_init=hidden_init)(x))

        # Output head: Xavier-initialized projection to `output_dim`
        # features, squashed element-wise by a sigmoid.
        # NOTE(review): sigmoid gives independent per-class scores; for
        # single-label classification a softmax (or raw logits) is the more
        # usual choice — kept as-is to preserve the example's behavior.
        logits = nn.Dense(
            self.output_dim,
            kernel_init=nn.initializers.xavier_normal(),
        )(x)
        return nn.sigmoid(logits)
30
def train(model: Classifier, num_iterations: int = 1000) -> None:
    """Train a model.

    Args:
        model: Model to train.
        num_iterations: Number of training iterations.
    """
    # The docstring above is rendered by tyro as CLI helptext, so it is
    # left verbatim. This placeholder "training" only echoes the parsed
    # configuration; `!r` matches what the self-documenting `{x=}` f-string
    # form prints (name, '=', then repr of the value).
    for label, value in (("model", model), ("num_iterations", num_iterations)):
        print(f"{label}={value!r}")
40
if __name__ == "__main__":
    # tyro derives the CLI (options and helptext) from train()'s signature
    # and docstring, parses the command line, then calls train() with the
    # parsed values — including a Classifier built from --model.* flags.
    tyro.cli(train)
$ python ./02_flax.py --help
usage: 02_flax.py [-h] [OPTIONS]

Train a model.

╭─ options ──────────────────────────────────────────────────────────────╮
│ -h, --help              show this help message and exit                │
│ --num-iterations INT    Number of training iterations. (default: 1000) │
╰────────────────────────────────────────────────────────────────────────╯
╭─ model options ────────────────────────────────────────────────────────╮
│ Model to train.                                                        │
│ ─────────────────────────────────────────────────────────              │
│ --model.layers INT      Layers in our network. (required)              │
│ --model.units INT       Hidden unit count. (default: 32)               │
│ --model.output-dim INT  Number of classes. (default: 10)               │
╰────────────────────────────────────────────────────────────────────────╯
$ python ./02_flax.py --model.layers 4
model=Classifier(
    # attributes
    layers = 4
    units = 32
    output_dim = 10
)
num_iterations=1000