PyTorch / JAX

In these examples, we show some patterns for using tyro.cli() with PyTorch and JAX.

PyTorch Parallelism

The console_outputs= argument can be set to False to suppress helptext and error message printing.

This is useful for distributed training scripts in PyTorch, where helptext and error messages should only be printed from the main process:

# HuggingFace Accelerate.
args = tyro.cli(Args, console_outputs=accelerator.is_main_process)

# PyTorch DDP.
args = tyro.cli(Args, console_outputs=(rank == 0))

# PyTorch Lightning.
args = tyro.cli(Args, console_outputs=trainer.is_global_zero)
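
In a plain torch.distributed script, the process rank is typically known before arguments are parsed. Below is a minimal sketch, assuming a torchrun-style launcher that sets the RANK environment variable; adapt the rank lookup to your own launcher.

import dataclasses
import os

import tyro

@dataclasses.dataclass
class Args:
    field1: int
    """A field."""

if __name__ == "__main__":
    # Only rank 0 prints helptext and error messages; other ranks stay quiet.
    # RANK is assumed to be set by the launcher (e.g. torchrun); this sketch
    # is illustrative and not part of the original example.
    rank = int(os.environ.get("RANK", "0"))
    args = tyro.cli(Args, console_outputs=(rank == 0))
    print(args)
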
# 01_pytorch_parallelism.py
import dataclasses

import tyro

@dataclasses.dataclass
class Args:
    """Description.
    This should show up in the helptext!"""

    field1: int
    """A field."""

    field2: int = 3
    """A numeric field, with a default value."""

if __name__ == "__main__":
    args = tyro.cli(Args, console_outputs=False)
    print(args)
$ python ./01_pytorch_parallelism.py --help
(No output is printed: console_outputs=False suppresses the helptext.)

JAX/Flax Integration

If you use flax.linen, nn.Module subclasses can be instantiated directly from tyro.cli(); their fields are exposed as command-line arguments.

# 02_flax.py
from flax import linen as nn
from jax import numpy as jnp

import tyro

class Classifier(nn.Module):
    layers: int
    """Layers in our network."""
    units: int = 32
    """Hidden unit count."""
    output_dim: int = 10
    """Number of classes."""

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:  # type: ignore
        for i in range(self.layers - 1):
            x = nn.Dense(
                self.units,
                kernel_init=nn.initializers.kaiming_normal(),
            )(x)
            x = nn.relu(x)

        x = nn.Dense(
            self.output_dim,
            kernel_init=nn.initializers.xavier_normal(),
        )(x)
        x = nn.sigmoid(x)
        return x

def train(model: Classifier, num_iterations: int = 1000) -> None:
    """Train a model.

    Args:
        model: Model to train.
        num_iterations: Number of training iterations.
    """
    print(f"{model=}")
    print(f"{num_iterations=}")

if __name__ == "__main__":
    tyro.cli(train)
$ python ./02_flax.py --help
usage: 02_flax.py [-h] [OPTIONS]

Train a model.

╭─ options ──────────────────────────────────────────────────────────────╮
│ -h, --help              show this help message and exit                │
│ --num-iterations INT    Number of training iterations. (default: 1000) │
╰────────────────────────────────────────────────────────────────────────╯
╭─ model options ────────────────────────────────────────────────────────╮
│ Model to train.                                                        │
│ ─────────────────────────────────────────────────────────              │
│ --model.layers INT      Layers in our network. (required)              │
│ --model.units INT       Hidden unit count. (default: 32)               │
│ --model.output-dim INT  Number of classes. (default: 10)               │
╰────────────────────────────────────────────────────────────────────────╯
$ python ./02_flax.py --model.layers 4
model=Classifier(
    # attributes
    layers = 4
    units = 32
    output_dim = 10
)
num_iterations=1000
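
The model object returned here is an ordinary flax.linen module, so it can be initialized and applied as usual. A minimal sketch, assuming the Classifier class from 02_flax.py above is in scope; the input shape and PRNG seed are arbitrary and only for illustration.

import jax
from jax import numpy as jnp

import tyro

if __name__ == "__main__":
    # Instantiate the module from the command line, e.g. `--layers 4`.
    # Assumes `Classifier` from 02_flax.py is defined above or imported;
    # the (1, 784) input shape and the PRNG seed are arbitrary placeholders.
    model = tyro.cli(Classifier)
    dummy_input = jnp.zeros((1, 784))
    variables = model.init(jax.random.PRNGKey(0), dummy_input)
    probs = model.apply(variables, dummy_input)
    print(probs.shape)  # (1, output_dim)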