// Keep first: presumably configures Eigen (e.g. the Eigen2-support mode that
// Eigen/LeastSquares needs) before any Eigen header is seen -- TODO confirm.
#include "eigen-proxy.h"

#include <stdio.h>

#include <new>
#include <sstream>

#include <Eigen/LU>
#include <Eigen/LeastSquares>

// Calls eigen_initParallel() (defined at the bottom of this file, so it must be
// declared earlier -- presumably in eigen-proxy.h) during dynamic initialization
// of this translation unit, so Eigen's initParallel() has run before the
// entry points below are used. The flag itself is not read anywhere in this file.
static bool inited = eigen_initParallel();

// Exception raised by eigen_assert_fail() when an Eigen assertion fails.
// It simply carries the pre-formatted diagnostic and exposes it via what().
class eigen_assert_exception : public std::exception {
public:
    eigen_assert_exception(const std::string& what) : message_(what) {}
    ~eigen_assert_exception() throw() {}

    // The returned pointer stays valid for the lifetime of this exception object.
    const char* what() const throw() { return message_.c_str(); }

private:
    std::string message_;
};

// Failure hook for Eigen assertions (presumably wired up through an eigen_assert
// override in eigen-proxy.h -- TODO confirm). Builds a message naming the failed
// condition and its source location, then throws eigen_assert_exception; it
// never returns normally. The message ends with a newline.
void eigen_assert_fail(const char* condition, const char* function, const char* file, int line) {
    std::ostringstream msg;
    msg << "assertion failed: " << condition
        << " in function " << function
        << " at " << file << ":" << line
        << '\n';
    throw eigen_assert_exception(msg.str());
}

// Scalar types selected by the integer `code` argument of the eigen_* entry
// points below: code 0 -> T0, 1 -> T1, 2 -> T2, 3 -> T3.
typedef float T0;
typedef double T1;
typedef std::complex<float> T2;
typedef std::complex<double> T3;

// Views the raw buffer `p` as an r x c dynamic Eigen matrix of Elem, without
// copying; writes through the returned Map modify the buffer.
template <class Elem>
Map< Matrix<Elem,Dynamic,Dynamic> > matrix(void* p, int r, int c) {
    Elem* data = static_cast<Elem*>(p);
    return Map< Matrix<Elem,Dynamic,Dynamic> >(data, r, c);
}

// Views the read-only buffer `p` as an r x c dynamic Eigen matrix of Elem,
// without copying. Chosen by overload resolution whenever `p` is const void*.
// NOTE(review): the result is Map<Matrix>, not Map<const Matrix>, so nothing at
// compile time prevents writing through it into memory handed in as const.
template <class Elem>
Map< Matrix<Elem,Dynamic,Dynamic> > matrix(const void* p, int r, int c) {
    const Elem* data = static_cast<const Elem*>(p);
    return Map< Matrix<Elem,Dynamic,Dynamic> >(data, r, c);
}

// BINOP(name, op) defines
//   extern "C" const char* eigen_<name>(code, p,r,c, p1,r1,c1, p2,r2,c2)
// which stores (p1 <op> p2) into the r x c buffer p, using the element type
// selected by `code` (0: float, 1: double, 2: complex<float>, 3: complex<double>).
// Returns the GUARD_END result on success (presumably null -- eigen-proxy.h) or a
// strdup-allocated message. An unknown `code` is reported as an error instead of
// silently "succeeding" with p left untouched.
// (No // comments inside the macro: line splicing would swallow the next line.)
#define BINOP(name,op) \
extern "C" const char* eigen_##name(\
    int code,\
    void* p, int r, int c,\
    const void* p1, int r1, int c1,\
    const void* p2, int r2, int c2)\
{\
    GUARD_START\
    switch (code) {\
        case 0: matrix<T0>(p,r,c) = matrix<T0>(p1,r1,c1) op matrix<T0>(p2,r2,c2); break;\
        case 1: matrix<T1>(p,r,c) = matrix<T1>(p1,r1,c1) op matrix<T1>(p2,r2,c2); break;\
        case 2: matrix<T2>(p,r,c) = matrix<T2>(p1,r1,c1) op matrix<T2>(p2,r2,c2); break;\
        case 3: matrix<T3>(p,r,c) = matrix<T3>(p1,r1,c1) op matrix<T3>(p2,r2,c2); break;\
        default: return strdup("Unsupported element type code.");\
    }\
    GUARD_END\
}

