/*
 * Copyright 1999, Alexander Feldman <alex@varna.net>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. Neither the name of Alexander Feldman nor the names of its contributors
 *    may be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY ALEXANDER FELDMAN AND CONTRIBUTORS ``AS IS''
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL ALEXANDER FELDMAN OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <stdio.h>
#include <string.h>

#include "rijndael.hpp"
#include "rijndaeltables.hpp"

/*
 * Rotate a 32-bit unsigned value p left (rol) or right (ror) by q bits.
 * q must be in 1..31: q == 0 would shift by 32, which is undefined.
 * p and q are each evaluated twice.
 */
#define rol(p, q) (((p) >> (32 - (q))) | ((p) << (q)))
#define ror(p, q) (((p) << (32 - (q))) | ((p) >> (q)))

/*
 * Substitute every byte of the word x through the byte table bSt and
 * reassemble the result with MakeWord (bSt, MakeWord and WBYTE0..3 come from
 * the project headers; bSt is presumably the AES S-box -- confirm in
 * rijndaeltables.hpp). x is evaluated four times, so it must not have side
 * effects.
 */
#define sbox(x) (MakeWord(bSt[WBYTE0(x)],					\
								  bSt[WBYTE1(x)],					\
								  bSt[WBYTE2(x)],					\
								  bSt[WBYTE3(x)]))

/*
 * Key-expansion steps for 4-, 6- and 8-word keys.
 *
 * Step i derives the next 4 / 6 / 8 words of the schedule wKey from the
 * previous ones.
 * - j carries the most recently produced word between steps and is updated
 *   in place.
 * - At the start of every step, j is rotated by 8 bits, passed through sbox
 *   and XORed with wR[i] (presumably the round constant -- see
 *   rijndaeltables.hpp).
 * - l8 additionally passes the word in the middle of each step through sbox
 *   (without rotation).
 *
 * Each macro expands to a single statement (do { ... } while (0)), so it is
 * safe as the body of an unbraced if/for, and every argument is
 * parenthesized. i, j and wKey are evaluated several times and must be free
 * of side effects.
 */
#define l4(i, j, wKey) do {											\
	(j) = sbox(ror((j), 8)) ^ wR[(i)];								\
	(wKey)[4 * (i) + 4] = (j) ^= (wKey)[4 * (i)    ];				\
	(wKey)[4 * (i) + 5] = (j) ^= (wKey)[4 * (i) + 1];				\
	(wKey)[4 * (i) + 6] = (j) ^= (wKey)[4 * (i) + 2];				\
	(wKey)[4 * (i) + 7] = (j) ^= (wKey)[4 * (i) + 3];				\
} while (0)

#define l6(i, j, wKey) do {											\
	(j) = sbox(ror((j), 8)) ^ wR[(i)];								\
	(wKey)[6 * (i) +  6] = (j) ^= (wKey)[6 * (i)    ];				\
	(wKey)[6 * (i) +  7] = (j) ^= (wKey)[6 * (i) + 1];				\
	(wKey)[6 * (i) +  8] = (j) ^= (wKey)[6 * (i) + 2];				\
	(wKey)[6 * (i) +  9] = (j) ^= (wKey)[6 * (i) + 3];				\
	(wKey)[6 * (i) + 10] = (j) ^= (wKey)[6 * (i) + 4];				\
	(wKey)[6 * (i) + 11] = (j) ^= (wKey)[6 * (i) + 5];				\
} while (0)

#define l8(i, j, wKey) do {											\
	(j) = sbox(ror((j), 8)) ^ wR[(i)];								\
	(wKey)[8 * (i) +  8] = (j) ^= (wKey)[8 * (i)    ];				\
	(wKey)[8 * (i) +  9] = (j) ^= (wKey)[8 * (i) + 1];				\
	(wKey)[8 * (i) + 10] = (j) ^= (wKey)[8 * (i) + 2];				\
	(wKey)[8 * (i) + 11] = (j) ^= (wKey)[8 * (i) + 3];				\
	(wKey)[8 * (i) + 12] = (j)  = (wKey)[8 * (i) + 4] ^ sbox(j);	\
	(wKey)[8 * (i) + 13] = (j) ^= (wKey)[8 * (i) + 5];				\
	(wKey)[8 * (i) + 14] = (j) ^= (wKey)[8 * (i) + 6];				\
	(wKey)[8 * (i) + 15] = (j) ^= (wKey)[8 * (i) + 7];				\
} while (0)

