#ifndef LLVM_CODEGEN_PBQP_MATH_H
#define LLVM_CODEGEN_PBQP_MATH_H
#include "llvm/ADT/Hashing.h"
#include <algorithm>
#include <cassert>
#include <functional>
#include <utility>
namespace llvm {
namespace PBQP {
/// Scalar type used for every PBQP cost entry stored in Vector and Matrix.
using PBQPNum = float;
class Vector {
friend hash_code hash_value(const Vector &);
public:
explicit Vector(unsigned Length)
: Length(Length), Data(new PBQPNum[Length]) {
}
Vector(unsigned Length, PBQPNum InitVal)
: Length(Length), Data(new PBQPNum[Length]) {
std::fill(Data, Data + Length, InitVal);
}
Vector(const Vector &V)
: Length(V.Length), Data(new PBQPNum[Length]) {
std::copy(V.Data, V.Data + Length, Data);
}
Vector(Vector &&V)
: Length(V.Length), Data(V.Data) {
V.Length = 0;
V.Data = nullptr;
}
~Vector() {
delete[] Data;
}
Vector& operator=(const Vector &V) {
delete[] Data;
Length = V.Length;
Data = new PBQPNum[Length];
std::copy(V.Data, V.Data + Length, Data);
return *this;
}
Vector& operator=(Vector &&V) {
delete[] Data;
Length = V.Length;
Data = V.Data;
V.Length = 0;
V.Data = nullptr;
return *this;
}
bool operator==(const Vector &V) const {
assert(Length != 0 && Data != nullptr && "Invalid vector");
if (Length != V.Length)
return false;
return std::equal(Data, Data + Length, V.Data);
}
unsigned getLength() const {
assert(Length != 0 && Data != nullptr && "Invalid vector");
return Length;
}
PBQPNum& operator[](unsigned Index) {
assert(Length != 0 && Data != nullptr && "Invalid vector");
assert(Index < Length && "Vector element access out of bounds.");
return Data[Index];
}
const PBQPNum& operator[](unsigned Index) const {
assert(Length != 0 && Data != nullptr && "Invalid vector");
assert(Index < Length && "Vector element access out of bounds.");
return Data[Index];
}
Vector& operator+=(const Vector &V) {
assert(Length != 0 && Data != nullptr && "Invalid vector");
assert(Length == V.Length && "Vector length mismatch.");
std::transform(Data, Data + Length, V.Data, Data, std::plus<PBQPNum>());
return *this;
}
Vector& operator-=(const Vector &V) {
assert(Length != 0 && Data != nullptr && "Invalid vector");
assert(Length == V.Length && "Vector length mismatch.");
std::transform(Data, Data + Length, V.Data, Data, std::minus<PBQPNum>());
return *this;
}
unsigned minIndex() const {
assert(Length != 0 && Data != nullptr && "Invalid vector");
return std::min_element(Data, Data + Length) - Data;
}
private:
unsigned Length;
PBQPNum *Data;
};
/// Hash a Vector from its length and the raw bit patterns of its elements.
///
/// NOTE(review): reading float storage through an `unsigned *` is a
/// strict-aliasing violation in standard C++. It also assumes
/// sizeof(unsigned) == sizeof(PBQPNum).
///
/// NOTE(review): Vector::operator== compares elements with float `==`, so
/// -0.0 and +0.0 compare equal but hash differently here. Confirm costs can
/// never be negative zero.
inline hash_code hash_value(const Vector &V) {
  return hash_combine(
      V.Length,
      hash_combine_range(reinterpret_cast<unsigned *>(V.Data),
                         reinterpret_cast<unsigned *>(V.Data + V.Length)));
}
/// Print a Vector as "[ e0, e1, ..., eN ]". The vector must be non-empty.
template <typename OStream>
OStream& operator<<(OStream &OS, const Vector &V) {
  assert((V.getLength() != 0) && "Zero-length vector badness.");
  OS << "[ ";
  for (unsigned Idx = 0, Len = V.getLength(); Idx != Len; ++Idx) {
    // A separator goes before every element except the first.
    if (Idx != 0)
      OS << ", ";
    OS << V[Idx];
  }
  OS << " ]";
  return OS;
}
class Matrix {
private:
friend hash_code hash_value(const Matrix &);
public:
Matrix(unsigned Rows, unsigned Cols) :
Rows(Rows), Cols(Cols), Data(new PBQPNum[Rows * Cols]) {
}
Matrix(unsigned Rows, unsigned Cols, PBQPNum InitVal)
: Rows(Rows), Cols(Cols), Data(new PBQPNum[Rows * Cols]) {
std::fill(Data, Data + (Rows * Cols), InitVal);
}
Matrix(const Matrix &M)
: Rows(M.Rows), Cols(M.Cols), Data(new PBQPNum[Rows * Cols]) {
std::copy(M.Data, M.Data + (Rows * Cols), Data);
}
Matrix(Matrix &&M)
: Rows(M.Rows), Cols(M.Cols), Data(M.Data) {
M.Rows = M.Cols = 0;
M.Data = nullptr;
}
~Matrix() { delete[] Data; }
Matrix& operator=(const Matrix &M) {
delete[] Data;
Rows = M.Rows; Cols = M.Cols;
Data = new PBQPNum[Rows * Cols];
std::copy(M.Data, M.Data + (Rows * Cols), Data);
return *this;
}
Matrix& operator=(Matrix &&M) {
delete[] Data;
Rows = M.Rows;
Cols = M.Cols;
Data = M.Data;
M.Rows = M.Cols = 0;
M.Data = nullptr;
return *this;
}
bool operator==(const Matrix &M) const {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
if (Rows != M.Rows || Cols != M.Cols)
return false;
return std::equal(Data, Data + (Rows * Cols), M.Data);
}
unsigned getRows() const {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
return Rows;
}
unsigned getCols() const {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
return Cols;
}
PBQPNum* operator[](unsigned R) {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
assert(R < Rows && "Row out of bounds.");
return Data + (R * Cols);
}
const PBQPNum* operator[](unsigned R) const {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
assert(R < Rows && "Row out of bounds.");
return Data + (R * Cols);
}
Vector getRowAsVector(unsigned R) const {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
Vector V(Cols);
for (unsigned C = 0; C < Cols; ++C)
V[C] = (*this)[R][C];
return V;
}
Vector getColAsVector(unsigned C) const {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
Vector V(Rows);
for (unsigned R = 0; R < Rows; ++R)
V[R] = (*this)[R][C];
return V;
}
Matrix& reset(PBQPNum Val = 0) {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
std::fill(Data, Data + (Rows * Cols), Val);
return *this;
}
Matrix& setRow(unsigned R, PBQPNum Val) {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
assert(R < Rows && "Row out of bounds.");
std::fill(Data + (R * Cols), Data + ((R + 1) * Cols), Val);
return *this;
}
Matrix& setCol(unsigned C, PBQPNum Val) {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
assert(C < Cols && "Column out of bounds.");
for (unsigned R = 0; R < Rows; ++R)
(*this)[R][C] = Val;
return *this;
}
Matrix transpose() const {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
Matrix M(Cols, Rows);
for (unsigned r = 0; r < Rows; ++r)
for (unsigned c = 0; c < Cols; ++c)
M[c][r] = (*this)[r][c];
return M;
}
Vector diagonalize() const {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
assert(Rows == Cols && "Attempt to diagonalize non-square matrix.");
Vector V(Rows);
for (unsigned r = 0; r < Rows; ++r)
V[r] = (*this)[r][r];
return V;
}
Matrix& operator+=(const Matrix &M) {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
assert(Rows == M.Rows && Cols == M.Cols &&
"Matrix dimensions mismatch.");
std::transform(Data, Data + (Rows * Cols), M.Data, Data,
std::plus<PBQPNum>());
return *this;
}
Matrix operator+(const Matrix &M) {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
Matrix Tmp(*this);
Tmp += M;
return Tmp;
}
PBQPNum getRowMin(unsigned R) const {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
assert(R < Rows && "Row out of bounds");
return *std::min_element(Data + (R * Cols), Data + ((R + 1) * Cols));
}
PBQPNum getColMin(unsigned C) const {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
PBQPNum MinElem = (*this)[0][C];
for (unsigned R = 1; R < Rows; ++R)
if ((*this)[R][C] < MinElem)
MinElem = (*this)[R][C];
return MinElem;
}
Matrix& subFromRow(unsigned R, PBQPNum Val) {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
assert(R < Rows && "Row out of bounds");
std::transform(Data + (R * Cols), Data + ((R + 1) * Cols),
Data + (R * Cols),
std::bind2nd(std::minus<PBQPNum>(), Val));
return *this;
}
Matrix& subFromCol(unsigned C, PBQPNum Val) {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
for (unsigned R = 0; R < Rows; ++R)
(*this)[R][C] -= Val;
return *this;
}
bool isZero() const {
assert(Rows != 0 && Cols != 0 && Data != nullptr && "Invalid matrix");
return find_if(Data, Data + (Rows * Cols),
std::bind2nd(std::not_equal_to<PBQPNum>(), 0)) ==
Data + (Rows * Cols);
}
private:
unsigned Rows, Cols;
PBQPNum *Data;
};
/// Hash a Matrix from its dimensions and the raw bit patterns of its
/// elements.
///
/// NOTE(review): reading float storage through an `unsigned *` is a
/// strict-aliasing violation in standard C++. It also assumes
/// sizeof(unsigned) == sizeof(PBQPNum).
///
/// NOTE(review): Matrix::operator== compares elements with float `==`, so
/// -0.0 and +0.0 compare equal but hash differently here.
inline hash_code hash_value(const Matrix &M) {
  return hash_combine(
      M.Rows, M.Cols,
      hash_combine_range(
          reinterpret_cast<unsigned *>(M.Data),
          reinterpret_cast<unsigned *>(M.Data + (M.Rows * M.Cols))));
}
/// Print a Matrix one row per line, each row formatted like a Vector.
/// The matrix must have at least one row.
template <typename OStream>
OStream& operator<<(OStream &OS, const Matrix &M) {
  assert((M.getRows() != 0) && "Zero-row matrix badness.");
  const unsigned NumRows = M.getRows();
  for (unsigned Row = 0; Row != NumRows; ++Row) {
    OS << M.getRowAsVector(Row);
    OS << "\n";
  }
  return OS;
}
/// A Vector bundled with a Metadata object built from it.
///
/// The Vector base is initialized first. The metadata is then constructed
/// from *this, so it sees the final element values.
///
/// NOTE(review): the non-const Vector members (operator[], operator+=,
/// operator-=) stay publicly accessible. Mutating through them does not
/// recompute the metadata.
template <typename Metadata>
class MDVector : public Vector {
public:
  /// Copy Costs into the base and derive metadata from the copy.
  MDVector(const Vector &Costs) : Vector(Costs), md(*this) {}

