Sequenced Subcommands

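When a function takes several arguments whose types are unions over nested types, each union is populated by its own subcommand, and the subcommands are given one after another on the command line.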

 1from __future__ import annotations
 2
 3import dataclasses
 4from typing import Literal
 5
 6import tyro
 7
 8# Possible dataset configurations.
 9
10
11@dataclasses.dataclass
12class Mnist:
13    binary: bool = False
14    """Set to load binary version of MNIST dataset."""
15
16
17@dataclasses.dataclass
18class ImageNet:
19    subset: Literal[50, 100, 1000]
20    """Choose between ImageNet-50, ImageNet-100, ImageNet-1000, etc."""
21
22
23# Possible optimizer configurations.
24
25
26@dataclasses.dataclass
27class Adam:
28    learning_rate: float = 1e-3
29    betas: tuple[float, float] = (0.9, 0.999)
30
31
32@dataclasses.dataclass
33class Sgd:
34    learning_rate: float = 3e-4
35
36
37# Train script.
38
39
40@tyro.conf.configure(tyro.conf.ConsolidateSubcommandArgs)
41def train(
42    dataset: Mnist | ImageNet = Mnist(),
43    optimizer: Adam | Sgd = Adam(),
44) -> None:
45    """Example training script.
46
47    Args:
48        dataset: Dataset to train on.
49        optimizer: Optimizer to train with.
50
51    Returns:
52        None:
53    """
54    print(dataset)
55    print(optimizer)
56
57
58if __name__ == "__main__":
59    tyro.cli(train)

python 02_nesting/03_multiple_subcommands.py --help
usage: 03_multiple_subcommands.py [-h] {dataset:mnist,dataset:image-net}

Example training script.

╭─ options ─────────────────────────────────────────╮
│ -h, --help        show this help message and exit │
╰───────────────────────────────────────────────────╯
╭─ subcommands ─────────────────────────────────────╮
│ Dataset to train on.                              │
│ ─────────────────────────────────                 │
│ {dataset:mnist,dataset:image-net}                 │
│     dataset:mnist                                 │
│     dataset:image-net                             │
╰───────────────────────────────────────────────────╯

python 02_nesting/03_multiple_subcommands.py dataset:mnist --help
usage: 03_multiple_subcommands.py dataset:mnist [-h]
                                                {optimizer:adam,optimizer:sgd}

╭─ options ─────────────────────────────────────────╮
│ -h, --help        show this help message and exit │
╰───────────────────────────────────────────────────╯
╭─ subcommands ─────────────────────────────────────╮
│ Optimizer to train with.                          │
│ ──────────────────────────────                    │
│ {optimizer:adam,optimizer:sgd}                    │
│     optimizer:adam                                │
│     optimizer:sgd                                 │
╰───────────────────────────────────────────────────╯

python 02_nesting/03_multiple_subcommands.py dataset:mnist optimizer:adam --help
usage: 03_multiple_subcommands.py dataset:mnist optimizer:adam
       [-h] [--optimizer.learning-rate FLOAT] [--optimizer.betas FLOAT FLOAT]
       [--dataset.binary | --dataset.no-binary]

╭─ options ─────────────────────────────────────────────────────────╮
│ -h, --help                                                        │
│     show this help message and exit                               │
╰───────────────────────────────────────────────────────────────────╯
╭─ optimizer options ───────────────────────────────────────────────╮
│ --optimizer.learning-rate FLOAT                                   │
│     (default: 0.001)                                              │
│ --optimizer.betas FLOAT FLOAT                                     │
│     (default: 0.9 0.999)                                          │
╰───────────────────────────────────────────────────────────────────╯
╭─ dataset options ─────────────────────────────────────────────────╮
│ --dataset.binary, --dataset.no-binary                             │
│     Set to load binary version of MNIST dataset. (default: False) │
╰───────────────────────────────────────────────────────────────────╯

python 02_nesting/03_multiple_subcommands.py dataset:mnist optimizer:adam --optimizer.learning-rate 3e-4 --dataset.binary
Mnist(binary=True)
Adam(learning_rate=0.0003, betas=(0.9, 0.999))