# type: ignore[no-redef]
import asyncio
import collections
import contextlib
import itertools
import operator
import random
import time
import types
from functools import partial
from typing import AsyncIterable, Callable, Iterable, Iterator, Sequence
from multimethod import multimethod, overload # type: ignore
# Version string of this module.
__version__ = '1.1'
# Short alias for asyncio's coroutine-function test; it is used further down
# as a parameter annotation (`func: iscoro`) on the `overload`-decorated methods.
iscoro = asyncio.iscoroutinefunction
def fibonacci(x, y):
    """Yield an endless additive sequence seeded with `x` and `y`.

    The first two values produced are `x` and `y`; every later value is the
    sum of the two values before it.  The generator never terminates on its
    own, so callers are expected to slice or otherwise bound it.
    """
    current, upcoming = x, y
    while True:
        yield current
        current, upcoming = upcoming, current + upcoming
@contextlib.contextmanager
def suppress(*exceptions):
    """Swallow the given exception types and remember what was swallowed.

    Works like `contextlib.suppress`, except that the `with` statement binds a
    list: it stays empty when the body finishes normally and holds the caught
    exception instance when one of `exceptions` was raised.  Any other
    exception type propagates unchanged.
    """
    caught = []
    try:
        yield caught
    except exceptions as error:
        caught.append(error)
def first(predicate: Callable, iterable: Iterable, *default):
    """Return the first item of `iterable` accepted by `predicate`.

    Filtering follows the built-in `filter`, so a `predicate` of ``None``
    selects the first truthy item.  When nothing matches, the optional
    `default` is returned; without a default, `StopIteration` is raised.
    """
    matches = filter(predicate, iterable)
    return next(matches, *default)
class reiter(partial):
"""A partial iterator which is re-iterable."""
__iter__ = partial.__call__
class partialmethod(partial):
    """A `partial` that binds like a method when accessed on an instance.

    Variant of `functools.partialmethod`: class-level access returns the
    partial itself, instance access returns a bound method, so the instance
    is passed after the arguments already stored in the partial.
    """

    def __get__(self, instance, owner):
        if instance is None:
            # Looked up on the class: hand back the raw partial.
            return self
        return types.MethodType(self, instance)
class Stats(collections.Counter):
    """Counter keyed by attempt number, tallying how often each was reached."""

    def add(self, attempt, elapsed):
        """Count one occurrence of `attempt` and pass `elapsed` back through.

        :param attempt: attempt number to tally (0 is the initial attempt)
        :param elapsed: value returned unchanged to the caller
        """
        self[attempt] = self[attempt] + 1
        return elapsed

    @property
    def total(self):
        """total number of attempts"""
        counts = self.values()
        return sum(counts)

    @property
    def failures(self):
        """number of repeat attempts"""
        # Everything that is not an initial (attempt 0) try was a repeat.
        initial = self[0]
        return self.total - initial
def grouped(queue, size=None):
    """Yield consecutive slices of `queue`, each at most `size` items long.

    The sequence is re-sliced after every yield instead of measuring its
    length up front, so items appended while the generator is suspended are
    still picked up.  With `size` left as ``None`` each slice takes
    everything not yet yielded.  Generation stops at the first empty slice.
    """
    offset = 0
    chunk = queue[:size]
    while chunk:
        offset += len(chunk)
        yield chunk
        # `size and ...` keeps a falsy size (None) as an open-ended stop.
        stop = size and offset + size
        chunk = queue[offset:stop]
