fix interval sets when overlap is contained completely in an interval (#21290)

Also add more documentation and tests.
This commit is contained in:
Benjamin Dauvergne 2018-01-22 01:55:26 +01:00
parent ca7d336c0b
commit b26c5c2283
2 changed files with 83 additions and 11 deletions

View File

@ -36,22 +36,51 @@ class Interval(object):
return False
return True
def contains(self, point):
    '''Tell whether point lies strictly inside the interval.

    Both bounds are excluded: begin and end are not contained.
    '''
    after_begin = self.begin < point
    return after_begin and point < self.end
def __repr__(self):
    '''Return a debugging representation with both bounds and the data.'''
    # falsy data (None, empty string, 0...) is rendered as an empty string
    payload = self.data or ''
    return '<Interval [%s, %s] %s>' % (self.begin, self.end, payload)
class Intervals(object):
"Maintain a list of mostly non overlapping intervals, allow removing overlap"
'''Maintain a list of mostly non overlapping intervals, allow removing overlap.
Intervals are indexed by their extremums (endpoints); an interval is also
added to the list of every extremum lying inside the interval.
Example of a set: a = [1, 10], b = [2, 4], c = [3, 5], d = [8, 9]
Structure, as an ordered dict of the endpoints:
{
1: [a],
2: [a, b],
3: [a, b, c],
4: [a, b, c],
5: [a, c],
8: [a, d],
9: [a, d],
10: [a],
}
'''
def __init__(self):
    '''Start with an empty set: no known endpoint and no container.'''
    # points and container are parallel lists: container[i] holds the
    # intervals registered for the endpoint points[i]
    self.container = []
    self.points = []
def __insert_point(self, point, interval):
    '''Register interval at the endpoint point and return the index of
    that endpoint in the sorted list of points.

    When point is not a known endpoint yet, a container is created for it
    and seeded with the intervals of the preceding endpoint which strictly
    contain point.
    '''
    idx = bisect.bisect_left(self.points, point)
    known = idx < len(self.container) and self.points[idx] == point
    if not known:
        bucket = []
        self.points.insert(idx, point)
        self.container.insert(idx, bucket)
        if idx:
            # inherit the intervals still open at this new endpoint
            bucket.extend(
                itv for itv in self.container[idx - 1] if itv.contains(point))
    self.container[idx].append(interval)
    return idx
def add(self, begin, end, data=None):
'Add an interval'
@ -59,17 +88,37 @@ class Intervals(object):
def add_interval(self, interval):
    '''Add an interval object.

    The interval is registered exactly once in the container of each of its
    two endpoints and in the container of every already known endpoint
    located between them.
    '''
    first = self.__insert_point(interval.begin, interval)
    if interval.end == interval.begin:
        # zero-length interval: both endpoints share a single container, a
        # second insertion would register the interval twice at that point
        return
    last = self.__insert_point(interval.end, interval)
    # the endpoint containers are already filled by the insertions above,
    # only the containers strictly between them remain
    for i in range(first + 1, last):
        self.container[i].append(interval)
def __iter_interval(self, begin, end, modify=False):
    '''Yield (point, interval) pairs for the intervals registered at the
    endpoints from the first one not lower than begin up to, and
    including, the first one greater than end.

    The first endpoint after end must be inspected too: when the searched
    range is completely included in an interval, no endpoint of that
    interval falls inside the range, but the interval is registered at
    the nearest endpoint greater than end.

    An interval is yielded only once, with the first endpoint where it is
    met; already yielded intervals are tracked by identity.

    When modify is true each container is copied before being iterated,
    so that the container can be altered while the generator is running.
    '''
    seen = set()
    i = bisect.bisect_left(self.points, begin)
    while i < len(self.points):
        container = self.container[i]
        if modify:
            container = list(container)
        for interval in container:
            if id(interval) in seen:
                continue
            seen.add(id(interval))
            yield self.points[i], interval
        # stop only after the first endpoint past the searched range has
        # been visited
        if not self.points[i] <= end:
            break
        i += 1
def remove(self, begin, end):
@ -124,10 +173,13 @@ class Intervals(object):
yield interval.data
def __remove_interval(self, interval):
    '''Remove the interval from the containers of its extremum points and
    from all the containers between them.

    Both extremums of the interval must be known points. The containers
    are kept, possibly empty, after the removal.
    '''
    first = bisect.bisect_left(self.points, interval.begin)
    last = bisect.bisect_left(self.points, interval.end)
    # check some invariants: both extremums are registered points
    assert self.points[first] == interval.begin
    assert self.points[last] == interval.end
    # the interval is registered at every point between its extremums,
    # removing it only at the extremums would leave it reachable
    for i in range(first, last + 1):
        self.container[i].remove(interval)

View File

@ -7,6 +7,7 @@ def test_interval_repr():
a = Interval(1, 4)
repr(a)
def test_interval_overlap():
a = Interval(1, 4)
@ -17,6 +18,7 @@ def test_interval_overlap():
assert a.overlap(3, 5)
assert not a.overlap(5, 6)
def test_intervals():
intervals = Intervals()
@ -56,6 +58,7 @@ def test_intervals():
for i in range(15, 20):
assert intervals.overlap(i, i + 1)
def test_interval_remove():
intervals = Intervals()
intervals.add(9, 12)
@ -70,3 +73,20 @@ def test_interval_remove():
intervals.add(14, 17)
intervals.remove(10, 11)
assert list(intervals.search(0, 24)) == [Interval(9, 10), Interval(11, 12), Interval(14, 17)]
def test_doc_test():
    '''Search an interval set built from four intervals, one of them
    enclosing the three others.
    '''
    a = Interval(1, 10)
    b = Interval(2, 4)
    c = Interval(3, 5)
    d = Interval(8, 9)
    intervals = Intervals()
    for interval in (a, b, c, d):
        intervals.add_interval(interval)

    def found(begin, end):
        # search results ordered by their lower bound
        return sorted(intervals.search(begin, end), key=lambda itv: itv.begin)

    assert found(2, 9) == [a, b, c, d]
    assert found(-1, 11) == [a, b, c, d]
    assert found(1, 3) == [a, b]
    assert found(4, 10) == [a, c, d]
    assert found(5, 10) == [a, d]