import collections
import dataclasses
from typing import Iterable, Optional, TypeVar, Union, cast, overload
import jax
from jax.lib import xla_client
from ._protocols import SizedIterable
PyTreeType = TypeVar("PyTreeType")
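

# Overloads preserve the "sized-ness" of the input in the return type: a SizedIterable
# input maps to a SizedIterable output, and a plain Iterable to a plain Iterable.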
@overload
def prefetching_map(
    inputs: SizedIterable[PyTreeType],
    device: Optional[xla_client.Device] = None,
    buffer_size: int = 2,
) -> SizedIterable[PyTreeType]:
    ...


@overload
def prefetching_map(
    inputs: Iterable[PyTreeType],
    device: Optional[xla_client.Device] = None,
    buffer_size: int = 2,
) -> Iterable[PyTreeType]:
    ...


def prefetching_map(
    inputs: Union[Iterable[PyTreeType], SizedIterable[PyTreeType]],
    device: Optional[xla_client.Device] = None,
    buffer_size: int = 2,
) -> Union[Iterable[PyTreeType], SizedIterable[PyTreeType]]:
    """Maps iterables over PyTrees to an identical iterable, but with a prefetching
    buffer under the hood. Adapted from `flax.jax_utils.prefetch_to_device()`.
    This can improve parallelization for GPUs, particularly when memory is re-allocated
    before freeing is finished. When the buffer size is set to 2, we make it explicit
    that two sets of data should live in GPU memory at once: for a standard training
    loop, this is typically both the "current" minibatch and the "next" one.
    If a device is specified, we commit arrays (via `jax.device_put()`) before pushing them
    onto the buffer. This should generally be set if the input iterable yields arrays
    that are still living on the CPU.
    For multi-device use cases, we can combine this function with
    :meth:`fifteen.data.sharding_map()`."""
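    # Dispatch on whether the input reports a length, so that `len()` keeps working
    # through the prefetching wrapper whenever the underlying iterable supports it.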
    if hasattr(inputs, "__len__"):
        return _PrefetchingMapSized(
            cast(SizedIterable[PyTreeType], inputs), device, buffer_size
        )
    else:
        return _PrefetchingMap(inputs, device, buffer_size)


@dataclasses.dataclass
class _PrefetchingMap(Iterable[PyTreeType]):
    inputs: Iterable[PyTreeType]
    device: Optional[xla_client.Device]
    buffer_size: int
    def __iter__(self):
        """Adapted from `flax.jax_utils.prefetch_to_device()`:
        https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device
        """
        queue = collections.deque()
        input_iter = iter(self.inputs)
        def try_enqueue() -> None:
            # Pull the next item from the input iterator; no-op once it's exhausted.
            try:
                item = next(input_iter)
            except StopIteration:
                return
            # Optionally commit the item to the target device before buffering it.
            if self.device is not None:
                assert not isinstance(
                    item, jax.lib.xla_extension.pmap_lib.ShardedDeviceArray
                ), "Should not move sharded arrays -- device should be set to `None`."
                item = jax.device_put(item, device=self.device)
            queue.append(item)
        # Fill the prefetch buffer before yielding anything.
        for _ in range(self.buffer_size):
            try_enqueue()
        # Yield in FIFO order, topping the buffer back up after each item.
        while len(queue) > 0:
            yield queue.popleft()
            try_enqueue()
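

# Sized variant: forwards `len()` from the wrapped iterable, so callers (e.g. progress
# bars or epoch-length queries) can still ask how many items to expect.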
@dataclasses.dataclass
class _PrefetchingMapSized(_PrefetchingMap[PyTreeType], SizedIterable[PyTreeType]):
    inputs: SizedIterable[PyTreeType]
    def __len__(self) -> int:
        return len(self.inputs)
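

# Example usage (a minimal sketch, not part of the library source; `make_batches()`
# and `train_step()` are hypothetical stand-ins for a real batch iterator and a
# jitted update function):
#
#     device = jax.devices()[0]
#     for batch in prefetching_map(make_batches(), device=device, buffer_size=2):
#         params, opt_state = train_step(params, opt_state, batch)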