BINOP(add,+);
BINOP(sub,-);
BINOP(mul,*);

// PROP(name) defines
//   extern "C" const char* eigen_<name>(code, q, p, r, c)
// which evaluates the Eigen member name() on the r x c matrix at p and stores
// the result into *q, converted to the element type selected by `code`
// (0: float, 1: double, 2: complex<float>, 3: complex<double>).
// Returns the GUARD_END result on success (presumably null -- eigen-proxy.h) or a
// strdup-allocated message. An unknown `code` is reported as an error instead of
// silently leaving *q unwritten.
#define PROP(name) \
extern "C" const char* eigen_##name(int code, void* q, const void* p, int r, int c) {\
    GUARD_START\
    switch (code) {\
        case 0: *(T0*)q = matrix<T0>(p,r,c).name(); break;\
        case 1: *(T1*)q = matrix<T1>(p,r,c).name(); break;\
        case 2: *(T2*)q = matrix<T2>(p,r,c).name(); break;\
        case 3: *(T3*)q = matrix<T3>(p,r,c).name(); break;\
        default: return strdup("Unsupported element type code.");\
    }\
    GUARD_END\
}

PROP(norm);
PROP(squaredNorm);
PROP(blueNorm);
PROP(hypotNorm);
PROP(sum);
PROP(prod);
PROP(mean);
PROP(trace);
PROP(determinant);
PROP(determinant);

// UNOP(name) defines
//   extern "C" const char* eigen_<name>(code, p,r,c, p1,r1,c1)
// which stores p1.name() into the r x c buffer p, using the element type
// selected by `code` (0: float, 1: double, 2: complex<float>, 3: complex<double>).
// Returns the GUARD_END result on success (presumably null -- eigen-proxy.h) or a
// strdup-allocated message. An unknown `code` is reported as an error instead of
// silently "succeeding" with p left untouched.
// NOTE(review): for transpose/adjoint, p must not overlap p1 -- plain assignment
// of these expressions is not alias-safe in Eigen (see Eigen's "Aliasing" page).
#define UNOP(name) \
extern "C" const char* eigen_##name(int code, void* p, int r, int c, const void* p1, int r1, int c1) {\
    GUARD_START\
    switch (code) {\
        case 0: matrix<T0>(p,r,c) = matrix<T0>(p1,r1,c1).name(); break;\
        case 1: matrix<T1>(p,r,c) = matrix<T1>(p1,r1,c1).name(); break;\
        case 2: matrix<T2>(p,r,c) = matrix<T2>(p1,r1,c1).name(); break;\
        case 3: matrix<T3>(p,r,c) = matrix<T3>(p1,r1,c1).name(); break;\
        default: return strdup("Unsupported element type code.");\
    }\
    GUARD_END\
}

UNOP(inverse);
UNOP(adjoint);
UNOP(conjugate);
UNOP(diagonal);
UNOP(transpose);

// Calls Eigen's normalize() in place on the r x c matrix stored at p, using the
// element type selected by `code` (0: float, 1: double, 2: complex<float>,
// 3: complex<double>).
// Returns the GUARD_END result on success (presumably null -- eigen-proxy.h) or a
// strdup-allocated message. An unknown `code` is reported as an error instead of
// silently succeeding without touching the buffer.
extern "C" const char* eigen_normalize(int code, void* p, int r, int c)
{
    GUARD_START
    switch (code) {
        case 0: matrix<T0>(p,r,c).normalize(); break;
        case 1: matrix<T1>(p,r,c).normalize(); break;
        case 2: matrix<T2>(p,r,c).normalize(); break;
        case 3: matrix<T3>(p,r,c).normalize(); break;
        default: return strdup("Unsupported element type code.");
    }
    GUARD_END
}

// Overwrites the r x c buffer at p with Eigen's Random(r, c) coefficients, using
// the element type selected by `code` (0: float, 1: double, 2: complex<float>,
// 3: complex<double>).
// Returns the GUARD_END result on success (presumably null -- eigen-proxy.h) or a
// strdup-allocated message. An unknown `code` is reported as an error instead of
// silently leaving the buffer uninitialized.
extern "C" const char* eigen_random(int code, void* p, int r, int c)
{
    GUARD_START
    switch (code) {
        case 0: matrix<T0>(p,r,c) = MatrixXf::Random(r,c); break;
        case 1: matrix<T1>(p,r,c) = MatrixXd::Random(r,c); break;
        case 2: matrix<T2>(p,r,c) = MatrixXcf::Random(r,c); break;
        case 3: matrix<T3>(p,r,c) = MatrixXcd::Random(r,c); break;
        default: return strdup("Unsupported element type code.");
    }
    GUARD_END
}

