Hierarchical Configs

Structures (typically dataclasses) can be nested to build hierarchical configuration objects. Nested fields are exposed on the command line as dot-delimited flags, like --config.optimizer.learning-rate below, which helps with modularity and grouping in larger projects.
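
Before the full example, a minimal sketch of the idea (the Inner/Outer names here are illustrative, not part of the example below): a dataclass field whose type is itself a dataclass is parsed under a dot-delimited prefix.

import dataclasses

import tyro


@dataclasses.dataclass
class Inner:
    x: int = 1


@dataclasses.dataclass
class Outer:
    # Nested config; parsed as --inner.x.
    inner: Inner

    # Top-level field; parsed as --y.
    y: int = 2


if __name__ == "__main__":
    config = tyro.cli(Outer)
    print(config)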

import dataclasses
import enum
import pathlib

import tyro


class OptimizerType(enum.Enum):
    ADAM = enum.auto()
    SGD = enum.auto()


@dataclasses.dataclass
class OptimizerConfig:
    # Gradient-based optimizer to use.
    algorithm: OptimizerType = OptimizerType.ADAM

    # Learning rate to use.
    learning_rate: float = 3e-4

    # Coefficient for L2 regularization.
    weight_decay: float = 1e-2


@dataclasses.dataclass
class ExperimentConfig:
    # Various configurable options for our optimizer.
    optimizer: OptimizerConfig

    # Batch size.
    batch_size: int = 32

    # Total number of training steps.
    train_steps: int = 100_000

    # Random seed. This is helpful for making sure that our experiments are all
    # reproducible!
    seed: int = 0


def train(
    out_dir: pathlib.Path,
    config: ExperimentConfig,
    restore_checkpoint: bool = False,
    checkpoint_interval: int = 1000,
) -> None:
    """Train a model.

    Args:
        out_dir: Where to save logs and checkpoints.
        config: Experiment configuration.
        restore_checkpoint: Set to restore an existing checkpoint.
        checkpoint_interval: Training steps between each checkpoint save.
    """
    print(f"{out_dir=}, {restore_checkpoint=}, {checkpoint_interval=}")
    print()
    print(f"{config=}")


if __name__ == "__main__":
    tyro.cli(train)

python 02_nesting/01_nesting.py --help
usage: 01_nesting.py [-h] [OPTIONS]

Train a model.

╭─ options ──────────────────────────────────────────────────────────────────╮
│ -h, --help              show this help message and exit                    │
│ --out-dir PATH          Where to save logs and checkpoints. (required)     │
│ --restore-checkpoint, --no-restore-checkpoint                              │
│                         Set to restore an existing checkpoint. (default:   │
│                         False)                                             │
│ --checkpoint-interval INT                                                  │
│                         Training steps between each checkpoint save.       │
│                         (default: 1000)                                    │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ config options ───────────────────────────────────────────────────────────╮
│ Experiment configuration.                                                  │
│ ────────────────────────────────────────────────────────────────────────── │
│ --config.batch-size INT                                                    │
│                         Batch size. (default: 32)                          │
│ --config.train-steps INT                                                   │
│                         Total number of training steps. (default: 100000)  │
│ --config.seed INT       Random seed. This is helpful for making sure that  │
│                         our experiments are all reproducible! (default: 0) │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ config.optimizer options ─────────────────────────────────────────────────╮
│ --config.optimizer.algorithm {ADAM,SGD}                                    │
│                         Gradient-based optimizer to use. (default: ADAM)   │
│ --config.optimizer.learning-rate FLOAT                                     │
│                         Learning rate to use. (default: 0.0003)            │
│ --config.optimizer.weight-decay FLOAT                                      │
│                         Coefficient for L2 regularization. (default: 0.01) │
╰────────────────────────────────────────────────────────────────────────────╯
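
Note that ExperimentConfig.optimizer has no default value, yet nothing under config.optimizer is marked required: every field of OptimizerConfig has a default, so tyro can instantiate it on its own. Spelling that default out explicitly (a sketch, equivalent in behavior and reusing OptimizerConfig from above) would look like:

import dataclasses


@dataclasses.dataclass
class ExperimentConfig:
    # Explicit nested default; default_factory avoids sharing one mutable
    # OptimizerConfig instance across ExperimentConfig instances.
    optimizer: OptimizerConfig = dataclasses.field(
        default_factory=OptimizerConfig
    )

    batch_size: int = 32
    train_steps: int = 100_000
    seed: int = 0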

python 02_nesting/01_nesting.py --out-dir . --config.optimizer.algorithm SGD
out_dir=PosixPath('.'), restore_checkpoint=False, checkpoint_interval=1000

config=ExperimentConfig(optimizer=OptimizerConfig(algorithm=<OptimizerType.SGD: 2>, learning_rate=0.0003, weight_decay=0.01), batch_size=32, train_steps=100000, seed=0)

python 02_nesting/01_nesting.py --out-dir . --restore-checkpoint
out_dir=PosixPath('.'), restore_checkpoint=True, checkpoint_interval=1000

config=ExperimentConfig(optimizer=OptimizerConfig(algorithm=<OptimizerType.ADAM: 1>, learning_rate=0.0003, weight_decay=0.01), batch_size=32, train_steps=100000, seed=0)
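
tyro.cli() also accepts an optional args sequence that stands in for sys.argv, which is convenient for tests or scripts. A sketch, assuming train is importable from the example module (the import path here is hypothetical):

import tyro

from nesting_example import train  # hypothetical module name

tyro.cli(
    train,
    args=[
        "--out-dir", "/tmp/run",
        "--config.optimizer.algorithm", "SGD",
        "--config.optimizer.learning-rate", "1e-3",
    ],
)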