general: add custom implementation of interval sets (#20732)
This commit is contained in:
parent
74984105b4
commit
38b79a368c
|
@ -0,0 +1,116 @@
|
|||
# chrono - agendas system
|
||||
# Copyright (C) 2017 Entr'ouvert
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Affero General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import bisect
|
||||
|
||||
|
||||
class Interval(object):
|
||||
__slots__ = ['begin', 'end', 'data']
|
||||
|
||||
def __init__(self, begin, end, data=None):
|
||||
assert begin < end
|
||||
self.begin = begin
|
||||
self.end = end
|
||||
self.data = data
|
||||
|
||||
def overlap(self, begin, end):
|
||||
if end <= self.begin:
|
||||
return False
|
||||
if begin >= self.end:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __repr__(self):
|
||||
return '<Interval [%s, %s] %s>' % (self.begin, self.end, self.data or '')
|
||||
|
||||
|
||||
class Intervals(object):
|
||||
"Maintain a list of mostly non overlapping intervals, allow removing overlap"
|
||||
def __init__(self):
|
||||
self.points = []
|
||||
self.container = []
|
||||
|
||||
def __insert_point(self, point, interval):
|
||||
i = bisect.bisect_left(self.points, point)
|
||||
if i >= len(self.container) or self.points[i] != point:
|
||||
self.points.insert(i, point)
|
||||
self.container.insert(i, [])
|
||||
self.container[i].append(interval)
|
||||
|
||||
def add(self, begin, end, data=None):
|
||||
'Add an interval'
|
||||
self.add_interval(Interval(begin, end, data))
|
||||
|
||||
def add_interval(self, interval):
|
||||
'Add an interval object'
|
||||
self.__insert_point(interval.begin, interval)
|
||||
self.__insert_point(interval.end, interval)
|
||||
|
||||
def __iter_interval(self, begin, end, modify=False):
|
||||
i = bisect.bisect_left(self.points, begin)
|
||||
while i < len(self.points) and self.points[i] <= end:
|
||||
container = self.container[i]
|
||||
if modify:
|
||||
container = list(container)
|
||||
for interval in container:
|
||||
yield self.points[i], interval
|
||||
i += 1
|
||||
|
||||
def remove_overlap(self, begin, end):
|
||||
'Remove all overlapping intervals'
|
||||
for point, interval in self.__iter_interval(begin, end, modify=True):
|
||||
if interval.overlap(begin, end):
|
||||
self.__remove_interval(interval)
|
||||
|
||||
def overlap(self, begin, end):
|
||||
'Test if some intervals overlap'
|
||||
for point, interval in self.__iter_interval(begin, end):
|
||||
if interval.overlap(begin, end):
|
||||
return True
|
||||
return False
|
||||
|
||||
def search(self, begin, end):
|
||||
'Search overlapping intervals'
|
||||
for point, interval in self.__iter_interval(begin, end):
|
||||
if interval.overlap(begin, end):
|
||||
# prevent returning the same interval twice
|
||||
if interval.begin < begin or interval.begin == point:
|
||||
yield interval
|
||||
|
||||
def search_data(self, begin, end):
|
||||
'Search data elements of overlapping intervals'
|
||||
for interval in self.search(begin, end):
|
||||
yield interval.data
|
||||
|
||||
def iter(self):
|
||||
'Iterate intervals'
|
||||
if not self.points:
|
||||
return []
|
||||
return self.search(self.points[0], self.points[-1])
|
||||
|
||||
def iter_data(self):
|
||||
'Iterate data element attached to intervals'
|
||||
for interval in self.iter():
|
||||
yield interval.data
|
||||
|
||||
def __remove_interval(self, interval):
|
||||
self.__remove_point_interval(interval.begin, interval)
|
||||
self.__remove_point_interval(interval.end, interval)
|
||||
|
||||
def __remove_point_interval(self, point, interval):
|
||||
i = bisect.bisect_left(self.points, point)
|
||||
assert self.points[i] == point
|
||||
self.container[i].remove(interval)
|
|
@ -0,0 +1,62 @@
|
|||
import pytest
|
||||
|
||||
try:
|
||||
from intervaltree import IntervalTree
|
||||
except ImportError:
|
||||
IntervalTree = None
|
||||
|
||||
from chrono.interval import Interval, Intervals
|
||||
|
||||
|
||||
def test_interval_repr():
|
||||
a = Interval(1, 4)
|
||||
repr(a)
|
||||
|
||||
def test_interval_overlap():
|
||||
a = Interval(1, 4)
|
||||
|
||||
assert not a.overlap(0, 1)
|
||||
assert a.overlap(0, 2)
|
||||
assert a.overlap(1, 4)
|
||||
assert a.overlap(2, 3)
|
||||
assert a.overlap(3, 5)
|
||||
assert not a.overlap(5, 6)
|
||||
|
||||
def test_intervals():
|
||||
intervals = Intervals()
|
||||
|
||||
assert len(list(intervals.search(0, 5))) == 0
|
||||
|
||||
for i in range(10):
|
||||
intervals.add(i, i + 1, 1)
|
||||
|
||||
for i in range(10, 20):
|
||||
intervals.add(i, i + 1, 2)
|
||||
|
||||
for i in range(5, 15):
|
||||
intervals.add(i, i + 1, 3)
|
||||
|
||||
assert len(list(intervals.search(0, 5))) == 5
|
||||
assert len(list(intervals.search(0, 10))) == 15
|
||||
assert len(list(intervals.search(5, 15))) == 20
|
||||
assert len(list(intervals.search(10, 20))) == 15
|
||||
assert len(list(intervals.search(15, 20))) == 5
|
||||
|
||||
assert set(intervals.search_data(0, 5)) == {1}
|
||||
assert set(intervals.search_data(0, 10)) == {1, 3}
|
||||
assert set(intervals.search_data(5, 15)) == {1, 2, 3}
|
||||
assert set(intervals.search_data(10, 20)) == {2, 3}
|
||||
assert set(intervals.search_data(15, 20)) == {2}
|
||||
|
||||
for i in range(20):
|
||||
assert intervals.overlap(i, i + 1)
|
||||
|
||||
intervals.remove_overlap(5, 15)
|
||||
assert set(intervals.search_data(0, 20)) == {1, 2}
|
||||
|
||||
for i in range(5):
|
||||
assert intervals.overlap(i, i + 1)
|
||||
for i in range(5, 15):
|
||||
assert not intervals.overlap(i, i + 1)
|
||||
for i in range(15, 20):
|
||||
assert intervals.overlap(i, i + 1)
|
Loading…
Reference in New Issue