diff --git a/chrono/utils/interval.py b/chrono/utils/interval.py index 2a17ecff..2b49b203 100644 --- a/chrono/utils/interval.py +++ b/chrono/utils/interval.py @@ -15,9 +15,35 @@ # along with this program. If not, see . import bisect -import collections +import typing -Interval = collections.namedtuple('Interval', ['begin', 'end']) + +class Interval(typing.NamedTuple): + begin: typing.Any + end: typing.Any + + def disjoint(self, other): + return self < other or self > other + + def overlaps(self, other): + return not self.disjoint(other) + + def __lt__(self, other): + return self[1] < other[0] + + def __gt__(self, other): + return other[1] < self[0] + + def union(self, other): + other = self.cast(other) + assert self.overlaps(other) + return Interval(min(self.begin, other.begin), max(self.end, other.end)) + + @classmethod + def cast(cls, other): + if isinstance(other, cls): + return other + return cls(*other) class IntervalSet: @@ -159,6 +185,9 @@ class IntervalSet: return value return cls(value) + def __rsub__(self, other): + return self.cast(other) - self + def __sub__(self, other): l1 = iter(self) l2 = iter(self.cast(other)) @@ -200,6 +229,95 @@ class IntervalSet: return self.__class__.from_ordered(gen()) + def __radd__(self, other): + return self.cast(other).__add__(self) + + def __add__(self, other): + l1 = iter(self) + l2 = iter(self.cast(other)) + + def gen(): + state = 3 + current = None + while True: + if state & 1: + c1 = next(l1, None) + if state & 2: + c2 = next(l2, None) + if current: + if not c1 and not c2: + yield current + break + if not c1: + if current < c2: + yield current + yield c2 + break + if c2 < current: + yield c2 + yield current + else: + yield current.union(c2) + break + if not c2: + if current < c1: + yield current + yield c1 + break + if c1 < current: + yield c1 + yield current + else: + yield current.union(c1) + break + if current < c1 and current < c2: + yield current + current = None + elif current.overlaps(c1) and current.overlaps(c2): + yield current.union(c1).union(c2) + current = None + state = 3 + continue + elif current < c2: + yield current.union(c1) + current = None + state = 1 + continue + else: + yield current.union(c2) + current = None + state = 2 + continue + if not c1 and not c2: + # l1 and l2 are empty, stop + break + if not c1: + # l1 is empty, yield c2 and stop + yield c2 + break + if not c2: + # l2 is empty, yield c1 and stop + yield c1 + break + if c1 < c2: + # l1 is before l2, yield c1 and advance l1 only + yield c1 + state = 1 + continue + if c2 < c1: + # l2 is before l1, yield c2 and advance l2 only + yield c2 + state = 2 + continue + current = c1.union(c2) + state = 3 + + # finish by yielding from the not empty ones + yield from l1 + yield from l2 + + return self.__class__.from_ordered(gen()) + def min(self): if self: return self.begin[0] diff --git a/tests/test_interval.py b/tests/test_interval.py index 04a53bf5..b27e8d0b 100644 --- a/tests/test_interval.py +++ b/tests/test_interval.py @@ -3,6 +3,11 @@ import pytest from chrono.utils.interval import Interval, IntervalSet +def test_interval_union(): + assert Interval(1, 2).union((2, 3)) == (1, 3) + assert Interval(1, 2).union((2, 3)) == Interval(1, 3) + + def test_interval_set_merge_adjacent(): """ Test that adjacent intervals are merged @@ -67,6 +72,8 @@ def test_interval_set_sub(): assert (s - []) == s assert (IntervalSet([(0, 2)]) - [(1, 2)]) == [(0, 1)] + assert ([] - s) == [] + def test_interval_set_min_max(): assert IntervalSet().min() is None @@ -86,3 +93,25 @@ def test_interval_set_eq(): assert not IntervalSet([(1, 2)]) == None # noqa pylint: disable=singleton-comparison # noqa pylint: disable=singleton-comparison assert not None == IntervalSet([(1, 2)]) + + +def test_interval_set_add(): + s = IntervalSet([(0, 3), (4, 7), (8, 11), (12, 15)]) + t = IntervalSet([(3, 4), (7, 8), (11, 12)]) + + assert s + t == [(0, 15)] + assert t + s == [(0, 15)] + assert [(3, 4), (7, 8), (11, 12)] + s == [(0, 15)] + + t = IntervalSet([(3, 4), (11, 12)]) + assert s + t == [(0, 7), (8, 15)] + assert t + s == [(0, 7), (8, 15)] + + t = IntervalSet([(2, 5), (10, 13)]) + assert s + t == [(0, 7), (8, 15)] + assert t + s == [(0, 7), (8, 15)] + + assert s + [] == s + assert [] + s == s + + assert IntervalSet([(1, 3), (4, 6)]) + [(3, 4), (4, 5)] == [(1, 6)]