general: add custom implementation of interval sets (#20732)

This commit is contained in:
Benjamin Dauvergne 2017-12-16 04:17:54 +01:00 committed by Frédéric Péters
parent 74984105b4
commit 38b79a368c
3 changed files with 179 additions and 0 deletions

116
chrono/interval.py Normal file
View File

@ -0,0 +1,116 @@
# chrono - agendas system
# Copyright (C) 2017 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import bisect
class Interval(object):
__slots__ = ['begin', 'end', 'data']
def __init__(self, begin, end, data=None):
assert begin < end
self.begin = begin
self.end = end
self.data = data
def overlap(self, begin, end):
if end <= self.begin:
return False
if begin >= self.end:
return False
return True
def __repr__(self):
return '<Interval [%s, %s] %s>' % (self.begin, self.end, self.data or '')
class Intervals(object):
"Maintain a list of mostly non overlapping intervals, allow removing overlap"
def __init__(self):
self.points = []
self.container = []
def __insert_point(self, point, interval):
i = bisect.bisect_left(self.points, point)
if i >= len(self.container) or self.points[i] != point:
self.points.insert(i, point)
self.container.insert(i, [])
self.container[i].append(interval)
def add(self, begin, end, data=None):
'Add an interval'
self.add_interval(Interval(begin, end, data))
def add_interval(self, interval):
'Add an interval object'
self.__insert_point(interval.begin, interval)
self.__insert_point(interval.end, interval)
def __iter_interval(self, begin, end, modify=False):
i = bisect.bisect_left(self.points, begin)
while i < len(self.points) and self.points[i] <= end:
container = self.container[i]
if modify:
container = list(container)
for interval in container:
yield self.points[i], interval
i += 1
def remove_overlap(self, begin, end):
'Remove all overlapping intervals'
for point, interval in self.__iter_interval(begin, end, modify=True):
if interval.overlap(begin, end):
self.__remove_interval(interval)
def overlap(self, begin, end):
'Test if some intervals overlap'
for point, interval in self.__iter_interval(begin, end):
if interval.overlap(begin, end):
return True
return False
def search(self, begin, end):
'Search overlapping intervals'
for point, interval in self.__iter_interval(begin, end):
if interval.overlap(begin, end):
# prevent returning the same interval twice
if interval.begin < begin or interval.begin == point:
yield interval
def search_data(self, begin, end):
'Search data elements of overlapping intervals'
for interval in self.search(begin, end):
yield interval.data
def iter(self):
'Iterate intervals'
if not self.points:
return []
return self.search(self.points[0], self.points[-1])
def iter_data(self):
'Iterate data element attached to intervals'
for interval in self.iter():
yield interval.data
def __remove_interval(self, interval):
self.__remove_point_interval(interval.begin, interval)
self.__remove_point_interval(interval.end, interval)
def __remove_point_interval(self, point, interval):
i = bisect.bisect_left(self.points, point)
assert self.points[i] == point
self.container[i].remove(interval)

62
tests/test_interval.py Normal file
View File

@ -0,0 +1,62 @@
import pytest
try:
from intervaltree import IntervalTree
except ImportError:
IntervalTree = None
from chrono.interval import Interval, Intervals
def test_interval_repr():
a = Interval(1, 4)
repr(a)
def test_interval_overlap():
a = Interval(1, 4)
assert not a.overlap(0, 1)
assert a.overlap(0, 2)
assert a.overlap(1, 4)
assert a.overlap(2, 3)
assert a.overlap(3, 5)
assert not a.overlap(5, 6)
def test_intervals():
intervals = Intervals()
assert len(list(intervals.search(0, 5))) == 0
for i in range(10):
intervals.add(i, i + 1, 1)
for i in range(10, 20):
intervals.add(i, i + 1, 2)
for i in range(5, 15):
intervals.add(i, i + 1, 3)
assert len(list(intervals.search(0, 5))) == 5
assert len(list(intervals.search(0, 10))) == 15
assert len(list(intervals.search(5, 15))) == 20
assert len(list(intervals.search(10, 20))) == 15
assert len(list(intervals.search(15, 20))) == 5
assert set(intervals.search_data(0, 5)) == {1}
assert set(intervals.search_data(0, 10)) == {1, 3}
assert set(intervals.search_data(5, 15)) == {1, 2, 3}
assert set(intervals.search_data(10, 20)) == {2, 3}
assert set(intervals.search_data(15, 20)) == {2}
for i in range(20):
assert intervals.overlap(i, i + 1)
intervals.remove_overlap(5, 15)
assert set(intervals.search_data(0, 20)) == {1, 2}
for i in range(5):
assert intervals.overlap(i, i + 1)
for i in range(5, 15):
assert not intervals.overlap(i, i + 1)
for i in range(15, 20):
assert intervals.overlap(i, i + 1)

View File

@ -15,6 +15,7 @@ deps =
django111: django>=1.11,<1.12
pytest-cov
pytest-django
intervaltree
pytest>=3.3.0
WebTest
mock