this repo has no description
1#!/usr/bin/env python3
2# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
3"""Functional tools for creating and using iterators."""
4# TODO(T42113424) Replace stubs with an actual implementation
5
6import operator
7from builtins import _number_check
8
9from _builtins import (
10 _int_check,
11 _int_guard,
12 _list_len,
13 _list_new,
14 _tuple_len,
15 _Unbound,
16 _unimplemented,
17)
18
19
class accumulate:
    """Yield accumulated results of a binary function, as ``itertools.accumulate``.

    ``func`` defaults to ``operator.add``. When ``initial`` is not None it is
    yielded first and becomes the starting accumulator.
    """

    def __iter__(self):
        return self

    def __new__(cls, iterable, func=None, initial=None):
        result = object.__new__(cls)
        result._it = iter(iterable)
        result._func = operator.add if func is None else func
        result._initial = initial
        result._accumulated = None
        # Whether the first value has been produced. A separate flag is needed
        # because None is a legitimate element and a legitimate result of
        # ``func``, so it cannot double as an "empty accumulator" marker.
        result._started = False
        return result

    def __next__(self):
        if self._started:
            value = self._func(self._accumulated, next(self._it))
        else:
            initial = self._initial
            if initial is not None:
                # ``initial`` is yielded exactly once; drop the reference.
                self._initial = None
                value = initial
            else:
                value = next(self._it)
            # Only mark as started once a first value was actually obtained.
            self._started = True
        self._accumulated = value
        return value

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()
55
56
class chain:
    """Iterate over several iterables one after another, as ``itertools.chain``.

    ``_iterables`` is an iterator over the source iterables and ``_it`` is the
    iterator currently being drained (None between sources).
    """

    @classmethod
    def _over(cls, sources):
        # Shared constructor; ``sources`` is already an iterator of iterables.
        obj = object.__new__(cls)
        obj._it = None
        obj._iterables = sources
        return obj

    def __new__(cls, *iterables):
        return cls._over(iter(iterables))

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            current = self._it
            if current is None:
                # When no sources remain, next() raises StopIteration, which
                # ends this iterator too.
                current = self._it = iter(next(self._iterables))
            try:
                return next(current)
            except StopIteration:
                # This source is drained; move on to the following one.
                self._it = None

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()

    @classmethod
    def from_iterable(cls, iterable):
        return cls._over(iter(iterable))
93
94
class combinations:
    """Yield r-length tuples of input elements in index order, as
    ``itertools.combinations``.

    State: ``_seq`` is the pooled input (None once exhausted), ``_indices`` the
    current strictly increasing positions, and ``_index_delta`` is ``n - r``,
    so position ``i`` can be at most ``i + _index_delta``.
    """

    def __iter__(self):
        return self

    def __new__(cls, iterable, r):
        _int_guard(r)
        if r < 0:
            raise ValueError("r must be non-negative")

        obj = object.__new__(cls)
        pool = tuple(iterable)
        size = _tuple_len(pool)
        if r > size:
            # Not enough elements for even one combination.
            obj._seq = None
        else:
            obj._seq = pool
            obj._indices = list(range(r))
            obj._r = r
            obj._index_delta = size - r
        return obj

    def __next__(self):
        pool = self._seq
        if pool is None:
            raise StopIteration

        indices = self._indices
        width = self._r
        delta = self._index_delta
        combo = tuple(pool[k] for k in indices)

        # Find the rightmost position that is still below its ceiling.
        for pos in range(width - 1, -1, -1):
            if indices[pos] < pos + delta:
                break
        else:
            # Every position is at its ceiling: that was the last combination.
            self._seq = None
            return combo

        # Bump that position, then reset everything to its right to the
        # smallest values that keep the indices strictly increasing.
        indices[pos] += 1
        for nxt in range(pos + 1, width):
            indices[nxt] = indices[nxt - 1] + 1
        return combo

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()

    def __sizeof__(self):
        _unimplemented()
