/*  CAO Compiler
    Copyright (C) 2014 Cryptography and Information Security Group, HASLab - INESC TEC and Universidade do Minho

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.  */

#include "CAO_matrix.h"

/*
 * Allocate a matrix descriptor for a rows x cols matrix with element type
 * tag `type`, plus an array of rows*cols element references. The element
 * slots are left uninitialised; callers fill them (e.g. _CAO_global_decl or
 * CAO_global_clone).
 * NOTE(review): neither malloc result is checked for NULL.
 */
CAO_matrix_s *newMatrix(int rows, int cols, char type)
{
	CAO_matrix_s *mat = (CAO_matrix_s *) malloc(sizeof(CAO_matrix_s));

	mat->type = type;
	mat->rows = rows;
	mat->cols = cols;
	mat->value = (CAO_REF *) malloc(rows * cols * sizeof(CAO_REF));

	return mat;
}

/*
 * Public entry point for declaring a rows x cols matrix into *m.
 * Delegates to _CAO_matrix_decl; the jump value that call hands to the
 * per-element declarations is not needed by callers here and is discarded.
 */
CAO_RES
CAO_matrix_decl(CAO_matrix * m, int rows, int cols, const char type[],
				void *indices[])
{
	int ignoredJump;
	CAO_RES status = _CAO_matrix_decl(m, rows, cols, type, indices,
									  &ignoredJump);
	return status;
}

/*
 * Allocate a rows x cols matrix, store type[0] as its element type tag, and
 * declare each element in row-major order with _CAO_global_decl, passing the
 * full type[] descriptor, indices and jump through unchanged.
 * Declaration stops at the first element whose result is not CAO_OK and that
 * result is returned; *m is assigned the new matrix in every case.
 */
CAO_RES
_CAO_matrix_decl(CAO_matrix * m, int rows, int cols, const char type[],
				 void *indices[], int *jump)
{
	int total = rows * cols;
	CAO_matrix_s *mat = newMatrix(rows, cols, type[0]);
	int res = CAO_OK;

	for (int k = 0; k < total && res == CAO_OK; ++k)
	{
		res = _CAO_global_decl(&(mat->value[k]), type, indices, jump);
	}

	*m = mat;
	return res;
}

/* Return 1 when m has exactly one column, 0 otherwise. */
int CAO_matrix_iscol(CAO_matrix m)
{
	const CAO_matrix_s *mat = (CAO_matrix_s *) m;
	return (mat->cols == 1) ? 1 : 0;
}

/*
 * Dispose every element of m with CAO_global_dispose, then free the element
 * array and the matrix descriptor.
 *
 * Every element is disposed even if an earlier one fails. The storage is
 * freed unconditionally afterwards, so stopping early would leak the elements
 * not yet disposed. The first non-CAO_OK result (or CAO_OK) is returned.
 */
CAO_RES CAO_matrix_dispose(CAO_matrix m)
{
	int i, size;
	CAO_RES res = CAO_OK;
	CAO_matrix_s *_m = (CAO_matrix_s *) m;
	size = _m->rows * _m->cols;

	for (i = 0; i < size; i++)
	{
		CAO_RES elemRes = CAO_global_dispose(_m->value[i], _m->type);
		if (res == CAO_OK)
			res = elemRes;	/* remember only the first failure */
	}

	free(_m->value);
	free(_m);

	return res;
}

/*
 * Initialise every element of m from the same source `value` via
 * CAO_global_const_init, using m's element type. Per-element results are
 * ignored; CAO_OK is always returned.
 */
CAO_RES CAO_matrix_const_init(CAO_matrix m, void *value)
{
	CAO_matrix_s *mat = (CAO_matrix_s *) m;
	int count = mat->rows * mat->cols;

	for (int k = 0; k < count; ++k)
	{
		CAO_global_const_init(mat->value[k], value, mat->type);
	}

	return CAO_OK;
}

/*
 * Public entry point: initialise m from the flat array value[].
 * The total number of value[] entries consumed, which _CAO_matrix_init
 * reports through its last argument, is discarded.
 */
