#pragma once
#include "PODInterval.h"
#include "PODRedBlackTree.h"
#include <wtf/Optional.h>
#include <wtf/Vector.h>
namespace WebCore {
struct PODIntervalNodeUpdater;
// Interval tree layered on PODRedBlackTree. Every node's PODInterval carries a
// cached maxHigh (kept current by PODIntervalNodeUpdater), which the queries
// below use to skip subtrees that cannot contain a match.
// NOTE(review): the navigation code assumes the underlying tree is ordered so
// that low endpoints never decrease in in-order traversal (i.e. PODInterval's
// ordering sorts by low first) — confirm against PODInterval.h.
template<typename T, typename UserData> class PODIntervalTree final : public PODRedBlackTree<PODInterval<T, UserData>, PODIntervalNodeUpdater> {
    WTF_MAKE_FAST_ALLOCATED;
public:
    using IntervalType = PODInterval<T, UserData>;
    class OverlapsSearchAdapter;

    // Returns every interval in the tree that overlaps |interval|, in in-order
    // (tree) order.
    Vector<IntervalType> allOverlaps(const IntervalType& interval) const
    {
        Vector<IntervalType> result;
        OverlapsSearchAdapter adapter(result, interval);
        allOverlapsWithAdapter(adapter);
        return result;
    }

    // Walks the tree in order, pruning subtrees with adapter.lowValue() /
    // adapter.highValue() and passing each visited interval to
    // adapter.collectIfNeeded(), which decides whether to keep it.
    template<typename AdapterType> void allOverlapsWithAdapter(AdapterType& adapter) const
    {
        searchForOverlapsFrom(this->root(), adapter);
    }

    // Returns the first interval (in tree order) whose low endpoint is strictly
    // greater than |point|, or nullopt if no such interval exists.
    Optional<IntervalType> nextIntervalAfter(const T& point) const
    {
        auto next = smallestNodeGreaterThanFrom(point, this->root());
        if (!next)
            return WTF::nullopt;
        return next->data();
    }

#ifndef NDEBUG
    // Debug-only: verifies the red-black invariants of the base tree and that
    // every node's cached maxHigh equals the true maximum high endpoint of its
    // subtree.
    bool checkInvariants() const
    {
        if (!Base::checkInvariants())
            return false;
        if (!this->root())
            return true;
        return checkInvariantsFromNode(this->root(), nullptr);
    }
#endif

private:
    using Base = PODRedBlackTree<PODInterval<T, UserData>, PODIntervalNodeUpdater>;
    using IntervalNode = typename Base::Node;

    // In-order traversal from |node|, offering each interval that might overlap
    // the adapter's query range to adapter.collectIfNeeded().
    template<typename AdapterType> void searchForOverlapsFrom(IntervalNode* node, AdapterType& adapter) const
    {
        if (!node)
            return;

        // Skip the left subtree when even its largest high endpoint lies below
        // the query's low value. The test is phrased with '<' only so that T
        // needs no operator<=.
        IntervalNode* left = node->left();
        if (left && !(left->data().maxHigh() < adapter.lowValue()))
            searchForOverlapsFrom<AdapterType>(left, adapter);

        adapter.collectIfNeeded(node->data());

        // Intervals to the right start at or after this node's low endpoint.
        // If the query ends before that, none of them can overlap.
        if (!(adapter.highValue() < node->data().low()))
            searchForOverlapsFrom<AdapterType>(node->right(), adapter);
    }

    // Returns the in-order-first node in the subtree rooted at |node| whose low
    // endpoint is strictly greater than |point|, or nullptr if there is none.
    IntervalNode* smallestNodeGreaterThanFrom(const T& point, IntervalNode* node) const
    {
        IntervalNode* candidate = nullptr;
        while (node) {
            if (point < node->data().low()) {
                // This node qualifies. Any smaller qualifying node must be in
                // its LEFT subtree. (The previous version wrongly searched the
                // right subtree here.)
                candidate = node;
                node = node->left();
            } else {
                // This node and its whole left subtree have low <= point, so
                // only the right subtree can contain a match.
                node = node->right();
            }
        }
        return candidate;
    }

#ifndef NDEBUG
    // Recursively checks that |node|'s cached maxHigh equals max(high, children's
    // subtree maxima). If |currentMaxValue| is non-null, it receives the
    // subtree maximum computed here.
    bool checkInvariantsFromNode(IntervalNode* node, T* currentMaxValue) const
    {
        T leftMaxValue(node->data().maxHigh());
        T rightMaxValue(node->data().maxHigh());
        IntervalNode* left = node->left();
        IntervalNode* right = node->right();
        if (left) {
            if (!checkInvariantsFromNode(left, &leftMaxValue))
                return false;
        }
        if (right) {
            if (!checkInvariantsFromNode(right, &rightMaxValue))
                return false;
        }
        if (!left && !right) {
            // A leaf's maxHigh must simply be its own high endpoint.
            if (currentMaxValue)
                *currentMaxValue = node->data().high();
            return (node->data().high() == node->data().maxHigh());
        }
        T localMaxValue(node->data().maxHigh());
        if (!left || !right) {
            if (left)
                localMaxValue = leftMaxValue;
            else
                localMaxValue = rightMaxValue;
        } else
            localMaxValue = (leftMaxValue < rightMaxValue) ? rightMaxValue : leftMaxValue;
        if (localMaxValue < node->data().high())
            localMaxValue = node->data().high();
        if (!(localMaxValue == node->data().maxHigh())) {
            TextStream stream;
            stream << "localMaxValue=" << localMaxValue << " and data=" << node->data();
            LOG_ERROR("PODIntervalTree verification failed at node 0x%p: %s",
                node, stream.release().utf8().data());
            return false;
        }
        if (currentMaxValue)
            *currentMaxValue = localMaxValue;
        return true;
    }
#endif
};
// Adapter used by PODIntervalTree::allOverlaps(). It exposes the query range's
// endpoints for subtree pruning and appends every offered interval that
// overlaps the query to the caller's vector.
// Both constructor arguments are held by reference, so they must outlive the
// adapter.
template<typename T, typename UserData> class PODIntervalTree<T, UserData>::OverlapsSearchAdapter {
public:
    using IntervalType = PODInterval<T, UserData>;

    OverlapsSearchAdapter(Vector<IntervalType>& result, const IntervalType& interval)
        : m_output(result)
        , m_query(interval)
    {
    }

    // Endpoints of the query interval, used by the tree walk to prune subtrees.
    const T& lowValue() const { return m_query.low(); }
    const T& highValue() const { return m_query.high(); }

    // Appends |data| to the output vector only when it overlaps the query.
    void collectIfNeeded(const IntervalType& data) const
    {
        if (!data.overlaps(m_query))
            return;
        m_output.append(data);
    }

private:
    Vector<IntervalType>& m_output;
    const IntervalType& m_query;
};
// Node-update policy for PODIntervalTree's red-black tree. It recomputes a
// node's cached maxHigh as the largest of the node's own high() and the
// maxHigh() of each child that exists.
struct PODIntervalNodeUpdater {
    // Returns true if the node's stored maxHigh was changed, false if it was
    // already correct.
    template<typename Node> static bool update(Node& node)
    {
        auto* largest = &node.data().high();

        // Left child is considered before the right one; on ties the earlier
        // value is kept because the comparison is strict.
        auto consider = [&largest](auto* child) {
            if (child && *largest < child->data().maxHigh())
                largest = &child->data().maxHigh();
        };
        consider(node.left());
        consider(node.right());

        if (*largest == node.data().maxHigh())
            return false;

        node.data().setMaxHigh(*largest);
        return true;
    }
};
}