159
160
class combinations_with_replacement:
    """Yield r-length tuples of input elements, allowing repeats, as
    ``itertools.combinations_with_replacement``.

    ``_indices`` is a non-decreasing list of positions into ``_seq`` (None
    once exhausted); each position may go up to ``_max_index``.
    """

    def __iter__(self):
        return self

    def __new__(cls, iterable, r):
        _int_guard(r)
        if r < 0:
            raise ValueError("r must be non-negative")

        obj = object.__new__(cls)
        pool = tuple(iterable)

        # An empty pool admits no combinations unless r is 0.
        if r and not pool:
            obj._seq = None
            return obj

        obj._seq = pool
        obj._indices = _list_new(r, 0)
        obj._r = r
        obj._max_index = _tuple_len(pool) - 1
        return obj

    def __next__(self):
        pool = self._seq
        if pool is None:
            raise StopIteration

        indices = self._indices
        width = self._r
        ceiling = self._max_index
        combo = tuple(pool[k] for k in indices)

        # Find the rightmost position that is still below the last index.
        for pos in range(width - 1, -1, -1):
            if indices[pos] < ceiling:
                break
        else:
            # Every position is at the last index: we are done.
            self._seq = None
            return combo

        # Bump that position and copy its new value to everything after it.
        bumped = indices[pos] + 1
        for nxt in range(pos, width):
            indices[nxt] = bumped
        return combo

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()

    def __sizeof__(self):
        _unimplemented()
223
224
class compress:
    """Yield the items of ``data`` whose paired selector is truthy, as
    ``itertools.compress``; stops as soon as either input runs out."""

    def __iter__(self):
        return self

    def __new__(cls, data, selectors):
        obj = object.__new__(cls)
        obj._data = iter(data)
        obj._selectors = iter(selectors)
        return obj

    def __next__(self):
        data_it = self._data
        selector_it = self._selectors
        while True:
            # The datum is always pulled before its selector.
            candidate = next(data_it)
            if next(selector_it):
                return candidate

    def __reduce__(self):
        _unimplemented()
247
248
class count:
    """Return evenly spaced values beginning with ``start``, as ``itertools.count``.

    The current value and increment are kept in the ``count`` and ``step``
    attributes.
    """

    def __iter__(self):
        return self

    def __new__(cls, start=0, step=1):
        # Like CPython, both start and step must be numbers.
        if not _number_check(start) or not _number_check(step):
            raise TypeError("a number is required")

        result = object.__new__(cls)
        result.count = start
        result.step = step
        return result

    def __next__(self):
        result = self.count
        self.count += self.step
        return result

    def __reduce__(self):
        _unimplemented()

    def __repr__(self):
        # Match CPython: values shown with repr(), the step omitted only when
        # it is the int 1, and the (sub)class name used.
        name = type(self).__name__
        step = self.step
        if isinstance(step, int) and step == 1:
            return f"{name}({self.count!r})"
        return f"{name}({self.count!r}, {step!r})"