CAO_RES CAO_matrix_init(CAO_matrix m, void *value[])
{
	int consumed = 0;
	CAO_RES status = _CAO_matrix_init(m, value, &consumed);
	return status;
}

/*
 * Initialise the elements of m in row-major order from consecutive entries
 * of value[]. After each _CAO_global_init call, *vjump is taken as the number
 * of value[] entries that element used, and the read cursor advances by that
 * much. On return *vjump holds the total number of entries consumed.
 * Per-element results are ignored; CAO_OK is always returned.
 */
CAO_RES _CAO_matrix_init(CAO_matrix m, void *value[], int *vjump)
{
	CAO_matrix_s *mat = (CAO_matrix_s *) m;
	int count = mat->rows * mat->cols;
	void **cursor = value;

	for (int k = 0; k < count; ++k)
	{
		_CAO_global_init(mat->value[k], cursor, vjump, mat->type);
		cursor += *vjump;
	}

	*vjump = (int) (cursor - value);
	return CAO_OK;
}

/*
 * Element-wise copy r = m. The element count and type are taken from r;
 * m is assumed to hold at least as many elements. Per-element results are
 * ignored; CAO_OK is always returned.
 */
CAO_RES CAO_matrix_assign(CAO_matrix r, CAO_matrix m)
{
	CAO_matrix_s *dst = (CAO_matrix_s *) r;
	CAO_matrix_s *src = (CAO_matrix_s *) m;
	int count = dst->rows * dst->cols;

	for (int k = 0; k < count; ++k)
	{
		CAO_global_assign(dst->value[k], src->value[k], dst->type);
	}

	return CAO_OK;
}

/*
 * Allocate a new matrix with m's shape and element type, clone every element
 * of m into it, and store it in *r. Per-element results are ignored; CAO_OK
 * is always returned.
 */
CAO_RES CAO_matrix_clone(CAO_matrix * r, CAO_matrix m)
{
	CAO_matrix_s *src = (CAO_matrix_s *) m;
	CAO_matrix_s *copy = newMatrix(src->rows, src->cols, src->type);
	int count = src->rows * src->cols;

	for (int k = 0; k < count; ++k)
	{
		CAO_global_clone(&(copy->value[k]), src->value[k], src->type);
	}

	*r = copy;
	return CAO_OK;
}

/*
 * Element-wise equality test. The first rows*cols elements of a are compared
 * with the same positions of b using a's element type; the scan stops at the
 * first mismatch. The shapes themselves are not compared.
 */
CAO_bool _CAO_matrix_equal(CAO_matrix a, CAO_matrix b)
{
	CAO_matrix_s *lhs = (CAO_matrix_s *) a;
	CAO_matrix_s *rhs = (CAO_matrix_s *) b;
	int count = lhs->rows * lhs->cols;
	CAO_bool same = true;

	for (int k = 0; same && k < count; ++k)
	{
		/* CAO_global_equal stores its verdict in its first argument, which is
		 * passed by name (presumably a macro). */
		CAO_global_equal(same, lhs->value[k], rhs->value[k], lhs->type);
	}

	return same;
}

/*
 * Copy element (i, j) of m into r. Returns CAO_ERR and leaves r untouched
 * when (i, j) lies outside m; otherwise returns CAO_OK.
 */
CAO_RES CAO_matrix_select(CAO_REF r, CAO_matrix m, CAO_rint i, CAO_rint j)
{
	CAO_matrix_s *mat = (CAO_matrix_s *) m;

	if (i < 0 || i >= mat->rows || j < 0 || j >= mat->cols)
		return CAO_ERR;

	CAO_global_assign(r, mat->value[i * mat->cols + j], mat->type);
	return CAO_OK;
}

/*
 * Return the reference stored at (i, j) of m without bounds checking.
 * The element type tag reported by _CAO_matrix_ref is discarded.
 */