// Overwrites the r x c buffer at p with Eigen's Identity(r, c) (ones on the main
// diagonal, zeros elsewhere; r and c need not be equal), using the element type
// selected by `code` (0: float, 1: double, 2: complex<float>, 3: complex<double>).
// Returns the GUARD_END result on success (presumably null -- eigen-proxy.h) or a
// strdup-allocated message. An unknown `code` is reported as an error instead of
// silently leaving the buffer uninitialized.
extern "C" const char* eigen_identity(int code, void* p, int r, int c)
{
    GUARD_START
    switch (code) {
        case 0: matrix<T0>(p,r,c) = MatrixXf::Identity(r,c); break;
        case 1: matrix<T1>(p,r,c) = MatrixXd::Identity(r,c); break;
        case 2: matrix<T2>(p,r,c) = MatrixXcf::Identity(r,c); break;
        case 3: matrix<T3>(p,r,c) = MatrixXcd::Identity(r,c); break;
        default: return strdup("Unsupported element type code.");
    }
    GUARD_END
}

// Computes the numerical rank of the r x c matrix at p (scalar type T) with the
// decomposition selected by d and stores it in *v.
// Returns 0 on success, or a strdup-allocated message when d is not one of the
// rank-revealing decompositions handled below (*v is then left untouched).
// The leading :: on the case labels selects the Decomposition enumerators rather
// than the identically named Eigen class templates.
template <class T>
const char* rank(Decomposition d, int* v, const void* p, int r, int c) {
    typedef Map< Matrix<T,Dynamic,Dynamic> > MapMatrix;
    MapMatrix A((const T*)p,r,c);
    switch (d) {
        case ::FullPivLU:
            *v = A.fullPivLu().rank();
            break;
        case ::ColPivHouseholderQR:
            *v = A.colPivHouseholderQr().rank();
            break;
        case ::FullPivHouseholderQR:
            *v = A.fullPivHouseholderQr().rank();
            break;
        case ::JacobiSVD:
            // The rank depends only on the singular values, so U and V are not
            // computed (requesting them here was pure extra work).
            *v = A.jacobiSvd().rank();
            break;
        default:
            return strdup("Selected decomposition doesn't support rank revealing.");
    }
    return 0;
}

// C entry point for rank<T>: dispatches on `code` (0: float, 1: double,
// 2: complex<float>, 3: complex<double>) and forwards rank<T>'s result.
// An unknown `code` now returns a strdup-allocated error on an explicit path.
// Before, it fell through to whatever GUARD_END does, with *v unwritten.
extern "C" const char* eigen_rank(int code, Decomposition d, int* v, const void* p, int r, int c) {
    GUARD_START
    switch (code) {
        case 0: return rank<T0>(d,v,p,r,c);
        case 1: return rank<T1>(d,v,p,r,c);
        case 2: return rank<T2>(d,v,p,r,c);
        case 3: return rank<T3>(d,v,p,r,c);
        default: return strdup("Unsupported element type code.");
    }
    GUARD_END
}

// Computes a basis of the kernel (null space) of the r1 x c1 matrix at p1 with
// Eigen's FullPivLU, the only decomposition accepted here.
// On success *p0 receives a freshly malloc'ed buffer holding the basis as an
// (*r0) x (*c0) matrix; the caller owns it and must release it with free().
// Returns 0 on success, or a strdup-allocated message if d is not FullPivLU.
// Throws std::bad_alloc if the output buffer cannot be allocated (presumably
// turned into an error by the caller's GUARD_END -- TODO confirm). The outputs
// are written only after the copy has succeeded.
template <class T>
const char* kernel(Decomposition d, void** p0, int* r0, int* c0, const void* p1, int r1, int c1) {
    typedef Map< Matrix<T,Dynamic,Dynamic> > MapMatrix;
    if (d != ::FullPivLU)
        return strdup("Selected decomposition doesn't support kernel revealing.");
    MapMatrix A((const T*)p1,r1,c1);
    Matrix<T,Dynamic,Dynamic> B = A.fullPivLu().kernel();
    // Size the buffer in size_t so rows * cols cannot overflow int.
    const size_t bytes = size_t(B.rows()) * size_t(B.cols()) * sizeof(T);
    void* out = malloc(bytes);
    // malloc(0) may legitimately return null; only a failed non-empty request is an error.
    if (out == 0 && bytes != 0)
        throw std::bad_alloc();
    MapMatrix((T*)out, B.rows(), B.cols()) = B;
    *p0 = out;
    *r0 = B.rows();
    *c0 = B.cols();
    return 0;
}

