Hierarchical Configs
Structures (typically dataclasses) can be nested to build hierarchical configuration objects. This helps with modularity and grouping of related options in larger projects. Each nested field is exposed on the command line as a dot-delimited flag, like --config.optimizer.learning-rate below.
import dataclasses
import enum
import pathlib

import tyro


class OptimizerType(enum.Enum):
    ADAM = enum.auto()
    SGD = enum.auto()


@dataclasses.dataclass
class OptimizerConfig:
    # Gradient-based optimizer to use.
    algorithm: OptimizerType = OptimizerType.ADAM

    # Learning rate to use.
    learning_rate: float = 3e-4

    # Coefficient for L2 regularization.
    weight_decay: float = 1e-2


@dataclasses.dataclass
class ExperimentConfig:
    # Various configurable options for our optimizer.
    optimizer: OptimizerConfig

    # Batch size.
    batch_size: int = 32

    # Total number of training steps.
    train_steps: int = 100_000

    # Random seed. This is helpful for making sure that our experiments are all
    # reproducible!
    seed: int = 0


def train(
    out_dir: pathlib.Path,
    config: ExperimentConfig,
    restore_checkpoint: bool = False,
    checkpoint_interval: int = 1000,
) -> None:
    """Train a model.

    Args:
        out_dir: Where to save logs and checkpoints.
        config: Experiment configuration.
        restore_checkpoint: Set to restore an existing checkpoint.
        checkpoint_interval: Training steps between each checkpoint save.
    """
    print(f"{out_dir=}, {restore_checkpoint=}, {checkpoint_interval=}")
    print()
    print(f"{config=}")


if __name__ == "__main__":
    tyro.cli(train)
python 02_nesting/01_nesting.py --help
usage: 01_nesting.py [-h] [OPTIONS]

Train a model.

╭─ options ──────────────────────────────────────────────────────────────────╮
│ -h, --help              show this help message and exit                    │
│ --out-dir PATH          Where to save logs and checkpoints. (required)     │
│ --restore-checkpoint, --no-restore-checkpoint                              │
│                         Set to restore an existing checkpoint. (default:   │
│                         False)                                             │
│ --checkpoint-interval INT                                                  │
│                         Training steps between each checkpoint save.       │
│                         (default: 1000)                                    │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ config options ───────────────────────────────────────────────────────────╮
│ Experiment configuration.                                                  │
│ ────────────────────────────────────────────────────────────────────────── │
│ --config.batch-size INT                                                    │
│                         Batch size. (default: 32)                          │
│ --config.train-steps INT                                                   │
│                         Total number of training steps. (default: 100000)  │
│ --config.seed INT       Random seed. This is helpful for making sure that  │
│                         our experiments are all reproducible! (default: 0) │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ config.optimizer options ─────────────────────────────────────────────────╮
│ --config.optimizer.algorithm {ADAM,SGD}                                    │
│                         Gradient-based optimizer to use. (default: ADAM)   │
│ --config.optimizer.learning-rate FLOAT                                     │
│                         Learning rate to use. (default: 0.0003)            │
│ --config.optimizer.weight-decay FLOAT                                      │
│                         Coefficient for L2 regularization. (default: 0.01) │
╰────────────────────────────────────────────────────────────────────────────╯
python 02_nesting/01_nesting.py --out-dir . --config.optimizer.algorithm SGD
out_dir=PosixPath('.'), restore_checkpoint=False, checkpoint_interval=1000

config=ExperimentConfig(optimizer=OptimizerConfig(algorithm=<OptimizerType.SGD: 2>, learning_rate=0.0003, weight_decay=0.01), batch_size=32, train_steps=100000, seed=0)
python 02_nesting/01_nesting.py --out-dir . --restore-checkpoint
out_dir=PosixPath('.'), restore_checkpoint=True, checkpoint_interval=1000

config=ExperimentConfig(optimizer=OptimizerConfig(algorithm=<OptimizerType.ADAM: 1>, learning_rate=0.0003, weight_decay=0.01), batch_size=32, train_steps=100000, seed=0)
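Nested defaults can also be adjusted from Python, without editing the dataclass definitions or the command line. The sketch below is illustrative, using abbreviated versions of the configs above: tyro.cli()'s default= keyword argument substitutes a different default instance, while args= parses a fixed argument list instead of sys.argv (handy for tests). The specific values are made up for the example.

import dataclasses

import tyro


@dataclasses.dataclass
class OptimizerConfig:
    learning_rate: float = 3e-4
    weight_decay: float = 1e-2


@dataclasses.dataclass
class ExperimentConfig:
    optimizer: OptimizerConfig
    batch_size: int = 32


if __name__ == "__main__":
    config = tyro.cli(
        ExperimentConfig,
        # Swap in a different default for the nested optimizer config.
        default=ExperimentConfig(optimizer=OptimizerConfig(learning_rate=1e-3)),
        # Parse a fixed argument list instead of sys.argv.
        args=["--optimizer.weight-decay", "0.1"],
    )
    assert config.optimizer.learning_rate == 1e-3  # From `default=`.
    assert config.optimizer.weight_decay == 0.1  # From `args=`.

Note that because the dataclass is passed to tyro.cli() directly here, flags are not prefixed with config. as in the train() example above: prefixes are built from parameter and field names, so the flags become --optimizer.weight-decay, --batch-size, and so on.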