Kaydet (Commit) b25aa36f authored tarafından Raymond Hettinger's avatar Raymond Hettinger

Improve the memory performance and speed of heapq.nsmallest() by using

an alternate algorithm when the number of selected items is small
relative to the full iterable.
üst 2e669408
...@@ -130,6 +130,7 @@ __all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'nlargest', ...@@ -130,6 +130,7 @@ __all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'nlargest',
'nsmallest'] 'nsmallest']
from itertools import islice, repeat from itertools import islice, repeat
import bisect
def heappush(heap, item): def heappush(heap, item):
"""Push item onto heap, maintaining the heap invariant.""" """Push item onto heap, maintaining the heap invariant."""
...@@ -196,6 +197,28 @@ def nsmallest(iterable, n): ...@@ -196,6 +197,28 @@ def nsmallest(iterable, n):
Equivalent to: sorted(iterable)[:n] Equivalent to: sorted(iterable)[:n]
""" """
if hasattr(iterable, '__len__') and n * 10 <= len(iterable):
# For smaller values of n, the bisect method is faster than a minheap.
# It is also memory efficient, consuming only n elements of space.
it = iter(iterable)
result = sorted(islice(it, 0, n))
if not result:
return result
insort = bisect.insort
pop = result.pop
los = result[-1] # los --> Largest of the nsmallest
for elem in it:
if los <= elem:
continue
insort(result, elem)
pop()
los = result[-1]
return result
# An alternative approach manifests the whole iterable in memory but
# saves comparisons by heapifying all at once. Also, saves time
# over bisect.insort() which has O(n) data movement time for every
# insertion. Finding the n smallest of an m length iterable requires
# O(m) + O(n log m) comparisons.
h = list(iterable) h = list(iterable)
heapify(h) heapify(h)
return map(heappop, repeat(h, min(n, len(h)))) return map(heappop, repeat(h, min(n, len(h))))
......
...@@ -92,6 +92,7 @@ class TestHeap(unittest.TestCase): ...@@ -92,6 +92,7 @@ class TestHeap(unittest.TestCase):
def test_nsmallest(self): def test_nsmallest(self):
data = [random.randrange(2000) for i in range(1000)] data = [random.randrange(2000) for i in range(1000)]
self.assertEqual(nsmallest(data, 400), sorted(data)[:400]) self.assertEqual(nsmallest(data, 400), sorted(data)[:400])
self.assertEqual(nsmallest(data, 50), sorted(data)[:50])
def test_largest(self): def test_largest(self):
data = [random.randrange(2000) for i in range(1000)] data = [random.randrange(2000) for i in range(1000)]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment