Source code for tyro._parsers
"""Interface for generating `argparse.ArgumentParser()` definitions from callables."""
from __future__ import annotations
import dataclasses
import numbers
import warnings
from typing import Any, Callable, Dict, List, Set, Tuple, Type, TypeVar, Union, cast
from typing_extensions import Annotated, get_args, get_origin
from tyro.constructors._registry import ConstructorRegistry
from tyro.constructors._struct_spec import (
InvalidDefaultInstanceError,
UnsupportedStructTypeMessage,
)
from . import (
_arguments,
_docstrings,
_fields,
_resolver,
_singleton,
_strings,
_subcommand_matching,
)
from . import _fmtlib as fmt
from ._typing import TypeForm
from ._typing_compat import is_typing_union
from .conf import _confstruct, _markers
from .constructors._primitive_spec import (
PrimitiveConstructorSpec,
UnsupportedTypeAnnotationError,
)
T = TypeVar("T")
@dataclasses.dataclass()
class LazyParserSpecification:
"""Lazy wrapper that defers full ParserSpecification creation until needed.
Stores lightweight metadata (description) for fast help text generation,
while deferring expensive parser construction until actually needed.
"""
# Lightweight field needed for tyro help formatting.
description: str
# Factory for creating the full parser when needed.
_factory: Callable[[], ParserSpecification]
_cached: ParserSpecification | None = dataclasses.field(default=None, init=False)
def evaluate(self) -> ParserSpecification:
"""Get the full ParserSpecification, creating it if needed."""
if self._cached is None:
self._cached = self._factory()
return self._cached
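# Illustrative sketch (not part of the module): the factory runs at most once,
# and later `evaluate()` calls return the cached ParserSpecification. Here
# `build_parser` stands in for any zero-argument callable returning a
# ParserSpecification.
#
#   lazy = LazyParserSpecification(description="...", _factory=build_parser)
#   spec = lazy.evaluate()          # invokes build_parser() and caches the result
#   assert spec is lazy.evaluate()  # cached; build_parser() is not called again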
@dataclasses.dataclass()
class ArgWithContext:
arg: _arguments.ArgumentDefinition
source_parser: ParserSpecification
"""ParserSpecification that directly contains this argument."""
local_root_parser: ParserSpecification
"""Furthest ancestor of `source_parser` within the same (sub)command."""
@dataclasses.dataclass(frozen=True)
class ParserSpecification:
"""Each parser contains a list of arguments and optionally some subparsers."""
f: Callable
markers: Set[_markers._Marker]
description: str
args: List[_arguments.ArgumentDefinition]
field_list: List[_fields.FieldDefinition]
child_from_prefix: Dict[str, ParserSpecification]
helptext_from_intern_prefixed_field_name: Dict[
str, str | Callable[[], str | None] | None
]
# Subparser groups that are direct children of this parser. The actual tree
# structure for argparse is built on-demand in apply().
subparsers_from_intern_prefix: Dict[str, SubparsersSpecification]
intern_prefix: str
extern_prefix: str
has_required_args: bool
subparser_parent: ParserSpecification | None
prog_suffix: str
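# Note on the two prefixes above (descriptive comment, inferred from how they
# are used in this module): `intern_prefix` is the dotted path used internally
# for bookkeeping and value assignment, while `extern_prefix` is the
# user-facing prefix used to build CLI flag and subcommand names. The two can
# diverge, for example when `prefix_name=False` or
# `tyro.conf.OmitSubcommandPrefixes` drops parts of the user-facing name.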
@staticmethod
def from_callable_or_type(
f: Callable[..., T],
markers: Set[_markers._Marker],
description: str | Callable[[], str | None] | None,
parent_classes: Set[Type[Any]],
default_instance: Union[
T, _singleton.PropagatingMissingType, _singleton.NonpropagatingMissingType
],
intern_prefix: str,
extern_prefix: str,
subcommand_prefix: str,
support_single_arg_types: bool,
prog_suffix: str,
) -> ParserSpecification:
"""Create a parser definition from a callable or type."""
# Consolidate subcommand types.
f_unwrapped, new_markers = _resolver.unwrap_annotated(f, _markers._Marker)
markers = markers | set(new_markers)
# Cycle detection.
#
# - `parent` here refers to the nesting hierarchy, not the superclass hierarchy.
# - We threshold by `max_nesting_depth` to suppress false positives, for
# example from custom constructors that behave differently depending on the
# default value. (example: ml_collections.ConfigDict)
max_nesting_depth = 128
if (
f in parent_classes
and f is not dict
and intern_prefix.count(".") > max_nesting_depth
):
raise UnsupportedTypeAnnotationError(
(
fmt.text(
"Found a cyclic dependency with type ",
fmt.text["cyan"](str(f)),
),
)
)
# TODO: we are abusing the (minor) distinctions between types, classes, and
# callables throughout the code. This is mostly for legacy reasons, could be
# cleaned up.
parent_classes = parent_classes | {cast(Type, f)}
# Wrap our type with a dummy dataclass if it can't be treated as a
# nested type. For example: passing in f=int will result in a dataclass
# with a single field typed as int.
#
# Why don't we always use a dummy dataclass?
# => Docstrings for inner structs are currently lost when we nest struct types.
from . import _calling
# Resolve the type of `f`, generate a field list.
# Try `f` directly first; we only fall back to DummyWrapper (and a second call
# to field_list_from_type_or_callable) below if needed.
f_for_field_list = f
default_instance_for_field_list = default_instance
# Check if we need DummyWrapper by trying to get fields first.
with _fields.FieldDefinition.marker_context(tuple(markers)):
out = _fields.field_list_from_type_or_callable(
f=f_for_field_list,
default_instance=default_instance_for_field_list,
support_single_arg_types=support_single_arg_types,
in_union_context=False,
)
# If not a struct type and not None, wrap in DummyWrapper and try again.
if isinstance(
out, UnsupportedStructTypeMessage
) and f_unwrapped is not type(None):
try:
f_for_field_list = _calling.DummyWrapper[f] # type: ignore
default_instance_for_field_list = _calling.DummyWrapper(
default_instance
) # type: ignore
out = _fields.field_list_from_type_or_callable(
f=f_for_field_list,
default_instance=default_instance_for_field_list,
support_single_arg_types=support_single_arg_types,
in_union_context=False,
)
except TypeError as e: # pragma: no cover
# In Python 3.8, DummyWrapper[f] raises TypeError if f is not a valid type.
# (e.g., "Parameters to generic types must be types. Got 5.")
raise UnsupportedTypeAnnotationError(
(
fmt.text(
"Expected a type, class, or callable, but got ",
fmt.text["cyan"](repr(f)),
".",
),
)
) from e
assert not isinstance(out, UnsupportedStructTypeMessage), out.message
assert not isinstance(out, InvalidDefaultInstanceError), "\n".join(
repr(fmt.rows(*out.message))
)
f, field_list = out
has_required_args = False
args: list[_arguments.ArgumentDefinition] = []
helptext_from_intern_prefixed_field_name: Dict[
str, str | Callable[[], str | None] | None
] = {}
child_from_prefix: Dict[str, ParserSpecification] = {}
subparsers_from_prefix = {}
for field in field_list:
field_out = handle_field(
field,
parent_classes=parent_classes,
intern_prefix=intern_prefix,
extern_prefix=extern_prefix,
subcommand_prefix=subcommand_prefix,
prog_suffix=prog_suffix,
)
if isinstance(field_out, _arguments.ArgumentDefinition):
# Handle single arguments.
args.append(field_out)
if field_out.lowered.required:
has_required_args = True
elif isinstance(field_out, SubparsersSpecification):
# Handle subparsers.
subparsers_from_prefix[field_out.intern_prefix] = field_out
elif isinstance(field_out, ParserSpecification):
# Handle nested parsers.
nested_parser = field_out
child_from_prefix[field_out.intern_prefix] = nested_parser
# Flatten subparsers from nested parser into current parser.
# This handles the case where a field's type has subcommands that need
# to be accessible at the parent level.
for (
prefix,
subparser_spec,
) in nested_parser.subparsers_from_intern_prefix.items():
subparsers_from_prefix[prefix] = subparser_spec
if nested_parser.has_required_args:
has_required_args = True
# Helptext for this field; used as description for grouping arguments.
class_field_name = _strings.make_field_name(
[intern_prefix, field.intern_name]
)
if field.helptext is not None:
# Keep lazy - don't evaluate yet.
helptext_from_intern_prefixed_field_name[class_field_name] = (
field.helptext
)
else:
helptext_from_intern_prefixed_field_name[class_field_name] = (
_docstrings.get_callable_description(nested_parser.f)
)
# If arguments are in an optional group, it indicates that the default_instance
# will be used if none of the arguments are passed in.
if (
len(nested_parser.args) >= 1
and _markers._OPTIONAL_GROUP in nested_parser.args[0].field.markers
):
current_helptext = helptext_from_intern_prefixed_field_name[
class_field_name
]
# Evaluate lazy helptext before concatenating.
if callable(current_helptext):
current_helptext = current_helptext()
helptext_from_intern_prefixed_field_name[class_field_name] = (
("" if current_helptext is None else current_helptext + "\n\n")
+ "Default: "
+ str(field.default)
)
# Evaluate lazy description if callable.
desc = (
description
if description is not None
else _docstrings.get_callable_description(f)
)
if callable(desc):
desc = desc()
# If still None after evaluation, use empty string.
if desc is None:
desc = ""
parser_spec = ParserSpecification(
f=f,
markers=markers,
description=_strings.remove_single_line_breaks(desc),
args=args,
field_list=field_list,
child_from_prefix=child_from_prefix,
helptext_from_intern_prefixed_field_name=helptext_from_intern_prefixed_field_name,
subparsers_from_intern_prefix=subparsers_from_prefix,
intern_prefix=intern_prefix,
extern_prefix=extern_prefix,
has_required_args=has_required_args,
subparser_parent=None,
prog_suffix=prog_suffix,
)
return parser_spec
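# Illustrative sketch (commented out; `MyConfig` is a hypothetical dataclass):
# a root parser specification can be built with empty prefixes, no parent
# classes, and the missing sentinel as the default instance.
#
#   spec = ParserSpecification.from_callable_or_type(
#       MyConfig,
#       markers=set(),
#       description=None,
#       parent_classes=set(),
#       default_instance=_singleton.MISSING_NONPROP,
#       intern_prefix="",
#       extern_prefix="",
#       subcommand_prefix="",
#       support_single_arg_types=False,
#       prog_suffix="",
#   )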
def get_args_including_children(
self,
local_root: ParserSpecification | None = None,
) -> list[ArgWithContext]:
"""Get all arguments in this parser and its children.
Does not include arguments in subparsers.
"""
if local_root is None:
local_root = self
args = [ArgWithContext(arg, self, local_root) for arg in self.args]
for child in self.child_from_prefix.values():
args.extend(child.get_args_including_children(local_root))
return args
def handle_field(
field: _fields.FieldDefinition,
parent_classes: Set[Type[Any]],
intern_prefix: str,
extern_prefix: str,
subcommand_prefix: str,
prog_suffix: str,
) -> Union[
_arguments.ArgumentDefinition,
ParserSpecification,
SubparsersSpecification,
]:
"""Determine what to do with a single field definition."""
# Check that the default value matches the final resolved type.
# There's some similar Union-specific logic for this in narrow_union_type(). We
# may be able to consolidate this.
if (
not _resolver.is_instance(field.type_stripped, field.default)
# If a custom constructor is set, static_type may not be
# matched to the annotated type.
and field.argconf.constructor_factory is None
and field.default not in _singleton.DEFAULT_SENTINEL_SINGLETONS
# The numeric tower in Python is wacky. This logic is non-critical, so
# we'll just skip it (+the complexity) for numbers.
and not isinstance(field.default, numbers.Number)
):
# If the default value doesn't match the resolved type, we expand the
# type. This is inspired by https://github.com/brentyi/tyro/issues/88.
field_name = _strings.make_field_name([extern_prefix, field.extern_name])
message = (
f"The field `{field_name}` is annotated with type `{field.type}`, "
f"but the default value `{field.default}` has type `{type(field.default)}`. "
f"We'll try to handle this gracefully, but it may cause unexpected behavior."
)
warnings.warn(message)
field = field.with_new_type_stripped(
Union[field.type_stripped, type(field.default)] # type: ignore
)
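# For example (illustrative): a field annotated as `int` whose default is the
# string "auto" would be widened here to `Union[int, str]`, after emitting the
# warning above.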
# Force primitive handling if (1) the field is annotated with a primitive
# constructor spec, or (2) a custom primitive constructor is registered for the type.
force_primitive = (
len(_resolver.unwrap_annotated(field.type, PrimitiveConstructorSpec)[1]) > 0
) or ConstructorRegistry._is_primitive_type(
field.type, field.markers, nondefault_only=True
)
if not force_primitive:
# (1) Handle Unions over callables; these result in subparsers.
if _markers.Suppress not in field.markers:
subparsers_attempt = SubparsersSpecification.from_field(
field,
parent_classes=parent_classes,
intern_prefix=_strings.make_field_name(
[intern_prefix, field.intern_name]
),
extern_prefix=_strings.make_field_name(
[extern_prefix, field.extern_name]
),
prog_suffix=prog_suffix,
)
if subparsers_attempt is not None:
return subparsers_attempt
# (2) Handle nested callables.
if _fields.is_struct_type(field.type, field.default, in_union_context=False):
# Keep description lazy - don't evaluate yet.
return ParserSpecification.from_callable_or_type(
field.type_stripped,
markers=field.markers,
description=field.helptext,
parent_classes=parent_classes,
default_instance=field.default,
intern_prefix=_strings.make_field_name(
[intern_prefix, field.intern_name]
),
extern_prefix=(
_strings.make_field_name([extern_prefix, field.extern_name])
if field.argconf.prefix_name in (True, None)
else field.extern_name
),
subcommand_prefix=subcommand_prefix,
support_single_arg_types=False,
prog_suffix=prog_suffix,
)
# (3) Handle primitive or fixed types. These produce a single argument!
arg = _arguments.ArgumentDefinition(
intern_prefix=intern_prefix,
extern_prefix=extern_prefix,
subcommand_prefix=subcommand_prefix,
field=field,
)
# Validate that Fixed/Suppress fields have defaults.
if (
_markers.Fixed in field.markers or _markers.Suppress in field.markers
) and field.default in _singleton.MISSING_AND_MISSING_NONPROP:
raise UnsupportedTypeAnnotationError(
(
fmt.text(
"Field ",
fmt.text["magenta", "bold"](field.intern_name),
" is marked as Fixed or Suppress but is missing a default value",
),
)
)
return arg
@dataclasses.dataclass(frozen=True)
class SubparsersSpecification:
"""Structure for defining subparsers. Each subparser is a parser with a name."""
description: str | Callable[[], str | None] | None
parser_from_name: Dict[str, LazyParserSpecification]
default_name: str | None
default_parser: ParserSpecification | None
intern_prefix: str
extern_prefix: str
required: bool
default_instance: Any
options: Tuple[Union[TypeForm[Any], Callable], ...]
prog_suffix: str
@staticmethod
def from_field(
field: _fields.FieldDefinition,
parent_classes: Set[Type[Any]],
intern_prefix: str,
extern_prefix: str,
prog_suffix: str,
) -> SubparsersSpecification | ParserSpecification | None:
"""From a field: return either a subparser specification, a parser
specification for subcommands when `tyro.conf.AvoidSubcommands` is used
and a default is set, or `None` if the field does not create a
subparser."""
# Union of classes should create subparsers.
typ = _resolver.unwrap_annotated(field.type_stripped)
if not is_typing_union(get_origin(typ)):
return None
# We use a list rather than a set here to retain the order of subcommands.
options: List[Union[type, Callable]]
options = [typ for typ in get_args(typ)]
# If specified, swap types using tyro.conf.subcommand(constructor=...).
found_subcommand_conf = False
for i, option in enumerate(options):
_, found_subcommand_configs = _resolver.unwrap_annotated(
option, _confstruct._SubcommandConfig
)
if (
len(found_subcommand_configs) > 0
and found_subcommand_configs[0].constructor_factory is not None
):
found_subcommand_conf = True
options[i] = Annotated[ # type: ignore
(
found_subcommand_configs[0].constructor_factory(),
*_resolver.unwrap_annotated(option, "all")[1],
)
]
# Exit early if the union doesn't contain any struct types.
def recursive_contains_struct_type(options: list[Any]) -> bool:
for o in options:
if _fields.is_struct_type(
o, _singleton.MISSING_NONPROP, in_union_context=True
):
return True
if is_typing_union(get_origin(_resolver.unwrap_annotated(o))):
if recursive_contains_struct_type(get_args(o)): # type: ignore
return True
return False
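# For example (illustrative): with a dataclass `Config`, a union like
# Union[Config, None] contains a struct type and will produce subparsers,
# while Union[int, str] does not and falls through to primitive handling.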
if not found_subcommand_conf and not recursive_contains_struct_type(options):
return None
# Get subcommand configurations from `tyro.conf.subcommand()`.
subcommand_config_from_name: Dict[str, _confstruct._SubcommandConfig] = {}
subcommand_type_from_name: Dict[str, type] = {}
subcommand_names: list[str] = []
for option in options:
option_unwrapped, found_subcommand_configs = _resolver.unwrap_annotated(
option, _confstruct._SubcommandConfig
)
subcommand_name = _strings.subparser_name_from_type(
(
""
if _markers.OmitSubcommandPrefixes in field.markers
else extern_prefix
),
cast(type, option),
)
subcommand_names.append(subcommand_name)
if subcommand_name in subcommand_type_from_name:
# Warn that this subcommand name is already in use and will be overwritten.
original_type = subcommand_type_from_name[subcommand_name]
original_type_full_name = (
f"{original_type.__module__}.{original_type.__name__}"
)
new_type_full_name = (
f"{option_unwrapped.__module__}.{option_unwrapped.__name__}"
if option_unwrapped is not None
else "none"
)
warnings.warn(
f"Duplicate subcommand name detected: '{subcommand_name}' is already used for "
f"{original_type_full_name} but will be overwritten by {new_type_full_name}. "
f"Only the last type ({new_type_full_name}) will be accessible via this subcommand. "
f"Consider using distinct class names or use tyro.conf.subcommand() to specify "
f"explicit subcommand names."
)
if len(found_subcommand_configs) != 0:
# Explicitly annotated default.
assert len(found_subcommand_configs) == 1, (
f"Expected only one subcommand config, but {subcommand_name} has"
f" {len(found_subcommand_configs)}."
)
subcommand_config_from_name[subcommand_name] = found_subcommand_configs[
0
]
subcommand_type_from_name[subcommand_name] = cast(type, option)
# If a field default is provided, try to find a matching subcommand name.
# Note: EXCLUDE_FROM_CALL (from TypedDict total=False or NotRequired[]) is
# a sentinel that means no default was provided, so we skip matching.
default_name = (
_subcommand_matching.match_subcommand(
field.default,
subcommand_config_from_name,
subcommand_type_from_name,
extern_prefix,
)
if field.default not in _singleton.DEFAULT_SENTINEL_SINGLETONS
else None
)
# Handle `tyro.conf.AvoidSubcommands` with a default value.
if default_name is not None and _markers.AvoidSubcommands in field.markers:
return ParserSpecification.from_callable_or_type(
subcommand_type_from_name[default_name],
markers=field.markers,
description=None,
parent_classes=parent_classes,
default_instance=field.default,
intern_prefix=intern_prefix,
extern_prefix=extern_prefix,
subcommand_prefix=extern_prefix,
support_single_arg_types=True,
prog_suffix=prog_suffix,
)
# Add subcommands for each option.
parser_from_name: Dict[str, LazyParserSpecification] = {}
for option, subcommand_name in zip(options, subcommand_names):
# Get a subcommand config: either pulled from the type annotations or the
# field default.
if subcommand_name in subcommand_config_from_name:
subcommand_config = subcommand_config_from_name[subcommand_name]
else:
subcommand_config = _confstruct._SubcommandConfig(
"unused",
description=None,
default=_singleton.MISSING_NONPROP,
prefix_name=True,
constructor_factory=None,
)
# If names match, borrow subcommand default from field default.
if default_name == subcommand_name and (
field.default not in _singleton.MISSING_AND_MISSING_NONPROP
):
subcommand_config = dataclasses.replace(
subcommand_config, default=field.default
)
# Strip the subcommand config from the option type.
# Relevant: https://github.com/brentyi/tyro/pull/117
option_unwrapped, annotations = _resolver.unwrap_annotated(option, "all")
annotations = tuple(
a
for a in annotations
if not isinstance(a, _confstruct._SubcommandConfig)
)
if _markers.Suppress in annotations:
continue
if len(annotations) == 0:
option = option_unwrapped
else:
option = Annotated[(option_unwrapped,) + annotations] # type: ignore
# Extract description early for fast help text generation.
# If no explicit description, get it from the callable's docstring.
description_for_help = subcommand_config.description
if option_unwrapped is type(None):
description_for_help = ""
elif description_for_help is None:
description_for_help = _docstrings.get_callable_description(
option_unwrapped
)
# Create lazy parser: defer expensive parsing until actually needed.
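# The `*_captured` default arguments below bind the *current* loop iteration's
# values at function-definition time. Without them, the closure would look the
# variables up lazily and every lazy parser would see the values from the final
# iteration (Python's late-binding closure behavior).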
def parser_factory(
option_captured: Any = option,
markers_captured: Set[_markers._Marker] = field.markers,
subcommand_config_captured: _confstruct._SubcommandConfig = subcommand_config,
parent_classes_captured: Set[Type[Any]] = parent_classes,
intern_prefix_captured: str = intern_prefix,
extern_prefix_captured: str = extern_prefix,
prog_suffix_captured: str = prog_suffix,
subcommand_name_captured: str = subcommand_name,
field_markers_captured: Set[_markers._Marker] = field.markers,
) -> ParserSpecification:
with _fields.FieldDefinition.marker_context(
tuple(field_markers_captured)
):
subparser = ParserSpecification.from_callable_or_type(
option_captured, # type: ignore
markers=markers_captured,
description=subcommand_config_captured.description,
parent_classes=parent_classes_captured,
default_instance=subcommand_config_captured.default,
intern_prefix=intern_prefix_captured,
extern_prefix=extern_prefix_captured,
subcommand_prefix=extern_prefix_captured,
support_single_arg_types=True,
prog_suffix=subcommand_name_captured
if prog_suffix_captured == ""
else prog_suffix_captured + " " + subcommand_name_captured,
)
# Apply prefix to helptext in nested classes in subparsers.
subparser = dataclasses.replace(
subparser,
helptext_from_intern_prefixed_field_name={
_strings.make_field_name([intern_prefix_captured, k]): v
for k, v in subparser.helptext_from_intern_prefixed_field_name.items()
},
)
return subparser
parser_from_name[subcommand_name] = LazyParserSpecification(
description=_strings.remove_single_line_breaks(description_for_help),
_factory=parser_factory, # type: ignore
)
# Default parser was suppressed!
if default_name not in parser_from_name:
default_name = None
# Mark the subparser group as required if no default is provided, or if a
# default is provided but it still has missing (required) parameters.
default_parser = None
if default_name is None:
# If the default is EXCLUDE_FROM_CALL (from TypedDict total=False or
# NotRequired[Union[...]]), the subparser is optional. When no subcommand
# is selected, the field will be excluded from the result (see _calling.py).
required = field.default is not _singleton.EXCLUDE_FROM_CALL
else:
required = False
# Evaluate the lazy parser to check for required args/subparsers.
default_parser_evaluated = parser_from_name[default_name].evaluate()
# Error should have been caught earlier.
assert not isinstance(
default_parser_evaluated, UnsupportedTypeAnnotationError
), "Unexpected UnsupportedTypeAnnotationError in backend"
# If there are any required arguments.
if any(
map(lambda arg: arg.lowered.required, default_parser_evaluated.args)
):
required = True
default_parser = None
# If there are any required subparsers.
elif any(
subparser_spec.required
for subparser_spec in default_parser_evaluated.subparsers_from_intern_prefix.values()
):
required = True
default_parser = None
else:
default_parser = default_parser_evaluated
return SubparsersSpecification(
# If we wanted, we could add information about the default instance
# automatically, as is done for normal fields. But for now we just rely on
# the user to include it in the docstring.
# Keep description lazy - don't evaluate yet.
description=field.helptext,
parser_from_name=parser_from_name,
default_name=default_name,
default_parser=default_parser,
intern_prefix=intern_prefix,
extern_prefix=extern_prefix,
required=required,
default_instance=field.default,
options=tuple(options),
prog_suffix=prog_suffix,
)
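# Rough picture of how the pieces above fit together (descriptive sketch, not
# executable as-is): a field annotated with a union over struct types, e.g.
#
#   checkpoint: Union[TrainConfig, EvalConfig]
#
# reaches SubparsersSpecification.from_field() via handle_field(). Each union
# member gets one entry in `parser_from_name`, wrapped in a
# LazyParserSpecification so that the member's full ParserSpecification is only
# constructed when that subcommand's arguments or help text are actually
# needed. `TrainConfig` and `EvalConfig` here are hypothetical user classes.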