272
273
class cycle:
    """Yield the elements of ``seq`` repeatedly, forever, as ``itertools.cycle``.

    The first pass reads the source iterator and records every element in
    ``_saved``; each later pass replays ``_saved``.
    """

    def __iter__(self):
        return self

    def __new__(cls, seq):
        obj = object.__new__(cls)
        obj._seq = iter(seq)
        obj._saved = []
        obj._first_pass = True
        return obj

    def __next__(self):
        try:
            item = next(self._seq)
        except StopIteration:
            # The source (or the previous replay) ran out: start a new replay.
            # With nothing saved, this next() raises StopIteration.
            self._first_pass = False
            self._seq = iter(self._saved)
            return next(self._seq)
        if self._first_pass:
            self._saved.append(item)
        return item

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()
301
302
class dropwhile:
    """Skip leading elements while ``predicate`` is truthy, then yield every
    remaining element, as ``itertools.dropwhile``."""

    def __iter__(self):
        return self

    def __new__(cls, predicate, iterable):
        obj = object.__new__(cls)
        obj._it = iter(iterable)
        obj._func = predicate
        # Becomes True once the first non-matching element has been seen.
        obj._start = False
        return obj

    def __next__(self):
        source = self._it
        if self._start:
            return next(source)

        predicate = self._func
        item = next(source)
        while predicate(item):
            item = next(source)
        self._start = True
        return item

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()
331
332
class filterfalse:
    """Yield the elements for which ``predicate`` is falsy, as
    ``itertools.filterfalse``. A None predicate means truthiness (``bool``)."""

    def __iter__(self):
        return self

    def __new__(cls, predicate, iterable):
        obj = object.__new__(cls)
        obj._it = iter(iterable)
        if predicate is None:
            predicate = bool
        obj._predicate = predicate
        return obj

    def __next__(self):
        item = next(self._it)
        while self._predicate(item):
            item = next(self._it)
        return item

    def __reduce__(self):
        _unimplemented()
351
352
# Internal helper class for groupby: iterates over the elements of one group.
class _groupby_iterator:
    def __iter__(self):
        return self

    def __new__(cls, parent, cur):
        obj = object.__new__(cls)
        obj._parent = parent
        # Key of the group this iterator yields.
        obj._currkey = cur
        return obj

    def __next__(self):
        parent = self._parent
        # Once the parent has moved on to a later group, this iterator is
        # finished, even if that later group happens to have an equal key.
        if parent._currgrouper is not self:
            raise StopIteration
        if parent._currval is _Unbound:
            # The previous element was already handed out; fetch the next
            # one lazily (StopIteration from the source ends this group).
            parent._step()
        tgtkey = self._currkey
        currkey = parent._currkey
        # Identity first, like CPython's PyObject_RichCompareBool, so keys
        # that are not equal to themselves (e.g. NaN) still match.
        if tgtkey is not currkey and not (tgtkey == currkey):
            raise StopIteration
        val = parent._currval
        parent._currval = parent._currkey = _Unbound
        return val


class groupby:
    """Group consecutive elements that have equal keys, as ``itertools.groupby``.

    State (``_Unbound`` means "absent"):
      _currval / _currkey: the element most recently read from the source and
          its key, not yet handed out by a group iterator;
      _tgtkey: the key of the group most recently returned;
      _currgrouper: the only group iterator still allowed to yield.
    """

    def __iter__(self):
        return self

    def __new__(cls, iterable, key=None):
        obj = object.__new__(cls)
        obj._it = iter(iterable)
        obj._tgtkey = obj._currkey = obj._currval = _Unbound
        obj._keyfunc = key
        obj._currgrouper = None
        return obj

    def _step(self):
        # Read one element and compute its key; the state is only updated
        # when both succeed.
        value = next(self._it)
        keyfunc = self._keyfunc
        newkey = value if keyfunc is None else keyfunc(value)
        self._currval = value
        self._currkey = newkey

    def __next__(self):
        # Returning a new group invalidates the previous group iterator.
        self._currgrouper = None
        # Skip the rest of the current group until an element with a
        # different key is found (StopIteration from the source ends us).
        while True:
            currkey = self._currkey
            if currkey is not _Unbound:
                tgtkey = self._tgtkey
                if tgtkey is _Unbound:
                    break
                # Identity first, mirroring PyObject_RichCompareBool.
                if tgtkey is not currkey and not (tgtkey == currkey):
                    break
            self._step()
        self._tgtkey = currkey
        grouper = _groupby_iterator(self, currkey)
        self._currgrouper = grouper
        return currkey, grouper