// C entry point for kernel<T>: dispatches on `code` (0: float, 1: double,
// 2: complex<float>, 3: complex<double>) and forwards kernel<T>'s result.
// An unknown `code` now returns a strdup-allocated error on an explicit path.
// Before, it fell through to whatever GUARD_END does, with the outputs unwritten.
extern "C" const char* eigen_kernel(int code, Decomposition d, void** p0, int* r0, int* c0, const void* p1, int r1, int c1) {
    GUARD_START
    switch (code) {
        case 0: return kernel<T0>(d,p0,r0,c0,p1,r1,c1);
        case 1: return kernel<T1>(d,p0,r0,c0,p1,r1,c1);
        case 2: return kernel<T2>(d,p0,r0,c0,p1,r1,c1);
        case 3: return kernel<T3>(d,p0,r0,c0,p1,r1,c1);
        default: return strdup("Unsupported element type code.");
    }
    GUARD_END
}

// Computes a basis of the image (column space) of the r1 x c1 matrix at p1 with
// Eigen's FullPivLU, the only decomposition accepted here.
// On success *p0 receives a freshly malloc'ed buffer holding the basis as an
// (*r0) x (*c0) matrix; the caller owns it and must release it with free().
// Returns 0 on success, or a strdup-allocated message if d is not FullPivLU.
// Throws std::bad_alloc if the output buffer cannot be allocated (presumably
// turned into an error by the caller's GUARD_END -- TODO confirm). The outputs
// are written only after the copy has succeeded.
template <class T>
const char* image(Decomposition d, void** p0, int* r0, int* c0, const void* p1, int r1, int c1) {
    typedef Map< Matrix<T,Dynamic,Dynamic> > MapMatrix;
    if (d != ::FullPivLU)
        return strdup("Selected decomposition doesn't support image revealing.");
    MapMatrix A((const T*)p1,r1,c1);
    // FullPivLU::image() takes the original matrix, which the decomposition does not keep.
    Matrix<T,Dynamic,Dynamic> B = A.fullPivLu().image(A);
    // Size the buffer in size_t so rows * cols cannot overflow int.
    const size_t bytes = size_t(B.rows()) * size_t(B.cols()) * sizeof(T);
    void* out = malloc(bytes);
    // malloc(0) may legitimately return null; only a failed non-empty request is an error.
    if (out == 0 && bytes != 0)
        throw std::bad_alloc();
    MapMatrix((T*)out, B.rows(), B.cols()) = B;
    *p0 = out;
    *r0 = B.rows();
    *c0 = B.cols();
    return 0;
}

// C entry point for image<T>: dispatches on `code` (0: float, 1: double,
// 2: complex<float>, 3: complex<double>) and forwards image<T>'s result.
// An unknown `code` now returns a strdup-allocated error on an explicit path.
// Before, it fell through to whatever GUARD_END does, with the outputs unwritten.
extern "C" const char* eigen_image(int code, Decomposition d, void** p0, int* r0, int* c0, const void* p1, int r1, int c1) {
    GUARD_START
    switch (code) {
        case 0: return image<T0>(d,p0,r0,c0,p1,r1,c1);
        case 1: return image<T1>(d,p0,r0,c0,p1,r1,c1);
        case 2: return image<T2>(d,p0,r0,c0,p1,r1,c1);
        case 3: return image<T3>(d,p0,r0,c0,p1,r1,c1);
        default: return strdup("Unsupported element type code.");
    }
    GUARD_END
}

