複製鏈接
請複製以下鏈接發送給好友

TimSort

鎖定
TimSort 是一個歸併排序做了大量優化的版本。對歸併排序排在已經反向排好序的輸入時做了特別優化。對已經正向排好序的輸入減少回溯。對兩種情況混合(一會升序,一會降序)的輸入處理比較好。
中文名
TimSort
特    點
表現O(n^2)
排    序
升序
減    少
升序部分的回溯

TimSort核心過程

假定,我們的 TimSort 是進行升序排序。TimSort 為了減少對升序部分的回溯和對降序部分的性能倒退,將輸入按其升序和降序特點進行了分區。排序的輸入的單位不是一個個單獨的數字了,而一個個的分區。其中每一個分區我們叫一個“run“。針對這個 run 序列,每次我們拿一個 run 出來進行歸併。每次歸併會將兩個 runs 合併成一個 run。歸併的結果保存到 "run_stack" 上。如果我們覺得有必要歸併了,那麼進行歸併,直到消耗掉所有的 runs。這時將 run_stack 上剩餘的 runs 歸併到只剩一個 run 為止。這時這個僅剩的 run 即為我們需要的排好序的結果。
Python代碼
def timsort(arr):
arr = arr or []
if len(arr) <= 0: return []
runs = _partition_to_runs(arr)
run_stack = []
for run in runs:
run_stack.append(run)
while _should_merge(run_stack):
_merge_stack(run_stack)
while len(run_stack) > 1:
_merge_stack(run_stack)
return run_stack[0]
這裏“覺得有必要”這句話很模糊,到底什麼時候有必要後面會給出定義。

TimSort分區方式

為了在已經按升序排好序的輸入面前減少回溯,我們把輸入當中已經有序的這些段分組,使得它們成為一個基本單元,這樣我們就不必在這個基本單元內部浪費時間進行回溯了。比如[1, 2, 3, 2] 進行分區後就變成了 [[1, 2, 3], [2]]。
為了在已經按降序排好序的輸入面前避免歸併排序倒退成 O(n^2),我們把輸入當中降序的部分翻轉成升序,也作為一個單元。比如 [3, 2, 1, 3] 進行分區後就變成了 [[1, 2, 3], [3]]。
Python代碼
def _partition_to_runs(arr):
partitioned_up_to = 0
while partitioned_up_to < len(arr):
if not len(arr) - partitioned_up_to:
return
if len(arr) - partitioned_up_to == 1:
part = list(arr[-1:])
partitioned_up_to += 1
yield part
else:
if arr[partitioned_up_to] > arr[partitioned_up_to + 1]: # 這裏必須是嚴格降序
next_pos = _find_desc_boundary(arr, partitioned_up_to)
_reverse(arr, partitioned_up_to, next_pos)
else:
next_pos = _find_asc_boundary(arr, partitioned_up_to)
part = arr[partitioned_up_to:next_pos]
partitioned_up_to = next_pos
yield part
def _find_desc_boundary(arr, start):
if start >= len(arr) - 1:
return start + 1
if arr[start] > arr[start+1]: # 這裏必須是嚴格降序
return _find_desc_boundary(arr, start + 1)
else:
return start + 1
def _reverse(arr, start=0, end=None):
# 正常的翻轉函數,實現省略
def _find_asc_boundary(arr, start):
if start >= len(arr) - 1:
return start + 1
if arr[start] <= arr[start+1]:
return _find_asc_boundary(arr, start + 1)
else:
return start + 1
這裏注意降序的部分必須是“嚴格”降序才能進行翻轉。因為 TimSort 的一個重要目標是保持穩定性(stability)。如果在 >= 的情況下進行翻轉這個算法就不再是 stable sorting algorithm 了。
逆向分解
傳統的歸併排序是通過遞歸,用函數棧把每次 "divide" 的結果保存下來的。divide 的最終結果是一個個的基本單元-單個數字。但是我們看到 TimSort 把這個過程反過來了。我們經過一次分區,已經拿到了了基本單元列表,只不過這次基本單元是一串數字。所以我們只能自己手工將將基本單元列表進行合併。

TimSort合併方式

