Package qubx :: Module sorted_collection
[hide private]
[frames] | no frames]

Source Code for Module qubx.sorted_collection

  1  """Sequence sorted by a key function. 
  2   
  3  Created by Raymond Hettinger on Fri, 16 Apr 2010 (MIT license) 
  4  """ 
  5   
  6  from bisect import bisect_left, bisect_right 
  7   
class SortedCollection(object):
    """Sequence sorted by a key function.

    SortedCollection() is much easier to work with than using bisect() directly.
    It supports key functions like those used in sorted(), min(), and max().
    The result of the key function call is saved so that keys can be searched
    efficiently.

    Instead of returning an insertion-point which can be hard to interpret, the
    five find-methods return a specific item in the sequence. They can scan for
    exact matches, the last item less-than-or-equal to a key, or the first item
    greater-than-or-equal to a key.  The index_le(), index_lt(), index_ge() and
    index_gt() methods perform the same searches but return the position of
    the item instead of the item itself.

    Once found, an item's ordinal position can be located with the index() method.
    New items can be added with the insert() and insert_right() methods.
    Old items can be deleted with the remove() method.

    The usual sequence methods are provided to support indexing, slicing,
    length lookup, clearing, copying, forward and reverse iteration, contains
    checking, item counts, item removal, and a nice looking repr.

    Finding and indexing are O(log n) operations while iteration and insertion
    are O(n).  The initial sort is O(n log n).

    The key function is stored in the 'key' attribute for easy introspection or
    so that you can assign a new key function (triggering an automatic re-sort).

    In short, the class was designed to handle all of the common use cases for
    bisect but with a simpler API and support for key functions.

    >>> from pprint import pprint
    >>> from operator import itemgetter

    >>> s = SortedCollection(key=itemgetter(2))
    >>> for record in [
    ...         ('roger', 'young', 30),
    ...         ('angela', 'jones', 28),
    ...         ('bill', 'smith', 22),
    ...         ('david', 'thomas', 32)]:
    ...     s.insert(record)

    >>> pprint(list(s))         # show records sorted by age
    [('bill', 'smith', 22),
     ('angela', 'jones', 28),
     ('roger', 'young', 30),
     ('david', 'thomas', 32)]

    >>> s.find_le(29)           # find oldest person aged 29 or younger
    ('angela', 'jones', 28)
    >>> s.find_lt(28)           # find oldest person under 28
    ('bill', 'smith', 22)
    >>> s.find_gt(28)           # find youngest person over 28
    ('roger', 'young', 30)

    >>> r = s.find_ge(32)       # find youngest person aged 32 or older
    >>> s.index(r)              # get the index of their record
    3
    >>> s[3]                    # fetch the record at that index
    ('david', 'thomas', 32)
    >>> s.index_le(29)          # position of the oldest person aged 29 or younger
    1

    >>> s.key = itemgetter(0)   # now sort by first name
    >>> pprint(list(s))
    [('angela', 'jones', 28),
     ('bill', 'smith', 22),
     ('david', 'thomas', 32),
     ('roger', 'young', 30)]

    """

    def __init__(self, iterable=(), key=None):
        """Build the collection from *iterable*, ordered by *key*.

        iterable -- initial items (any iterable; consumed once).
        key      -- one-argument function returning the sort key of an item,
                    or None to sort the items by their own value.
        """
        # Remember exactly what the caller passed so that repr(), pickling,
        # copy() and clear() can reproduce it (None stays None).
        self._given_key = key
        key = (lambda x: x) if key is None else key
        # Decorate-sort-undecorate: items with equal keys are ordered by
        # comparing the items themselves (plain tuple comparison).
        decorated = sorted((key(item), item) for item in iterable)
        self._keys = [k for k, item in decorated]
        self._items = [item for k, item in decorated]
        self._key = key

    def _getkey(self):
        return self._key

    def _setkey(self, key):
        # Re-sort only when the key function actually changes.
        if key is not self._key:
            self.__init__(self._items, key=key)

    def _delkey(self):
        self._setkey(None)

    key = property(_getkey, _setkey, _delkey, 'key function')

    def clear(self):
        """Remove all items, keeping the current key function."""
        # Pass the key as originally given (not the internal identity
        # lambda) so repr() and __reduce__() stay accurate afterwards.
        self.__init__([], self._given_key)

    def copy(self):
        """Return a new collection with the same items and key function."""
        # See clear(): the given key keeps the copy's repr/pickle faithful.
        return self.__class__(self, self._given_key)

    def __len__(self):
        return len(self._items)

    def __getitem__(self, i):
        return self._items[i]

    def __iter__(self):
        return iter(self._items)

    def __reversed__(self):
        return reversed(self._items)

    def __repr__(self):
        return '%s(%r, key=%s)' % (
            self.__class__.__name__,
            self._items,
            getattr(self._given_key, '__name__', repr(self._given_key))
        )

    def __reduce__(self):
        return self.__class__, (self._items, self._given_key)

    def _equal_key_bounds(self, item):
        """Return (i, j) such that _items[i:j] holds every item whose key equals item's key."""
        k = self._key(item)
        return bisect_left(self._keys, k), bisect_right(self._keys, k)

    def __contains__(self, item):
        i, j = self._equal_key_bounds(item)
        return item in self._items[i:j]

    def index(self, item):
        """Find the position of an item.  Raise ValueError if not found."""
        i, j = self._equal_key_bounds(item)
        return self._items[i:j].index(item) + i

    def index_le(self, k):
        """Find the position of last item with a key <= k.  Raise ValueError if not found."""
        i = bisect_right(self._keys, k)
        if i:
            # i is the insertion point; the last key <= k sits just before it.
            return i - 1
        raise ValueError('No item found with key at or below: %r' % (k,))

    def index_lt(self, k):
        """Find the position of last item with a key < k.  Raise ValueError if not found."""
        i = bisect_left(self._keys, k)
        if i:
            # i is the insertion point; the last key < k sits just before it.
            return i - 1
        raise ValueError('No item found with key below: %r' % (k,))

    def index_ge(self, k):
        """Find the position of first item with a key >= k.  Raise ValueError if not found."""
        i = bisect_left(self._keys, k)
        if i != len(self):
            return i
        raise ValueError('No item found with key at or above: %r' % (k,))

    def index_gt(self, k):
        """Find the position of first item with a key > k.  Raise ValueError if not found."""
        i = bisect_right(self._keys, k)
        if i != len(self):
            return i
        raise ValueError('No item found with key above: %r' % (k,))

    def count(self, item):
        """Return number of occurrences of item."""
        i, j = self._equal_key_bounds(item)
        return self._items[i:j].count(item)

    def insert(self, item):
        """Insert a new item.  If equal keys are found, add to the left."""
        k = self._key(item)
        i = bisect_left(self._keys, k)
        self._keys.insert(i, k)
        self._items.insert(i, item)

    def insert_right(self, item):
        """Insert a new item.  If equal keys are found, add to the right."""
        k = self._key(item)
        i = bisect_right(self._keys, k)
        self._keys.insert(i, k)
        self._items.insert(i, item)

    def remove(self, item):
        """Remove first occurrence of item.  Raise ValueError if not found."""
        i = self.index(item)
        del self._keys[i]
        del self._items[i]

    def find(self, k):
        """Return first item with a key == k.  Raise ValueError if not found."""
        i = bisect_left(self._keys, k)
        if i != len(self) and self._keys[i] == k:
            return self._items[i]
        raise ValueError('No item found with key equal to: %r' % (k,))

    def find_le(self, k):
        """Return last item with a key <= k.  Raise ValueError if not found."""
        return self._items[self.index_le(k)]

    def find_lt(self, k):
        """Return last item with a key < k.  Raise ValueError if not found."""
        return self._items[self.index_lt(k)]

    def find_ge(self, k):
        """Return first item with a key >= k.  Raise ValueError if not found."""
        return self._items[self.index_ge(k)]

    def find_gt(self, k):
        """Return first item with a key > k.  Raise ValueError if not found."""
        return self._items[self.index_gt(k)]