/*
 * Multiply each of the four bytes of x by 2 in GF(2^8): shift every byte left
 * by one and XOR 0x1b into the bytes whose top bit was set. The whole
 * expansion is parenthesized so the macro can appear inside any larger
 * expression. The result has type long because of the L suffixes; callers
 * assign it to a Word, which keeps the low 32 bits. x is evaluated twice.
 */
#define s(x) ((((x) & 0x7f7f7f7fL) << 1) ^ ((((x) & 0x80808080L) >> 7) * 0x1bL))

/*
 * Per-column round helpers.
 *
 * Each macro yields output column n (0..3), computed from the four-word
 * state i and XORed with round-key word *(q + n). wF, wI and WBYTE0..3
 * come from the project headers. wF/wI are presumably the forward/inverse
 * round lookup tables -- confirm in rijndaeltables.hpp.
 *   z: full forward round; reads state columns n, n+1, n+2, n+3 (mod 4).
 *   x: full inverse round; reads state columns n, n+3, n+2, n+1 (mod 4).
 *   y: last forward round; only substitutes bytes through table r.
 *   w: last inverse round; only substitutes bytes through table r.
 *
 * Every macro is fully parenthesized, and so is every argument, so the
 * results can be used inside larger expressions.
 */
#define z(i, n, q) (wF[0][WBYTE3((i)[(n)])] ^								\
					wF[1][WBYTE2((i)[((n) + 1) & 3])] ^						\
					wF[2][WBYTE1((i)[((n) + 2) & 3])] ^						\
					wF[3][WBYTE0((i)[((n) + 3) & 3])] ^ *((q) + (n)))

#define x(i, n, q) (wI[0][WBYTE3((i)[(n)])] ^								\
					wI[1][WBYTE2((i)[((n) + 3) & 3])] ^						\
					wI[2][WBYTE1((i)[((n) + 2) & 3])] ^						\
					wI[3][WBYTE0((i)[((n) + 1) & 3])] ^ *((q) + (n)))

#define y(i, n, q, r) (MakeWord((r)[WBYTE0((i)[((n) + 3) & 3])],			\
								(r)[WBYTE1((i)[((n) + 2) & 3])],			\
								(r)[WBYTE2((i)[((n) + 1) & 3])],			\
								(r)[WBYTE3((i)[(n)])]) ^ *((q) + (n)))

#define w(i, n, q, r) (MakeWord((r)[WBYTE0((i)[((n) + 1) & 3])],			\
								(r)[WBYTE1((i)[((n) + 2) & 3])],			\
								(r)[WBYTE2((i)[((n) + 3) & 3])],			\
								(r)[WBYTE3((i)[(n)])]) ^ *((q) + (n)))

/*
 * Whole-block round steps built on the per-column helpers.
 *   f: full forward round     a[c] = z(b, c, k)
 *   l: last forward round     a[c] = y(b, c, k, bSt)
 *   g: full inverse round     a[c] = x(b, c, k)
 *   m: last inverse round     a[c] = w(b, c, k, bIt)
 * k points at the four round-key words for this round.
 *
 * a and b must be distinct arrays, because a[0] is written while later
 * columns still read b. Each macro expands to a single statement
 * (do { ... } while (0)), so it is safe as the body of an unbraced if/for.
 */
#define f(a, b, k) do {												\
	(a)[0] = z(b, 0, k);												\
	(a)[1] = z(b, 1, k);												\
	(a)[2] = z(b, 2, k);												\
	(a)[3] = z(b, 3, k);												\
} while (0)

#define l(a, b, k) do {												\
	(a)[0] = y(b, 0, k, bSt);											\
	(a)[1] = y(b, 1, k, bSt);											\
	(a)[2] = y(b, 2, k, bSt);											\
	(a)[3] = y(b, 3, k, bSt);											\
} while (0)

