/*  CAO Compiler
    Copyright (C) 2014 Cryptography and Information Security Group, HASLab - INESC TEC and Universidade do Minho

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.  */

#include "CAO_globalOp.h"

// `type` is a string that encodes the following Haskell type:
// data CAOType = A // Int
// | B // Bool
// | C Integer // Mod
// | D RInt // Unsigned Bit Array with dimension
// | E RInt CAOType // Vector n of type
// | F RInt RInt CAOType // Matrix n x m of type
// | G RInt [CAOType] // Struct with n components
// | H RInt [Integer] // ModPol degree [basemod,coef_0,..,coef_deg]
// | I // RInt
// | J RInt // Signed Bit Array with dimension

// Tag characters: the first character of a type-descriptor string selects
// one of the CAOType constructors listed in the comment above.
#define INT_Type    'A'	// Int
#define BOOL_Type   'B'	// Bool
#define MOD_Type    'C'	// Mod
#define UBITS_Type  'D'	// unsigned bit array
#define VECTOR_Type 'E'	// vector
#define MATRIX_Type 'F'	// matrix
#define STRUCT_Type 'G'	// struct
#define MODPOL_Type 'H'	// ModPol
#define RINT_Type   'I'	// RInt
#define SBITS_Type  'J'	// signed bit array

/* Allocates into *n a value of the type described by the descriptor string
   `type`, reading size/modulus parameters from `indices`.
   Public entry point over _CAO_global_decl; the `jump` count that the worker
   reports is not needed by callers and is discarded. */
CAO_RES CAO_global_decl(CAO_REF * n, const char *type, void *indices[])
{
	int consumed;	/* filled in by the worker, ignored here */
	CAO_RES status = _CAO_global_decl(n, type, indices, &consumed);
	return status;
}

/* Recursive worker behind CAO_global_decl.

   Dispatches on the first character of `type` and allocates the matching
   value into *n. Parameters for the allocation are read from indices[0];
   container tags (vector, matrix, struct) forward `type + 1` and
   `indices + 1` to their own declaration routine.

   On return:
     - scalar tags set *jump to 1;
     - container tags add 1 to whatever the container routine wrote to *jump
       (NOTE(review): this looks like a count of descriptor characters
       consumed, but confirm against _CAO_vector_decl & co.);
     - an unknown tag returns CAO_ERR and leaves *jump untouched. */
CAO_RES _CAO_global_decl(CAO_REF * n, const char *type, void *indices[],
						 int *jump)
{
	CAO_RES status;
	bool nested = false;	/* set for vector / matrix / struct tags */

	switch (type[0])
	{
	case INT_Type:
		status = CAO_int_decl(n);
		break;

	case BOOL_Type:
	case RINT_Type:
		/* Both are represented as a plain heap-allocated int, zeroed. */
		*(int **)n = new int(0);
		status = CAO_OK;
		break;

	case MOD_Type:
		status = CAO_mod_decl(n, (CAO_int) indices[0]);
		break;

	case UBITS_Type:
		status = CAO_ubits_decl(n, *(CAO_rint *) indices[0]);
		break;

	case SBITS_Type:
		status = CAO_sbits_decl(n, *(CAO_rint *) indices[0]);
		break;

	case MODPOL_Type:
		{
			/* indices[0] points at a CAO_int array: two leading parameters,
			   then the remaining entries passed on as a pointer. */
			CAO_int *params = (CAO_int *) indices[0];
			status = CAO_modpol_decl(n, params[0], params[1], params + 2);
		}
		break;

	case VECTOR_Type:
		nested = true;
		status = _CAO_vector_decl(n, *(CAO_rint *) indices[0], type + 1,
								  indices + 1, jump);
		break;

	case MATRIX_Type:
		{
			/* indices[0] points at the two matrix dimensions. */
			CAO_rint *dims = (CAO_rint *) indices[0];
			nested = true;
			status = _CAO_matrix_decl(n, dims[0], dims[1], type + 1,
									  indices + 1, jump);
		}
		break;

	case STRUCT_Type:
		nested = true;
		status = _CAO_struct_decl(n, *(CAO_rint *) indices[0], type + 1,
								  indices + 1, jump);
		break;

	default:
		return CAO_ERR;
	}

	/* Containers count their own tag on top of what the element reported. */
	if (nested)
		++*jump;
	else
		*jump = 1;
	return status;
}

