utils: add IntervalSet.__add__ (#76335)
Most algo around agendas amounts to adding a bunch of intervals them removing some. The method to add them was missing.
This commit is contained in:
parent
246e62d96b
commit
33b4c807b4
|
@ -15,9 +15,35 @@
|
|||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import bisect
|
||||
import collections
|
||||
import typing
|
||||
|
||||
Interval = collections.namedtuple('Interval', ['begin', 'end'])
|
||||
|
||||
class Interval(typing.NamedTuple):
|
||||
begin: typing.Any
|
||||
end: typing.Any
|
||||
|
||||
def disjoint(self, other):
|
||||
return self < other or self > other
|
||||
|
||||
def overlaps(self, other):
|
||||
return not self.disjoint(other)
|
||||
|
||||
def __lt__(self, other):
|
||||
return self[1] < other[0]
|
||||
|
||||
def __gt__(self, other):
|
||||
return other[1] < self[0]
|
||||
|
||||
def union(self, other):
|
||||
other = self.cast(other)
|
||||
assert self.overlaps(other)
|
||||
return Interval(min(self.begin, other.begin), max(self.end, other.end))
|
||||
|
||||
@classmethod
|
||||
def cast(cls, other):
|
||||
if isinstance(other, cls):
|
||||
return other
|
||||
return cls(*other)
|
||||
|
||||
|
||||
class IntervalSet:
|
||||
|
@ -159,6 +185,9 @@ class IntervalSet:
|
|||
return value
|
||||
return cls(value)
|
||||
|
||||
def __rsub__(self, other):
|
||||
return self.cast(other) - self
|
||||
|
||||
def __sub__(self, other):
|
||||
l1 = iter(self)
|
||||
l2 = iter(self.cast(other))
|
||||
|
@ -200,6 +229,95 @@ class IntervalSet:
|
|||
|
||||
return self.__class__.from_ordered(gen())
|
||||
|
||||
def __radd__(self, other):
|
||||
return self.cast(other).__add__(self)
|
||||
|
||||
def __add__(self, other):
|
||||
l1 = iter(self)
|
||||
l2 = iter(self.cast(other))
|
||||
|
||||
def gen():
|
||||
state = 3
|
||||
current = None
|
||||
while True:
|
||||
if state & 1:
|
||||
c1 = next(l1, None)
|
||||
if state & 2:
|
||||
c2 = next(l2, None)
|
||||
if current:
|
||||
if not c1 and not c2:
|
||||
yield current
|
||||
break
|
||||
if not c1:
|
||||
if current < c2:
|
||||
yield current
|
||||
yield c2
|
||||
break
|
||||
if c2 < current:
|
||||
yield c2
|
||||
yield current
|
||||
else:
|
||||
yield current.union(c2)
|
||||
break
|
||||
if not c2:
|
||||
if current < c1:
|
||||
yield current
|
||||
yield c1
|
||||
break
|
||||
if c1 < current:
|
||||
yield c1
|
||||
yield current
|
||||
else:
|
||||
yield current.union(c1)
|
||||
break
|
||||
if current < c1 and current < c2:
|
||||
yield current
|
||||
current = None
|
||||
elif current.overlaps(c1) and current.overlaps(c2):
|
||||
yield current.union(c1).union(c2)
|
||||
current = None
|
||||
state = 3
|
||||
continue
|
||||
elif current < c2:
|
||||
yield current.union(c1)
|
||||
current = None
|
||||
state = 1
|
||||
continue
|
||||
else:
|
||||
yield current.union(c2)
|
||||
current = None
|
||||
state = 2
|
||||
continue
|
||||
if not c1 and not c2:
|
||||
# l1 and l2 are empty, stop
|
||||
break
|
||||
if not c1:
|
||||
# l1 is empty, yield c2 and stop
|
||||
yield c2
|
||||
break
|
||||
if not c2:
|
||||
# l2 is empty, yield c1 and stop
|
||||
yield c1
|
||||
break
|
||||
if c1 < c2:
|
||||
# l1 is before l2, yield c1 and advance l1 only
|
||||
yield c1
|
||||
state = 1
|
||||
continue
|
||||
if c2 < c1:
|
||||
# l2 is before l1, yield c2 and advance l2 only
|
||||
yield c2
|
||||
state = 2
|
||||
continue
|
||||
current = c1.union(c2)
|
||||
state = 3
|
||||
|
||||
# finish by yielding from the not empty ones
|
||||
yield from l1
|
||||
yield from l2
|
||||
|
||||
return self.__class__.from_ordered(gen())
|
||||
|
||||
def min(self):
|
||||
if self:
|
||||
return self.begin[0]
|
||||
|
|
|
@ -3,6 +3,11 @@ import pytest
|
|||
from chrono.utils.interval import Interval, IntervalSet
|
||||
|
||||
|
||||
def test_interval_union():
|
||||
assert Interval(1, 2).union((2, 3)) == (1, 3)
|
||||
assert Interval(1, 2).union((2, 3)) == Interval(1, 3)
|
||||
|
||||
|
||||
def test_interval_set_merge_adjacent():
|
||||
"""
|
||||
Test that adjacent intervals are merged
|
||||
|
@ -67,6 +72,8 @@ def test_interval_set_sub():
|
|||
assert (s - []) == s
|
||||
assert (IntervalSet([(0, 2)]) - [(1, 2)]) == [(0, 1)]
|
||||
|
||||
assert ([] - s) == []
|
||||
|
||||
|
||||
def test_interval_set_min_max():
|
||||
assert IntervalSet().min() is None
|
||||
|
@ -86,3 +93,25 @@ def test_interval_set_eq():
|
|||
assert not IntervalSet([(1, 2)]) == None # noqa pylint: disable=singleton-comparison
|
||||
# noqa pylint: disable=singleton-comparison
|
||||
assert not None == IntervalSet([(1, 2)])
|
||||
|
||||
|
||||
def test_interval_set_add():
|
||||
s = IntervalSet([(0, 3), (4, 7), (8, 11), (12, 15)])
|
||||
t = IntervalSet([(3, 4), (7, 8), (11, 12)])
|
||||
|
||||
assert s + t == [(0, 15)]
|
||||
assert t + s == [(0, 15)]
|
||||
assert [(3, 4), (7, 8), (11, 12)] + s == [(0, 15)]
|
||||
|
||||
t = IntervalSet([(3, 4), (11, 12)])
|
||||
assert s + t == [(0, 7), (8, 15)]
|
||||
assert t + s == [(0, 7), (8, 15)]
|
||||
|
||||
t = IntervalSet([(2, 5), (10, 13)])
|
||||
assert s + t == [(0, 7), (8, 15)]
|
||||
assert t + s == [(0, 7), (8, 15)]
|
||||
|
||||
assert s + [] == s
|
||||
assert [] + s == s
|
||||
|
||||
assert IntervalSet([(1, 3), (4, 6)]) + [(3, 4), (4, 5)] == [(1, 6)]
|
||||
|
|
Loading…
Reference in New Issue