CAO_REF CAO_matrix_ref(CAO_matrix m, CAO_rint i, CAO_rint j)
{
	char ignoredType;
	CAO_REF cell = _CAO_matrix_ref(m, i, j, &ignoredType);
	return cell;
}

/*
 * Return the reference stored at row i, column j of m (row-major storage, no
 * bounds checking) and write m's element type tag to *t.
 */
CAO_REF _CAO_matrix_ref(CAO_matrix m, CAO_rint i, CAO_rint j, char *t)
{
	CAO_matrix_s *mat = (CAO_matrix_s *) m;
	CAO_REF *rowStart = mat->value + i * mat->cols;

	*t = mat->type;
	return rowStart[j];
}

/*
 * Copy the sub-matrix of m spanning rows ri..rj and columns ci..cj (both
 * inclusive, 0-based) into r, packed row-major from r's first element.
 * r's size is not checked; it must hold (rj-ri+1)*(cj-ci+1) elements.
 *
 * Returns CAO_ERR when the row range is not within m's rows or the column
 * range is not within m's columns, and CAO_OK otherwise.
 */
CAO_RES
CAO_matrix_range_select(CAO_matrix r, CAO_matrix m, CAO_rint ri, CAO_rint rj,
						CAO_rint ci, CAO_rint cj)
{
	CAO_matrix_s *_r = (CAO_matrix_s *) r;
	CAO_matrix_s *_m = (CAO_matrix_s *) m;
	int i, j, pr, pm;

	/* Column indices must be bounded by the column count, not the row count;
	 * otherwise non-square matrices are read out of bounds. */
	if (!((ri >= 0) && (ri <= rj) && (rj < _m->rows) &&
		  (ci >= 0) && (ci <= cj) && (cj < _m->cols)))
	{
		return CAO_ERR;
	}

	pr = 0;
	for (i = ri; i <= rj; i++)
	{
		pm = i * _m->cols + ci;	/* start of the selected columns in row i */
		for (j = ci; j <= cj; j++)
		{
			CAO_global_assign(_r->value[pr++], _m->value[pm++], _r->type);
		}
	}

	return CAO_OK;
}

/*
 * Copy rows ri..rj (inclusive, 0-based) of column c of m into consecutive
 * elements of r, starting at r's first element. r's size is not checked.
 *
 * Returns CAO_ERR when c is not a valid column of m or the row range is not
 * within m's rows, and CAO_OK otherwise.
 */
CAO_RES
CAO_matrix_row_range_select(CAO_matrix r, CAO_matrix m, CAO_rint c,
							CAO_rint ri, CAO_rint rj)
{
	CAO_matrix_s *_r = (CAO_matrix_s *) r;
	CAO_matrix_s *_m = (CAO_matrix_s *) m;
	int size;
	int i, pr, pm;

	/* The column index must be validated too, or m is indexed out of bounds. */
	if (!((c >= 0) && (c < _m->cols) &&
		  (ri >= 0) && (ri <= rj) && (rj < _m->rows)))
	{
		return CAO_ERR;
	}

	size = (rj - ri + 1);
	pr = 0;
	pm = ri * _m->cols + c;
	for (i = 0; i < size; i++)
	{
		CAO_global_assign(_r->value[pr], _m->value[pm], _r->type);
		pr++;
		pm += _m->cols;	/* step down one row in column c */
	}
	return CAO_OK;
}

/*
 * Copy columns ci..cj (inclusive, 0-based) of row `row` of m into consecutive
 * elements of r, starting at r's first element. r's size is not checked.
 *
 * Returns CAO_ERR when `row` is not a valid row of m or the column range is
 * not within m's columns, and CAO_OK otherwise.
 */