/* Releases the value `r`, whose type tag is `type`.
   BOOL and RINT values are plain ints allocated with `new` (as done by
   CAO_global_decl / CAO_global_clone) and are released with `delete`. All
   other tags delegate to the type-specific dispose routine. An unknown tag
   returns CAO_ERR and frees nothing. */
CAO_RES CAO_global_dispose(CAO_REF r, char type)
{
	switch (type)
	{
	case BOOL_Type:
	case RINT_Type:
		delete (int *)r;
		return CAO_OK;
	case INT_Type:
		return CAO_int_dispose(r);
	case MOD_Type:
		return CAO_mod_dispose(r);
	case UBITS_Type:
		return CAO_ubits_dispose(r);
	case SBITS_Type:
		return CAO_sbits_dispose(r);
	case VECTOR_Type:
		return CAO_vector_dispose(r);
	case MATRIX_Type:
		return CAO_matrix_dispose(r);
	case STRUCT_Type:
		return CAO_struct_dispose(r);
	case MODPOL_Type:
		return CAO_modpol_dispose(r);
	default:
		return CAO_ERR;
	}
}

/* Copies the value of `source` into the already-allocated `dest`; both have
   type tag `type`. BOOL and RINT copy the underlying int directly. Other tags
   delegate to the type-specific assign routine. An unknown tag returns
   CAO_ERR. */
CAO_RES CAO_global_assign(CAO_REF dest, CAO_REF source, char type)
{
	switch (type)
	{
	case BOOL_Type:
	case RINT_Type:
		*(int *)dest = *(int *)source;
		return CAO_OK;
	case INT_Type:
		return CAO_int_assign(dest, source);
	case MOD_Type:
		return CAO_mod_assign(dest, source);
	case UBITS_Type:
		return CAO_ubits_assign(dest, source);
	case SBITS_Type:
		return CAO_sbits_assign(dest, source);
	case VECTOR_Type:
		return CAO_vector_assign(dest, source);
	case MATRIX_Type:
		return CAO_matrix_assign(dest, source);
	case STRUCT_Type:
		return CAO_struct_assign(dest, source);
	case MODPOL_Type:
		return CAO_modpol_assign(dest, source);
	default:
		return CAO_ERR;
	}
}

/* Allocates into *dest a fresh copy of `source`, whose type tag is `type`.
   BOOL and RINT values are plain heap ints: a new int is allocated and
   initialised with the source value. Every other tag delegates to the
   type-specific clone routine. An unknown tag returns CAO_ERR without
   allocating. */
CAO_RES CAO_global_clone(CAO_REF * dest, CAO_REF source, char type)
{
	int res;
	switch (type)
	{
	case INT_Type:
		res = CAO_int_clone(dest, source);
		break;
	case BOOL_Type:
	case RINT_Type:
		/* The copied value must go into the newly allocated int, not into
		   the CAO_REF slot that holds the pointer to it. Previously the
		   RINT path overwrote that pointer with the value, which leaked the
		   allocation and left *dest pointing at garbage. */
		*(int **)dest = new int(*(int *)source);
		res = CAO_OK;
		break;
	case MOD_Type:
		res = CAO_mod_clone(dest, source);
		break;
	case UBITS_Type:
		res = CAO_ubits_clone(dest, source);
		break;
	case SBITS_Type:
		res = CAO_sbits_clone(dest, source);
		break;
	case VECTOR_Type:
		res = CAO_vector_clone(dest, source);
		break;
	case MATRIX_Type:
		res = CAO_matrix_clone(dest, source);
		break;
	case STRUCT_Type:
		res = CAO_struct_clone(dest, source);
		break;
	case MODPOL_Type:
		res = CAO_modpol_clone(dest, source);
		break;
	default:
		return CAO_ERR;
	}
	return res;
}

