utils: add IntervalSet.__add__ (#76335)

Most algo around agendas amounts to adding a bunch of intervals them
removing some. The method to add them was missing.
This commit is contained in:
Benjamin Dauvergne 2023-04-06 15:57:21 +02:00
parent 881b585c3d
commit fd28d075a5
2 changed files with 149 additions and 2 deletions

View File

@ -15,9 +15,35 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import bisect
import collections
import typing
Interval = collections.namedtuple('Interval', ['begin', 'end'])
class Interval(typing.NamedTuple):
begin: typing.Any
end: typing.Any
def disjoint(self, other):
return self < other or self > other
def overlaps(self, other):
return not self.disjoint(other)
def __lt__(self, other):
return self[1] < other[0]
def __gt__(self, other):
return other[1] < self[0]
def union(self, other):
other = self.cast(other)
assert self.overlaps(other)
return Interval(min(self.begin, other.begin), max(self.end, other.end))
@classmethod
def cast(cls, other):
if isinstance(other, cls):
return other
return cls(*other)
class IntervalSet:
@ -159,6 +185,9 @@ class IntervalSet:
return value
return cls(value)
def __rsub__(self, other):
return self.cast(other) - self
def __sub__(self, other):
l1 = iter(self)
l2 = iter(self.cast(other))
@ -200,6 +229,95 @@ class IntervalSet:
return self.__class__.from_ordered(gen())
def __radd__(self, other):
return self.cast(other).__add__(self)
def __add__(self, other):
l1 = iter(self)
l2 = iter(self.cast(other))
def gen():
state = 3
current = None
while True:
if state & 1:
c1 = next(l1, None)
if state & 2:
c2 = next(l2, None)
if current:
if not c1 and not c2:
yield current
break
if not c1:
if current < c2:
yield current
yield c2
break
if c2 < current:
yield c2
yield current
else:
yield current.union(c2)
break
if not c2:
if current < c1:
yield current
yield c1
break
if c1 < current:
yield c1
yield current
else:
yield current.union(c1)
break
if current < c1 and current < c2:
yield current
current = None
elif current.overlaps(c1) and current.overlaps(c2):
yield current.union(c1).union(c2)
current = None
state = 3
continue
elif current < c2:
yield current.union(c1)
current = None
state = 1
continue
else:
yield current.union(c2)
current = None
state = 2
continue
if not c1 and not c2:
# l1 and l2 are empty, stop
break
if not c1:
# l1 is empty, yield c2 and stop
yield c2
break
if not c2:
# l2 is empty, yield c1 and stop
yield c1
break
if c1 < c2:
# l1 is before l2, yield c1 and advance l1 only
yield c1
state = 1
continue
if c2 < c1:
# l2 is before l1, yield c2 and advance l2 only
yield c2
state = 2
continue
current = c1.union(c2)
state = 3
# finish by yielding from the not empty ones
yield from l1
yield from l2
return self.__class__.from_ordered(gen())
def min(self):
if self:
return self.begin[0]

View File

@ -3,6 +3,11 @@ import pytest
from chrono.utils.interval import Interval, IntervalSet
def test_interval_union():
assert Interval(1, 2).union((2, 3)) == (1, 3)
assert Interval(1, 2).union((2, 3)) == Interval(1, 3)
def test_interval_set_merge_adjacent():
"""
Test that adjacent intervals are merged
@ -67,6 +72,8 @@ def test_interval_set_sub():
assert (s - []) == s
assert (IntervalSet([(0, 2)]) - [(1, 2)]) == [(0, 1)]
assert ([] - s) == []
def test_interval_set_min_max():
assert IntervalSet().min() is None
@ -86,3 +93,25 @@ def test_interval_set_eq():
assert not IntervalSet([(1, 2)]) == None # noqa pylint: disable=singleton-comparison
# noqa pylint: disable=singleton-comparison
assert not None == IntervalSet([(1, 2)])
def test_interval_set_add():
s = IntervalSet([(0, 3), (4, 7), (8, 11), (12, 15)])
t = IntervalSet([(3, 4), (7, 8), (11, 12)])
assert s + t == [(0, 15)]
assert t + s == [(0, 15)]
assert [(3, 4), (7, 8), (11, 12)] + s == [(0, 15)]
t = IntervalSet([(3, 4), (11, 12)])
assert s + t == [(0, 7), (8, 15)]
assert t + s == [(0, 7), (8, 15)]
t = IntervalSet([(2, 5), (10, 13)])
assert s + t == [(0, 7), (8, 15)]
assert t + s == [(0, 7), (8, 15)]
assert s + [] == s
assert [] + s == s
assert IntervalSet([(1, 3), (4, 6)]) + [(3, 4), (4, 5)] == [(1, 6)]