CAO_RES
CAO_matrix_col_range_select(CAO_matrix r, CAO_matrix m, CAO_rint row,
							CAO_rint ci, CAO_rint cj)
{
	CAO_matrix_s *_r = (CAO_matrix_s *) r;
	CAO_matrix_s *_m = (CAO_matrix_s *) m;
	int size;
	int i, pr, pm;

	/* Column indices are bounded by the column count (not the row count),
	 * and the row index must be in range as well. */
	if (!((row >= 0) && (row < _m->rows) &&
		  (ci >= 0) && (ci <= cj) && (cj < _m->cols)))
	{
		return CAO_ERR;
	}

	size = (cj - ci + 1);
	pr = 0;
	pm = row * _m->cols + ci;
	for (i = 0; i < size; i++)
	{
		CAO_global_assign(_r->value[pr++], _m->value[pm++], _r->type);
	}

	return CAO_OK;
}

/*
 * Overwrite the sub-matrix of r spanning rows ri..rj and columns ci..cj (both
 * inclusive, 0-based) with consecutive elements of m, read row-major from m's
 * first element. m's size is not checked.
 *
 * Returns CAO_ERR when the row range is not within r's rows or the column
 * range is not within r's columns, and CAO_OK otherwise.
 */
CAO_RES
CAO_matrix_range_set(CAO_matrix r, CAO_matrix m, CAO_rint ri, CAO_rint rj,
					 CAO_rint ci, CAO_rint cj)
{
	CAO_matrix_s *_r = (CAO_matrix_s *) r;
	CAO_matrix_s *_m = (CAO_matrix_s *) m;
	int i, j, pr, pm;

	/* Column indices must be bounded by r's column count, not its row
	 * count; otherwise non-square targets are written out of bounds. */
	if (!((ri >= 0) && (ri <= rj) && (rj < _r->rows) &&
		  (ci >= 0) && (ci <= cj) && (cj < _r->cols)))
	{
		return CAO_ERR;
	}

	pm = 0;
	for (i = ri; i <= rj; i++)
	{
		pr = i * _r->cols + ci;	/* start of the target columns in row i */
		for (j = ci; j <= cj; j++)
		{
			CAO_global_assign(_r->value[pr++], _m->value[pm++], _r->type);
		}
	}

	return CAO_OK;
}

/*
 * Overwrite rows ri..rj (inclusive, 0-based) of column c of r with
 * consecutive elements of m, starting at m's first element. m's size is not
 * checked.
 *
 * Returns CAO_ERR when c is not a valid column of r or the row range is not
 * within r's rows, and CAO_OK otherwise.
 */
CAO_RES
CAO_matrix_row_range_set(CAO_matrix r, CAO_matrix m, CAO_rint c, CAO_rint ri,
						 CAO_rint rj)
{
	CAO_matrix_s *_r = (CAO_matrix_s *) r;
	CAO_matrix_s *_m = (CAO_matrix_s *) m;
	int size;
	int i, pr, pm;

	/* The column index must be validated too, or r is written out of bounds. */
	if (!((c >= 0) && (c < _r->cols) &&
		  (ri >= 0) && (ri <= rj) && (rj < _r->rows)))
	{
		return CAO_ERR;
	}

	size = (rj - ri + 1);
	pr = _r->cols * ri + c;
	pm = 0;
	for (i = 0; i < size; i++)
	{
		CAO_global_assign(_r->value[pr], _m->value[pm], _r->type);
		pr += _r->cols;	/* step down one row in column c */
		pm++;
	}
	return CAO_OK;
}

/*
 * Overwrite columns ci..cj (inclusive, 0-based) of row `row` of r with
 * consecutive elements of m, starting at m's first element. m's size is not
 * checked.
 *
 * Returns CAO_ERR when `row` is not a valid row of r or the column range is
 * not within r's columns, and CAO_OK otherwise.
 */
CAO_RES
CAO_matrix_col_range_set(CAO_matrix r, CAO_matrix m, CAO_rint row, CAO_rint ci,
						 CAO_rint cj)
{
	CAO_matrix_s *_r = (CAO_matrix_s *) r;
	CAO_matrix_s *_m = (CAO_matrix_s *) m;
	int size;
	int i, pr, pm;

	/* Column indices are bounded by r's column count (not its row count),
	 * and the row index must be in range as well. */
	if (!((row >= 0) && (row < _r->rows) &&
		  (ci >= 0) && (ci <= cj) && (cj < _r->cols)))
	{
		return CAO_ERR;
	}

	size = cj - ci + 1;
	pr = row * _r->cols + ci;
	pm = 0;
	for (i = 0; i < size; i++)
	{
		CAO_global_assign(_r->value[pr++], _m->value[pm++], _r->type);
	}

	return CAO_OK;
}

