this repo has no description
at trunk 793 lines 20 kB view raw
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
"""Functional tools for creating and using iterators."""
# TODO(T42113424) Replace stubs with an actual implementation

import operator
from builtins import _number_check

from _builtins import (
    _int_check,
    _int_guard,
    _list_len,
    _list_new,
    _tuple_len,
    _Unbound,
    _unimplemented,
)


class accumulate:
    """Iterator of running results of a binary function over an iterable.

    `func` defaults to `operator.add`. When `initial` is not None it is
    yielded first and seeds the running value.
    """

    def __iter__(self):
        return self

    def __new__(cls, iterable, func=None, initial=None):
        result = object.__new__(cls)
        result._it = iter(iterable)
        result._func = operator.add if func is None else func
        result._initial = initial
        # _Unbound (rather than None) marks "nothing accumulated yet", so
        # that None is a legal accumulated value.
        result._accumulated = _Unbound
        return result

    def __next__(self):
        initial = self._initial
        if initial is not None:
            # Emit the seed exactly once, then forget it.
            self._accumulated = initial
            self._initial = None
            return initial

        result = self._accumulated
        if result is _Unbound:
            # First element is passed through without calling func.
            result = next(self._it)
        else:
            result = self._func(result, next(self._it))
        self._accumulated = result
        return result

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()


class chain:
    """Iterator yielding the items of each given iterable in turn."""

    def __iter__(self):
        return self

    def __new__(cls, *iterables):
        result = object.__new__(cls)
        result._it = None
        result._iterables = iter(iterables)
        return result

    def __next__(self):
        while True:
            if self._it is None:
                # StopIteration from the outer iterator propagates and ends
                # the chain.
                self._it = iter(next(self._iterables))
            try:
                return next(self._it)
            except StopIteration:
                # Current inner iterator is exhausted; move to the next one.
                self._it = None

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()

    @classmethod
    def from_iterable(cls, iterable):
        """Build a chain from a single iterable of iterables."""
        result = object.__new__(cls)
        result._it = None
        result._iterables = iter(iterable)
        return result


class combinations:
    """Iterator over r-length tuples of elements taken at increasing indices.

    Raises ValueError when `r` is negative. Yields nothing when `r` exceeds
    the number of elements.
    """

    def __iter__(self):
        return self

    def __new__(cls, iterable, r):
        _int_guard(r)
        if r < 0:
            raise ValueError("r must be non-negative")

        result = object.__new__(cls)

        seq = tuple(iterable)
        n = _tuple_len(seq)

        if r > n:
            # _seq of None marks the iterator as exhausted.
            result._seq = None
            return result

        result._seq = seq
        result._indices = list(range(r))
        result._r = r
        result._index_delta = n - r
        return result

    def __next__(self):
        seq = self._seq
        if seq is None:
            raise StopIteration

        r = self._r
        indices = self._indices
        index_delta = self._index_delta

        # The result is the elements of the sequence at the current indices
        result = (*(seq[indices[i]] for i in range(r)),)

        # Scan indices right-to-left until finding one that is not at its
        # maximum (i + n - r).
        i = r - 1
        while i >= 0:
            if indices[i] < i + index_delta:
                # Increment the current index which we know is not at its
                # maximum. Then move back to the right setting each index
                # to its lowest possible value (one higher than the index
                # to its left -- this maintains the sort order invariant).
                indices[i] += 1
                for j in range(i + 1, r):
                    indices[j] = indices[j - 1] + 1
                break
            i -= 1
        else:
            # The indices are all at their maximum values and we're done.
            self._seq = None

        return result

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()

    def __sizeof__(self):
        _unimplemented()


