/*  CAO Compiler
    Copyright (C) 2014 Cryptography and Information Security Group, HASLab - INESC TEC and Universidade do Minho

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.  */

#include "CAO_ubits.h"

/* Create a new ubits handle of width s and store it in *b.
 * The handle owns a default-constructed NTL ZZ (value 0); release it with
 * CAO_ubits_dispose, which pairs the delete/free below.
 * NOTE(review): the malloc result is not checked, so an allocation failure
 * dereferences NULL. There is no visible error code to return instead. */
CAO_RES CAO_ubits_decl(CAO_ubits * b, const int s)
{
	CAO_ubits_s *handle = (CAO_ubits_s *) malloc(sizeof *handle);

	handle->size = s;
	handle->value = new ZZ;

	*b = handle;
	return CAO_OK;
}

/* Set b's value from the string val, parsed with NTL's to_ZZ.
 * The parsed value is stored as-is; it is not reduced to b's width. */
CAO_RES CAO_ubits_init(CAO_ubits b, const char *val)
{
	ZZ *dst = ((CAO_ubits_s *) b)->value;

	*dst = to_ZZ(val);
	return CAO_OK;
}

/* Copy b's integer value into r.
 * Only the value is copied: r keeps its own width, and the value is not
 * reduced to it. */
CAO_RES CAO_ubits_assign(CAO_ubits r, CAO_ubits b)
{
	const CAO_ubits_s *src = (const CAO_ubits_s *) b;
	CAO_ubits_s *dst = (CAO_ubits_s *) r;

	*dst->value = *src->value;
	return CAO_OK;
}

/* Create a new ubits in *b with a's width and a copy of a's value.
 * The caller owns *b and must release it with CAO_ubits_dispose. */
CAO_RES CAO_ubits_clone(CAO_ubits * b, CAO_ubits a)
{
	const int width = ((CAO_ubits_s *) a)->size;

	CAO_ubits_decl(b, width);
	CAO_ubits_assign(*b, a);
	return CAO_OK;
}

/* Return whether i and j hold the same integer value.
 * Only the values are compared; the widths are not. */
CAO_bool _CAO_ubits_equal(CAO_ubits i, CAO_ubits j)
{
	const ZZ &lhs = *((CAO_ubits_s *) i)->value;
	const ZZ &rhs = *((CAO_ubits_s *) j)->value;

	return lhs == rhs;
}

/* Return whether i and j hold different integer values.
 * This is the negation of _CAO_ubits_equal; widths are not compared. */
CAO_bool _CAO_ubits_nequal(CAO_ubits i, CAO_ubits j)
{
	const ZZ &lhs = *((CAO_ubits_s *) i)->value;
	const ZZ &rhs = *((CAO_ubits_s *) j)->value;

	return !(lhs == rhs);
}

/* r = i | j, the bitwise OR of the two integer values.
 * The right-hand side is built in a temporary, so r may alias i or j. */
CAO_RES CAO_ubits_or(CAO_ubits r, CAO_ubits i, CAO_ubits j)
{
	const ZZ &x = *((CAO_ubits_s *) i)->value;
	const ZZ &y = *((CAO_ubits_s *) j)->value;
	ZZ &out = *((CAO_ubits_s *) r)->value;

	out = x | y;
	return CAO_OK;
}

/* r = i & j, the bitwise AND of the two integer values.
 * The right-hand side is built in a temporary, so r may alias i or j. */
CAO_RES CAO_ubits_and(CAO_ubits r, CAO_ubits i, CAO_ubits j)
{
	const ZZ &x = *((CAO_ubits_s *) i)->value;
	const ZZ &y = *((CAO_ubits_s *) j)->value;
	ZZ &out = *((CAO_ubits_s *) r)->value;

	out = x & y;
	return CAO_OK;
}

/* r = i ^ j, the bitwise XOR of the two integer values.
 * The right-hand side is built in a temporary, so r may alias i or j. */
