Custom constructors¶
tyro.cli()
aims for comprehensive support of standard Python type
constructs. It can still, however, be useful to extend the set of supported
types.
We provide two complementary approaches for doing so:
tyro.conf
provides a simple API for specifying custom constructor functions.
tyro.constructors
provides a more flexible API for defining behavior for different types. There are two categories of types: primitive types are instantiated from a single command-line argument (which may consume one or more values), while struct types are broken down into multiple arguments.
Warning
Custom constructors are useful, but can be verbose and require care. We recommend using them sparingly.
Simple Constructors¶
For simple custom constructors, we can pass a constructor function into
tyro.conf.arg()
or tyro.conf.subcommand()
. Command-line arguments for the annotated field will then be
generated by parsing the signature of the constructor function.
In this example, we use this pattern to define custom behavior for instantiating a NumPy array.
# 01_simple_constructors.py
from typing import Literal

import numpy as np
from typing_extensions import Annotated

import tyro


# tyro parses this function's signature to generate the `--array.values` and
# `--array.dtype` arguments; its docstring is shown as the description of the
# `array` group in `--help`.
def construct_array(
    values: tuple[float, ...], dtype: Literal["float32", "float64"] = "float64"
) -> np.ndarray:
    """A custom constructor for 1D NumPy arrays."""
    # Map the CLI-facing dtype name onto the corresponding NumPy scalar type.
    return np.array(
        values,
        dtype={"float32": np.float32, "float64": np.float64}[dtype],
    )


def main(
    # We can specify a custom constructor for an argument in `tyro.conf.arg()`.
    array: Annotated[np.ndarray, tyro.conf.arg(constructor=construct_array)],
) -> None:
    print(f"{array=}")


if __name__ == "__main__":
    tyro.cli(main)
$ python ./01_simple_constructors.py --help usage: 01_simple_constructors.py [-h] --array.values [FLOAT [FLOAT ...]] [--array.dtype {float32,float64}] ╭─ options ─────────────────────────────────────────╮ │ -h, --help show this help message and exit │ ╰───────────────────────────────────────────────────╯ ╭─ array options ───────────────────────────────────╮ │ A custom constructor for 1D NumPy arrays. │ │ ───────────────────────────────────────── │ │ --array.values [FLOAT [FLOAT ...]] │ │ (required) │ │ --array.dtype {float32,float64} │ │ (default: float64) │ ╰───────────────────────────────────────────────────╯
$ python ./01_simple_constructors.py --array.values 1 2 3
array=array([1., 2., 3.])
$ python ./01_simple_constructors.py --array.values 1 2 3 4 5 --array.dtype float32
array=array([1., 2., 3., 4., 5.], dtype=float32)
Custom Primitive¶
In this example, we use tyro.constructors
to attach a primitive
constructor via a runtime annotation.
# 02_primitive_annotation.py
import json

from typing_extensions import Annotated

import tyro

# A dictionary type, but `tyro` will expect a JSON string from the CLI.
JsonDict = Annotated[
    dict,
    tyro.constructors.PrimitiveConstructorSpec(
        # Number of arguments to consume.
        nargs=1,
        # Argument name in usage messages.
        metavar="JSON",
        # Convert a list of strings to an instance. The length of the list
        # should match `nargs`.
        # Note: `json.loads` accepts any JSON value, so an input such as
        # '[1, 2]' would be returned as a list; no dict check is done here.
        instance_from_str=lambda args: json.loads(args[0]),
        # Check if an instance is of the expected type. This is only used for
        # helptext formatting in the presence of union types.
        is_instance=lambda instance: isinstance(instance, dict),
        # Convert an instance to a list of strings. This is used for handling
        # default values that are set in Python. The length of the list should
        # match `nargs`.
        str_from_instance=lambda instance: [json.dumps(instance)],
    ),
]


# `dict2`'s default is rendered in `--help` through `str_from_instance` above.
# The default dict is created once and shared between calls, which is safe
# here because `main` only reads it.
def main(
    dict1: JsonDict,
    dict2: JsonDict = {"default": None},
) -> None:
    print(f"{dict1=}")
    print(f"{dict2=}")


if __name__ == "__main__":
    tyro.cli(main)
$ python ./02_primitive_annotation.py --help usage: 02_primitive_annotation.py [-h] --dict1 JSON [--dict2 JSON] ╭─ options ───────────────────────────────────────────╮ │ -h, --help show this help message and exit │ │ --dict1 JSON (required) │ │ --dict2 JSON (default: '{"default": null}') │ ╰─────────────────────────────────────────────────────╯
$ python ./02_primitive_annotation.py --dict1 '{"hello": "world"}'
dict1={'hello': 'world'}
dict2={'default': None}
$ python ./02_primitive_annotation.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}'
dict1={'hello': 'world'}
dict2={'hello': 'world'}
Custom Primitive (Registry)¶
In this example, we use a tyro.constructors.ConstructorRegistry
to
define a rule that applies to all types that match dict[str, Any].
# 03_primitive_registry.py
import json
from typing import Any

import tyro

# Create a custom registry, which stores constructor rules.
custom_registry = tyro.constructors.ConstructorRegistry()