class waiter:
    """An iterable which sleeps for given delays.

    Iterating a waiter yields the elapsed time in seconds: once immediately,
    then once more after sleeping for each successive delay, until the delays
    run out or the timeout is exceeded.  Most methods return a *new* waiter
    with transformed delays rather than modifying this one.

    :param delays: any iterable of seconds, or a scalar which is repeated endlessly
    :param timeout: optional timeout for iteration
    """

    # Counter class used to record attempts; exposed as a class attribute so
    # `self.Stats()` can be swapped out by assigning a different class.
    Stats = Stats

    def __init__(self, delays, timeout=float('inf')):
        # A scalar delay is not iterable: `iter` raises TypeError, which is
        # recorded by `suppress`, and the scalar is then repeated forever.
        with suppress(TypeError) as excs:
            iter(delays)
        self.delays = itertools.repeat(delays) if excs else delays
        self.timeout = timeout
        self.stats = self.Stats()

    def __iter__(self):
        """Generate a slow loop of elapsed time.

        The first value (0.0) is yielded immediately, without sleeping.  Each
        later value follows a sleep of the next delay, truncated so that the
        total never runs past `timeout`.
        """
        # A monotonic clock keeps the timeout budget and the reported elapsed
        # time unaffected by adjustments to the system wall clock.
        start = time.monotonic()
        yield self.stats.add(0, 0.0)
        for attempt, delay in enumerate(self.delays, 1):
            remaining = start + self.timeout - time.monotonic()
            if remaining < 0:
                break
            # `time.sleep` rejects negative values (possible after jitter such
            # as `random(-1, 1)`), so a negative delay means "don't sleep".
            time.sleep(max(0, min(delay, remaining)))
            yield self.stats.add(attempt, time.monotonic() - start)

    async def __aiter__(self):
        """Asynchronously generate a slow loop of elapsed time.

        Same schedule as `__iter__`, but sleeping with `asyncio.sleep`.
        """
        start = time.monotonic()
        yield self.stats.add(0, 0.0)
        for attempt, delay in enumerate(self.delays, 1):
            remaining = start + self.timeout - time.monotonic()
            if remaining < 0:
                break
            # Clamped the same way as the synchronous loop for consistency.
            await asyncio.sleep(max(0, min(delay, remaining)))
            yield self.stats.add(attempt, time.monotonic() - start)

    def clone(self, func, *args):
        """Return a new waiter of the same type with the same timeout.

        Its delays are a `reiter` of ``func(*args)``, i.e. the call is
        repeated each time the new waiter's delays are iterated.  The new
        waiter starts with its own empty stats.
        """
        return type(self)(reiter(func, *args), self.timeout)

    def map(self, func: Callable, *iterables: Iterable) -> 'waiter':
        """Return new waiter with function mapped across delays.

        :param func: called with each delay (plus one item from each of `iterables`)
        :param iterables: optional extra iterables zipped alongside the delays
        """
        return self.clone(map, func, self.delays, *iterables)

    @classmethod
    def fibonacci(cls, delay, **kwargs) -> 'waiter':
        """Create waiter with fibonacci backoff.

        :param delay: seed used for both of the first two delays
        :param kwargs: forwarded to the constructor (e.g. `timeout`)
        """
        return cls(reiter(fibonacci, delay, delay), **kwargs)

    @classmethod
    def count(cls, *args, **kwargs) -> 'waiter':
        """Create waiter based on `itertools.count`.

        :param args: positional arguments for `itertools.count`
        :param kwargs: forwarded to the constructor (e.g. `timeout`)
        """
        return cls(reiter(itertools.count, *args), **kwargs)

    @classmethod
    def accumulate(cls, *args, **kwargs) -> 'waiter':
        """Create waiter based on `itertools.accumulate`.

        :param args: positional arguments for `itertools.accumulate`
        :param kwargs: forwarded to the constructor (e.g. `timeout`)
        """
        return cls(reiter(itertools.accumulate, *args), **kwargs)

    @classmethod
    def exponential(cls, base, **kwargs) -> 'waiter':
        """Create waiter with exponential backoff.

        Delays are ``base ** n`` for n = 0, 1, 2, ...
        """
        return cls.count(**kwargs).map(base.__pow__)

    @classmethod
    def polynomial(cls, exp, **kwargs) -> 'waiter':
        """Create waiter with polynomial backoff.

        Delays are ``n ** exp`` for n = 0, 1, 2, ...
        """
        return cls.count(**kwargs).map(exp.__rpow__)

    def __getitem__(self, slc: slice) -> 'waiter':
        """Slice delays, e.g., to limit attempt count."""
        return self.clone(itertools.islice, self.delays, slc.start, slc.stop, slc.step)

    def __le__(self, ceiling) -> 'waiter':
        """Limit maximum delay generated."""
        return self.map(partial(min, ceiling))

    def __ge__(self, floor) -> 'waiter':
        """Limit minimum delay generated."""
        return self.map(partial(max, floor))

    def __add__(self, step) -> 'waiter':
        """Generate incremental backoff.

        The n-th delay (counting from 0) is increased by ``n * step``.
        """
        return self.map(operator.add, reiter(itertools.count, 0, step))

    def __mul__(self, factor) -> 'waiter':
        """Generate exponential backoff.

        The n-th delay (counting from 0) is multiplied by ``factor ** n``.
        """
        return self.map(operator.mul, reiter(map, factor.__pow__, reiter(itertools.count)))

    def random(self, start, stop) -> 'waiter':
        """Add random jitter within given range.

        :param start: lower bound passed to `random.uniform`
        :param stop: upper bound passed to `random.uniform`
        """
        return self.map(lambda delay: delay + random.uniform(start, stop))

    @multimethod
    def throttle(self, iterable):
        """Delay iteration."""
        # Pair every item with one step of this waiter and keep only the item;
        # iteration ends when either the waiter or the iterable runs out.
        return map(operator.itemgetter(1), zip(self, iterable))

    # Asynchronous variant of `throttle`, registered for `AsyncIterable` input.
    @multimethod
    async def throttle(self, iterable: AsyncIterable):
        advance = iterable.__aiter__().__anext__
        # Exhaustion of the source is signalled by StopAsyncIteration; catching
        # it here ends this generator cleanly.
        with suppress(StopAsyncIteration):
            async for _ in self:
                yield await advance()

    def stream(self, queue: Sequence, size: int = None) -> Iterator:
        """Generate chained values in groups from an iterable.

        The queue can be extended while in use.

        :param queue: source of values; sequences are re-sliced on every step
        :param size: maximum number of values released per waiter step
        """
        it = iter(queue)
        # Generic iterables: pull up to `size` items at a time until an empty
        # group is produced.
        groups = iter(lambda: list(itertools.islice(it, size)), [])
        if isinstance(queue, Sequence):
            # Sequences are sliced by position instead, so items appended
            # during iteration are picked up.
            groups = grouped(queue, size)
        return itertools.chain.from_iterable(self.throttle(groups))

    def suppressed(self, exception, func: Callable, iterable: Iterable) -> Iterator[tuple]:
        """Provisionally generate `arg, func(arg)` pairs while exception isn't raised.

        Arguments whose call raises `exception` are re-queued and attempted
        again on a later step of the waiter.
        """
        queue = list(iterable)
        for arg in self.stream(queue):
            try:
                yield arg, func(arg)
            except exception:
                queue.append(arg)

    def filtered(self, predicate: Callable, func: Callable, iterable: Iterable) -> Iterator[tuple]:
        """Provisionally generate `arg, func(arg)` pairs while predicate evaluates to true.

        Arguments whose result fails `predicate` are re-queued and attempted
        again on a later step of the waiter.
        """
        queue = list(iterable)
        for arg in self.stream(queue):
            result = func(arg)
            if predicate(result):
                yield arg, result
            else:
                queue.append(arg)

    @overload
    def repeat(self, func, *args, **kwargs):
        """Repeat function call."""
        return (func(*args, **kwargs) for _ in self)

    # Asynchronous variant of `repeat`, selected when `func` satisfies `iscoro`.
    @overload
    async def repeat(self, func: iscoro, *args, **kwargs):  # type: ignore
        async for _ in self:
            yield await func(*args, **kwargs)  # type: ignore

    @overload
    def retry(self, exception, func, *args, **kwargs):
        """Repeat function call until exception isn't raised."""
        for _ in self:
            with suppress(exception) as excs:
                return func(*args, **kwargs)
        # Iteration always yields at least once, so `excs` is bound here and
        # holds the exception from the final attempt.
        raise excs[0]

    # Asynchronous variant of `retry`, selected when `func` satisfies `iscoro`.
    @overload
    async def retry(self, exception, func: iscoro, *args, **kwargs):  # type: ignore
        async for _ in self:
            with suppress(exception) as excs:
                return await func(*args, **kwargs)  # type: ignore
        raise excs[0]

    @overload
    def poll(self, predicate, func, *args, **kwargs):
        """Repeat function call until predicate evaluates to true."""
        # `first` is called without a default, so running out of attempts
        # raises StopIteration.
        return first(predicate, self.repeat(func, *args, **kwargs))

    # Asynchronous variant of `poll`, selected when `func` satisfies `iscoro`.
    @overload
    async def poll(self, predicate, func: iscoro, *args, **kwargs):  # type: ignore
        async for result in self.repeat(func, *args, **kwargs):
            if predicate(result):
                return result
        # Async counterpart of the StopIteration raised by the sync variant.
        raise StopAsyncIteration

    def repeating(self, func: Callable):
        """A decorator for `repeat`."""
        return partialmethod(self.repeat, func)

    def retrying(self, exception: Exception):
        """Return a decorator for `retry`."""
        return partial(partialmethod, self.retry, exception)

    def polling(self, predicate: Callable):
        """Return a decorator for `poll`."""
        return partial(partialmethod, self.poll, predicate)
# Alias: `wait` is bound to the same class object as `waiter`.
wait = waiter