chrono/chrono/interval.py

187 lines
6.5 KiB
Python

# chrono - agendas system
# Copyright (C) 2017 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import bisect
class Interval(object):
    '''An interval from ``begin`` to ``end`` with an optional ``data`` payload.

    ``begin`` must be strictly lower than ``end``. For ``overlap`` two
    intervals that only share an endpoint do not overlap, and ``contains``
    excludes both endpoints.
    '''
    __slots__ = ['begin', 'end', 'data']

    def __init__(self, begin, end, data=None):
        assert begin < end
        self.begin = begin
        self.end = end
        self.data = data

    def __eq__(self, other):
        '''Intervals are equal when their begin, end and data are all equal.'''
        if not isinstance(other, Interval):
            # let Python try the reflected operation, then fall back to identity
            return NotImplemented
        return (self.begin == other.begin
                and self.end == other.end
                and self.data == other.data)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so spell it out
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def overlap(self, begin, end):
        '''Return True if the range ``begin``..``end`` overlaps this interval.

        Ranges that merely touch at an endpoint do not overlap.
        '''
        return not (end <= self.begin or begin >= self.end)

    def contains(self, point):
        '''Return True if ``point`` lies strictly inside this interval.'''
        return self.begin < point < self.end

    def __repr__(self):
        return '<Interval [%s, %s] %s>' % (self.begin, self.end, self.data or '')
class Intervals(object):
    '''Maintain a list of mostly non-overlapping intervals, allow removing overlap.

    Intervals are indexed by their endpoints. ``points`` is the sorted list
    of every endpoint inserted so far, and ``container[i]`` holds every
    interval whose closed range ``[begin, end]`` includes ``points[i]``. An
    interval is therefore listed under its own endpoints and under every
    other endpoint lying inside it.

    Example of set: a = [1, 10], b = [2, 4], c = [3, 5], d = [8, 9]

    Structure, as an ordered dict of the endpoints:

    {
        1: [a],
        2: [a, b],
        3: [a, b, c],
        4: [a, b, c],
        5: [a, c],
        8: [a, d],
        9: [a, d],
        10: [a],
    }

    Endpoints are never dropped from ``points``. Removing an interval only
    takes it out of the containers, so some containers may become empty.
    '''

    def __init__(self):
        # sorted endpoints, and for each one the list of intervals covering it
        self.points = []
        self.container = []

    def __insert_point(self, point, interval):
        '''Register ``interval`` in the container of ``point``.

        If ``point`` is a new endpoint, it is inserted at its sorted position.
        Its container is then seeded with the intervals of the previous
        endpoint that strictly contain it, which preserves the class invariant.

        Return the index of ``point`` in ``self.points``.
        '''
        i = bisect.bisect_left(self.points, point)
        if i == len(self.points) or self.points[i] != point:
            self.points.insert(i, point)
            self.container.insert(i, [])
            if i:
                # an interval strictly containing the new point necessarily
                # covers the previous endpoint too, so it is listed there
                self.container[i].extend(
                    itv for itv in self.container[i - 1] if itv.contains(point))
        self.container[i].append(interval)
        return i

    def add(self, begin, end, data=None):
        'Add an interval from ``begin`` to ``end`` carrying ``data``'
        self.add_interval(Interval(begin, end, data))

    def add_interval(self, interval):
        'Add an interval object'
        a = self.__insert_point(interval.begin, interval)
        # end > begin, so inserting it cannot shift index ``a``
        b = self.__insert_point(interval.end, interval)
        # also list the interval under every endpoint strictly inside it
        for i in range(a + 1, b):
            self.container[i].append(interval)

    def __iter_interval(self, begin, end, modify=False):
        '''Yield ``(point, interval)`` pairs of candidate overlapping intervals.

        Bisect the list of endpoints and walk the containers from the first
        endpoint >= ``begin``, up to and including the first endpoint greater
        than ``end``. That last endpoint must be visited: if the searched
        range lies completely inside an interval, that interval is listed
        under the nearest endpoint greater than ``end``.

        Candidates are not filtered, so callers must check
        ``interval.overlap``. Each interval is yielded once, paired with the
        endpoint at which it was first met. An already-seen set keyed on
        identity prevents repeats.

        If ``modify`` is true, each container is copied before being iterated.
        This lets the caller remove intervals while consuming the generator.
        '''
        seen = set()
        i = bisect.bisect_left(self.points, begin)
        while i < len(self.points):
            container = self.container[i]
            if modify:
                container = list(container)
            for interval in container:
                if id(interval) in seen:
                    continue
                seen.add(id(interval))
                yield self.points[i], interval
            if not self.points[i] <= end:
                break
            i += 1

    def remove(self, begin, end):
        '''Subtract the range ``begin``..``end`` from the stored intervals.

        Every overlapping interval is removed. The parts of it lying outside
        the range are added back as new intervals carrying the same data.
        '''
        # materialize first: the loop body mutates the index
        for interval in list(self.search(begin, end)):
            if interval.begin < begin:
                self.add(interval.begin, begin, interval.data)
            if interval.end > end:
                self.add(end, interval.end, interval.data)
            self.__remove_interval(interval)

    def remove_overlap(self, begin, end):
        'Remove all overlapping intervals'
        for point, interval in self.__iter_interval(begin, end, modify=True):
            if interval.overlap(begin, end):
                self.__remove_interval(interval)

    def overlap(self, begin, end):
        'Test if some intervals overlap'
        for point, interval in self.__iter_interval(begin, end):
            if interval.overlap(begin, end):
                return True
        return False

    def search(self, begin, end):
        'Search overlapping intervals'
        for point, interval in self.__iter_interval(begin, end):
            if interval.overlap(begin, end):
                # prevent returning the same interval twice (__iter_interval
                # also deduplicates by identity)
                if interval.begin < begin or interval.begin == point:
                    yield interval

    def search_data(self, begin, end):
        'Search data elements of overlapping intervals'
        for interval in self.search(begin, end):
            yield interval.data

    def iter(self):
        'Iterate over all stored intervals'
        if not self.points:
            return []
        return self.search(self.points[0], self.points[-1])

    def iter_data(self):
        'Iterate data element attached to intervals'
        for interval in self.iter():
            yield interval.data

    def __remove_interval(self, interval):
        '''Remove the interval from the containers of its endpoints and of
        every endpoint between them.

        Raise ValueError if the interval is missing from one of them.
        '''
        a = bisect.bisect_left(self.points, interval.begin)
        b = bisect.bisect_left(self.points, interval.end)
        # check some invariants
        assert self.points[a] == interval.begin
        assert self.points[b] == interval.end
        for container in self.container[a:b + 1]:
            # match by identity, not equality: distinct intervals may compare
            # equal, and comparing would invoke data.__eq__ on unrelated items
            for index, itv in enumerate(container):
                if itv is interval:
                    del container[index]
                    break
            else:
                raise ValueError('interval not found in container')