CAO_RES CAO_ubits_xor(CAO_ubits r, CAO_ubits i, CAO_ubits j)
{
	const ZZ &x = *((CAO_ubits_s *) i)->value;
	const ZZ &y = *((CAO_ubits_s *) j)->value;
	ZZ &out = *((CAO_ubits_s *) r)->value;

	out = x ^ y;
	return CAO_OK;
}

/* r = bitwise complement of i within i's width.
 * i's value is copied into r, then bits 0 .. size(i)-1 are flipped one by one
 * with NTL's SwitchBit. Any bits at or above the width keep their copied value. */
CAO_RES CAO_ubits_not(CAO_ubits r, CAO_ubits i)
{
	const CAO_ubits_s *src = (const CAO_ubits_s *) i;
	ZZ &out = *((CAO_ubits_s *) r)->value;

	out = *src->value;
	for (long pos = 0; pos < src->size; ++pos)
	{
		SwitchBit(out, pos);
	}
	return CAO_OK;
}

/* r = (i << e) mod 2^size(i).
 * This is a left shift whose result is truncated to i's width, so bits
 * shifted past the top are discarded. */
CAO_RES CAO_ubits_shift_up(CAO_ubits r, CAO_ubits i, CAO_rint e)
{
	const CAO_ubits_s *src = (const CAO_ubits_s *) i;
	ZZ &out = *((CAO_ubits_s *) r)->value;
	const int width = src->size;
	ZZ modulus;

	power(modulus, 2, width);
	out = (*src->value << e) % modulus;
	return CAO_OK;
}

/* r = i >> e.
 * No reduction to the width is performed here, unlike CAO_ubits_shift_up. */
CAO_RES CAO_ubits_shift_down(CAO_ubits r, CAO_ubits i, CAO_rint e)
{
	const ZZ &src = *((CAO_ubits_s *) i)->value;
	ZZ &out = *((CAO_ubits_s *) r)->value;

	out = src >> e;
	return CAO_OK;
}

/* r = i rotated towards the most significant bit by e positions, within
 * size(i) bits. Assumes i's value fits in size(i) bits.
 *
 * The rotation amount is first reduced modulo the width. Rotating by a
 * multiple of the width is the identity, and a negative amount rotates the
 * other way. Without this reduction, e >= size(i) leaves bits above the
 * width in `upper`, and the result no longer fits in size(i) bits.
 */
CAO_RES CAO_ubits_rot_up(CAO_ubits r, CAO_ubits i, CAO_rint e)
{
	CAO_ubits_s *_i = (CAO_ubits_s *) i;
	CAO_ubits_s *_r = (CAO_ubits_s *) r;
	ZZ a, base, upper;

	ZZ *zi = _i->value;
	ZZ *zr = _r->value;
	int si = _i->size;

	/* Normalise the rotation amount into [0, si); a zero width rotates by 0. */
	long k = (si > 0) ? ((long) e % si) : 0;
	if (k < 0)
	{
		k += si;
	}

	power(base, 2, si);
	a = *zi << k;		/* value shifted past the width */
	upper = a / base;	/* bits that left through the top ... */
	a = a % base;		/* ... and the bits that stayed inside */
	*zr = a + upper;	/* wrap the top bits back into the bottom */
	return CAO_OK;
}

/* r = i rotated towards the least significant bit by e positions, within
 * size(i) bits. Assumes i's value fits in size(i) bits.
 *
 * The rotation amount is first reduced modulo the width. Without this
 * reduction:
 *   - e > size(i) makes the `si - e` shift negative, so the wrapped bits
 *     move the wrong way;
 *   - a negative e is passed to NTL's power(), which requires a
 *     non-negative exponent.
 */
CAO_RES CAO_ubits_rot_down(CAO_ubits r, CAO_ubits i, CAO_rint e)
{
	CAO_ubits_s *_i = (CAO_ubits_s *) i;
	CAO_ubits_s *_r = (CAO_ubits_s *) r;
	ZZ a, base, lower;

	ZZ *zi = _i->value;
	ZZ *zr = _r->value;
	int si = _i->size;

	/* Normalise the rotation amount into [0, si); a zero width rotates by 0. */
	long k = (si > 0) ? ((long) e % si) : 0;
	if (k < 0)
	{
		k += si;
	}

	power(base, 2, k);
	lower = *zi % base;		/* the k bits that leave through the bottom ... */
	a = *zi >> k;
	lower = lower << (si - k);	/* ... re-enter at the top of the width */
	*zr = a + lower;
	return CAO_OK;
}