class combinations_with_replacement:
    """Iterator over r-length tuples taken at non-decreasing indices.

    Raises ValueError when `r` is negative. Yields nothing when the input is
    empty and `r` is non-zero.
    """

    def __iter__(self):
        return self

    def __new__(cls, iterable, r):
        _int_guard(r)
        if r < 0:
            raise ValueError("r must be non-negative")

        result = object.__new__(cls)

        seq = tuple(iterable)

        # We can't create combinations if seq is empty and r > 0
        if not seq and r:
            result._seq = None
            return result

        result._seq = seq
        result._indices = _list_new(r, 0)
        result._r = r
        result._max_index = _tuple_len(seq) - 1
        return result

    def __next__(self):
        seq = self._seq
        if seq is None:
            raise StopIteration

        r = self._r
        indices = self._indices
        max_index = self._max_index

        # The result is the elements of the sequence at the current indices
        result = (*(seq[indices[i]] for i in range(r)),)

        # Scan indices right-to-left until finding one that is not at its
        # maximum (n - 1).
        i = r - 1
        while i >= 0:
            if indices[i] < max_index:
                # Increment the current index which we know is not at its
                # maximum. Then set all to the right to the same value.
                index = indices[i] = indices[i] + 1
                for j in range(i, r):
                    indices[j] = index
                break
            i -= 1
        else:
            # The indices are all at their maximum values and we're done.
            self._seq = None

        return result

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()

    def __sizeof__(self):
        _unimplemented()


class compress:
    """Iterator over the items of `data` whose paired selector is truthy."""

    def __iter__(self):
        return self

    def __new__(cls, data, selectors):
        result = object.__new__(cls)
        result._data = iter(data)
        result._selectors = iter(selectors)
        return result

    def __next__(self):
        data = self._data
        selectors = self._selectors

        # The datum is fetched before its selector; either iterator running
        # out ends the iteration.
        while True:
            datum = next(data)
            selector = next(selectors)
            if selector:
                return datum

    def __reduce__(self):
        _unimplemented()


class count:
    """Endless iterator yielding `start`, `start + step`, and so on.

    Raises TypeError when `start` or `step` fails `_number_check`.
    """

    def __iter__(self):
        return self

    def __new__(cls, start=0, step=1):
        # Validate both arguments up front so that a bad step fails here
        # instead of on the first addition in __next__.
        if not _number_check(start) or not _number_check(step):
            raise TypeError("a number is required")

        result = object.__new__(cls)
        result.count = start
        result.step = step
        return result

    def __next__(self):
        result = self.count
        self.count += self.step
        return result

    def __reduce__(self):
        _unimplemented()

    def __repr__(self):
        step = self.step
        # The step is omitted only when it is the default integer 1.
        if _int_check(step) and step == 1:
            return f"count({self.count!r})"
        return f"count({self.count!r}, {step!r})"


class cycle:
    """Iterator that yields the input's items, then replays them."""

    def __iter__(self):
        return self

    def __new__(cls, seq):
        result = object.__new__(cls)
        result._seq = iter(seq)
        result._saved = []
        result._first_pass = True
        return result

    def __next__(self):
        try:
            result = next(self._seq)
        except StopIteration:
            # Restart from the saved items. If nothing was saved, the
            # next() below raises StopIteration and the cycle ends.
            self._first_pass = False
            self._seq = iter(self._saved)
            return next(self._seq)
        if self._first_pass:
            # Only the first pass over the source records items.
            self._saved.append(result)
        return result

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()


class dropwhile:
    """Iterator that skips leading items while `predicate` is truthy."""

    def __iter__(self):
        return self

    def __new__(cls, predicate, iterable):
        result = object.__new__(cls)
        result._it = iter(iterable)
        result._func = predicate
        result._start = False
        return result

    def __next__(self):
        if self._start:
            # Dropping is finished; pass everything through unchecked.
            return next(self._it)

        func = self._func

        while True:
            item = next(self._it)
            if not func(item):
                self._start = True
                return item

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()


class filterfalse:
    """Iterator over items for which `predicate` is falsy.

    A `predicate` of None is replaced by `bool`.
    """

    def __iter__(self):
        return self

    def __new__(cls, predicate, iterable):
        result = object.__new__(cls)
        result._it = iter(iterable)
        result._predicate = bool if predicate is None else predicate
        return result

    def __next__(self):
        while True:
            item = next(self._it)
            if not self._predicate(item):
                return item

    def __reduce__(self):
        _unimplemented()


