"""Sequence sorted by a key function.

Created by Raymond Hettinger on Fri, 16 Apr 2010 (MIT license)
"""

from bisect import bisect_left, bisect_right

class SortedCollection(object):
    '''Sequence sorted by a key function.

    SortedCollection() is much easier to work with than using bisect() directly.
    It supports key functions like those used in sorted(), min(), and max().
    The result of the key function call is saved so that keys can be searched
    efficiently.

    Instead of returning an insertion-point which can be hard to interpret, the
    five find-methods return a specific item in the sequence. They can scan for
    exact matches, the last item less-than-or-equal to a key, or the first item
    greater-than-or-equal to a key.

    Once found, an item's ordinal position can be located with the index() method.
    New items can be added with the insert() and insert_right() methods.
    Old items can be deleted with the remove() method.

    The usual sequence methods are provided to support indexing, slicing,
    length lookup, clearing, copying, forward and reverse iteration, contains
    checking, item counts, item removal, and a nice looking repr.

    Finding and indexing are O(log n) operations while iteration and insertion
    are O(n). The initial sort is O(n log n).

    The key function is stored in the 'key' attribute for easy introspection or
    so that you can assign a new key function (triggering an automatic re-sort).

    In short, the class was designed to handle all of the common use cases for
    bisect but with a simpler API and support for key functions.

    >>> from pprint import pprint
    >>> from operator import itemgetter

    >>> s = SortedCollection(key=itemgetter(2))
    >>> for record in [
    ...         ('roger', 'young', 30),
    ...         ('angela', 'jones', 28),
    ...         ('bill', 'smith', 22),
    ...         ('david', 'thomas', 32)]:
    ...     s.insert(record)

    >>> pprint(list(s))         # show records sorted by age
    [('bill', 'smith', 22),
     ('angela', 'jones', 28),
     ('roger', 'young', 30),
     ('david', 'thomas', 32)]

    >>> s.find_le(29)           # find oldest person aged 29 or younger
    ('angela', 'jones', 28)
    >>> s.find_lt(28)           # find oldest person under 28
    ('bill', 'smith', 22)
    >>> s.find_gt(28)           # find youngest person over 28
    ('roger', 'young', 30)

    >>> r = s.find_ge(32)       # find youngest person aged 32 or older
    >>> s.index(r)              # get the index of their record
    3
    >>> s[3]                    # fetch the record at that index
    ('david', 'thomas', 32)

    >>> s.key = itemgetter(0)   # now sort by first name
    >>> pprint(list(s))
    [('angela', 'jones', 28),
     ('bill', 'smith', 22),
     ('david', 'thomas', 32),
     ('roger', 'young', 30)]

    '''

    def __init__(self, iterable=(), key=None):
        'Build the collection from iterable, ordered by key (identity if None).'
        # Keep the key exactly as supplied (possibly None) so that __repr__,
        # __reduce__, copy() and clear() can reproduce it faithfully.
        self._given_key = key
        key = (lambda x: x) if key is None else key
        # Decorate-sort-undecorate with the input position as a tie-breaker:
        # the key is computed once per item, items with equal keys keep their
        # input order (like sorted(iterable, key=key)), and the items
        # themselves are never compared, so unorderable items are accepted.
        decorated = sorted(
            (key(item), position, item)
            for position, item in enumerate(iterable)
        )
        self._keys = [k for k, _, _ in decorated]
        self._items = [item for _, _, item in decorated]
        self._key = key

    def _getkey(self):
        return self._key

    def _setkey(self, key):
        # Assigning a different key function re-sorts the existing items.
        if key is not self._key:
            self.__init__(self._items, key=key)

    def _delkey(self):
        # Deleting the key reverts to the identity key (natural ordering).
        self._setkey(None)

    key = property(_getkey, _setkey, _delkey, 'key function')

    def clear(self):
        'Remove all items, keeping the current key function.'
        self.__init__([], self._given_key)

    def copy(self):
        'Return a shallow copy ordered by the same key function.'
        # Pass the key as originally given so the copy's repr matches.
        return self.__class__(self, self._given_key)

    def __len__(self):
        return len(self._items)

    def __getitem__(self, i):
        # Accepts integer indexes and slices; a slice returns a plain list.
        return self._items[i]

    def __iter__(self):
        return iter(self._items)

    def __reversed__(self):
        return reversed(self._items)

    def __repr__(self):
        return '%s(%r, key=%s)' % (
            self.__class__.__name__,
            self._items,
            getattr(self._given_key, '__name__', repr(self._given_key))
        )

    def __reduce__(self):
        return self.__class__, (self._items, self._given_key)

    def _key_span(self, k):
        'Return (i, j) such that self._keys[i:j] are exactly the keys equal to k.'
        return bisect_left(self._keys, k), bisect_right(self._keys, k)

    def __contains__(self, item):
        # Only items whose key equals item's key need a linear check.
        i, j = self._key_span(self._key(item))
        return item in self._items[i:j]

    def index(self, item):
        'Find the position of an item.  Raise ValueError if not found.'
        i, j = self._key_span(self._key(item))
        return self._items[i:j].index(item) + i

    # NOTE(review): the names index_le/index_lt/index_ge/index_gt are a
    # reconstruction (their original def lines were lost); they mirror
    # find_le/find_lt/find_ge/find_gt -- confirm against callers.

    def index_le(self, k):
        'Find the position of last item with a key <= k.  Raise ValueError if not found.'
        # bisect_right lands just past every key <= k, so the last such key
        # sits one slot to the left of the insertion point.
        i = bisect_right(self._keys, k)
        if i:
            return i - 1
        raise ValueError('No item found with key at or below: %r' % (k,))

    def index_lt(self, k):
        'Find the position of last item with a key < k.  Raise ValueError if not found.'
        # bisect_left lands just past every key < k; step back one slot.
        i = bisect_left(self._keys, k)
        if i:
            return i - 1
        raise ValueError('No item found with key below: %r' % (k,))

    def index_ge(self, k):
        'Find the position of first item with a key >= k.  Raise ValueError if not found.'
        i = bisect_left(self._keys, k)
        if i != len(self):
            return i
        raise ValueError('No item found with key at or above: %r' % (k,))

    def index_gt(self, k):
        'Find the position of first item with a key > k.  Raise ValueError if not found.'
        i = bisect_right(self._keys, k)
        if i != len(self):
            return i
        raise ValueError('No item found with key above: %r' % (k,))

    def count(self, item):
        'Return number of occurrences of item'
        i, j = self._key_span(self._key(item))
        return self._items[i:j].count(item)

    def insert(self, item):
        'Insert a new item.  If equal keys are found, add to the left'
        k = self._key(item)
        i = bisect_left(self._keys, k)
        self._keys.insert(i, k)
        self._items.insert(i, item)

    def insert_right(self, item):
        'Insert a new item.  If equal keys are found, add to the right'
        k = self._key(item)
        i = bisect_right(self._keys, k)
        self._keys.insert(i, k)
        self._items.insert(i, item)

    def remove(self, item):
        'Remove first occurrence of item.  Raise ValueError if not found'
        i = self.index(item)
        del self._keys[i]
        del self._items[i]

    def find(self, k):
        'Return first item with a key == k.  Raise ValueError if not found.'
        i = bisect_left(self._keys, k)
        if i != len(self) and self._keys[i] == k:
            return self._items[i]
        raise ValueError('No item found with key equal to: %r' % (k,))

    def find_le(self, k):
        'Return last item with a key <= k.  Raise ValueError if not found.'
        return self._items[self.index_le(k)]

    def find_lt(self, k):
        'Return last item with a key < k.  Raise ValueError if not found.'
        return self._items[self.index_lt(k)]

    def find_ge(self, k):
        'Return first item with a key >= k.  Raise ValueError if not found.'
        return self._items[self.index_ge(k)]

    def find_gt(self, k):
        'Return first item with a key > k.  Raise ValueError if not found.'
        return self._items[self.index_gt(k)]


