Source code for fifteen.data._prefetching_map

import collections
import dataclasses
from typing import Iterable, Optional, TypeVar, Union, cast, overload

import jax
from jax.lib import xla_client

from ._protocols import SizedIterable

PyTreeType = TypeVar("PyTreeType")


@overload
def prefetching_map(
    inputs: SizedIterable[PyTreeType],
    device: Optional[xla_client.Device] = None,
    buffer_size: int = 2,
) -> SizedIterable[PyTreeType]:
    ...


@overload
def prefetching_map(
    inputs: Iterable[PyTreeType],
    device: Optional[xla_client.Device] = None,
    buffer_size: int = 2,
) -> Iterable[PyTreeType]:
    ...


def prefetching_map(
    inputs: Union[Iterable[PyTreeType], SizedIterable[PyTreeType]],
    device: Optional[xla_client.Device] = None,
    buffer_size: int = 2,
) -> Union[Iterable[PyTreeType], SizedIterable[PyTreeType]]:
    """Maps an iterable over PyTrees to an identical iterable, but with a prefetching
    buffer under the hood. Adapted from `flax.jax_utils.prefetch_to_device()`.

    This can improve parallelization for GPUs, particularly when memory is re-allocated
    before freeing is finished. When the buffer size is set to 2, we make it explicit
    that two sets of data should live in GPU memory at once: for a standard training
    loop, this is typically both the "current" minibatch and the "next" one.

    If a device is specified, we commit arrays (via `jax.device_put()`) before pushing
    them onto the buffer. This should generally be set if the input iterable yields
    arrays that are still living on the CPU. For multi-device use cases, we can combine
    this function with :meth:`fifteen.data.sharding_map()`."""
    if hasattr(inputs, "__len__"):
        return _PrefetchingMapSized(
            cast(SizedIterable[PyTreeType], inputs), device, buffer_size
        )
    else:
        return _PrefetchingMap(inputs, device, buffer_size)
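
# Usage sketch (illustrative only; `load_minibatches()` and `train_step()` below are
# hypothetical names, not part of this module): wrap a CPU-side iterable of pytree
# minibatches so that the "current" and "next" minibatches are committed to the GPU
# ahead of the training loop.
#
#     for batch in prefetching_map(load_minibatches(), device=jax.devices()[0]):
#         train_state = train_step(train_state, batch)
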
@dataclasses.dataclass
class _PrefetchingMap(Iterable[PyTreeType]):
    inputs: Iterable[PyTreeType]
    device: Optional[xla_client.Device]
    buffer_size: int

    def __iter__(self):
        """Adapted from `flax.jax_utils.prefetch_to_device()`:
        https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device
        """
        queue = collections.deque()
        input_iter = iter(self.inputs)

        def try_enqueue() -> None:
            # Pull the next item off the input iterator; no-op if it's exhausted.
            try:
                item = next(input_iter)
            except StopIteration:
                return

            # Optionally commit the item to the target device before buffering it.
            if self.device is not None:
                assert not isinstance(
                    item, jax.lib.xla_extension.pmap_lib.ShardedDeviceArray
                ), "Should not move sharded arrays -- device should be set to `None`."
                item = jax.device_put(item, device=self.device)
            queue.append(item)

        # Fill the prefetch buffer, then yield items in FIFO order, topping the
        # buffer back up after each one.
        for _ in range(self.buffer_size):
            try_enqueue()
        while len(queue) > 0:
            yield queue.popleft()
            try_enqueue()


@dataclasses.dataclass
class _PrefetchingMapSized(_PrefetchingMap[PyTreeType], SizedIterable[PyTreeType]):
    inputs: SizedIterable[PyTreeType]

    def __len__(self) -> int:
        return len(self.inputs)
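

if __name__ == "__main__":
    # Illustrative, self-contained sketch: prefetch a small synthetic stream of pytree
    # minibatches onto the default device. Because the input list is sized, the wrapper
    # also exposes `__len__()` (the `_PrefetchingMapSized` path above).
    import jax.numpy as jnp

    synthetic_batches = [
        {"x": jnp.ones((4, 3)), "y": jnp.zeros((4,))} for _ in range(5)
    ]
    prefetched = prefetching_map(synthetic_batches, device=jax.devices()[0])
    assert len(prefetched) == len(synthetic_batches)  # Sized inputs preserve length.
    for batch in prefetched:
        assert batch["x"].shape == (4, 3)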