.. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually.

.. _example-category-pytorch_jax:

PyTorch / JAX
=============

In these examples, we show some patterns for using :func:`tyro.cli` with PyTorch and JAX.

.. _example-01_pytorch_parallelism:

PyTorch Parallelism
-------------------

The :code:`console_outputs=` argument can be set to :code:`False` to suppress printing of helptext and error messages.

This is useful in distributed PyTorch training scripts, where only the main process should print them:

.. code-block:: python

    # HuggingFace Accelerate.
    args = tyro.cli(Args, console_outputs=accelerator.is_main_process)

    # PyTorch DDP.
    args = tyro.cli(Args, console_outputs=(rank == 0))

    # PyTorch Lightning.
    args = tyro.cli(Args, console_outputs=trainer.is_global_zero)

.. code-block:: python
    :linenos:

    # 01_pytorch_parallelism.py
    import dataclasses

    import tyro


    @dataclasses.dataclass
    class Args:
        """Description.
        This should show up in the helptext!"""

        field1: int
        """A field."""

        field2: int = 3
        """A numeric field, with a default value."""


    if __name__ == "__main__":
        args = tyro.cli(Args, console_outputs=False)
        print(args)

Because :code:`console_outputs=False`, the :code:`--help` output is suppressed:

.. raw:: html

    <pre>
    $ python ./01_pytorch_parallelism.py --help
    </pre>
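
The DDP snippet assumes a :code:`rank` variable is already available. If :func:`tyro.cli` is called before the process group is initialized, one option is to read the rank from the environment instead. The sketch below assumes the script is launched with :code:`torchrun`, which sets a :code:`RANK` environment variable for each worker; the helper name :code:`is_main_process()` is hypothetical and not part of tyro or PyTorch.

.. code-block:: python

    import os


    def is_main_process() -> bool:
        """Return True for the global rank-0 worker.

        Assumes a torchrun-style launcher that sets the RANK environment
        variable. Single-process runs (RANK unset) count as the main process.
        """
        return int(os.environ.get("RANK", "0")) == 0

The result could then be passed as :code:`tyro.cli(Args, console_outputs=is_main_process())`. Reading the environment avoids depending on :code:`torch.distributed` being initialized; once it is, :code:`torch.distributed.get_rank() == 0` is an alternative.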

.. _example-02_flax:

JAX/Flax Integration
--------------------

If you use :code:`flax.linen`, modules can be instantiated directly by :func:`tyro.cli`. In the example below, the :code:`model` argument of :code:`train()` is populated from :code:`--model.*` arguments.

.. code-block:: python
    :linenos:

    # 02_flax.py
    from flax import linen as nn
    from jax import numpy as jnp

    import tyro


    class Classifier(nn.Module):
        layers: int
        """Layers in our network."""
        units: int = 32
        """Hidden unit count."""
        output_dim: int = 10
        """Number of classes."""

        @nn.compact
        def __call__(self, x: jnp.ndarray) -> jnp.ndarray:  # type: ignore
            for _ in range(self.layers - 1):
                x = nn.Dense(
                    self.units,
                    kernel_init=nn.initializers.kaiming_normal(),
                )(x)
                x = nn.relu(x)

            x = nn.Dense(
                self.output_dim,
                kernel_init=nn.initializers.xavier_normal(),
            )(x)
            x = nn.sigmoid(x)
            return x


    def train(model: Classifier, num_iterations: int = 1000) -> None:
        """Train a model.

        Args:
            model: Model to train.
            num_iterations: Number of training iterations.
        """
        print(f"{model=}")
        print(f"{num_iterations=}")


    if __name__ == "__main__":
        tyro.cli(train)

.. raw:: html

    <pre>
    $ python ./02_flax.py --help
    usage: 02_flax.py [-h] [OPTIONS]

    Train a model.

    ╭─ options ──────────────────────────────────────────────────────────────╮
     -h, --help              show this help message and exit                
     --num-iterations INT    Number of training iterations. (default: 1000) 
    ╰────────────────────────────────────────────────────────────────────────╯
    ╭─ model options ────────────────────────────────────────────────────────╮
     Model to train.                                                        
     ─────────────────────────────────────────────────────────              
     --model.layers INT      Layers in our network. (required)              
     --model.units INT       Hidden unit count. (default: 32)               
     --model.output-dim INT  Number of classes. (default: 10)               
    ╰────────────────────────────────────────────────────────────────────────╯
    </pre>

.. raw:: html

    <pre>
    $ python ./02_flax.py --model.layers 4
    model=Classifier(
        # attributes
        layers = 4
        units = 32
        output_dim = 10
    )
    num_iterations=1000
    </pre>