# ---------------------------  Simple demo and tests  -------------------------
if __name__ == '__main__':

    def ve2no(f, *args):
        """Call f(*args) and return its result, or -1 if it raises ValueError."""
        try:
            outcome = f(*args)
        except ValueError:
            outcome = -1
        return outcome

    def slow_index(seq, k):
        """Linear scan: position of the first element equal to k, else -1."""
        return next((pos for pos, elem in enumerate(seq) if elem == k), -1)

    def slow_find(seq, k):
        """Linear scan: first element equal to k, else -1."""
        return next((elem for elem in seq if elem == k), -1)

    def slow_find_le(seq, k):
        """Linear scan from the end: last element <= k, else -1."""
        return next((elem for elem in reversed(seq) if elem <= k), -1)

    def slow_find_lt(seq, k):
        """Linear scan from the end: last element < k, else -1."""
        return next((elem for elem in reversed(seq) if elem < k), -1)

    def slow_find_ge(seq, k):
        """Linear scan: first element >= k, else -1."""
        return next((elem for elem in seq if elem >= k), -1)

    def slow_find_gt(seq, k):
        """Linear scan: first element strictly greater than k, else -1."""
        return next((elem for elem in seq if elem > k), -1)

    from random import choice
    pool = [1.5, 2, 2.0, 3, 3.0, 3.5, 4, 4.0, 4.5]
    for trial in range(500):
        for n in range(6):
            s = [choice(pool) for _ in range(n)]
            sc = SortedCollection(s)
            s.sort()
            # Each fast lookup is checked against its brute-force reference.
            lookups = (
                (sc.index, slow_index),
                (sc.find, slow_find),
                (sc.find_le, slow_find_le),
                (sc.find_lt, slow_find_lt),
                (sc.find_ge, slow_find_ge),
                (sc.find_gt, slow_find_gt),
            )
            for probe in pool:
                for fast, slow in lookups:
                    assert repr(ve2no(fast, probe)) == repr(slow(s, probe))
            for pos, item in enumerate(s):
                assert repr(item) == repr(sc[pos])      # test __getitem__
                assert item in sc                       # test __contains__ and __iter__
                assert s.count(item) == sc.count(item)  # test count()
            assert len(sc) == n                         # test __len__
            assert list(map(repr, reversed(sc))) == list(map(repr, reversed(s)))    # test __reversed__
            assert list(sc.copy()) == list(sc)          # test copy()
            sc.clear()                                  # test clear()
            assert len(sc) == 0

    sd = SortedCollection('The quick Brown Fox jumped'.split(), key=str.lower)
    assert sd._keys == ['brown', 'fox', 'jumped', 'quick', 'the']
    assert sd._items == ['Brown', 'Fox', 'jumped', 'quick', 'The']
    assert sd._key == str.lower
    assert repr(sd) == "SortedCollection(['Brown', 'Fox', 'jumped', 'quick', 'The'], key=lower)"
    sd.key = str.upper
    assert sd._key == str.upper
    assert len(sd) == 5
    assert list(reversed(sd)) == ['The', 'quick', 'jumped', 'Fox', 'Brown']
    for item in sd:
        assert item in sd
    for pos, item in enumerate(sd):
        assert item == sd[pos]
    sd.insert('jUmPeD')
    sd.insert_right('QuIcK')
    assert sd._keys == ['BROWN', 'FOX', 'JUMPED', 'JUMPED', 'QUICK', 'QUICK', 'THE']
    assert sd._items == ['Brown', 'Fox', 'jUmPeD', 'jumped', 'quick', 'QuIcK', 'The']
    assert sd.find_le('JUMPED') == 'jumped', sd.find_le('JUMPED')
    assert sd.find_ge('JUMPED') == 'jUmPeD'
    assert sd.find_le('GOAT') == 'Fox'
    assert sd.find_ge('GOAT') == 'jUmPeD'
    assert sd.find('FOX') == 'Fox'
    assert sd[3] == 'jumped'
    assert sd[3:5] == ['jumped', 'quick']
    assert sd[-2] == 'QuIcK'
    assert sd[-4:-2] == ['jumped', 'quick']
    for pos, item in enumerate(sd):
        assert sd.index(item) == pos
    try:
        sd.index('xyzpdq')
    except ValueError:
        pass
    else:
        assert 0, 'Oops, failed to notify of missing value'
    sd.remove('jumped')
    assert list(sd) == ['Brown', 'Fox', 'jUmPeD', 'quick', 'QuIcK', 'The']

    import doctest
    from operator import itemgetter
    print(doctest.testmod())