/*
 * Write the elements of a, then the elements of b, each walked row by row,
 * into consecutive positions of r starting at r's first element, using r's
 * element type. Shapes are not checked; r must hold both element counts.
 * Always returns CAO_OK.
 */
CAO_RES CAO_matrix_concat(CAO_matrix r, CAO_matrix a, CAO_matrix b)
{
	CAO_matrix_s *dst = (CAO_matrix_s *) r;
	CAO_matrix_s *parts[2] = { (CAO_matrix_s *) a, (CAO_matrix_s *) b };
	int out = 0;

	for (int p = 0; p < 2; ++p)
	{
		CAO_matrix_s *src = parts[p];
		int in = 0;

		for (int row = 0; row < src->rows; ++row)
		{
			for (int col = 0; col < src->cols; ++col)
			{
				CAO_global_assign(dst->value[out], src->value[in], dst->type);
				++out;
				++in;
			}
		}
	}

	return CAO_OK;
}

/*
 * Print m to standard output. The output is a header with its dimensions,
 * then for each row a "row N" label followed by that row's elements (each
 * printed by CAO_global_dump and terminated by a newline), and finally a
 * closing line. Always returns CAO_OK.
 */
CAO_RES CAO_matrix_dump(CAO_matrix m)
{
	CAO_matrix_s *mat = (CAO_matrix_s *) m;
	const int nr = mat->rows;
	const int nc = mat->cols;
	CAO_REF *cell = mat->value;

	std::cout << "matrix[" << nr << " x " << nc << "] = \n";
	for (int row = 0; row < nr; ++row)
	{
		std::cout << "row " << row << "\n";
		for (int col = 0; col < nc; ++col, ++cell)
		{
			CAO_global_dump(*cell, mat->type);
			std::cout << "\n";
		}
	}
	std::cout << "end of matrix[" << nr << " x " << nc << "] = \n";

	return CAO_OK;
}

/*
 * In-place element-wise addition r += m over r's rows*cols elements, using
 * m's element type. Stops at the first element whose addition does not
 * return CAO_OK and returns that result; otherwise returns CAO_OK.
 */
CAO_RES CAO_matrix_addTo(CAO_matrix r, CAO_matrix m)
{
	CAO_matrix_s *acc = (CAO_matrix_s *) r;
	CAO_matrix_s *rhs = (CAO_matrix_s *) m;
	int count = acc->rows * acc->cols;

	for (int k = 0; k < count; ++k)
	{
		CAO_RES status =
			CAO_global_addTo(acc->value[k], rhs->value[k], rhs->type);
		if (status != CAO_OK)
			return status;
	}

	return CAO_OK;
}

/*
 * r = a + b, computed as r = a followed by r += b. Returns CAO_OK when both
 * steps succeed and CAO_ERR otherwise.
 * NOTE(review): r is overwritten with a before b is read, so passing the same
 * matrix as r and b yields a + a — confirm callers never alias them.
 */
CAO_RES CAO_matrix_add(CAO_matrix r, CAO_matrix a, CAO_matrix b)
{
	if (CAO_matrix_assign(r, a) != CAO_OK)
		return CAO_ERR;

	if (CAO_matrix_addTo(r, b) != CAO_OK)
		return CAO_ERR;

	return CAO_OK;
}

/*
 * In-place element-wise subtraction r -= m over r's rows*cols elements, using
 * m's element type. Stops at the first element whose subtraction does not
 * return CAO_OK and returns that result; otherwise returns CAO_OK.
 */