406
407
class islice:
    """Return selected elements from ``seq``, as ``itertools.islice``.

    Called as ``islice(seq, stop)`` or ``islice(seq, start, stop[, step])``.
    Internally a stop of None is stored as -1. ``_next`` is the index of the
    next element to return and ``_count`` is how many elements were consumed.
    """

    def __new__(cls, seq, stop_or_start, stop=_Unbound, step=_Unbound):
        if stop is _Unbound:
            # Single-argument form: islice(seq, stop).
            start = 0
            stop = stop_or_start
            step = None
        else:
            start = 0 if stop_or_start is None else stop_or_start
            if step is _Unbound:
                step = None

        # Validation order follows CPython: stop, then indices, then step.
        if stop is None:
            stop = -1
        elif not _int_check(stop) or stop == -1:
            # -1 is the internal "no stop" marker; other negative stops are
            # rejected with the "Indices" message below, as in CPython.
            raise ValueError(
                "Stop argument for islice() must be None or an "
                "integer: 0 <= x <= sys.maxsize."
            )
        if not _int_check(start) or start < 0 or stop < -1:
            raise ValueError(
                "Indices for islice() must be None or an integer: "
                "0 <= x <= sys.maxsize."
            )
        if step is None:
            step = 1
        elif not _int_check(step) or step < 1:
            raise ValueError("Step for islice() must be a positive integer or None.")

        # The source is only touched after the arguments have been accepted.
        result = object.__new__(cls)
        result._it = iter(seq)
        result._count = 0
        result._next = start
        result._stop = stop
        result._step = step
        return result

    def __iter__(self):
        return self

    def __next__(self):
        it = self._it
        if it is None:
            raise StopIteration
        consumed = self._count
        target = self._next
        stop = self._stop
        try:
            # Consume and discard elements up to the next selected index.
            while consumed < target:
                next(it)
                consumed += 1
            if stop != -1 and consumed >= stop:
                raise StopIteration
            item = next(it)
        except Exception:
            # Exhausted or failed: drop the source so later calls stop at
            # once. A bare ``raise`` preserves the original traceback.
            self._it = None
            raise
        self._count = consumed + 1
        target += self._step
        if stop != -1 and target > stop:
            target = stop
        self._next = target
        return item

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()
479
480
class permutations:
    """Yield successive r-length permutations of the input, as
    ``itertools.permutations``.

    ``_indices`` is a permutation of range(n) whose first r entries select the
    current output. ``_cycles[i]`` counts down the swaps remaining at
    position i. ``_seq`` becomes None once exhausted.
    """

    def __iter__(self):
        return self

    def __new__(cls, iterable, r=None):
        seq = tuple(iterable)
        n = _tuple_len(seq)

        if r is None:
            r = n
        else:
            # Same validation as combinations(): r must be a non-negative int.
            _int_guard(r)
            if r < 0:
                raise ValueError("r must be non-negative")

        result = object.__new__(cls)

        if r > n:
            # More slots than elements: there are no permutations.
            result._seq = None
            return result

        result._seq = seq
        result._r = r
        result._indices = list(range(n))
        result._cycles = list(range(n, n - r, -1))
        return result

    def __next__(self):
        seq = self._seq
        if seq is None:
            raise StopIteration
        r = self._r
        indices = self._indices
        indices_len = _list_len(indices)
        result = (*(seq[indices[i]] for i in range(r)),)
        cycles = self._cycles
        # Advance to the next permutation, scanning positions right-to-left.
        i = r - 1
        while i >= 0:
            j = cycles[i] - 1
            if j > 0:
                cycles[i] = j
                indices[i], indices[-j] = indices[-j], indices[i]
                break
            # Position i has gone through every candidate: rotate indices[i:]
            # left by one to restore it and reset its counter.
            cycles[i] = indices_len - i
            tmp = indices[i]
            k = i + 1
            while k < indices_len:
                indices[k - 1] = indices[k]
                k += 1
            indices[k - 1] = tmp
            i -= 1
        else:
            # Every position wrapped around: that was the last permutation.
            self._seq = None
        return result

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()

    def __sizeof__(self):
        _unimplemented()