/* Equality test for two values that share the type tag `type`.
   BOOL and RINT compare their underlying ints. Every other tag delegates to
   the type-specific _CAO_*_equal routine. An unknown tag compares as not
   equal (false). */
CAO_bool _CAO_global_equal(CAO_REF a, CAO_REF b, char type)
{
	switch (type)
	{
	case BOOL_Type:
	case RINT_Type:
		return *(int *)a == *(int *)b;
	case INT_Type:
		return _CAO_int_equal(a, b);
	case MOD_Type:
		return _CAO_mod_equal(a, b);
	case UBITS_Type:
		return _CAO_ubits_equal(a, b);
	case SBITS_Type:
		return _CAO_sbits_equal(a, b);
	case VECTOR_Type:
		return _CAO_vector_equal(a, b);
	case MATRIX_Type:
		return _CAO_matrix_equal(a, b);
	case STRUCT_Type:
		return _CAO_struct_equal(a, b);
	case MODPOL_Type:
		return _CAO_modpol_equal(a, b);
	default:
		return false;
	}
}

/* Initialises `r` (type tag `type`) from a single constant `value`.
   How `value` is interpreted depends on the tag:
     - INT, MOD, UBITS, SBITS, MODPOL: a C string passed to the type's init;
     - BOOL, RINT: a pointer to int, copied directly;
     - VECTOR, MATRIX, STRUCT: forwarded untouched to the *_const_init routine.
   An unknown tag returns CAO_ERR. */
CAO_RES CAO_global_const_init(CAO_REF r, void *value, char type)
{
	char *text = (char *)value;	/* string view used by the scalar inits */

	switch (type)
	{
	case BOOL_Type:
	case RINT_Type:
		*(int *)r = *(int *)value;
		return CAO_OK;
	case INT_Type:
		return CAO_int_init(r, text);
	case MOD_Type:
		return CAO_mod_init(r, text);
	case UBITS_Type:
		return CAO_ubits_init(r, text);
	case SBITS_Type:
		return CAO_sbits_init(r, text);
	case MODPOL_Type:
		return CAO_modpol_init(r, text);
	case VECTOR_Type:
		return CAO_vector_const_init(r, value);
	case MATRIX_Type:
		return CAO_matrix_const_init(r, value);
	case STRUCT_Type:
		return CAO_struct_const_init(r, value);
	default:
		return CAO_ERR;
	}
}

/* Initialises `r` (type tag `type`) from the literal array `value`.
   Public entry point over _CAO_global_init; the `vjump` count that the
   worker reports is started at 0 and then discarded. */
CAO_RES CAO_global_init(CAO_REF r, void *value[], char type)
{
	int used = 0;
	CAO_RES status = _CAO_global_init(r, value, &used, type);
	return status;
}

/* Worker behind CAO_global_init.

   Scalar tags read their literal from value[0]: a C string for INT, MOD,
   UBITS, SBITS and MODPOL, or a pointer to int for BOOL and RINT. They then
   set *vjump to 1.
   Vector, matrix and struct tags pass the whole `value` array and `vjump` to
   the container's own init routine. This function does not write *vjump for
   them (NOTE(review): presumably the container routine does — confirm).
   An unknown tag returns CAO_ERR and leaves *vjump untouched. */
