fix interval sets when overlap is contained completely in an interval (#21290)
Also add more documentation and tests.
This commit is contained in:
parent
ca7d336c0b
commit
b26c5c2283
|
@ -36,22 +36,51 @@ class Interval(object):
|
|||
return False
|
||||
return True
|
||||
|
||||
def contains(self, point):
|
||||
return self.begin < point < self.end
|
||||
|
||||
def __repr__(self):
|
||||
return '<Interval [%s, %s] %s>' % (self.begin, self.end, self.data or '')
|
||||
|
||||
|
||||
class Intervals(object):
|
||||
"Maintain a list of mostly non overlapping intervals, allow removing overlap"
|
||||
'''Maintain a list of mostly non-overlapping intervals; allow removing overlap.
|
||||
|
||||
Intervals are indexed by their endpoints; each interval is also added to the
|
||||
list of every other endpoint that falls inside that interval.
|
||||
|
||||
Example set: a = [1, 10], b = [2, 4], c = [3, 5], d = [8, 9]
|
||||
|
||||
Structure, as an ordered dict of the endpoints:
|
||||
{
|
||||
1: [a],
|
||||
2: [a, b],
|
||||
3: [a, b, c],
|
||||
4: [a, b, c],
|
||||
5: [a, c],
|
||||
8: [a, d],
|
||||
9: [a, d],
|
||||
10: [a],
|
||||
}
|
||||
'''
|
||||
def __init__(self):
|
||||
self.points = []
|
||||
self.container = []
|
||||
|
||||
def __insert_point(self, point, interval):
|
||||
'''Insert interval in the container for this point; if the point is new, first
|
||||
copy every interval from the previous container that contains this point.
|
||||
'''
|
||||
i = bisect.bisect_left(self.points, point)
|
||||
if i >= len(self.container) or self.points[i] != point:
|
||||
self.points.insert(i, point)
|
||||
self.container.insert(i, [])
|
||||
if i:
|
||||
for itv in self.container[i - 1]:
|
||||
if itv.contains(point):
|
||||
self.container[i].append(itv)
|
||||
self.container[i].append(interval)
|
||||
return i
|
||||
|
||||
def add(self, begin, end, data=None):
|
||||
'Add an interval'
|
||||
|
@ -59,17 +88,37 @@ class Intervals(object):
|
|||
|
||||
def add_interval(self, interval):
|
||||
'Add an interval object'
|
||||
self.__insert_point(interval.begin, interval)
|
||||
self.__insert_point(interval.end, interval)
|
||||
a = self.__insert_point(interval.begin, interval)
|
||||
b = self.__insert_point(interval.end, interval)
|
||||
for i in xrange(a + 1, b):
|
||||
self.container[i].append(interval)
|
||||
|
||||
def __iter_interval(self, begin, end, modify=False):
|
||||
'''Search for overlapping intervals by bisecting over the list of
|
||||
interval endpoints and iterating until a point after the greatest
|
||||
extremum of the search interval.
|
||||
|
||||
We test the first point after the end of the searched interval
|
||||
because if the searched interval is completely included in one of
|
||||
the stored intervals, that interval will be listed at the nearest point
|
||||
greater than the end point of the searched interval.
|
||||
|
||||
Avoid returning an interval more than once by keeping a set of the
|
||||
intervals already seen.
|
||||
'''
|
||||
seen = set()
|
||||
i = bisect.bisect_left(self.points, begin)
|
||||
while i < len(self.points) and self.points[i] <= end:
|
||||
while i < len(self.points):
|
||||
container = self.container[i]
|
||||
if modify:
|
||||
container = list(container)
|
||||
for interval in container:
|
||||
if id(interval) in seen:
|
||||
continue
|
||||
seen.add(id(interval))
|
||||
yield self.points[i], interval
|
||||
if not self.points[i] <= end:
|
||||
break
|
||||
i += 1
|
||||
|
||||
def remove(self, begin, end):
|
||||
|
@ -124,10 +173,13 @@ class Intervals(object):
|
|||
yield interval.data
|
||||
|
||||
def __remove_interval(self, interval):
|
||||
self.__remove_point_interval(interval.begin, interval)
|
||||
self.__remove_point_interval(interval.end, interval)
|
||||
|
||||
def __remove_point_interval(self, point, interval):
|
||||
i = bisect.bisect_left(self.points, point)
|
||||
assert self.points[i] == point
|
||||
self.container[i].remove(interval)
|
||||
'''Remove the interval from the containers of its two endpoints and from
|
||||
all containers in between.
|
||||
'''
|
||||
a = bisect.bisect_left(self.points, interval.begin)
|
||||
b = bisect.bisect_left(self.points, interval.end)
|
||||
# check some invariants
|
||||
assert self.points[a] == interval.begin
|
||||
assert self.points[b] == interval.end
|
||||
for i in xrange(a, b + 1):
|
||||
self.container[i].remove(interval)
|
||||
|
|
|
@ -7,6 +7,7 @@ def test_interval_repr():
|
|||
a = Interval(1, 4)
|
||||
repr(a)
|
||||
|
||||
|
||||
def test_interval_overlap():
|
||||
a = Interval(1, 4)
|
||||
|
||||
|
@ -17,6 +18,7 @@ def test_interval_overlap():
|
|||
assert a.overlap(3, 5)
|
||||
assert not a.overlap(5, 6)
|
||||
|
||||
|
||||
def test_intervals():
|
||||
intervals = Intervals()
|
||||
|
||||
|
@ -56,6 +58,7 @@ def test_intervals():
|
|||
for i in range(15, 20):
|
||||
assert intervals.overlap(i, i + 1)
|
||||
|
||||
|
||||
def test_interval_remove():
|
||||
intervals = Intervals()
|
||||
intervals.add(9, 12)
|
||||
|
@ -70,3 +73,20 @@ def test_interval_remove():
|
|||
intervals.add(14, 17)
|
||||
intervals.remove(10, 11)
|
||||
assert list(intervals.search(0, 24)) == [Interval(9, 10), Interval(11, 12), Interval(14, 17)]
|
||||
|
||||
|
||||
def test_doc_test():
|
||||
a = Interval(1, 10)
|
||||
b = Interval(2, 4)
|
||||
c = Interval(3, 5)
|
||||
d = Interval(8, 9)
|
||||
|
||||
s = Intervals()
|
||||
for x in [a, b, c, d]:
|
||||
s.add_interval(x)
|
||||
|
||||
assert sorted(s.search(2, 9), key=lambda x: x.begin) == [a, b, c, d]
|
||||
assert sorted(s.search(-1, 11), key=lambda x: x.begin) == [a, b, c, d]
|
||||
assert sorted(s.search(1, 3), key=lambda x: x.begin) == [a, b]
|
||||
assert sorted(s.search(4, 10), key=lambda x: x.begin) == [a, c, d]
|
||||
assert sorted(s.search(5, 10), key=lambda x: x.begin) == [a, d]
|
||||
|
|
Loading…
Reference in New Issue