.. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually.

.. _example-category-pytorch_jax:

PyTorch / JAX
=============

In these examples, we show some patterns for using :func:`tyro.cli` with PyTorch and JAX.

.. _example-01_pytorch_parallelism:

PyTorch Parallelism
-------------------

The :code:`console_outputs=` argument can be set to :code:`False` to suppress helptext and error message printing.

This is useful in PyTorch for distributed training scripts, where you only want to print the helptext from the main process:

.. code-block:: python

    # HuggingFace Accelerate.
    args = tyro.cli(Args, console_outputs=accelerator.is_main_process)

    # PyTorch DDP.
    args = tyro.cli(Args, console_outputs=(rank == 0))

    # PyTorch Lightning.
    args = tyro.cli(Args, console_outputs=trainer.is_global_zero)

.. code-block:: python
    :linenos:

    # 01_pytorch_parallelism.py
    import dataclasses

    import tyro


    @dataclasses.dataclass
    class Args:
        """Description.
        This should show up in the helptext!"""

        field1: int
        """A field."""

        field2: int = 3
        """A numeric field, with a default value."""


    if __name__ == "__main__":
        args = tyro.cli(Args, console_outputs=False)
        print(args)

With :code:`console_outputs=False`, running :code:`--help` prints nothing:

.. code-block:: console

    $ python ./01_pytorch_parallelism.py --help

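The DDP snippet above assumes a :code:`rank` variable is already defined. As a rough end-to-end sketch (not part of the generated examples), and assuming the script is launched with :code:`torchrun`, the rank can be read from :code:`torch.distributed` after the process group is initialized:

.. code-block:: python

    # Hypothetical sketch: obtaining `rank` for the DDP snippet above. Assumes
    # the script is launched with `torchrun`, which sets the environment
    # variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) that
    # init_process_group() reads.
    import dataclasses

    import torch.distributed as dist

    import tyro


    @dataclasses.dataclass
    class Args:
        field1: int
        field2: int = 3


    if __name__ == "__main__":
        dist.init_process_group(backend="gloo")  # Use "nccl" for multi-GPU training.
        rank = dist.get_rank()

        # Only the main process (rank 0) prints helptext and error messages.
        args = tyro.cli(Args, console_outputs=(rank == 0))
        print(args)

        dist.destroy_process_group()

Launched with, for example, :code:`torchrun --nproc_per_node 2 script.py --field1 1`, only rank 0 will print helptext for :code:`--help` or error messages for invalid arguments.
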
.. _example-02_flax:

JAX/Flax Integration
--------------------

If you use `flax.linen <https://flax.readthedocs.io/>`_, neural network modules can be instantiated directly from :func:`tyro.cli`.

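The example script itself is not reproduced above; the following is a minimal sketch of what a :code:`02_flax.py` consistent with the outputs below could look like. The :code:`Classifier` fields and the :code:`train()` signature are taken from the helptext; the forward pass and the (omitted) training loop are illustrative assumptions:

.. code-block:: python

    # Sketch of a 02_flax.py-style script. The Classifier fields and the
    # train() signature mirror the helptext below; the MLP forward pass is
    # illustrative only.
    from flax import linen as nn
    from jax import numpy as jnp

    import tyro


    class Classifier(nn.Module):
        layers: int
        """Layers in our network."""

        units: int = 32
        """Hidden unit count."""

        output_dim: int = 10
        """Number of classes."""

        @nn.compact
        def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
            # Assumed architecture: a small MLP classifier.
            for _ in range(self.layers - 1):
                x = nn.relu(nn.Dense(self.units)(x))
            return nn.Dense(self.output_dim)(x)


    def train(model: Classifier, num_iterations: int = 1000) -> None:
        """Train a model.

        Args:
            model: Model to train.
            num_iterations: Number of training iterations.
        """
        # Placeholder: a real script would run an optimization loop here.
        print(f"{model=}")
        print(f"{num_iterations=}")


    if __name__ == "__main__":
        tyro.cli(train)

Because :code:`flax.linen` modules are dataclasses, :func:`tyro.cli` can expose the fields of :code:`Classifier` as the nested :code:`--model.*` arguments shown below.
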
.. code-block:: console

    $ python ./02_flax.py --help
    usage: 02_flax.py [-h] [OPTIONS]

    Train a model.

    ╭─ options ────────────────────────────────────────────────────╮
    │ -h, --help              show this help message and exit                 │
    │ --num-iterations INT    Number of training iterations. (default: 1000)  │
    ╰──────────────────────────────────────────────────────────────────────────╯
    ╭─ model options ──────────────────────────────────────────╮
    │ Model to train.                                                          │
    │ ─────────────────────────────────────────────────────────               │
    │ --model.layers INT      Layers in our network. (required)               │
    │ --model.units INT       Hidden unit count. (default: 32)                │
    │ --model.output-dim INT  Number of classes. (default: 10)                │
    ╰──────────────────────────────────────────────────────────────────────────╯

.. code-block:: console

    $ python ./02_flax.py --model.layers 4
    model=Classifier(
        # attributes
        layers = 4
        units = 32
        output_dim = 10
    )
    num_iterations=1000