# internal helper class for groupby
class _groupby_iterator:
    """Iterator over one group; shares and advances the parent's state."""

    def __iter__(self):
        return self

    def __new__(cls, parent, cur):

        obj = object.__new__(cls)
        obj._parent = parent
        obj._currkey = cur
        return obj

    def __next__(self):
        parent = self._parent
        if parent._currkey == self._currkey:
            val = parent._currval
            # Pre-fetch the next value and key so the following call can
            # tell whether the group has ended.
            try:
                parent._currval = next(parent._it)
                parent._currkey = (
                    parent._currval
                    if parent._keyfunc is None
                    else parent._keyfunc(parent._currval)
                )
            except StopIteration:
                # Source exhausted; _Unbound tells the parent to stop.
                parent._currkey = _Unbound
            return val
        raise StopIteration


class groupby:
    """Iterator of (key, group-iterator) pairs for runs of equal keys.

    The key of an item is the item itself when `key` is None, otherwise
    `key(item)`.
    """

    def __iter__(self):
        return self

    def __new__(cls, iterable, key=None):

        obj = object.__new__(cls)
        obj._it = iter(iterable)
        obj._tgtkey = obj._currkey = obj._currval = _Unbound
        obj._keyfunc = key
        return obj

    def __next__(self):
        # In middle of previous iterator
        while self._currkey == self._tgtkey:
            self._currval = next(self._it)
            self._currkey = (
                self._currval if self._keyfunc is None else self._keyfunc(self._currval)
            )
        if self._currkey is _Unbound:
            raise StopIteration
        # remember group of returned iterator
        self._tgtkey = self._currkey
        return self._currkey, _groupby_iterator(self, self._currkey)


class islice:
    """Iterator over selected items of `seq`: islice(seq, stop) or
    islice(seq, start, stop[, step]).

    A `stop` of None is stored internally as -1, meaning "no upper bound".
    Raises ValueError for out-of-range or non-integer arguments.
    """

    def __new__(cls, seq, stop_or_start, stop=_Unbound, step=_Unbound):
        result = object.__new__(cls)
        result._it = iter(seq)
        result._count = 0
        if stop is _Unbound:
            # Single-argument form: the argument is the stop value.
            start = 0
            stop = stop_or_start
            step = 1
        else:
            start = 0 if stop_or_start is None else stop_or_start
            if step is _Unbound or step is None:
                step = 1
            elif not _int_check(step) or step < 1:
                raise ValueError(
                    "Step for islice() must be a positive integer or None."
                )
        if stop is None:
            stop = -1
        elif not _int_check(stop) or stop == -1:
            raise ValueError(
                "Stop argument for islice() must be None or an "
                "integer: 0 <= x <= sys.maxsize."
            )
        if not _int_check(start) or start < 0 or stop < -1:
            raise ValueError(
                "Indices for islice() must be None or an integer: "
                "0 <= x <= sys.maxsize."
            )
        result._next = start
        result._stop = stop
        result._step = step
        return result

    def __iter__(self):
        return self

    def __next__(self):
        it = self._it
        if it is None:
            raise StopIteration
        count = self._count
        new_next = self._next
        # Discard items up to the next index that should be yielded.
        while count < new_next:
            try:
                next(it)
            except Exception:
                # Any failure (including exhaustion) finishes this iterator.
                self._it = None
                raise
            count += 1
        stop = self._stop
        if count >= stop and stop != -1:
            self._it = None
            raise StopIteration
        try:
            item = next(it)
        except Exception:
            self._it = None
            raise
        self._count = count + 1
        new_next += self._step
        # Never schedule an index beyond stop.
        if new_next > stop and stop != -1:
            new_next = stop
        self._next = new_next
        return item

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()