那麼何時進行合併?合併的策略是要在 "run_stack" 上維護一個不變式。當這個不變式被打破時即進行合併。傳統的歸併排序通過二分法可以保證函數棧的深度為 log(n)。我們也模擬這個策略,也讓 run_stack 的長度不超過 log(n)。假如 runN 先入棧,runN+1 緊隨其後入棧。那麼就要求 runN 的長度要是 runN+1 長度的 2 倍。所以歸併的條件是:如果 runN 的長度 < (runN+1 的長度 * 2) 即進行歸併。
Python代碼
# 因為我們每次新添 run 進入 run_stack 時都判斷是否需要歸併,
# 並且在每次歸併之後還要進一步確保 run_stack 是滿足不變式的,
# 所以這裏只判斷棧頭的兩個 run 就夠了。
def _should_merge(run_stack):
if len(run_stack) < 2:
return False
return len(run_stack[-2]) < 2*len(run_stack[-1])
def _merge(ls1, ls2):
# 正常的歸併函數,實現省略
def _merge_stack(run_stack):
head = run_stack.pop()
next = run_stack.pop()
new_run = _merge(next, head)
run_stack.append(new_run)
跟分區的情況類似,這裏在歸併的時候也要用 stable merge。
插入排序優化
到上面的步驟為止,程序已經可以正確地排序了。但是我們知道插入排序在輸入元素數小於一個閥值的時候相比其它排序會更快,所以很多排序算法在 divide 這一步進行到只剩不到這個閥值個數的元素的時候會改用插入排序(比如 JDK6 的快排,參考這裏),所以我們也要做這個優化。
在分區的時候,如果我們觀察到新產生出來的 run 的長度小於適用於插入排序的閥值,我們就用插入排序把這個 run 的長度擴充到這個閥值。
Python代碼
def _partition_to_runs(arr):
partitioned_up_to = 0
while partitioned_up_to < len(arr):
if not len(arr) - partitioned_up_to:
return
if len(arr) - partitioned_up_to == 1:
part = list(arr[-1:])
partitioned_up_to += 1
yield part
else:
if arr[partitioned_up_to] > arr[partitioned_up_to + 1]:
next_pos = _find_desc_boundary(arr, partitioned_up_to)
_reverse(arr, partitioned_up_to, next_pos)
else:
next_pos = _find_asc_boundary(arr, partitioned_up_to)
# 只加了這一句話
next_pos = _do_insertion_sort_optimization(arr, partitioned_up_to, next_pos)
part = arr[partitioned_up_to:next_pos]
partitioned_up_to = next_pos
yield part
def _insertion_sort(arr, start, end):
# 標準插入排序實現
def _do_insertion_sort_optimization(arr, start, end):
length = end - start
if length < INSERTION_SORT_THRESHOLD:
end = min(start+INSERTION_SORT_THRESHOLD, len(arr))
_insertion_sort(arr, start, end)
return end
這裏我們只加一句話就夠了。剩餘的就是標準的插入排序實現。
與原文代碼的差異
TimSort 最多使用 O(n) 臨時內存空間。由於原文是 C 的代碼,為了減少 malloc 的次數而一次性分配了 O(n) 的數組空間。我們這裏因為是用 python,也這麼做會顯得很怪異。所以內存是在每次歸併的時候一點點分配的。
TimSort 的實現邏輯上可以看成分區和歸併兩部分。但由於 C 不支持協程,而 python 通過 generator 部分支持協程。所以為了提高可讀性,分區的部分我是用 generator 的方式做的。在代碼上與歸併的部分完全分離。而原文為了達到 lazy 的目的,是一邊分區一邊歸併的。
完整的實現和測試代碼
Python代碼
# -*- coding: utf-8 -*-
import functools
from unittest import TestCase
INSERTION_SORT_THRESHOLD = 6
def _find_desc_boundary(arr, start):
if start >= len(arr) - 1:
return start + 1
if arr[start] > arr[start+1]:
return _find_desc_boundary(arr, start + 1)
else:
return start + 1
def _reverse(arr, start=0, end=None):
if end is None:
end = len(arr)
for i in range(start, start + (end-start)//2):
opposite = end - i - 1
arr[i], arr[opposite] = arr[opposite], arr[i]
def _find_asc_boundary(arr, start):
if start >= len(arr) - 1:
return start + 1
if arr[start] <= arr[start+1]:
return _find_asc_boundary(arr, start + 1)
else:
return start + 1
def _insertion_sort(arr, start, end):
if end - start <= 1:
return
for i in range(start, end):
v = arr[i]
j = i - 1
while j>=0 and arr[j] > v:
arr[j+1] = arr[j]
j -= 1
arr[j+1] = v
def _do_insertion_sort_optimization(arr, start, end):
length = end - start
if length < INSERTION_SORT_THRESHOLD:
end = min(start+INSERTION_SORT_THRESHOLD, len(arr))
_insertion_sort(arr, start, end)
return end
def _partition_to_runs(arr):
partitioned_up_to = 0
while partitioned_up_to < len(arr):
if not len(arr) - partitioned_up_to:
return
if len(arr) - partitioned_up_to == 1:
part = list(arr[-1:])
partitioned_up_to += 1
yield part
else:
if arr[partitioned_up_to] > arr[partitioned_up_to + 1]:
next_pos = _find_desc_boundary(arr, partitioned_up_to)
_reverse(arr, partitioned_up_to, next_pos)
else:
next_pos = _find_asc_boundary(arr, partitioned_up_to)
next_pos = _do_insertion_sort_optimization(arr, partitioned_up_to, next_pos)
part = arr[partitioned_up_to:next_pos]
partitioned_up_to = next_pos
yield part
def _should_merge(run_stack):
if len(run_stack) < 2:
return False
return len(run_stack[-2]) < 2*len(run_stack[-1])
def _merge(ls1, ls2, merge_storage=None):
ret = merge_storage or []
i1 = 0
i2 = 0
while i1 < len(ls1) and i2 < len(ls2):
a = ls1[i1]
b = ls2[i2]
if a <= b:
ret.append(a)
i1 += 1
else:
ret.append(b)
i2 += 1
ret += ls1[i1:]
ret += ls2[i2:]
return ret
def _merge_stack(run_stack, merge_storage=None):
head = run_stack.pop()
next = run_stack.pop()
new_run = _merge(next, head, merge_storage=merge_storage)
run_stack.append(new_run)
def timsort(arr):
arr = arr or []
if len(arr) <= 0: return []
runs = _partition_to_runs(arr)
run_stack = []
for run in runs:
run_stack.append(run)
while _should_merge(run_stack):
_merge_stack(run_stack)
while len(run_stack) > 1:
_merge_stack(run_stack)
return run_stack[0]
class Test(TestCase):
class Elem:
seq_no = 0
def __init__(self, n):
Elem = Test.Elem
self.n = n
self.seq_no = Elem.seq_no
Elem.seq_no += 1
def __lt__(self, other):
return self.n < other.n
def __str__(self):
return "E" + str(self.n) + "S" + str(self.seq_no)
Elem = functools.total_ordering(Elem)
def setUp(self):
Test.Elem.seq_no = 0
def test_reverse(self):
arr = [3, 2, 1, 4, 7, 5, 6]
_reverse(arr)
self.assertEquals(arr, [6, 5, 7, 4, 1, 2, 3])
arr = [3, 2, 1]
_reverse(arr)
self.assertEquals(arr, [1, 2, 3])
def test_find_asc_boundary(self):
arr = [1, 2, 3, 3, 2]
self.assertEqual(_find_asc_boundary(arr, 0), 4)
arr = [1, 2, 3, 3]
self.assertEqual(_find_asc_boundary(arr, 0), 4)
def test_find_desc_boundary(self):
arr = [3, 2, 1]
self.assertEqual(_find_desc_boundary(arr, 0), 3)
arr = [3, 2, 1, 1]
self.assertEqual(_find_desc_boundary(arr, 0), 3)
def test_merge_stack(self):
arr1 = [1, 2, 3]
arr2 = [2, 3, 4]
stack = [arr1, arr2]
_merge_stack(stack)
self.assertEqual(stack, [[1, 2, 2, 3, 3, 4]])
def test_merge_stability(self):
Elem = Test.Elem
arr1 = map(lambda e: Elem(e), [1, 2, 3])
arr2 = map(lambda e: Elem(e), [2, 3, 4])
stack = [arr1, arr2]
_merge_stack(stack)
self.assertEqual(map(lambda lst: map(str, lst), stack), [['E1S0', 'E2S1', 'E2S3', 'E3S2', 'E3S4', 'E4S5']])
def test_timsort(self):
Elem = Test.Elem
arr = map(lambda e: Elem(e), [3, 1, 2, 2, 7, 5])
ret = timsort(arr)
self.assertEquals(map(str, ret), ['E1S1', 'E2S2', 'E2S3', 'E3S0', 'E5S5', 'E7S4'])
self.assertEqual(timsort([]), [])
self.assertEqual(timsort(None), [])