#define g(a, b, k) do {												\
	(a)[0] = x(b, 0, k);												\
	(a)[1] = x(b, 1, k);												\
	(a)[2] = x(b, 2, k);												\
	(a)[3] = x(b, 3, k);												\
} while (0)

#define m(a, b, k) do {												\
	(a)[0] = w(b, 0, k, bIt);											\
	(a)[1] = w(b, 1, k, bIt);											\
	(a)[2] = w(b, 2, k, bIt);											\
	(a)[3] = w(b, 3, k, bIt);											\
} while (0)

/*
 * Default constructor: an all-zero 4-word (128-bit) key.
 * Every member is initialized, including both derived key schedules, so
 * GetKeyLength(), Encrypt() and Decrypt() never read indeterminate memory on
 * a default-constructed key.
 */
CRijndaelKey::CRijndaelKey()
{
	for (int i = 0; i < 8; i++)
		wMaster[i] = 0;
	wKeyLength = 4;
	MakeKeys();
}

/*
 * Copy constructor: take over the other key's master words and length, then
 * derive this object's own encryption/decryption schedules from them.
 */
CRijndaelKey::CRijndaelKey(const CRijndaelKey &cRijndaelKey)
{
	wKeyLength = cRijndaelKey.wKeyLength;
	int n = 8;
	while (n-- > 0)
		wMaster[n] = cRijndaelKey.wMaster[n];
	MakeKeys();
}

/*
 * Build a key from raw bytes.
 *
 * pbMasterKey  wKeyLength bytes of key material. The length should be a
 *              multiple of BYTESINWORD giving 4, 6 or 8 words.
 * wKeyLength   Length in BYTES:
 *              - Anything beyond sizeof(wMaster) is ignored.
 *              - A trailing partial word is copied, but it is not counted
 *                in the stored word length.
 *
 * The unused tail of wMaster is zero-filled. The fill offset is computed in
 * bytes: wMaster + wKeyLength would advance by wKeyLength WORDS and write
 * past the end of the array.
 */
CRijndaelKey::CRijndaelKey(const Byte *pbMasterKey, Word wKeyLength)
{
	if (wKeyLength > sizeof(wMaster))
		wKeyLength = sizeof(wMaster);
	CRijndaelKey::wKeyLength = wKeyLength / BYTESINWORD;
	if (wKeyLength)
		memcpy(wMaster, pbMasterKey, wKeyLength);
	memset((Byte *)wMaster + wKeyLength, 0, sizeof(wMaster) - wKeyLength);
	MakeKeys();
}

/*
 * Build a key from 32-bit words.
 *
 * pwMasterKey  wKeyLength words of key material (4, 6 or 8 are supported).
 * wKeyLength   Length in WORDS. Values above 8 are clamped to the 8 words
 *              wMaster can hold.
 *
 * The unused tail of wMaster is zero-filled.
 */
CRijndaelKey::CRijndaelKey(const Word *pwMasterKey, Word wKeyLength)
{
	Word i;
	if (wKeyLength > 8)
		wKeyLength = 8;
	CRijndaelKey::wKeyLength = wKeyLength;
	for (i = 0; i < wKeyLength; i++)
		wMaster[i] = pwMasterKey[i];
	for ( ; i < 8; i++)
		wMaster[i] = 0;
	MakeKeys();
}

/*
 * Destructor: wipe the master key and both derived key schedules.
 *
 * The stores go through volatile pointers. Ordinary writes into an object
 * that is being destroyed are dead stores, which an optimizing compiler may
 * remove, leaving key material behind in released memory. Each array is
 * cleared over its full sizeof rather than a fixed word count.
 */
CRijndaelKey::~CRijndaelKey()
{
	volatile unsigned char *pb;
	size_t n;

	pb = (volatile unsigned char *)wMaster;
	for (n = 0; n < sizeof(wMaster); n++)
		pb[n] = 0;

	pb = (volatile unsigned char *)wEncryptionKey;
	for (n = 0; n < sizeof(wEncryptionKey); n++)
		pb[n] = 0;

	pb = (volatile unsigned char *)wDecryptionKey;
	for (n = 0; n < sizeof(wDecryptionKey); n++)
		pb[n] = 0;
}

