Overriding Configs
In these examples, we show how tyro.cli() can be used to override values in pre-instantiated configuration objects.
Dataclasses + Defaults
The default= argument can be used to override default values in dataclass types.
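As a rough mental model, and not a description of tyro's internals, the precedence that default= sets up can be sketched in plain Python. A value passed on the command line wins. Otherwise the value on the default instance is used, and a field set to a MISSING sentinel becomes required. The MISSING sentinel and the resolve() helper below are hypothetical stand-ins for tyro.MISSING and tyro.cli(), and the parsed command-line values are assumed to arrive as a plain dict.

```python
import dataclasses
from typing import Any, Dict

# Hypothetical sentinel standing in for tyro.MISSING.
MISSING: Any = object()


@dataclasses.dataclass
class Args:
    string: str
    reps: int = 3  # Class default; ignored once a default instance supplies this field.


def resolve(default: Args, cli_values: Dict[str, Any]) -> Args:
    """Hypothetical helper: layer parsed CLI values over a default instance."""
    merged: Dict[str, Any] = {}
    for field in dataclasses.fields(default):
        if field.name in cli_values:
            # A value passed on the command line takes precedence.
            merged[field.name] = cli_values[field.name]
            continue
        value = getattr(default, field.name)
        if value is MISSING:
            # A MISSING field on the default instance must come from the CLI.
            raise ValueError(f"--{field.name} is required")
        merged[field.name] = value
    return Args(**merged)


default = Args(string="default string", reps=MISSING)
print(resolve(default, {"reps": 3}))
# Args(string='default string', reps=3)

try:
    resolve(default, {"string": "hello"})
except ValueError as err:
    print(err)  # --reps is required
```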
Note
When default= is used, we advise against mutating configuration objects from a dataclass’s __post_init__ method [1]. If Args in the example below defined a __post_init__ method, it would be called twice: once for the Args() object provided as the default value, and again for the Args() object instantiated by tyro.cli(). This can cause confusing behavior! Instead, the example below shows one way to define derived fields immutably.
```python
# 01_dataclasses_defaults.py
import dataclasses

import tyro


@dataclasses.dataclass
class Args:
    """Description.
    This should show up in the helptext!"""

    string: str
    """A string field."""

    reps: int = 3
    """A numeric field, with a default value."""

    # Derived on access rather than written in __post_init__, so the
    # dataclass fields are never mutated.
    @property
    def derived_field(self) -> str:
        return ", ".join([self.string] * self.reps)


if __name__ == "__main__":
    args = tyro.cli(
        Args,
        # Values on this instance replace the class defaults; tyro.MISSING
        # removes the default for `reps`, so --reps becomes required.
        default=Args(
            string="default string",
            reps=tyro.MISSING,
        ),
    )
    print(args.derived_field)
```
```
$ python ./01_dataclasses_defaults.py --help
usage: 01_dataclasses_defaults.py [-h] [--string STR] --reps INT

Description. This should show up in the helptext!

╭─ options ───────────────────────────────────────────────────────╮
│ -h, --help    show this help message and exit                   │
│ --string STR  A string field. (default: 'default string')       │
│ --reps INT    A numeric field, with a default value. (required) │
╰─────────────────────────────────────────────────────────────────╯

$ python ./01_dataclasses_defaults.py --reps 3
default string, default string, default string
$ python ./01_dataclasses_defaults.py --string hello --reps 5
hello, hello, hello, hello, hello
```
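The double-call pitfall from the note above can be reproduced with the standard library alone. This sketch uses dataclasses.replace() to stand in for the second construction, on the assumption that building the parsed object re-runs __init__ (and therefore __post_init__). MutatingArgs and PropertyArgs are hypothetical names used only for illustration.

```python
import dataclasses


@dataclasses.dataclass
class MutatingArgs:
    string: str
    reps: int = 3

    def __post_init__(self) -> None:
        # Anti-pattern: overwrite a field in place. Every construction
        # repeats the (already repeated) string again.
        self.string = ", ".join([self.string] * self.reps)


@dataclasses.dataclass
class PropertyArgs:
    string: str
    reps: int = 3

    @property
    def derived_field(self) -> str:
        # Computed on access; the stored fields are never modified.
        return ", ".join([self.string] * self.reps)


# First construction: the object passed as a default.
default = MutatingArgs(string="hi", reps=2)
# Second construction from its field values re-runs __post_init__.
rebuilt = dataclasses.replace(default)
print(default.string)  # hi, hi
print(rebuilt.string)  # hi, hi, hi, hi

# The property-based version survives re-construction unchanged.
safe = dataclasses.replace(PropertyArgs(string="hi", reps=2))
print(safe.derived_field)  # hi, hi
```

Because the property never writes to self, constructing the object any number of times yields the same derived_field.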
Overriding YAML Configs
tyro understands a wide range of data structures, including standard dictionaries and lists.
If you have a library of existing YAML files that you want to use, tyro.cli() can help override values within them.
Note
We recommend dataclass configs for new projects.
```python
# 02_overriding_yaml.py
import yaml

import tyro

# YAML configuration. This could also be loaded from a file! Environment
# variables are an easy way to select between different YAML files.
default_yaml = r"""
exp_name: test
optimizer:
    learning_rate: 0.0001
    type: adam
training:
    batch_size: 32
    num_steps: 10000
    checkpoint_steps:
    - 500
    - 1000
    - 1500
""".strip()

if __name__ == "__main__":
    # Convert our YAML config into a nested dictionary.
    default_config = yaml.safe_load(default_yaml)

    # Override fields in the dictionary.
    overridden_config = tyro.cli(dict, default=default_config)

    # Print the overridden config.
    overridden_yaml = yaml.safe_dump(overridden_config)
    print(overridden_yaml)
```
```
$ python ./02_overriding_yaml.py --help
usage: 02_overriding_yaml.py [-h] [OPTIONS]

╭─ options ────────────────────────────────────────────╮
│ -h, --help           show this help message and exit │
│ --exp-name STR       (default: test)                 │
╰──────────────────────────────────────────────────────╯
╭─ optimizer options ──────────────────────────────────╮
│ --optimizer.learning-rate FLOAT                      │
│                      (default: 0.0001)               │
│ --optimizer.type STR (default: adam)                 │
╰──────────────────────────────────────────────────────╯
╭─ training options ───────────────────────────────────╮
│ --training.batch-size INT                            │
│                      (default: 32)                   │
│ --training.num-steps INT                             │
│                      (default: 10000)                │
│ --training.checkpoint-steps [INT [INT ...]]          │
│                      (default: 500 1000 1500)        │
╰──────────────────────────────────────────────────────╯

$ python ./02_overriding_yaml.py --training.checkpoint-steps 300 1000 9000
exp_name: test
optimizer:
  learning_rate: 0.0001
  type: adam
training:
  batch_size: 32
  checkpoint_steps:
  - 300
  - 1000
  - 9000
  num_steps: 10000
```
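To make the effect of a flag such as --training.checkpoint-steps concrete, here is a plain-Python sketch of the same update applied to the nested dictionary. The flag_to_path() and set_nested() helpers are hypothetical: they only mirror the naming visible in the helptext above (dots for nesting, hyphens in place of underscores) and skip the value parsing and type checking that tyro performs.

```python
import copy
from typing import Any, Dict, List


def flag_to_path(flag: str) -> List[str]:
    """Hypothetical: '--training.checkpoint-steps' -> ['training', 'checkpoint_steps']."""
    name = flag[2:] if flag.startswith("--") else flag
    # Dots separate nesting levels; hyphens stand in for underscores.
    return [part.replace("-", "_") for part in name.split(".")]


def set_nested(config: Dict[str, Any], path: List[str], value: Any) -> Dict[str, Any]:
    """Return a deep copy of config with the entry at path replaced."""
    updated = copy.deepcopy(config)  # Leave the default config untouched.
    node = updated
    for key in path[:-1]:
        node = node[key]
    node[path[-1]] = value
    return updated


# The same structure yaml.safe_load() produces from the YAML string above.
default_config = {
    "exp_name": "test",
    "optimizer": {"learning_rate": 0.0001, "type": "adam"},
    "training": {
        "batch_size": 32,
        "num_steps": 10000,
        "checkpoint_steps": [500, 1000, 1500],
    },
}

overridden = set_nested(
    default_config, flag_to_path("--training.checkpoint-steps"), [300, 1000, 9000]
)
print(overridden["training"]["checkpoint_steps"])  # [300, 1000, 9000]
print(default_config["training"]["checkpoint_steps"])  # [500, 1000, 1500]
```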
Choosing Base Configs
One common pattern is to have a set of “base” configurations, which can be selected from and then overridden.
This is often implemented with a set of configuration files (e.g., YAML files).
With tyro, we can instead define each base configuration as a separate Python object.
After creating the base configurations, we can use the CLI to select one of them and then override (existing) or fill in (missing) values.
The helper function used here, tyro.extras.overridable_config_cli(), is a lightweight wrapper over tyro.cli() and its Union-based subcommand syntax.
```python
# 03_choosing_base_configs.py
from dataclasses import dataclass
from typing import Callable, Literal

from torch import nn

import tyro


@dataclass
class ExperimentConfig:
    # Dataset to run experiment on.
    dataset: Literal["mnist", "imagenet-50"]

    # Model size.
    num_layers: int
    units: int

    # Batch size.
    batch_size: int

    # Total number of training steps.
    train_steps: int

    # Random seed.
    seed: int

    # Not specifiable via the commandline.
    activation: Callable[[], nn.Module]


# We could also define this library using separate YAML files (similar to
# `config_path`/`config_name` in Hydra), but staying in Python enables seamless
# type checking + IDE support.
default_configs = {
    "small": (
        "Small experiment.",
        ExperimentConfig(
            dataset="mnist",
            batch_size=2048,
            num_layers=4,
            units=64,
            train_steps=30_000,
            seed=0,
            activation=nn.ReLU,
        ),
    ),
    "big": (
        "Big experiment.",
        ExperimentConfig(
            dataset="imagenet-50",
            batch_size=32,
            num_layers=8,
            units=256,
            train_steps=100_000,
            seed=0,
            activation=nn.GELU,
        ),
    ),
}

if __name__ == "__main__":
    config = tyro.extras.overridable_config_cli(default_configs)
    print(config)
```
Overall helptext:
```
$ python ./03_choosing_base_configs.py --help
usage: 03_choosing_base_configs.py [-h] {small,big}

╭─ options ───────────────────────────────────╮
│ -h, --help  show this help message and exit │
╰─────────────────────────────────────────────╯
╭─ subcommands ───────────────────────────────╮
│ {small,big}                                 │
│     small   Small experiment.               │
│     big     Big experiment.                 │
╰─────────────────────────────────────────────╯
```
The “small” subcommand:
```
$ python ./03_choosing_base_configs.py small --help
usage: 03_choosing_base_configs.py small [-h] [SMALL OPTIONS]

Small experiment.

╭─ options ────────────────────────────────────────────────────────────────╮
│ -h, --help            show this help message and exit                    │
│ --dataset {mnist,imagenet-50}                                            │
│                       Dataset to run experiment on. (default: mnist)     │
│ --num-layers INT      Model size. (default: 4)                           │
│ --units INT           Model size. (default: 64)                          │
│ --batch-size INT      Batch size. (default: 2048)                        │
│ --train-steps INT     Total number of training steps. (default: 30000)   │
│ --seed INT            Random seed. (default: 0)                          │
│ --activation {fixed}  Not specifiable via the commandline. (fixed to:    │
│                       <class 'torch.nn.modules.activation.ReLU'>)        │
╰──────────────────────────────────────────────────────────────────────────╯

$ python ./03_choosing_base_configs.py small --seed 94720
ExperimentConfig(dataset='mnist', num_layers=4, units=64, batch_size=2048, train_steps=30000, seed=94720, activation=<class 'torch.nn.modules.activation.ReLU'>)
```
The “big” subcommand:
```
$ python ./03_choosing_base_configs.py big --help
usage: 03_choosing_base_configs.py big [-h] [BIG OPTIONS]

Big experiment.

╭─ options ────────────────────────────────────────────────────────────────╮
│ -h, --help            show this help message and exit                    │
│ --dataset {mnist,imagenet-50}                                            │
│                       Dataset to run experiment on. (default:            │
│                       imagenet-50)                                       │
│ --num-layers INT      Model size. (default: 8)                           │
│ --units INT           Model size. (default: 256)                         │
│ --batch-size INT      Batch size. (default: 32)                          │
│ --train-steps INT     Total number of training steps. (default: 100000)  │
│ --seed INT            Random seed. (default: 0)                          │
│ --activation {fixed}  Not specifiable via the commandline. (fixed to:    │
│                       <class 'torch.nn.modules.activation.GELU'>)        │
╰──────────────────────────────────────────────────────────────────────────╯

$ python ./03_choosing_base_configs.py big --seed 94720
ExperimentConfig(dataset='imagenet-50', num_layers=8, units=256, batch_size=32, train_steps=100000, seed=94720, activation=<class 'torch.nn.modules.activation.GELU'>)
```
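Conceptually, each subcommand pairs a name with a base object that is then updated field by field. The sketch below models that select-then-override step with the standard library only. TrainConfig is a hypothetical, trimmed-down stand-in for ExperimentConfig (torch is omitted so it runs anywhere), and select_config() is a hypothetical helper, not how tyro.extras.overridable_config_cli() is implemented.

```python
import dataclasses
from typing import Any, Mapping, Tuple


@dataclasses.dataclass(frozen=True)
class TrainConfig:
    """Hypothetical, trimmed-down stand-in for ExperimentConfig."""

    dataset: str
    num_layers: int
    batch_size: int
    seed: int


# Same shape as default_configs: name -> (description, base config).
base_configs: Mapping[str, Tuple[str, TrainConfig]] = {
    "small": ("Small experiment.", TrainConfig("mnist", 4, 2048, 0)),
    "big": ("Big experiment.", TrainConfig("imagenet-50", 8, 32, 0)),
}


def select_config(name: str, **overrides: Any) -> TrainConfig:
    """Hypothetical helper: pick a base config by name, then override fields."""
    if name not in base_configs:
        raise ValueError(f"unknown config {name!r}; choose from {sorted(base_configs)}")
    _description, base = base_configs[name]
    # replace() returns a new object, so the base configs stay unchanged.
    return dataclasses.replace(base, **overrides)


print(select_config("small", seed=94720))
# TrainConfig(dataset='mnist', num_layers=4, batch_size=2048, seed=94720)
```

Freezing the dataclass is a design choice in this sketch: any accidental in-place edit of a shared base config raises dataclasses.FrozenInstanceError instead of silently leaking into later selections.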