# Define a rule that applies to all types that match `dict[str, Any]`.
# "Match" means compares equal: plain `dict` or `dict[str, int]` is not
# handled by this rule.
# Note: without `from __future__ import annotations`, the `X | None` return
# annotation is evaluated when the function is defined, so this example
# needs Python 3.10+ (PEP 604).
@custom_registry.primitive_rule
def _(
    type_info: tyro.constructors.PrimitiveTypeInfo,
) -> tyro.constructors.PrimitiveConstructorSpec | None:
    # We return `None` if the rule does not apply.
    if type_info.type != dict[str, Any]:
        return None

    # If the rule applies, we return the constructor spec. The fields mean the
    # same as in a `PrimitiveConstructorSpec` attached via `Annotated`: one CLI
    # token is parsed as JSON, and defaults are serialized back to JSON for help.
    return tyro.constructors.PrimitiveConstructorSpec(
        nargs=1,
        metavar="JSON",
        instance_from_str=lambda args: json.loads(args[0]),
        is_instance=lambda instance: isinstance(instance, dict),
        str_from_instance=lambda instance: [json.dumps(instance)],
    )


def main(
    dict1: dict[str, Any],
    dict2: dict[str, Any] = {"default": None},
) -> None:
    """A function with two arguments, which can be populated from the CLI via JSON."""
    print(f"{dict1=}")
    print(f"{dict2=}")


if __name__ == "__main__":
    # To activate a custom registry, we should use it as a context manager.
    with custom_registry:
        tyro.cli(main)
$ python ./03_primitive_registry.py --help usage: 03_primitive_registry.py [-h] --dict1 JSON [--dict2 JSON] A function with two arguments, which can be populated from the CLI via JSON. ╭─ options ───────────────────────────────────────────╮ │ -h, --help show this help message and exit │ │ --dict1 JSON (required) │ │ --dict2 JSON (default: '{"default": null}') │ ╰─────────────────────────────────────────────────────╯
$ python ./03_primitive_registry.py --dict1 '{"hello": "world"}'
dict1={'hello': 'world'}
dict2={'default': None}
$ python ./03_primitive_registry.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}'
dict1={'hello': 'world'}
dict2={'hello': 'world'}
Custom Structs (Registry)¶
In this example, we use a tyro.constructors.ConstructorRegistry
to
add support for a custom type.
Warning
This will be complicated!
# 04_struct_registry.py
import tyro


# A custom type that we'll add support for to tyro.
class Bounds:
    def __init__(self, lower: int, upper: int):
        self.bounds = (lower, upper)

    def __repr__(self) -> str:
        return f"(lower={self.bounds[0]}, upper={self.bounds[1]})"


# Create a custom registry, which stores constructor rules.
custom_registry = tyro.constructors.ConstructorRegistry()


# Define a rule that applies to all types that match `Bounds`.
@custom_registry.struct_rule
def _(
    type_info: tyro.constructors.StructTypeInfo,
) -> tyro.constructors.StructConstructorSpec | None:
    # We return `None` if the rule does not apply.
    if type_info.type != Bounds:
        return None

    # We can extract the default value of the field from `type_info`.
    if isinstance(type_info.default, Bounds):
        # A `Bounds` default was provided: unpack it so that each child field
        # (`lower`, `upper`) gets its own default value.
        default = (type_info.default.bounds[0], type_info.default.bounds[1])
        is_default_overridden = True
    else:
        # Otherwise, the default value is missing. We'll mark the child
        # defaults as missing as well. (This `assert` is an invariant check
        # and is skipped when Python runs with `-O`.)
        assert type_info.default in (
            tyro.constructors.MISSING,
            tyro.constructors.MISSING_NONPROP,
        )
        default = (tyro.MISSING, tyro.MISSING)
        is_default_overridden = False

    # If the rule applies, we return the constructor spec.
    return tyro.constructors.StructConstructorSpec(
        # The instantiate function will be called with the fields as keyword arguments.
        instantiate=Bounds,
        fields=(
            tyro.constructors.StructFieldSpec(
                name="lower",
                type=int,
                default=default[0],
                is_default_overridden=is_default_overridden,
                helptext="Lower bound.",
            ),
            tyro.constructors.StructFieldSpec(
                name="upper",
                type=int,
                default=default[1],
                is_default_overridden=is_default_overridden,
                helptext="Upper bound.",
            ),
        ),
    )


def main(
    bounds: Bounds,
    bounds_with_default: Bounds = Bounds(0, 100),
) -> None:
    """A function with two `Bounds` instances as input."""
    print(f"{bounds=}")
    print(f"{bounds_with_default=}")


if __name__ == "__main__":
    # To activate a custom registry, we should use it as a context manager.
    with custom_registry:
        tyro.cli(main)
$ python ./04_struct_registry.py --help usage: 04_struct_registry.py [-h] [OPTIONS] A function with two `Bounds` instances as input. ╭─ options ───────────────────────────────────────────────╮ │ -h, --help show this help message and exit │ ╰─────────────────────────────────────────────────────────╯ ╭─ bounds options ────────────────────────────────────────╮ │ --bounds.lower INT Lower bound. (required) │ │ --bounds.upper INT Upper bound. (required) │ ╰─────────────────────────────────────────────────────────╯ ╭─ bounds-with-default options ───────────────────────────╮ │ --bounds-with-default.lower INT │ │ Lower bound. (default: 0) │ │ --bounds-with-default.upper INT │ │ Upper bound. (default: 100) │ ╰─────────────────────────────────────────────────────────╯
$ python ./04_struct_registry.py --bounds.lower 5 --bounds.upper 10
bounds=(lower=5, upper=10)
bounds_with_default=(lower=0, upper=100)