539
540
class product:
    """Cartesian product of the input iterables, as ``itertools.product``.

    ``_iterables`` holds one tuple per output position (the argument pools
    repeated ``repeat`` times), or None once exhausted. ``_digits`` is an
    odometer of the current index into each pool.
    """

    def __iter__(self):
        return self

    def __new__(cls, *iterables, repeat=1):
        if not _int_check(repeat):
            raise TypeError(
                f"'{type(repeat).__name__}' object cannot be interpreted as an integer"
            )
        if repeat < 0:
            raise ValueError("repeat argument cannot be negative")
        # repeat == 0 means "no pools": the product is a single empty tuple.
        # Every argument is materialized before checking for emptiness, so
        # non-iterable arguments are always reported (as in CPython).
        pools = [tuple(iterable) for iterable in iterables] if repeat else []
        result = object.__new__(cls)
        for pool in pools:
            if not pool:
                # Any empty pool makes the whole product empty.
                result._iterables = None
                return result
        pools *= repeat
        result._iterables = pools
        result._digits = _list_new(_list_len(pools), 0)
        return result

    def __next__(self):
        iterables = self._iterables
        if iterables is None:
            raise StopIteration
        digits = self._digits
        length = _list_len(iterables)
        result = _list_new(length)
        # Emit the current odometer reading, then add one to it, carrying
        # from the rightmost position leftwards.
        i = length - 1
        carry = 1
        while i >= 0:
            j = digits[i]
            result[i] = iterables[i][j]
            j += carry
            if j < _tuple_len(iterables[i]):
                carry = 0
                digits[i] = j
            else:
                carry = 1
                digits[i] = 0
            i -= 1
        if carry:
            # Every position wrapped around: this was the last tuple.
            self._iterables = None
        return tuple(result)

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()

    def __sizeof__(self):
        _unimplemented()
597
598
class repeat:
    """Yield ``elem`` repeatedly, as ``itertools.repeat``.

    With ``times=None`` the iterator never ends. Otherwise ``elem`` is yielded
    ``times`` times, and negative counts behave like 0.
    """

    def __iter__(self):
        return self

    def __new__(cls, elem, times=None):
        result = object.__new__(cls)
        result._elem = elem
        if times is not None:
            _int_guard(times)
            # Clamp negatives to 0 like CPython. Iteration is unaffected, and
            # __repr__/__length_hint__ then report a sensible count.
            if times < 0:
                times = 0
        result._times = times
        return result

    def __next__(self):
        times = self._times
        if times is None:
            return self._elem
        if times > 0:
            self._times = times - 1
            return self._elem
        raise StopIteration

    def __length_hint__(self):
        # An endless repeat has no length; CPython raises TypeError here.
        times = self._times
        if times is None:
            raise TypeError("len() of unsized object")
        return times

    def __reduce__(self):
        _unimplemented()

    def __repr__(self):
        name = type(self).__name__
        if self._times is None:
            return f"{name}({self._elem!r})"
        return f"{name}({self._elem!r}, {self._times!r})"
627
628
class starmap:
    """Call ``function`` with each item of ``iterable`` unpacked as its
    positional arguments, as ``itertools.starmap``."""

    def __iter__(self):
        return self

    def __new__(cls, function, iterable):
        obj = object.__new__(cls)
        obj._it = iter(iterable)
        obj._func = function
        return obj

    def __next__(self):
        return self._func(*next(self._it))

    def __reduce__(self):
        _unimplemented()
645
646
def tee(iterable, n=2):
    """Return a tuple of ``n`` independent iterators over ``iterable``, as
    ``itertools.tee``.

    If the source iterator has ``__copy__``, it is itself the first result and
    the others are its copies. Otherwise it is wrapped in a ``_tee``, whose
    copies share one cache.
    """
    _int_guard(n)
    if n < 0:
        raise ValueError("n must be >= 0")
    if n == 0:
        return ()

    source = iter(iterable)
    if hasattr(source, "__copy__"):
        first = source
    else:
        first = _tee.from_iterable(source)
    make_copy = first.__copy__
    return (first, *(make_copy() for _ in range(n - 1)))