/* r = bit e of b, as reported by NTL's bit() (0 or 1). */
CAO_RES CAO_ubits_select(CAO_ubits r, CAO_ubits b, CAO_rint e)
{
	const ZZ &src = *((CAO_ubits_s *) b)->value;
	ZZ &out = *((CAO_ubits_s *) r)->value;

	out = bit(src, e);
	return CAO_OK;
}

/* Set bit e of r to the least significant bit of b.
 * All other bits of r are left unchanged. */
CAO_RES CAO_ubits_set(CAO_ubits r, CAO_ubits b, CAO_rint e)
{
	ZZ &dst = *((CAO_ubits_s *) r)->value;
	const ZZ &src = *((CAO_ubits_s *) b)->value;
	const long wanted = bit(src, 0);

	/* The bit already has the right value; nothing to flip. */
	if (bit(dst, e) == wanted)
	{
		return CAO_OK;
	}

	SwitchBit(dst, e);
	return CAO_OK;
}

/* r = bits e..j (inclusive) of b, moved down so that bit e lands at bit 0.
 * Computed as (b >> e) mod 2^(j - e + 1). */
CAO_RES CAO_ubits_range_select(CAO_ubits r, CAO_ubits b, CAO_rint e, CAO_rint j)
{
	const ZZ &src = *((CAO_ubits_s *) b)->value;
	ZZ &out = *((CAO_ubits_s *) r)->value;
	const int count = j - e + 1;	/* number of selected bits */
	ZZ modulus;

	power(modulus, 2, count);
	out = (src >> e) % modulus;
	return CAO_OK;
}

/* Overwrite bits e..j (inclusive) of r with b's value.
 * r's bits below e and above j are kept.
 * NOTE(review): b's value is added unmasked, so this assumes it fits in
 * j - e + 1 bits. A wider value would spill into the bits above j. */
CAO_RES CAO_ubits_range_set(CAO_ubits r, CAO_ubits b, CAO_rint e, CAO_rint j)
{
	const ZZ &src = *((CAO_ubits_s *) b)->value;
	ZZ &dst = *((CAO_ubits_s *) r)->value;
	const int count = j - e + 1;	/* width of the field being written */
	ZZ low_mod, low_part, result;

	power(low_mod, 2, e);
	low_part = dst % low_mod;	/* r's bits 0 .. e-1 */

	result = dst >> (j + 1);	/* r's bits above j */
	result <<= count;		/* open a gap for the field */
	result += src;			/* drop the new field into the gap */
	result <<= e;			/* open a gap for the low bits */

	dst = result + low_part;
	return CAO_OK;
}

/* r = b * 2^size(a) + a.
 * a occupies the low size(a) bits and b's value sits directly above them.
 * The right-hand side is fully computed before r is written, so r may alias
 * a or b. */
CAO_RES CAO_ubits_concat(CAO_ubits r, CAO_ubits a, CAO_ubits b)
{
	const CAO_ubits_s *low = (const CAO_ubits_s *) a;
	const ZZ &high = *((CAO_ubits_s *) b)->value;
	ZZ &out = *((CAO_ubits_s *) r)->value;
	const int low_width = low->size;
	ZZ scale;

	power(scale, 2, low_width);
	out = high * scale + *low->value;
	return CAO_OK;
}

/* Print b to cout as "bits[<size>] = <value>" followed by a newline.
 * Intended for debugging. */
CAO_RES CAO_ubits_dump(CAO_ubits b)
{
	const CAO_ubits_s *src = (const CAO_ubits_s *) b;
	const int width = src->size;

	cout << "bits[" << width << "] = ";
	cout << *src->value << "\n";
	return CAO_OK;
}

/* Release a handle created by CAO_ubits_decl.
 * The NTL value was allocated with new and the struct with malloc, so they
 * are released with delete and free respectively. */
