Custom constructors

tyro.cli() aims for comprehensive support of standard Python type constructs. It can still, however, be useful to extend the set of supported types.

We provide two complementary approaches for doing so:

  • tyro.conf provides a simple API for specifying custom constructor functions.

  • tyro.constructors provides a more flexible API for defining behavior for different types. There are two categories of types: primitive types are instantiated from a single commandline argument, while struct types are broken down into multiple arguments.

Warning

Custom constructors are useful, but can be verbose and require care. We recommend using them sparingly.

Simple Constructors

For simple custom constructors, we can pass a constructor function into tyro.conf.arg() or tyro.conf.subcommand(). Arguments will be generated by parsing the signature of the constructor function.

In this example, we define custom behavior for instantiating a NumPy array.

# 01_simple_constructors.py
from typing import Literal

import numpy as np
from typing_extensions import Annotated

import tyro

def construct_array(
    values: tuple[float, ...], dtype: Literal["float32", "float64"] = "float64"
) -> np.ndarray:
    """A custom constructor for 1D NumPy arrays."""
    return np.array(
        values,
        dtype={"float32": np.float32, "float64": np.float64}[dtype],
    )

def main(
    # We can specify a custom constructor for an argument in `tyro.conf.arg()`.
    array: Annotated[np.ndarray, tyro.conf.arg(constructor=construct_array)],
) -> None:
    print(f"{array=}")

if __name__ == "__main__":
    tyro.cli(main)

$ python ./01_simple_constructors.py --help
usage: 01_simple_constructors.py [-h] --array.values [FLOAT
                                 [FLOAT ...]] [--array.dtype
{float32,float64}]

╭─ options ─────────────────────────────────────────╮
 -h, --help        show this help message and exit 
╰───────────────────────────────────────────────────╯
╭─ array options ───────────────────────────────────╮
 A custom constructor for 1D NumPy arrays.         
 ─────────────────────────────────────────         
 --array.values [FLOAT [FLOAT ...]]                
                   (required)                      
 --array.dtype {float32,float64}                   
                   (default: float64)              
╰───────────────────────────────────────────────────╯
$ python ./01_simple_constructors.py --array.values 1 2 3
array=array([1., 2., 3.])
$ python ./01_simple_constructors.py --array.values 1 2 3 4 5 --array.dtype float32
array=array([1., 2., 3., 4., 5.], dtype=float32)
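
To make "parsing the signature of the constructor function" more concrete, the sketch below uses only the standard library to list the parameters that a function like construct_array exposes. It is an illustration, not tyro's implementation: construct_values is a hypothetical stand-in that avoids the NumPy dependency, and describe_parameters is a helper invented for this example.

```python
import inspect
from typing import Literal, get_type_hints


def construct_values(
    values: tuple[float, ...], dtype: Literal["float32", "float64"] = "float64"
) -> list[float]:
    """Hypothetical stand-in for `construct_array`, without NumPy."""
    return list(values)


def describe_parameters(func) -> dict[str, tuple[object, bool]]:
    """Map each parameter name to (annotation, is_required).

    Illustrative helper only; tyro's own signature handling is more involved.
    """
    hints = get_type_hints(func)
    described = {}
    for name, param in inspect.signature(func).parameters.items():
        # A parameter without a default has to be supplied by the caller.
        required = param.default is inspect.Parameter.empty
        described[name] = (hints[name], required)
    return described


params = describe_parameters(construct_values)
for name, (annotation, required) in params.items():
    print(name, annotation, "required" if required else "optional")
```

This mirrors the helptext above: values has no default and is shown as required, while dtype has a default and is optional.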

Custom Primitive

In this example, we use tyro.constructors to attach a primitive constructor via a runtime annotation.

# 02_primitive_annotation.py
import json

from typing_extensions import Annotated

import tyro

# A dictionary type, but `tyro` will expect a JSON string from the CLI.
JsonDict = Annotated[
    dict,
    tyro.constructors.PrimitiveConstructorSpec(
        # Number of arguments to consume.
        nargs=1,
        # Argument name in usage messages.
        metavar="JSON",
        # Convert a list of strings to an instance. The length of the list
        # should match `nargs`.
        instance_from_str=lambda args: json.loads(args[0]),
        # Check if an instance is of the expected type. This is only used for
        # helptext formatting in the presence of union types.
        is_instance=lambda instance: isinstance(instance, dict),
        # Convert an instance to a list of strings. This is used for handling
        # default values that are set in Python. The length of the list should
        # match `nargs`.
        str_from_instance=lambda instance: [json.dumps(instance)],
    ),
]

def main(
    dict1: JsonDict,
    dict2: JsonDict = {"default": None},
) -> None:
    print(f"{dict1=}")
    print(f"{dict2=}")

if __name__ == "__main__":
    tyro.cli(main)

$ python ./02_primitive_annotation.py --help
usage: 02_primitive_annotation.py [-h] --dict1 JSON [--dict2 JSON]

╭─ options ───────────────────────────────────────────╮
 -h, --help          show this help message and exit 
 --dict1 JSON        (required)                      
 --dict2 JSON        (default: '{"default": null}')  
╰─────────────────────────────────────────────────────╯
$ python ./02_primitive_annotation.py --dict1 '{"hello": "world"}'
dict1={'hello': 'world'}
dict2={'default': None}
$ python ./02_primitive_annotation.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}'
dict1={'hello': 'world'}
dict2={'hello': 'world'}
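
The two converters in the spec are expected to be inverses of each other, and both work with lists whose length matches nargs. The sketch below checks that contract with plain functions and no tyro import. strict_instance_from_str is a hypothetical variant added for illustration; how tyro reports an exception raised by a converter is not covered here.

```python
import json


def instance_from_str(args: list[str]) -> dict:
    """Same conversion as the lambda in the spec above."""
    return json.loads(args[0])


def str_from_instance(instance: dict) -> list[str]:
    """Same conversion as the lambda in the spec above."""
    return [json.dumps(instance)]


def strict_instance_from_str(args: list[str]) -> dict:
    """Hypothetical stricter converter: reject JSON that is not an object."""
    value = json.loads(args[0])
    if not isinstance(value, dict):
        raise ValueError(f"expected a JSON object, got {type(value).__name__}")
    return value


default = {"default": None}

# The default is rendered as a single string, matching `nargs=1`.
as_strings = str_from_instance(default)

# Converting back recovers the original default.
round_tripped = instance_from_str(as_strings)

print(as_strings, round_tripped)
```

Note that json.loads() also accepts valid JSON that is not an object, such as a list, so the plain converter can return something other than a dict. The stricter variant is one way to guard against that.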

Custom Primitive (Registry)

In this example, we use a tyro.constructors.ConstructorRegistry to define a rule that applies to all types that match dict[str, Any].

# 03_primitive_registry.py
import json
from typing import Any

import tyro

# Create a custom registry, which stores constructor rules.
custom_registry = tyro.constructors.ConstructorRegistry()

# Define a rule that applies to all types that match `dict[str, Any]`.
@custom_registry.primitive_rule
def _(
    type_info: tyro.constructors.PrimitiveTypeInfo,
) -> tyro.constructors.PrimitiveConstructorSpec | None:
    # We return `None` if the rule does not apply.
    if type_info.type != dict[str, Any]:
        return None

    # If the rule applies, we return the constructor spec.
    return tyro.constructors.PrimitiveConstructorSpec(
        nargs=1,
        metavar="JSON",
        instance_from_str=lambda args: json.loads(args[0]),
        is_instance=lambda instance: isinstance(instance, dict),
        str_from_instance=lambda instance: [json.dumps(instance)],
    )

def main(
    dict1: dict[str, Any],
    dict2: dict[str, Any] = {"default": None},
) -> None:
    """A function with two arguments, which can be populated from the CLI via JSON."""
    print(f"{dict1=}")
    print(f"{dict2=}")

if __name__ == "__main__":
    # To activate a custom registry, we should use it as a context manager.
    with custom_registry:
        tyro.cli(main)

$ python ./03_primitive_registry.py --help
usage: 03_primitive_registry.py [-h] --dict1 JSON [--dict2 JSON]

A function with two arguments, which can be populated from the CLI via JSON.

╭─ options ───────────────────────────────────────────╮
 -h, --help          show this help message and exit 
 --dict1 JSON        (required)                      
 --dict2 JSON        (default: '{"default": null}')  
╰─────────────────────────────────────────────────────╯
$ python ./03_primitive_registry.py --dict1 '{"hello": "world"}'
dict1={'hello': 'world'}
dict2={'default': None}
$ python ./03_primitive_registry.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}'
dict1={'hello': 'world'}
dict2={'hello': 'world'}
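
The rule above decides whether it applies by comparing type_info.type against dict[str, Any] for equality, so other parameterizations such as dict[str, int] are left to other rules. The standard-library sketch below shows how that comparison behaves, next to a looser check built on typing.get_origin() and typing.get_args(). Both helper functions are hypothetical names for this illustration and are not part of tyro.

```python
from typing import Any, get_args, get_origin


def matches_exactly(tp: object) -> bool:
    """Same test as the rule above: only `dict[str, Any]` qualifies."""
    return tp == dict[str, Any]


def matches_any_str_keyed_dict(tp: object) -> bool:
    """Hypothetical looser test: any `dict[str, X]`, whatever `X` is."""
    args = get_args(tp)
    return get_origin(tp) is dict and len(args) == 2 and args[0] is str


for candidate in (dict[str, Any], dict[str, int], dict[int, str], dict):
    print(
        candidate,
        matches_exactly(candidate),
        matches_any_str_keyed_dict(candidate),
    )
```

Inside a rule, the looser check would replace the `type_info.type != dict[str, Any]` comparison. Whether a broader match is desirable depends on which other rules should keep handling the remaining dict types.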

Custom Structs (Registry)

In this example, we use a tyro.constructors.ConstructorRegistry to add support for a custom struct type. Each struct type is broken down into multiple fields, which themselves can be either primitive types or nested structs.

Warning

This will be complicated!

# 04_struct_registry.py
import tyro

# A custom type that we'll add support for to tyro.
class Bounds:
    def __init__(self, lower: int, upper: int):
        self.bounds = (lower, upper)

    def __repr__(self) -> str:
        return f"(lower={self.bounds[0]}, upper={self.bounds[1]})"

# Create a custom registry, which stores constructor rules.
custom_registry = tyro.constructors.ConstructorRegistry()

# Define a rule that applies to all types that match `Bounds`.
@custom_registry.struct_rule
def _(
    type_info: tyro.constructors.StructTypeInfo,
) -> tyro.constructors.StructConstructorSpec | None:
    # We return `None` if the rule does not apply.
    if type_info.type != Bounds:
        return None

    # We can extract the default value of the field from `type_info`.
    if isinstance(type_info.default, Bounds):
        # If the default value is a `Bounds` instance, we use its lower and
        # upper values as the defaults for the child fields.
        default = (type_info.default.bounds[0], type_info.default.bounds[1])
    else:
        # Otherwise, the default value is missing. We'll mark the child defaults as missing as well.
        assert type_info.default in (
            tyro.constructors.MISSING,
            tyro.constructors.MISSING_NONPROP,
        )
        default = (tyro.MISSING, tyro.MISSING)

    # If the rule applies, we return the constructor spec.
    return tyro.constructors.StructConstructorSpec(
        # The instantiate function will be called with the fields as keyword arguments.
        instantiate=Bounds,
        fields=(
            tyro.constructors.StructFieldSpec(
                name="lower",
                type=int,
                default=default[0],
                helptext="Lower bound.",
            ),
            tyro.constructors.StructFieldSpec(
                name="upper",
                type=int,
                default=default[1],
                helptext="Upper bound.",
            ),
        ),
    )

def main(
    bounds: Bounds,
    bounds_with_default: Bounds = Bounds(0, 100),
) -> None:
    """A function with two `Bounds` instances as input."""
    print(f"{bounds=}")
    print(f"{bounds_with_default=}")

if __name__ == "__main__":
    # To activate a custom registry, we should use it as a context manager.
    with custom_registry:
        tyro.cli(main)

$ python ./04_struct_registry.py --help
usage: 04_struct_registry.py [-h] [OPTIONS]

A function with two `Bounds` instances as input.

╭─ options ───────────────────────────────────────────────╮
 -h, --help              show this help message and exit 
╰─────────────────────────────────────────────────────────╯
╭─ bounds options ────────────────────────────────────────╮
 --bounds.lower INT      Lower bound. (required)         
 --bounds.upper INT      Upper bound. (required)         
╰─────────────────────────────────────────────────────────╯
╭─ bounds-with-default options ───────────────────────────╮
 --bounds-with-default.lower INT                         
                         Lower bound. (default: 0)       
 --bounds-with-default.upper INT                         
                         Upper bound. (default: 100)     
╰─────────────────────────────────────────────────────────╯
$ python ./04_struct_registry.py --bounds.lower 5 --bounds.upper 10
bounds=(lower=5, upper=10)
bounds_with_default=(lower=0, upper=100)
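
As the comment in the spec notes, the instantiate function is called with the fields as keyword arguments. The toy sketch below imitates that last step without tyro: each field's raw string is converted with the field's type, missing fields fall back to per-field defaults, and the result is passed to Bounds as keyword arguments. The build helper and the two dictionaries are hypothetical names for this illustration and say nothing about tyro's internals.

```python
class Bounds:
    """Same class as in the example above."""

    def __init__(self, lower: int, upper: int):
        self.bounds = (lower, upper)

    def __repr__(self) -> str:
        return f"(lower={self.bounds[0]}, upper={self.bounds[1]})"


# Hypothetical per-field metadata, mirroring the two field specs above.
FIELD_TYPES = {"lower": int, "upper": int}
FIELD_DEFAULTS = {"lower": 0, "upper": 100}


def build(instantiate, raw: dict[str, str], defaults=None):
    """Convert raw strings field by field, then call `instantiate(**kwargs)`."""
    kwargs = {}
    for name, field_type in FIELD_TYPES.items():
        if name in raw:
            # A value given on the command line takes priority.
            kwargs[name] = field_type(raw[name])
        elif defaults is not None and name in defaults:
            # Otherwise fall back to the field's default, if there is one.
            kwargs[name] = defaults[name]
        else:
            raise ValueError(f"missing required field: {name}")
    return instantiate(**kwargs)


# Both fields given explicitly, as in `--bounds.lower 5 --bounds.upper 10`.
explicit = build(Bounds, {"lower": "5", "upper": "10"})

# No values given, so the per-field defaults are used.
from_defaults = build(Bounds, {}, FIELD_DEFAULTS)

print(explicit, from_defaults)
```

Splitting the default Bounds instance into per-field defaults is what lets a user override only one of the two values from the command line.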