CAO_RES _CAO_global_init(CAO_REF r, void *value[], int *vjump, char type)
{
	CAO_RES status;
	bool scalar = true;	/* cleared for container tags */

	switch (type)
	{
	case INT_Type:
		status = CAO_int_init(r, (char *)value[0]);
		break;

	case BOOL_Type:
	case RINT_Type:
		*(int *)r = *(int *)value[0];
		status = CAO_OK;
		break;

	case MOD_Type:
		status = CAO_mod_init(r, (char *)value[0]);
		break;

	case UBITS_Type:
		status = CAO_ubits_init(r, (char *)value[0]);
		break;

	case SBITS_Type:
		status = CAO_sbits_init(r, (char *)value[0]);
		break;

	case MODPOL_Type:
		status = CAO_modpol_init(r, (char *)value[0]);
		break;

	case VECTOR_Type:
		scalar = false;
		status = _CAO_vector_init(r, value, vjump);
		break;

	case MATRIX_Type:
		scalar = false;
		status = _CAO_matrix_init(r, value, vjump);
		break;

	case STRUCT_Type:
		scalar = false;
		status = _CAO_struct_init(r, value, vjump);
		break;

	default:
		return CAO_ERR;
	}

	if (scalar)
		*vjump = 1;
	return status;
}

/* Dumps the value `r` (type tag `type`) using the type-specific dump routine.
   BOOL and RINT pass the underlying int value rather than the reference.
   An unknown tag returns CAO_ERR. */
CAO_RES CAO_global_dump(CAO_REF r, char type)
{
	switch (type)
	{
	case BOOL_Type:
		return CAO_bool_dump(*(int *)r);
	case RINT_Type:
		return CAO_rint_dump(*(int *)r);
	case INT_Type:
		return CAO_int_dump(r);
	case MOD_Type:
		return CAO_mod_dump(r);
	case UBITS_Type:
		return CAO_ubits_dump(r);
	case SBITS_Type:
		return CAO_sbits_dump(r);
	case VECTOR_Type:
		return CAO_vector_dump(r);
	case MATRIX_Type:
		return CAO_matrix_dump(r);
	case STRUCT_Type:
		return CAO_struct_dump(r);
	case MODPOL_Type:
		return CAO_modpol_dump(r);
	default:
		return CAO_ERR;
	}
}

/* Walks `path` from `root` through nested containers and stores the element
   reached in *res.

   Vector and struct steps consume one path entry. Matrix steps consume two
   entries, passed as the two index arguments of _CAO_matrix_ref. After each
   step `type` is updated (via the callee's out-parameter) to the tag of the
   element just reached.

   Returns CAO_OK only when the whole path was consumed. The walk stops early
   and returns CAO_ERR when:
     - a non-container type is reached with path entries left over, or
     - a matrix step finds fewer than two entries remaining.
   In both cases *res receives the last element reached. */
CAO_RES CAO_global_ref(CAO_REF * res, CAO_REF root, char type, CAO_rint path[],
					   int pathlen)
{
	int container = 1, i = 0;
	CAO_REF ref = root;

	while (container && (i < pathlen))
	{
		switch (type)
		{
		case VECTOR_Type:
			ref = _CAO_vector_ref(ref, path[i++], &type);
			break;
		case MATRIX_Type:
			/* A matrix step needs two indices. Previously path[i + 1] was read
			   unconditionally, reading past the end of `path` when only one
			   entry remained. */
			if (i + 1 >= pathlen)
			{
				container = 0;
				break;
			}
			ref = _CAO_matrix_ref(ref, path[i], path[i + 1], &type);
			i += 2;
			break;
		case STRUCT_Type:
			ref = _CAO_struct_ref(ref, path[i++], &type);
			break;
		default:
			container = 0;
		}
	}
	*res = ref;
	if (i == pathlen)
		return CAO_OK;
	else
		return CAO_ERR;
}

/* In-place addition x += y for values of type tag `type`.
   Supported tags: INT, MOD, MODPOL, MATRIX (via type-specific routines) and
   RINT (plain int addition). Any other tag returns CAO_ERR. */
CAO_RES CAO_global_addTo(CAO_REF x, CAO_REF y, char type)
{
	switch (type)
	{
	case RINT_Type:
		*(int *)x += *(int *)y;
		return CAO_OK;
	case INT_Type:
		return CAO_int_addTo(x, y);
	case MOD_Type:
		return CAO_mod_addTo(x, y);
	case MODPOL_Type:
		return CAO_modpol_addTo(x, y);
	case MATRIX_Type:
		return CAO_matrix_addTo(x, y);
	default:
		return CAO_ERR;
	}
}