CAO_RES CAO_ubits_dispose(CAO_ubits a)
{
	CAO_ubits_s *handle = (CAO_ubits_s *) a;
	ZZ *number = handle->value;

	delete number;
	free(handle);
	return CAO_OK;
}

/* Copy a's integer value into the CAO_int b.
 * The b handle is treated as a pointer to an NTL ZZ. */
CAO_RES CAO_ubits_cast_int(CAO_int b, CAO_ubits a)
{
	const ZZ &src = *((CAO_ubits_s *) a)->value;
	ZZ *dst = (ZZ *) b;

	*dst = src;
	return CAO_OK;
}

/* Convert the CAO_int a (treated as a pointer to an NTL ZZ) to b's width.
 * A negative value is first offset by 2^size(b), two's-complement style.
 * The result is then reduced modulo 2^size(b). */
CAO_RES CAO_int_cast_ubits(CAO_ubits b, CAO_int a)
{
	CAO_ubits_s *dst = (CAO_ubits_s *) b;
	const ZZ &src = *((ZZ *) a);
	ZZ modulus;

	power(modulus, 2, dst->size);

	if (sign(src) != -1)
	{
		*dst->value = src;
	}
	else
	{
		*dst->value = modulus + src;
	}

	*dst->value = *dst->value % modulus;
	return CAO_OK;
}

/* Convert the ubits value a into the mod value b, via a temporary CAO_int.
 *
 * Returns CAO_OK on success, otherwise the first non-CAO_OK result produced
 * by the helper calls. Previously every helper result was ignored, and the
 * function always reported CAO_OK. Once declared, the temporary is always
 * disposed.
 */
CAO_RES CAO_ubits_cast_mod(CAO_mod b, CAO_ubits a)
{
	CAO_int aux;
	CAO_RES res = CAO_int_decl(&aux);

	/* NOTE(review): assumes a failed CAO_int_decl leaves nothing to dispose
	 * and aux unusable, so we stop here rather than touch it. */
	if (res != CAO_OK)
	{
		return res;
	}

	res = CAO_ubits_cast_int(aux, a);
	if (res == CAO_OK)
	{
		res = CAO_int_cast_mod(b, aux);
	}

	/* Report a dispose failure only if nothing failed earlier. */
	CAO_RES dispose_res = CAO_int_dispose(aux);
	if (res == CAO_OK)
	{
		res = dispose_res;
	}
	return res;
}

/* Convert the mod value a into the ubits b (reduced to b's width), via a
 * temporary CAO_int.
 *
 * Returns CAO_OK on success, otherwise the first non-CAO_OK result produced
 * by the helper calls. Previously every helper result was ignored, and the
 * function always reported CAO_OK. Once declared, the temporary is always
 * disposed.
 */
CAO_RES CAO_mod_cast_ubits(CAO_ubits b, CAO_mod a)
{
	CAO_int aux;
	CAO_RES res = CAO_int_decl(&aux);

	/* NOTE(review): assumes a failed CAO_int_decl leaves nothing to dispose
	 * and aux unusable, so we stop here rather than touch it. */
	if (res != CAO_OK)
	{
		return res;
	}

	res = CAO_mod_cast_int(aux, a);
	if (res == CAO_OK)
	{
		res = CAO_int_cast_ubits(b, aux);
	}

	/* Report a dispose failure only if nothing failed earlier. */
	CAO_RES dispose_res = CAO_int_dispose(aux);
	if (res == CAO_OK)
	{
		res = dispose_res;
	}
	return res;
}

/* Convert a to b's width: b's value becomes a's value mod 2^size(b).
 * b's own width is unchanged. */
CAO_RES CAO_ubits_cast_ubits(CAO_ubits b, CAO_ubits a)
{
	CAO_ubits_s *dst = (CAO_ubits_s *) b;
	const ZZ &src = *((CAO_ubits_s *) a)->value;
	ZZ modulus;

	power(modulus, 2, dst->size);
	*dst->value = src % modulus;
	return CAO_OK;
}