658
659
# Internal cache for tee: a singly linked list in which each link caches a
# fixed-size window of values read from the shared source iterator.
class _tee_dataobject:
    # CPython stores 57 values per link so that a link fills a cache line
    # exactly. We use 55 to match our own layout: arrays of up to 255 elements
    # carry a one-word header, and both the header and each element take 8
    # bytes on a 64-bit machine. x86 cache lines are 64 bytes (often fetched
    # in pairs), so any multiple of 8, minus 1, up to 255 would work.
    _MAX_VALUES = 55

    def __init__(self, it):
        self._it = it
        self._values = []
        self._num_read = 0
        self._next_link = _Unbound

    def get_item(self, i):
        assert i < self.__class__._MAX_VALUES

        if i >= self._num_read:
            # A cache miss must be exactly one past what has been read so far.
            assert i == self._num_read
            item = next(self._it)
            self._num_read += 1
            # NOTE: a mutable tuple might be a nice future optimization here.
            self._values.append(item)
            return item
        return self._values[i]

    def next_link(self):
        # The following link is created on first request and then reused.
        link = self._next_link
        if link is _Unbound:
            link = self._next_link = self.__class__(self._it)
        return link
693
694
class _tee:
    """Iterator over a chain of _tee_dataobject links. Copies share the links.

    ``_data`` is the current link and ``_index`` the position within it.
    """

    def __copy__(self):
        return self.__class__(self._data, self._index)

    def __init__(self, data, index):
        self._data = data
        self._index = index

    def __iter__(self):
        return self

    def __next__(self):
        if self._index >= _tee_dataobject._MAX_VALUES:
            # This link is used up; move on to the following one. The index
            # is reset before reading so a failed read leaves a valid state.
            self._data = self._data.next_link()
            self._index = 0

        item = self._data.get_item(self._index)
        self._index += 1
        return item

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()

    @classmethod
    def from_iterable(cls, iterable):
        source = iter(iterable)
        if isinstance(source, _tee):
            # Already a _tee: share its cache instead of wrapping it again.
            return source.__copy__()
        return cls(_tee_dataobject(source), 0)
729
730
class takewhile:
    """Yield elements while ``predicate`` is truthy, as ``itertools.takewhile``.

    The first failing element is consumed and dropped, and iteration then
    stops for good.
    """

    def __iter__(self):
        return self

    def __new__(cls, predicate, iterable):
        obj = object.__new__(cls)
        obj._it = iter(iterable)
        obj._func = predicate
        obj._stop = False
        return obj

    def __next__(self):
        if not self._stop:
            item = next(self._it)
            if self._func(item):
                return item
            self._stop = True
        raise StopIteration

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()
758
759
class zip_longest:
    """Aggregate one element from each input per step, padding exhausted inputs
    with ``fillvalue``, as ``itertools.zip_longest``.

    ``_num_active`` counts inputs that are not yet exhausted. Iteration ends
    when it reaches zero.
    """

    def __iter__(self):
        return self

    def __new__(cls, *seqs, fillvalue=None):
        total = _tuple_len(seqs)
        obj = object.__new__(cls)
        obj._iters = [iter(seq) for seq in seqs]
        obj._num_iters = total
        obj._num_active = total
        obj._fillvalue = fillvalue
        return obj

    def __next__(self):
        if not self._num_active:
            raise StopIteration
        pad = self._fillvalue
        iterators = self._iters
        row = _list_new(self._num_iters, pad)
        for idx in range(self._num_iters):
            try:
                row[idx] = next(iterators[idx])
            except StopIteration:
                self._num_active -= 1
                if not self._num_active:
                    raise
                # Swap the exhausted input for an endless supply of padding.
                iterators[idx] = repeat(pad)
        return tuple(row)

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()