PyTorch / JAX
In these examples, we show some patterns for using tyro.cli() with PyTorch and JAX.
PyTorch Parallelism
The console_outputs argument of tyro.cli() can be set to False to suppress helptext and error-message printing. This is useful in PyTorch distributed training scripts, where you only want to print helptext from the main process:
# Pass a boolean that is True only on the main process, so helptext and
# error messages are printed once. Use whichever line matches your
# framework -- these are alternatives, not sequential steps.
# HuggingFace Accelerate.
args = tyro.cli(Args, console_outputs=accelerator.is_main_process)
# PyTorch DDP.
# NOTE(review): assumes `rank` holds this process's global rank -- confirm.
args = tyro.cli(Args, console_outputs=(rank == 0))
# PyTorch Lightning.
args = tyro.cli(Args, console_outputs=trainer.is_global_zero)
# 01_pytorch_parallelism.py
import tyro


def train(foo: int, bar: str) -> None:
    """Description. This should show up in the helptext!"""
    # tyro.cli() calls this function with the parsed arguments and returns
    # its result. The result is always None here, so echo the arguments
    # from inside the function instead of printing the return value.
    print(f"{foo=}")
    print(f"{bar=}")


if __name__ == "__main__":
    # console_outputs=False suppresses helptext and error-message printing.
    # It is hard-coded here for demonstration; in a distributed job you
    # would pass something like (rank == 0) instead.
    tyro.cli(train, console_outputs=False)
$ python ./01_pytorch_parallelism.py --help
(Nothing is printed: console_outputs=False suppresses the helptext.)
JAX/Flax Integration
If you use flax.linen, modules can be instantiated directly from tyro.cli().
# 02_flax.py
from flax import linen as nn
from jax import numpy as jnp

import tyro


class Classifier(nn.Module):
    layers: int
    """Layers in our network."""
    units: int = 32
    """Hidden unit count."""
    output_dim: int = 10
    """Number of classes."""

    # The string literals under the fields above are read by tyro as each
    # field's helptext (they appear in `--help`), so keep them in place.

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:  # type: ignore
        """Apply the classifier MLP to ``x``.

        ``layers`` counts the output layer too: ``layers - 1`` hidden
        Dense+ReLU layers of width ``units`` are followed by one Dense
        output layer of width ``output_dim`` and a sigmoid. When
        ``layers <= 1`` there are no hidden layers.
        """
        # Hidden layers; the loop index itself is not needed.
        for _ in range(self.layers - 1):
            x = nn.Dense(
                self.units,
                kernel_init=nn.initializers.kaiming_normal(),
            )(x)
            x = nn.relu(x)

        x = nn.Dense(
            self.output_dim,
            kernel_init=nn.initializers.xavier_normal(),
        )(x)
        # NOTE(review): sigmoid gives independent per-class scores; if the
        # classes are mutually exclusive, softmax (or returning raw logits)
        # is the usual choice.
        x = nn.sigmoid(x)
        return x

def train(model: Classifier, num_iterations: int = 1000) -> None:
    """Train a model.

    Args:
        model: Model to train.
        num_iterations: Number of training iterations.
    """
    # Placeholder body for the example: it only echoes the arguments that
    # tyro parsed. The docstring above supplies the `--help` text.
    print(f"{model=}")
    print(f"{num_iterations=}")

if __name__ == "__main__":
    # tyro builds the CLI from train's signature and docstring, then calls
    # train with the parsed arguments, including a Classifier instance.
    tyro.cli(train)
$ python ./02_flax.py --help
usage: 02_flax.py [-h] [OPTIONS]

Train a model.

╭─ options ──────────────────────────────────────────────────────────────╮
│ -h, --help              show this help message and exit                │
│ --num-iterations INT    Number of training iterations. (default: 1000) │
╰────────────────────────────────────────────────────────────────────────╯
╭─ model options ────────────────────────────────────────────────────────╮
│ Model to train.                                                        │
│ ─────────────────────────────────────────────────────────              │
│ --model.layers INT      Layers in our network. (required)              │
│ --model.units INT       Hidden unit count. (default: 32)               │
│ --model.output-dim INT  Number of classes. (default: 10)               │
╰────────────────────────────────────────────────────────────────────────╯
$ python ./02_flax.py --model.layers 4
model=Classifier(
# attributes
layers = 4
units = 32
output_dim = 10
)
num_iterations=1000