/* In-place subtraction x -= y for values of type tag `type`.
   Supported tags: INT, MOD, MODPOL, MATRIX (via type-specific routines) and
   RINT (plain int subtraction). Any other tag returns CAO_ERR. */
CAO_RES CAO_global_subTo(CAO_REF x, CAO_REF y, char type)
{
	switch (type)
	{
	case RINT_Type:
		*(int *)x -= *(int *)y;
		return CAO_OK;
	case INT_Type:
		return CAO_int_subTo(x, y);
	case MOD_Type:
		return CAO_mod_subTo(x, y);
	case MODPOL_Type:
		return CAO_modpol_subTo(x, y);
	case MATRIX_Type:
		return CAO_matrix_subTo(x, y);
	default:
		return CAO_ERR;
	}
}

/* Dispatches the `sym` operation on (x, y) by type tag `type`.
   Supported tags: INT, MOD, MODPOL, MATRIX and RINT. For RINT the result of
   CAO_rint_sym is not inspected and CAO_OK is always returned. Any other
   tag returns CAO_ERR. */
CAO_RES CAO_global_sym(CAO_REF x, CAO_REF y, char type)
{
	switch (type)
	{
	case RINT_Type:
		CAO_rint_sym(*(int *)x, *(int *)y);
		return CAO_OK;
	case INT_Type:
		return CAO_int_sym(x, y);
	case MOD_Type:
		return CAO_mod_sym(x, y);
	case MODPOL_Type:
		return CAO_modpol_sym(x, y);
	case MATRIX_Type:
		return CAO_matrix_sym(x, y);
	default:
		return CAO_ERR;
	}
}

/* Dispatches multiplication r = a * b by type tag `type`.
   Supported tags: INT, MOD, MODPOL, MATRIX and RINT. For RINT the result of
   CAO_rint_mul is not inspected and CAO_OK is always returned. Any other tag
   returns CAO_ERR. */
CAO_RES CAO_global_mul(CAO_REF r, CAO_REF a, CAO_REF b, char type)
{
	switch (type)
	{
	case RINT_Type:
		CAO_rint_mul(*(int *)r, *(int *)a, *(int *)b);
		return CAO_OK;
	case INT_Type:
		return CAO_int_mul(r, a, b);
	case MOD_Type:
		return CAO_mod_mul(r, a, b);
	case MODPOL_Type:
		return CAO_modpol_mul(r, a, b);
	case MATRIX_Type:
		return CAO_matrix_mul(r, a, b);
	default:
		return CAO_ERR;
	}
}

/* Sets `s` (type tag `type`) to its zero value.
   Supported tags: INT, MOD, MODPOL, MATRIX (via type-specific routines) and
   RINT (stores 0). Any other tag returns CAO_ERR. */
CAO_RES CAO_global_assign_zero(CAO_REF s, char type)
{
	switch (type)
	{
	case RINT_Type:
		*(int *)s = 0;
		return CAO_OK;
	case INT_Type:
		return CAO_int_assign_zero(s);
	case MOD_Type:
		return CAO_mod_assign_zero(s);
	case MODPOL_Type:
		return CAO_modpol_assign_zero(s);
	case MATRIX_Type:
		return CAO_matrix_assign_zero(s);
	default:
		return CAO_ERR;
	}
}

/* Sets `s` (type tag `type`) to its unit value.
   Supported tags: INT, MOD, MODPOL, MATRIX (via type-specific routines) and
   RINT (stores 1). Any other tag returns CAO_ERR. */