class permutations:
    """Iterator over r-length orderings of the input's elements.

    `r` defaults to the number of elements. Raises ValueError when `r` is
    negative. Yields nothing when `r` exceeds the number of elements.
    """

    def __iter__(self):
        return self

    def __new__(cls, iterable, r=None):
        seq = tuple(iterable)
        n = _tuple_len(seq)

        result = object.__new__(cls)

        if r is None:
            r = n
        else:
            # Validate r the same way combinations does; without this a
            # negative r would yield a single empty tuple.
            # NOTE(review): _int_guard presumably raises for non-integers;
            # confirm against its definition.
            _int_guard(r)
            if r < 0:
                raise ValueError("r must be non-negative")
            if r > n:
                result._seq = None
                return result

        result._seq = seq
        result._r = r
        result._indices = list(range(n))
        result._cycles = list(range(n, n - r, -1))
        return result

    def __next__(self):
        seq = self._seq
        if seq is None:
            raise StopIteration
        r = self._r
        indices = self._indices
        indices_len = _list_len(indices)
        result = (*(seq[indices[i]] for i in range(r)),)
        cycles = self._cycles
        i = r - 1
        while i >= 0:
            j = cycles[i] - 1
            if j > 0:
                cycles[i] = j
                indices[i], indices[-j] = indices[-j], indices[i]
                break
            # This position's cycle is used up: reset its counter and
            # rotate indices[i:] left by one before moving left.
            cycles[i] = indices_len - i
            tmp = indices[i]
            k = i + 1
            while k < indices_len:
                indices[k - 1] = indices[k]
                k += 1
            indices[k - 1] = tmp
            i -= 1
        else:
            # Every position has cycled; this was the last permutation.
            self._seq = None
        return result

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()

    def __sizeof__(self):
        _unimplemented()


class product:
    """Iterator over the Cartesian product of the input iterables.

    The inputs are repeated `repeat` times. Raises TypeError when `repeat`
    fails `_int_check` and ValueError when it is negative.
    """

    def __iter__(self):
        return self

    def __new__(cls, *iterables, repeat=1):
        if not _int_check(repeat):
            raise TypeError
        if repeat < 0:
            raise ValueError("repeat argument cannot be negative")
        length = _tuple_len(iterables) if repeat else 0
        i = 0
        repeated = _list_new(length)
        result = object.__new__(cls)
        while i < length:
            item = tuple(iterables[i])
            if not item:
                # An empty input makes the whole product empty.
                result._iterables = None
                return result
            repeated[i] = item
            i += 1
        repeated *= repeat
        result._iterables = repeated
        result._digits = _list_new(length * repeat, 0)
        return result

    def __next__(self):
        iterables = self._iterables
        if iterables is None:
            raise StopIteration
        digits = self._digits
        length = _list_len(iterables)
        result = _list_new(length)
        i = length - 1
        carry = 1
        # Read the current digits into the result while incrementing them
        # like an odometer, rightmost position first.
        while i >= 0:
            j = digits[i]
            result[i] = iterables[i][j]
            j += carry
            if j < _tuple_len(iterables[i]):
                carry = 0
                digits[i] = j
            else:
                carry = 1
                digits[i] = 0
            i -= 1
        if carry:
            # counter overflowed, stop iteration
            self._iterables = None
        return tuple(result)

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()

    def __sizeof__(self):
        _unimplemented()


class repeat:
    """Iterator yielding `elem` forever, or `times` times when given."""

    def __iter__(self):
        return self

    def __new__(cls, elem, times=None):
        result = object.__new__(cls)
        result._elem = elem
        if times is not None:
            _int_guard(times)
        result._times = times
        return result

    def __next__(self):
        if self._times is None:
            return self._elem
        if self._times > 0:
            self._times -= 1
            return self._elem
        raise StopIteration

    def __length_hint__(self):
        _unimplemented()

    def __reduce__(self):
        _unimplemented()

    def __repr__(self):
        _unimplemented()


class starmap:
    """Iterator calling `function(*args)` for each `args` in `iterable`."""

    def __iter__(self):
        return self

    def __new__(cls, function, iterable):
        result = object.__new__(cls)
        result._it = iter(iterable)
        result._func = function
        return result

    def __next__(self):
        args = next(self._it)
        return self._func(*args)

    def __reduce__(self):
        _unimplemented()