if __name__ == '__main__':

    # Self-test: cross-check SortedCollection against brute-force linear
    # scans on random data, then exercise a key function, then run doctests.

    def ve2no(f, *args):
        'Convert ValueError result to -1'
        try:
            return f(*args)
        except ValueError:
            return -1

    def slow_index(seq, k):
        'Location of match or -1 if not found'
        for i, item in enumerate(seq):
            if item == k:
                return i
        return -1

    def slow_find(seq, k):
        'First item with a key equal to k. -1 if not found'
        for item in seq:
            if item == k:
                return item
        return -1

    def slow_find_le(seq, k):
        'Last item with a key less-than or equal to k.'
        for item in reversed(seq):
            if item <= k:
                return item
        return -1

    def slow_find_lt(seq, k):
        'Last item with a key less-than k.'
        for item in reversed(seq):
            if item < k:
                return item
        return -1

    def slow_find_ge(seq, k):
        'First item with a key-value greater-than or equal to k.'
        for item in seq:
            if item >= k:
                return item
        return -1

    def slow_find_gt(seq, k):
        'First item with a key-value greater-than k.'
        for item in seq:
            if item > k:
                return item
        return -1

    from random import choice
    # Ints mixed with equal-valued floats exercise ties between distinct
    # objects; comparing repr() tells 2 apart from 2.0.
    pool = [1.5, 2, 2.0, 3, 3.0, 3.5, 4, 4.0, 4.5]
    for _ in range(500):
        for n in range(6):
            s = [choice(pool) for _ in range(n)]
            sc = SortedCollection(s)
            s.sort()
            for probe in pool:
                assert repr(ve2no(sc.index, probe)) == repr(slow_index(s, probe))
                assert repr(ve2no(sc.find, probe)) == repr(slow_find(s, probe))
                assert repr(ve2no(sc.find_le, probe)) == repr(slow_find_le(s, probe))
                assert repr(ve2no(sc.find_lt, probe)) == repr(slow_find_lt(s, probe))
                assert repr(ve2no(sc.find_ge, probe)) == repr(slow_find_ge(s, probe))
                assert repr(ve2no(sc.find_gt, probe)) == repr(slow_find_gt(s, probe))
            for i, item in enumerate(s):
                assert repr(item) == repr(sc[i])        # test __getitem__
                assert item in sc                       # test __contains__
                assert s.count(item) == sc.count(item)  # test count()
            assert len(sc) == n
            assert list(map(repr, reversed(sc))) == list(map(repr, reversed(s)))
            assert list(sc.copy()) == list(sc)
            sc.clear()
            assert len(sc) == 0

    sd = SortedCollection('The quick Brown Fox jumped'.split(), key=str.lower)
    assert sd._keys == ['brown', 'fox', 'jumped', 'quick', 'the']
    assert sd._items == ['Brown', 'Fox', 'jumped', 'quick', 'The']
    assert sd._key == str.lower
    assert repr(sd) == "SortedCollection(['Brown', 'Fox', 'jumped', 'quick', 'The'], key=lower)"
    sd.key = str.upper
    assert sd._key == str.upper
    assert len(sd) == 5
    assert list(reversed(sd)) == ['The', 'quick', 'jumped', 'Fox', 'Brown']
    for item in sd:
        assert item in sd
    for i, item in enumerate(sd):
        assert item == sd[i]
    sd.insert('jUmPeD')
    sd.insert_right('QuIcK')
    assert sd._keys == ['BROWN', 'FOX', 'JUMPED', 'JUMPED', 'QUICK', 'QUICK', 'THE']
    assert sd._items == ['Brown', 'Fox', 'jUmPeD', 'jumped', 'quick', 'QuIcK', 'The']
    assert sd.find_le('JUMPED') == 'jumped', sd.find_le('JUMPED')
    assert sd.find_ge('JUMPED') == 'jUmPeD'
    assert sd.find_le('GOAT') == 'Fox'
    assert sd.find_ge('GOAT') == 'jUmPeD'
    assert sd.find('FOX') == 'Fox'
    assert sd[3] == 'jumped'
    assert sd[3:5] == ['jumped', 'quick']
    assert sd[-2] == 'QuIcK'
    assert sd[-4:-2] == ['jumped', 'quick']
    for i, item in enumerate(sd):
        assert sd.index(item) == i
    try:
        sd.index('xyzpdq')
    except ValueError:
        pass
    else:
        assert 0, 'Oops, failed to notify of missing value'
    sd.remove('jumped')
    assert list(sd) == ['Brown', 'Fox', 'jUmPeD', 'quick', 'QuIcK', 'The']

    import doctest
    from operator import itemgetter
    print(doctest.testmod())