CAO_RES CAO_matrix_subTo(CAO_matrix r, CAO_matrix m)
{
	CAO_matrix_s *acc = (CAO_matrix_s *) r;
	CAO_matrix_s *rhs = (CAO_matrix_s *) m;
	int count = acc->rows * acc->cols;

	for (int k = 0; k < count; ++k)
	{
		CAO_RES status =
			CAO_global_subTo(acc->value[k], rhs->value[k], rhs->type);
		if (status != CAO_OK)
			return status;
	}

	return CAO_OK;
}

/*
 * r = a - b, computed as r = a followed by r -= b. Returns CAO_OK when both
 * steps succeed and CAO_ERR otherwise.
 * NOTE(review): r is overwritten with a before b is read, so passing the same
 * matrix as r and b yields a - a — confirm callers never alias them.
 */
CAO_RES CAO_matrix_sub(CAO_matrix r, CAO_matrix a, CAO_matrix b)
{
	if (CAO_matrix_assign(r, a) != CAO_OK)
		return CAO_ERR;

	if (CAO_matrix_subTo(r, b) != CAO_OK)
		return CAO_ERR;

	return CAO_OK;
}

/*
 * Element-wise r[k] = sym(m[k]) via CAO_global_sym (presumably the additive
 * inverse — confirm), over r's rows*cols elements using m's element type.
 * Stops at, and returns, the first non-CAO_OK element result.
 */
CAO_RES CAO_matrix_sym(CAO_matrix r, CAO_matrix m)
{
	CAO_matrix_s *dst = (CAO_matrix_s *) r;
	CAO_matrix_s *src = (CAO_matrix_s *) m;
	int count = dst->rows * dst->cols;

	for (int k = 0; k < count; ++k)
	{
		CAO_RES status = CAO_global_sym(dst->value[k], src->value[k], src->type);
		if (status != CAO_OK)
			return status;
	}

	return CAO_OK;
}

/*
 * Matrix product r = a * b, with a of shape (n x p) and b of shape (p x q),
 * using r's element type for every scalar operation.
 *
 * Shapes are not checked: r must be n x q. r must not alias a or b, because
 * r's elements are overwritten while a and b are still being read.
 * Per-element results are ignored; CAO_OK is always returned.
 */
CAO_RES CAO_matrix_mul(CAO_matrix r, CAO_matrix a, CAO_matrix b)
{
	CAO_matrix_s *_r = (CAO_matrix_s *) r;
	CAO_matrix_s *_a = (CAO_matrix_s *) a;
	CAO_matrix_s *_b = (CAO_matrix_s *) b;

	CAO_REF tmp;
	char type = _r->type;
	int i, j, k;

	/* Scratch element of the right type to hold each partial product. */
	CAO_global_clone(&tmp, _r->value[0], type);

	for (i = 0; i < _a->rows; i++)
	{
		for (j = 0; j < _b->cols; j++)
		{
			CAO_REF dst = _r->value[i * _r->cols + j];

			/* dst = a[i][0] * b[0][j] */
			CAO_global_mul(dst, _a->value[i * _a->cols], _b->value[j], type);

			/* dst += a[i][k] * b[k][j] for every remaining k. Each product
			 * must be accumulated inside the loop; adding once afterwards
			 * keeps only the last term (or the stale scratch value when
			 * p == 1). */
			for (k = 1; k < _a->cols; k++)
			{
				CAO_global_mul(tmp, _a->value[i * _a->cols + k],
							   _b->value[k * _b->cols + j], type);
				CAO_global_addTo(dst, tmp, type);
			}
		}
	}

	CAO_global_dispose(tmp, type);

	return CAO_OK;
}

/*
 * Set every element of r to zero via CAO_global_assign_zero. Stops at, and
 * returns, the first non-CAO_OK element result; otherwise returns CAO_OK.
 */