CAO_RES CAO_global_assign_one(CAO_REF s, char type)
{
	switch (type)
	{
	case RINT_Type:
		*(int *)s = 1;
		return CAO_OK;
	case INT_Type:
		return CAO_int_assign_one(s);
	case MOD_Type:
		return CAO_mod_assign_one(s);
	case MODPOL_Type:
		return CAO_modpol_assign_one(s);
	case MATRIX_Type:
		return CAO_matrix_assign_one(s);
	default:
		return CAO_ERR;
	}
}

/* Converts `s` (type tag `ts`) into `d` (type tag `td`).
   The supported (source -> destination) pairs are exactly those listed in
   the nested switches below. Any other combination returns CAO_ERR.
   For INT -> RINT the result of CAO_int_cast_rint is not inspected and
   CAO_OK is returned. STRUCT -> STRUCT is implemented as an assignment. */
CAO_RES CAO_global_cast(CAO_REF d, char td, CAO_REF s, char ts)
{
	CAO_RES res;

	switch (ts)
	{
	case INT_Type:
		switch (td)
		{
		case RINT_Type:
			CAO_int_cast_rint(*(int *)d, s);
			res = CAO_OK;
			break;
		case UBITS_Type:
			res = CAO_int_cast_ubits(d, s);
			break;
		case SBITS_Type:
			res = CAO_int_cast_sbits(d, s);
			break;
		case MOD_Type:
			res = CAO_int_cast_mod(d, s);
			break;
		default:
			res = CAO_ERR;
		}
		break;
	case MOD_Type:
		switch (td)
		{
		case INT_Type:
			res = CAO_mod_cast_int(d, s);
			break;
		case MOD_Type:
			res = CAO_mod_cast_mod(d, s);
			break;
		case MODPOL_Type:
			res = CAO_mod_cast_modpol(d, s);
			break;
		default:
			res = CAO_ERR;
		}
		break;
	case UBITS_Type:
		switch (td)
		{
		case INT_Type:
			res = CAO_ubits_cast_int(d, s);
			break;
		case UBITS_Type:
			res = CAO_ubits_cast_ubits(d, s);
			break;
		default:
			res = CAO_ERR;
		}
		break;
	case SBITS_Type:
		switch (td)
		{
		case INT_Type:
			res = CAO_sbits_cast_int(d, s);
			break;
		case SBITS_Type:
			res = CAO_sbits_cast_sbits(d, s);
			break;
		default:
			res = CAO_ERR;
		}
		break;
	case VECTOR_Type:
		switch (td)
		{
		case VECTOR_Type:
			res = CAO_vector_cast_vector(d, s);
			break;
		case MATRIX_Type:
			res = CAO_vector_cast_matrix(d, s);
			break;
		case MODPOL_Type:
			res = CAO_vector_cast_modpol(d, s);
			break;
		default:
			res = CAO_ERR;
		}
		break;
	case MATRIX_Type:
		switch (td)
		{
		case MATRIX_Type:
			res = CAO_matrix_cast_matrix(d, s);
			break;
		case VECTOR_Type:
			res = CAO_matrix_cast_vector(d, s);
			break;
		case MODPOL_Type:
			res = CAO_matrix_cast_modpol(d, s);
			break;
		default:
			res = CAO_ERR;
		}
		break;
	case STRUCT_Type:
		switch (td)
		{
		case STRUCT_Type:
			res = CAO_struct_assign(d, s);
			break;
		default:
			res = CAO_ERR;
		}
		break;
	case MODPOL_Type:
		switch (td)
		{
		case VECTOR_Type:
			res = CAO_modpol_cast_vector(d, s);
			break;
		case MATRIX_Type:
			res = CAO_modpol_cast_matrix(d, s);
			break;
		default:
			res = CAO_ERR;
		}
		break;
	case RINT_Type:
		switch (td)
		{
		case INT_Type:
			res = CAO_rint_cast_int(d, *(int *)s);
			break;
		default:
			/* Previously missing: an unsupported RINT target returned the
			   uninitialised `res`. */
			res = CAO_ERR;
		}
		break;
	default:
		res = CAO_ERR;
	}
	return res;
}