def tee(iterable, n=2):
    """Return a tuple of `n` iterators over a single source iterable.

    Raises ValueError when `n` is negative; returns an empty tuple when `n`
    is zero. An iterator that already has `__copy__` is copied directly,
    otherwise it is wrapped in a `_tee`.
    """
    _int_guard(n)
    if n < 0:
        raise ValueError("n must be >= 0")
    if n == 0:
        return ()

    it = iter(iterable)
    copyable = it if hasattr(it, "__copy__") else _tee.from_iterable(it)
    copyfunc = copyable.__copy__
    return tuple(copyable if i == 0 else copyfunc() for i in range(n))


# Internal cache for tee, a linked list where each link is a cached window to
# a section of the source iterator
class _tee_dataobject:
    # CPython sets this at 57 to align exactly with cache line size. We choose
    # 55 to align with cache lines in our system: Arrays <=255 elements have 1
    # word of header. The header and each data element is 8 bytes on a 64-bit
    # machine. Cache lines are 64-bytes on all x86 machines though they tend to
    # be fetched in pairs, so any multiple of 8 minus 1 up to 255 is fine.
    _MAX_VALUES = 55

    def __init__(self, it):
        self._num_read = 0
        self._next_link = _Unbound
        self._it = it
        self._values = []

    def get_item(self, i):
        """Return the i-th value of this link, reading it from the source
        iterator if it has not been cached yet."""
        assert i < self.__class__._MAX_VALUES

        if i < self._num_read:
            return self._values[i]
        else:
            assert i == self._num_read
            value = next(self._it)
            self._num_read += 1
            # mutable tuple might be a nice future optimization here
            self._values.append(value)
            return value

    def next_link(self):
        """Return the following link, creating it on first use."""
        if self._next_link is _Unbound:
            self._next_link = self.__class__(self._it)
        return self._next_link


class _tee:
    """One reader over a shared chain of `_tee_dataobject` links."""

    def __copy__(self):
        return self.__class__(self._data, self._index)

    def __init__(self, data, index):
        self._data = data
        self._index = index

    def __iter__(self):
        return self

    def __next__(self):
        if self._index >= _tee_dataobject._MAX_VALUES:
            # The current link is fully consumed; step to the next one.
            self._data = self._data.next_link()
            self._index = 0

        value = self._data.get_item(self._index)
        self._index += 1
        return value

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()

    @classmethod
    def from_iterable(cls, iterable):
        """Wrap `iterable` in a `_tee`, copying it if it already is one."""
        it = iter(iterable)

        if isinstance(it, _tee):
            return it.__copy__()
        else:
            return cls(_tee_dataobject(it), 0)


class takewhile:
    """Iterator yielding items until `predicate` is first falsy."""

    def __iter__(self):
        return self

    def __new__(cls, predicate, iterable):
        result = object.__new__(cls)
        result._it = iter(iterable)
        result._func = predicate
        result._stop = False
        return result

    def __next__(self):
        if self._stop:
            raise StopIteration

        item = next(self._it)
        if self._func(item):
            return item

        # The failing item is consumed from the source and discarded.
        self._stop = True
        raise StopIteration

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()


class zip_longest:
    """Iterator of tuples drawn in lockstep from the inputs, padding
    exhausted inputs with `fillvalue` until every input is exhausted."""

    def __iter__(self):
        return self

    def __new__(cls, *seqs, fillvalue=None):
        length = _tuple_len(seqs)
        result = object.__new__(cls)
        result._iters = [iter(seq) for seq in seqs]
        result._num_iters = length
        result._num_active = length
        result._fillvalue = fillvalue
        return result

    def __next__(self):
        iters = self._iters
        if not self._num_active:
            raise StopIteration
        fillvalue = self._fillvalue
        values = _list_new(self._num_iters, fillvalue)
        for i, it in enumerate(iters):
            try:
                values[i] = next(it)
            except StopIteration:
                self._num_active -= 1
                if not self._num_active:
                    # The last live input just ran out.
                    raise
                # Swap the exhausted input for an endless source of the
                # fill value; values[i] already holds fillvalue.
                self._iters[i] = repeat(fillvalue)
        return tuple(values)

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()