void CRijndaelKey::MakeKeys()
{
	Word i, j, u, v, w;
	
	memset(wEncryptionKey, 0, sizeof(wEncryptionKey));
	memset(wDecryptionKey, 0, sizeof(wDecryptionKey));
	
	wEncryptionKey[0] = wMaster[0];
	wEncryptionKey[1] = wMaster[1];
	wEncryptionKey[2] = wMaster[2];
	wEncryptionKey[3] = wMaster[3];

	switch(wKeyLength) {
		case 4:
			j = wEncryptionKey[3];
			for (i = 0; i < 10; i++) {
				l4(i, j, wEncryptionKey);
			}
		break;
		case 6:
			wEncryptionKey[4] = wMaster[4];
			wEncryptionKey[5] = wMaster[5];
			j = wEncryptionKey[5];
			for (i = 0; i < 8; i++) {
				l6(i, j, wEncryptionKey);
			}
		break;
		case 8:
			wEncryptionKey[4] = wMaster[4];
			wEncryptionKey[5] = wMaster[5];
			wEncryptionKey[6] = wMaster[6];
			wEncryptionKey[7] = wMaster[7];
			j = wEncryptionKey[7];
			for (i = 0; i < 7; i++) {
				l8(i, j, wEncryptionKey);
			}
		break;
	}
	wDecryptionKey[0] = wEncryptionKey[0];
	wDecryptionKey[1] = wEncryptionKey[1];
	wDecryptionKey[2] = wEncryptionKey[2];
	wDecryptionKey[3] = wEncryptionKey[3];
	for (i = 4; i < wKeyLength * 4 + 24; i++) {
		u = s(wEncryptionKey[i]);
		v = s(u);
		w = s(v);
		j = w ^ wEncryptionKey[i];
		wDecryptionKey[i] = (u ^ v ^ w) ^ ror(u ^ j, 8) ^ ror(v ^ j, 16) ^ ror(j, 24);
	}
}

/*
 * Default constructor: an all-zero block, so the contents are always well
 * defined.
 */
CRijndaelBlock::CRijndaelBlock()
{
	wData[0] = wData[1] = wData[2] = wData[3] = 0;
}

/* Copy constructor: duplicate the four state words of another block. */
CRijndaelBlock::CRijndaelBlock(const CRijndaelBlock &cRijndaelBlock)
{
	wData[0] = cRijndaelBlock.wData[0];
	wData[1] = cRijndaelBlock.wData[1];
	wData[2] = cRijndaelBlock.wData[2];
	wData[3] = cRijndaelBlock.wData[3];
}

/* Build a block from four words, stored in argument order a, b, c, d. */
CRijndaelBlock::CRijndaelBlock(Word a, Word b, Word c, Word d)
{
	const Word awInit[4] = { a, b, c, d };
	for (int n = 0; n < 4; n++)
		wData[n] = awInit[n];
}

/*
 * Build a block from wLength raw bytes. Filling and zero-padding are
 * delegated to SetData().
 */
CRijndaelBlock::CRijndaelBlock(const Byte *pbData, Word wLength)
{
	this->SetData(pbData, wLength);
}

/*
 * Destructor: wipe the block contents.
 *
 * The stores go through a volatile pointer, because plain writes into an
 * object that is being destroyed are dead stores an optimizing compiler may
 * remove, leaving plaintext or ciphertext in released memory. The whole
 * array (sizeof(wData)) is cleared.
 */
CRijndaelBlock::~CRijndaelBlock()
{
	volatile unsigned char *pb = (volatile unsigned char *)wData;
	for (size_t n = 0; n < sizeof(wData); n++)
		pb[n] = 0;
}