// Solves A * x = b for x with the decomposition selected by d, where A is the
// ra x ca matrix at pa, b is the rb x cb matrix at pb, and x is written into the
// rx x cx buffer at px.
// A Map cannot be resized, so px must already have the solution's shape. A
// mismatch trips an Eigen assertion (presumably surfaced via eigen_assert_fail).
// Returns 0 on success, or a strdup-allocated message for a Decomposition value
// not handled below. Such a value used to return success with x untouched,
// unlike rank(), which already reported it.
template <class T>
const char* solve(Decomposition d,
    void* px, int rx, int cx,
    const void* pa, int ra, int ca,
    const void* pb, int rb, int cb)
{
    typedef Map< Matrix<T,Dynamic,Dynamic> > MapMatrix;
    MapMatrix x((T*)px, rx, cx);
    MapMatrix A((const T*)pa, ra, ca);
    MapMatrix b((const T*)pb, rb, cb);
    switch (d) {
        case ::PartialPivLU:
            x = A.partialPivLu().solve(b);
            break;
        case ::FullPivLU:
            x = A.fullPivLu().solve(b);
            break;
        case ::HouseholderQR:
            x = A.householderQr().solve(b);
            break;
        case ::ColPivHouseholderQR:
            x = A.colPivHouseholderQr().solve(b);
            break;
        case ::FullPivHouseholderQR:
            x = A.fullPivHouseholderQr().solve(b);
            break;
        case ::LLT:
            x = A.llt().solve(b);
            break;
        case ::LDLT:
            x = A.ldlt().solve(b);
            break;
        case ::JacobiSVD:
            // solve() needs U and V; the thin variants suffice.
            x = A.jacobiSvd(ComputeThinU | ComputeThinV).solve(b);
            break;
        default:
            return strdup("Selected decomposition doesn't support solving.");
    }
    return 0;
}

// C entry point for solve<T>: dispatches on `code` (0: float, 1: double,
// 2: complex<float>, 3: complex<double>) and forwards solve<T>'s result.
// An unknown `code` now returns a strdup-allocated error on an explicit path.
// Before, it fell through to whatever GUARD_END does, with x unwritten.
extern "C" const char* eigen_solve(int code, Decomposition d,
    void* px, int rx, int cx,
    const void* pa, int ra, int ca,
    const void* pb, int rb, int cb)
{
    GUARD_START
    switch (code) {
        case 0: return solve<T0>(d,px,rx,cx,pa,ra,ca,pb,rb,cb);
        case 1: return solve<T1>(d,px,rx,cx,pa,ra,ca,pb,rb,cb);
        case 2: return solve<T2>(d,px,rx,cx,pa,ra,ca,pb,rb,cb);
        case 3: return solve<T3>(d,px,rx,cx,pa,ra,ca,pb,rb,cb);
        default: return strdup("Unsupported element type code.");
    }
    GUARD_END
}

template <class T>
const char* relativeError(void* e,
    const void* px, int rx, int cx,
    const void* pa, int ra, int ca,
    const void* pb, int rb, int cb)
{
    typedef Map< Matrix<T,Dynamic,Dynamic> > MapMatrix;
    MapMatrix x((const T*)px, rx, cx);
    MapMatrix A((const T*)pa, ra, ca);
    MapMatrix b((const T*)pb, rb, cb);
    *(T*)e = (A*x - b).norm() / b.norm();
    return 0;
}


// C entry point for relativeError<T>: dispatches on `code` (0: float, 1: double,
// 2: complex<float>, 3: complex<double>) and forwards relativeError<T>'s result.
// An unknown `code` now returns a strdup-allocated error on an explicit path.
// Before, it fell through to whatever GUARD_END does, with *e unwritten.
extern "C" const char* eigen_relativeError(int code, void* e,
    const void* px, int rx, int cx,
    const void* pa, int ra, int ca,
    const void* pb, int rb, int cb)
{
    GUARD_START
    switch (code) {
        case 0: return relativeError<T0>(e,px,rx,cx,pa,ra,ca,pb,rb,cb);
        case 1: return relativeError<T1>(e,px,rx,cx,pa,ra,ca,pb,rb,cb);
        case 2: return relativeError<T2>(e,px,rx,cx,pa,ra,ca,pb,rb,cb);
        case 3: return relativeError<T3>(e,px,rx,cx,pa,ra,ca,pb,rb,cb);
        default: return strdup("Unsupported element type code.");
    }
    GUARD_END
}


// Calls Eigen's initParallel() and returns true. It is invoked once at file
// scope to initialize the static `inited` flag, and being extern "C" it can
// also be called directly by C callers.
extern "C" bool eigen_initParallel() {
    initParallel();
    return true;
}

// Forwards n unchanged to Eigen's setNbThreads(); no validation is done here.
// (How Eigen interprets values such as 0 is defined by Eigen -- see its docs.)
extern "C" void eigen_setNbThreads(int n) {
    setNbThreads(n);
}

// Returns Eigen's nbThreads(), the thread count Eigen reports it will use.
extern "C" int eigen_getNbThreads() {
    return nbThreads();
}