CAO_RES CAO_matrix_assign_zero(CAO_matrix r)
{
	CAO_matrix_s *mat = (CAO_matrix_s *) r;
	int count = mat->rows * mat->cols;

	for (int k = 0; k < count; ++k)
	{
		CAO_RES status = CAO_global_assign_zero(mat->value[k], mat->type);
		if (status != CAO_OK)
			return status;
	}

	return CAO_OK;
}

/*
 * Set r to the identity: every element is zeroed, then each main-diagonal
 * element r[d][d] for d < min(rows, cols) is set to one. Stops at, and
 * returns, the first non-CAO_OK element result; otherwise returns CAO_OK.
 */
CAO_RES CAO_matrix_assign_one(CAO_matrix r)
{
	CAO_matrix_s *_r = (CAO_matrix_s *) r;
	int i, size = _r->rows * _r->cols;
	int diag = (_r->rows < _r->cols) ? _r->rows : _r->cols;
	CAO_RES res = CAO_OK;

	for (i = 0; ((i < size) && (res == CAO_OK)); i++)
		res = CAO_global_assign_zero(_r->value[i], _r->type);

	/* Element (d, d) lives at d * cols + d in row-major storage. The loop
	 * guard is a comparison (it was previously the assignment res = CAO_OK). */
	for (i = 0; ((i < diag) && (res == CAO_OK)); i++)
		res = CAO_global_assign_one(_r->value[i * _r->cols + i], _r->type);

	return res;
}

/*
 * r = m^n by binary (square-and-multiply) exponentiation. The exponent is
 * read as an NTL ZZ through the CAO_int handle. m is presumably square, and r
 * serves as the accumulator, so it must have m's shape.
 *
 * Returns CAO_ERR without modifying r when n is negative (unsupported) or
 * zero; otherwise returns CAO_OK. Results of the individual matrix
 * operations are not checked.
 */
CAO_RES CAO_matrix_pow(CAO_matrix r, CAO_matrix m, CAO_int n)
{
	CAO_matrix a, aAux, rAux;
	int junk = 1;	/* set while no factor has been placed in r yet */
	CAO_RES res = CAO_OK;
	ZZ _n = *(ZZ *) n;

	/* Reject negative exponents before allocating anything. ZZ division by 2
	 * floors (-1 / 2 == -1), so the loop below would never terminate. */
	if (sign(_n) < 0)
		return CAO_ERR;

	CAO_matrix_clone(&a, m);		/* a holds m^(2^k) at step k */
	CAO_matrix_clone(&aAux, a);		/* scratch target for squaring a */
	CAO_matrix_clone(&(rAux), r);	/* scratch copy of r for r = r * a */

	while (!IsZero(_n))
	{
		if (IsOdd(_n))
		{
			if (junk)
			{
				junk = 0;
				CAO_matrix_assign(r, a);
			}
			else
			{
				/* mul must not alias its output with an input */
				CAO_matrix_assign(rAux, r);
				CAO_matrix_mul(r, rAux, a);
			}
		}
		CAO_matrix_mul(aAux, a, a);
		CAO_matrix_assign(a, aAux);
		_n = _n / 2;
	}
	if (junk)
		res = CAO_ERR;	/* n == 0: r was never written */
	CAO_matrix_dispose(a);
	CAO_matrix_dispose(aAux);
	CAO_matrix_dispose(rAux);
	return res;
}

/*
 * Element-wise conversion of s into d over s's rows*cols elements. Each
 * element is converted from s's element type to d's element type with
 * CAO_global_cast. Stops at, and returns, the first non-CAO_OK result;
 * otherwise returns CAO_OK.
 */
CAO_RES CAO_matrix_cast_matrix(CAO_matrix d, CAO_matrix s)
{
	CAO_matrix_s *src = (CAO_matrix_s *) s;
	CAO_matrix_s *dst = (CAO_matrix_s *) d;
	int count = src->rows * src->cols;

	for (int k = 0; k < count; ++k)
	{
		CAO_RES status = CAO_global_cast(dst->value[k], dst->type,
										 src->value[k], src->type);
		if (status != CAO_OK)
			return status;
	}

	return CAO_OK;
}