void CRijndaelBlock::Encrypt(const CRijndaelKey &cRijndaelKey)
{
	Word wTemp[4], k = 4;
	const Word *pwEncryptionKey = cRijndaelKey.GetEncryptionKeys();
	
	wData[0] ^= pwEncryptionKey[0];
	wData[1] ^= pwEncryptionKey[1];
	wData[2] ^= pwEncryptionKey[2];
	wData[3] ^= pwEncryptionKey[3];

	if (cRijndaelKey.GetKeyLength() > 4) {
		f(wTemp, wData, pwEncryptionKey + k); k += 4;
		f(wData, wTemp, pwEncryptionKey + k); k += 4;
	}
	if (cRijndaelKey.GetKeyLength() > 6) {
		f(wTemp, wData, pwEncryptionKey + k); k += 4;
		f(wData, wTemp, pwEncryptionKey + k); k += 4;
	}
	f(wTemp, wData, pwEncryptionKey + k); k += 4;
	f(wData, wTemp, pwEncryptionKey + k); k += 4;
	f(wTemp, wData, pwEncryptionKey + k); k += 4;
	f(wData, wTemp, pwEncryptionKey + k); k += 4;
	f(wTemp, wData, pwEncryptionKey + k); k += 4;
	f(wData, wTemp, pwEncryptionKey + k); k += 4;
	f(wTemp, wData, pwEncryptionKey + k); k += 4;
	f(wData, wTemp, pwEncryptionKey + k); k += 4;
	f(wTemp, wData, pwEncryptionKey + k); k += 4;
	l(wData, wTemp, pwEncryptionKey + k);
}

void CRijndaelBlock::Decrypt(const CRijndaelKey &cRijndaelKey)
{
	Word wKeyLength = cRijndaelKey.GetKeyLength();
	Word wTemp[4], k = 4 * (wKeyLength + 5);
	const Word *pwDecryptionKey = cRijndaelKey.GetDecryptionKeys();
	const Word *pwEncryptionKey = cRijndaelKey.GetEncryptionKeys();
	
	wData[0] ^= pwEncryptionKey[4 * wKeyLength + 24];
	wData[1] ^= pwEncryptionKey[4 * wKeyLength + 25];
	wData[2] ^= pwEncryptionKey[4 * wKeyLength + 26];
	wData[3] ^= pwEncryptionKey[4 * wKeyLength + 27];
	
	if (wKeyLength > 4) {
		g(wTemp, wData, pwDecryptionKey + k); k -= 4;
		g(wData, wTemp, pwDecryptionKey + k); k -= 4;
	}
	if (wKeyLength > 6) {
		g(wTemp, wData, pwDecryptionKey + k); k -= 4;
		g(wData, wTemp, pwDecryptionKey + k); k -= 4;
	}
	g(wTemp, wData, pwDecryptionKey + k); k -= 4;
	g(wData, wTemp, pwDecryptionKey + k); k -= 4;
	g(wTemp, wData, pwDecryptionKey + k); k -= 4;
	g(wData, wTemp, pwDecryptionKey + k); k -= 4;
	g(wTemp, wData, pwDecryptionKey + k); k -= 4;
	g(wData, wTemp, pwDecryptionKey + k); k -= 4;
	g(wTemp, wData, pwDecryptionKey + k); k -= 4;
	g(wData, wTemp, pwDecryptionKey + k); k -= 4;
	g(wTemp, wData, pwDecryptionKey + k); k -= 4;
	m(wData, wTemp, pwDecryptionKey + k);
}

/*
 * Load up to sizeof(wData) bytes into the block and zero-pad the remainder.
 *
 * pbData   Source bytes. May be null when wLength is 0.
 * wLength  Number of bytes to take. Values above sizeof(wData) are clamped:
 *          otherwise memcpy would overflow wData, and
 *          sizeof(wData) - wLength would wrap to a huge memset length.
 */
void CRijndaelBlock::SetData(const Byte *pbData, Word wLength)
{
	if (wLength > sizeof(wData))
		wLength = sizeof(wData);
	if (wLength)
		memcpy((Byte *)wData, pbData, wLength);
	memset((Byte *)wData + wLength, 0, sizeof(wData) - wLength);
}

/*
 * Expose the block's storage as raw bytes. The pointer aliases wData and is
 * valid for the lifetime of this object.
 */
Byte *CRijndaelBlock::GetData()
{
	return reinterpret_cast<Byte *>(&wData[0]);
}