  /// Move Costs into the base and derive metadata from the result.
  MDVector(Vector &&Costs) : Vector(std::move(Costs)), md(*this) {}

  /// Return the metadata computed at construction time.
  const Metadata &getMetadata() const { return md; }

private:
  Metadata md;
};
/// Hash an MDVector by its Vector contents alone; the metadata is not hashed.
template <typename Metadata>
inline hash_code hash_value(const MDVector<Metadata> &V) {
  const Vector &Base = V;
  return hash_value(Base);
}
/// A Matrix bundled with a Metadata object built from it.
///
/// The Matrix base is initialized first. The metadata is then constructed
/// from *this, so it sees the final element values.
///
/// NOTE(review): the mutating Matrix members (operator[], reset, setRow,
/// subFromRow, ...) stay publicly accessible. Mutating through them does not
/// recompute the metadata.
template <typename Metadata>
class MDMatrix : public Matrix {
public:
  /// Copy Costs into the base and derive metadata from the copy.
  MDMatrix(const Matrix &Costs) : Matrix(Costs), md(*this) {}

  /// Move Costs into the base and derive metadata from the result.
  MDMatrix(Matrix &&Costs) : Matrix(std::move(Costs)), md(*this) {}

  /// Return the metadata computed at construction time.
  const Metadata &getMetadata() const { return md; }

private:
  Metadata md;
};
/// Hash an MDMatrix by its Matrix contents alone; the metadata is not hashed.
template <typename Metadata>
inline hash_code hash_value(const MDMatrix<Metadata> &M) {
  const Matrix &Base = M;
  return hash_value(Base);
}
} }
#endif // LLVM_CODEGEN_PBQP_MATH_H