/*
 * Copyright 1999, Alexander Feldman <alex@varna.net>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. Neither the name of Alexander Feldman nor the names of its contributors
 *    may be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY ALEXANDER FELDMAN AND CONTRIBUTORS ``AS IS''
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL ALEXANDER FELDMAN OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <stdio.h>

#include "square.hpp"
#include "squaretables.hpp"
#include "gf256_1f5.hpp"

/*
 * Rotate helpers: rol rotates p left by q bit positions, ror rotates it
 * right.  Both hard-code a width of 32, so the result is a true rotation
 * only when p is a 32-bit unsigned value and 0 < q < 32 (q == 0 would
 * shift by the full 32 bits).  Each argument is expanded more than once,
 * so pass expressions without side effects.
 */
#define rol(p, q)			(((p) << (q)) | ((p) >> (32 - (q))))
#define ror(p, q)			(((p) >> (q)) | ((p) << (32 - (q))))

/*
 * t(a, b, t, s): computes one output word a[t] of a table-driven round.
 *
 * For each of the four state words b[0..3] the value is shifted right by
 * (24 - 8 * t) and cast to Byte to form an index into table s.  The four
 * table entries are then combined with XOR:
 *   - the entry for b[0] is used as is;
 *   - the entry for b[1] contributes (entry >> 8)  + (entry << 24);
 *   - the entry for b[2] contributes (entry >> 16) + (entry << 16);
 *   - the entry for b[3] contributes (entry >> 24) + (entry << 8).
 * When the table entries are 32 bits wide the two shifted halves do not
 * overlap, so each sum is a rotate right by 8, 16 and 24 bits respectively.
 *
 * The index argument t is substituted unparenthesised into "8 * t"; pass a
 * plain literal or identifier, never an expression such as "i + 1".
 */
#define t(a, b, t, s)	(a[t] = (s[(Byte)(b[0] >> (24 - 8 * t))]) ^ 			\
										 ((s[(Byte)(b[1] >> (24 - 8 * t))] >>  8) +	\
										  (s[(Byte)(b[1] >> (24 - 8 * t))] << 24)) ^ \
								  		 ((s[(Byte)(b[2] >> (24 - 8 * t))] >> 16) +	\
										  (s[(Byte)(b[2] >> (24 - 8 * t))] << 16)) ^ \
								  		 ((s[(Byte)(b[3] >> (24 - 8 * t))] >> 24) +	\
										  (s[(Byte)(b[3] >> (24 - 8 * t))] <<  8)))

/*
 * r(b, k, s): one table-driven round applied in place to the four-word
 * state b.
 *
 * All four output words are first built with the t macro from table s into
 * a scratch array, so that every output is computed from the unmodified
 * input; only then is b overwritten with scratch[i] ^ k[i].
 *
 * Expands to a braced compound statement that declares its own wTemp, which
 * shadows any wTemp in the enclosing scope.  k is indexed directly as k[i],
 * so a pointer expression must be parenthesised by the caller.
 */
#define r(b, k, s)			\
{									\
	Word wTemp[4];				\
	t(wTemp, b, 0, s);		\
	t(wTemp, b, 1, s);		\
	t(wTemp, b, 2, s);		\
	t(wTemp, b, 3, s);		\
	b[0] = wTemp[0] ^ k[0];	\
	b[1] = wTemp[1] ^ k[1];	\
	b[2] = wTemp[2] ^ k[2];	\
	b[3] = wTemp[3] ^ k[3];	\
}

/*
 * s(a, b, t, u): computes output word a[t] of the final, substitution-only
 * round.
 *
 * From each state word b[0..3] the value shifted right by (24 - 8 * t) is
 * cast to Byte, looked up in the byte table u, and the four results are
 * packed into one word with MakeWord (b[0]'s byte is the first argument).
 *
 * As with the t macro, the index argument t is substituted unparenthesised
 * into "8 * t"; pass a plain literal or identifier.
 */
#define s(a, b, t, u)	(a[t] = MakeWord(u[(Byte)(b[0] >> (24 - 8 * t))],	\
													  u[(Byte)(b[1] >> (24 - 8 * t))],	\
													  u[(Byte)(b[2] >> (24 - 8 * t))],	\
													  u[(Byte)(b[3] >> (24 - 8 * t))]))

/*
 * Default constructor.
 *
 * Zeroes the four master-key words so the object never holds indeterminate
 * key material: the copy constructor reads wMaster from its source, which
 * may be a default-constructed key.  This mirrors CSquareBlock's default
 * constructor, which zeroes its data for definiteness.
 *
 * No round keys are derived here; the encryption/decryption schedules are
 * only filled in by the constructors that call MakeKeys().
 */
CSquareKey::CSquareKey()
{
	for (int i = 0; i < 4; i++)
		wMaster[i] = 0;
}

/*
 * Copy constructor.
 *
 * Takes over the four master-key words of cSquareKey and then rebuilds the
 * round keys from them with MakeKeys(); the source's schedules are not
 * copied.
 */
CSquareKey::CSquareKey(const CSquareKey &cSquareKey)
{
	wMaster[0] = cSquareKey.wMaster[0];
	wMaster[1] = cSquareKey.wMaster[1];
	wMaster[2] = cSquareKey.wMaster[2];
	wMaster[3] = cSquareKey.wMaster[3];

	MakeKeys();
}

/*
 * Constructs a key from four master-key words.
 *
 * pwMasterKey must point to at least four readable Words; they are stored
 * unchanged (no byte reordering) and the round keys are derived from them
 * with MakeKeys().
 */
CSquareKey::CSquareKey(const Word *pwMasterKey)
{
	Word wIndex = 0;

	while (wIndex < 4) {
		wMaster[wIndex] = pwMasterKey[wIndex];
		++wIndex;
	}

	MakeKeys();
}

/*
 * Constructs a key from a byte string.
 *
 * pbMasterKey  - key bytes; must point to at least wLength readable bytes
 *                when wLength is non-zero.
 * wLength      - number of key bytes supplied.  At most sizeof(wMaster)
 *                bytes are used; any excess is ignored, and a shorter key
 *                is padded with zero bytes.
 *
 * After the bytes are placed in wMaster, REVERSEWORD is applied to every
 * word (defined elsewhere; presumably a byte-order swap -- confirm) and the
 * round keys are derived with MakeKeys().
 */
CSquareKey::CSquareKey(const Byte *pbMasterKey, Word wLength)
{
	// Clamp the length: an oversized wLength would otherwise overrun wMaster
	// in memcpy and make "sizeof(wMaster) - wLength" wrap to a huge size.
	const Word wUsable = (wLength < sizeof(wMaster)) ? wLength : (Word)sizeof(wMaster);

	// Zero everything first so the tail past the supplied bytes is padding.
	memset((Byte *)wMaster, 0, sizeof(wMaster));
	if (wUsable != 0)
		memcpy((Byte *)wMaster, pbMasterKey, wUsable);

	for (Word i = 0; i < 4; i++)
		wMaster[i] = REVERSEWORD(wMaster[i]);
	MakeKeys();
}

/*
 * Destructor: wipes all key material held by the object.
 *
 * The stores go through volatile lvalues because plain stores to members of
 * an object whose lifetime is ending are dead stores that an optimiser may
 * remove, which would leave the key in memory.
 *
 * Besides the master key, both round-key schedules are cleared as well: they
 * are derived from the master key and are just as sensitive.  The index
 * ranges match the ones MakeKeys() writes (rows 0..SQUAREROUNDS, four words
 * each).  NOTE(review): assumes wEncryptionKey / wDecryptionKey are
 * per-object members -- confirm against the class declaration.
 */
CSquareKey::~CSquareKey()
{
	volatile Word *pwMasterWipe = wMaster;
	for (int i = 0; i < 4; i++)			// Cleanup
		pwMasterWipe[i] = 0;

	for (int l = 0; l <= SQUAREROUNDS; l++) {
		volatile Word *pwEncryptWipe = wEncryptionKey[l];
		volatile Word *pwDecryptWipe = wDecryptionKey[l];
		for (int i = 0; i < 4; i++) {
			pwEncryptWipe[i] = 0;
			pwDecryptWipe[i] = 0;
		}
	}
}

void CSquareKey::MakeKeys()
{
	Word i, l, wTemp[SQUAREROUNDS + 1][4];

	wTemp[0][0] = wMaster[0];
	wTemp[0][1] = wMaster[1];
	wTemp[0][2] = wMaster[2];
	wTemp[0][3] = wMaster[3];
	
	for (i = 1; i <= SQUAREROUNDS; i++) {
		wTemp[i][0] = wTemp[i - 1][0] ^ rol(wTemp[i - 1][3], 8) ^ wOffset[i - 1];
		wTemp[i][1] = wTemp[i - 1][1] ^ wTemp[i][0];
		wTemp[i][2] = wTemp[i - 1][2] ^ wTemp[i][1];
		wTemp[i][3] = wTemp[i - 1][3] ^ wTemp[i][2];
	}

	for (l = 0; l < SQUAREROUNDS; l++) {
		for (i = 0; i < 4; i++) {
			wEncryptionKey[l][i] = 0;
			wEncryptionKey[l][i] ^= MakeWord(gf256_1f5_mul(WBYTE0(wTemp[l][i]), bSquareG[0][0]),
														gf256_1f5_mul(WBYTE0(wTemp[l][i]), bSquareG[0][1]),
														gf256_1f5_mul(WBYTE0(wTemp[l][i]), bSquareG[0][2]),
														gf256_1f5_mul(WBYTE0(wTemp[l][i]), bSquareG[0][3]));
			wEncryptionKey[l][i] ^= MakeWord(gf256_1f5_mul(WBYTE1(wTemp[l][i]), bSquareG[1][0]),
														gf256_1f5_mul(WBYTE1(wTemp[l][i]), bSquareG[1][1]),
														gf256_1f5_mul(WBYTE1(wTemp[l][i]), bSquareG[1][2]),
														gf256_1f5_mul(WBYTE1(wTemp[l][i]), bSquareG[1][3]));
			wEncryptionKey[l][i] ^= MakeWord(gf256_1f5_mul(WBYTE2(wTemp[l][i]), bSquareG[2][0]),
														gf256_1f5_mul(WBYTE2(wTemp[l][i]), bSquareG[2][1]),
														gf256_1f5_mul(WBYTE2(wTemp[l][i]), bSquareG[2][2]),
														gf256_1f5_mul(WBYTE2(wTemp[l][i]), bSquareG[2][3]));
			wEncryptionKey[l][i] ^= MakeWord(gf256_1f5_mul(WBYTE3(wTemp[l][i]), bSquareG[3][0]),
														gf256_1f5_mul(WBYTE3(wTemp[l][i]), bSquareG[3][1]),
														gf256_1f5_mul(WBYTE3(wTemp[l][i]), bSquareG[3][2]),
														gf256_1f5_mul(WBYTE3(wTemp[l][i]), bSquareG[3][3]));
		}
	}
	
	wEncryptionKey[SQUAREROUNDS][0] = wTemp[SQUAREROUNDS][0];
	wEncryptionKey[SQUAREROUNDS][1] = wTemp[SQUAREROUNDS][1];
	wEncryptionKey[SQUAREROUNDS][2] = wTemp[SQUAREROUNDS][2];
	wEncryptionKey[SQUAREROUNDS][3] = wTemp[SQUAREROUNDS][3];

	for (i = 0; i < SQUAREROUNDS; i++) {
		wDecryptionKey[i][0] = wTemp[SQUAREROUNDS - i][0];
		wDecryptionKey[i][1] = wTemp[SQUAREROUNDS - i][1];
		wDecryptionKey[i][2] = wTemp[SQUAREROUNDS - i][2];
		wDecryptionKey[i][3] = wTemp[SQUAREROUNDS - i][3];
	}
	wDecryptionKey[SQUAREROUNDS][0] = wEncryptionKey[0][0];
	wDecryptionKey[SQUAREROUNDS][1] = wEncryptionKey[0][1];
	wDecryptionKey[SQUAREROUNDS][2] = wEncryptionKey[0][2];
	wDecryptionKey[SQUAREROUNDS][3] = wEncryptionKey[0][3];
}

/*
 * Default constructor: starts the block with all four data words set to
 * zero so its contents are always well defined.
 */
CSquareBlock::CSquareBlock()
{
	wData[0] = 0;
	wData[1] = 0;
	wData[2] = 0;
	wData[3] = 0;
}

/*
 * Copy constructor: duplicates the four data words of cSquareBlock.
 */
CSquareBlock::CSquareBlock(const CSquareBlock &cSquareBlock)
{
	wData[0] = cSquareBlock.wData[0];
	wData[1] = cSquareBlock.wData[1];
	wData[2] = cSquareBlock.wData[2];
	wData[3] = cSquareBlock.wData[3];
}

/*
 * Constructs a block from four words.
 *
 * pwData must point to at least four readable Words; they are stored
 * unchanged (no byte reordering).
 */
CSquareBlock::CSquareBlock(const Word *pwData)
{
	int nWord = 0;

	while (nWord < 4) {
		wData[nWord] = pwData[nWord];
		++nWord;
	}
}

/*
 * Constructs a block from a byte string by forwarding both arguments to
 * SetData(const Byte *, Word), which does the actual loading.
 */
CSquareBlock::CSquareBlock(const Byte *pbData, Word wLength)
{
	this->SetData(pbData, wLength);
}

/*
 * Destructor: wipes the block's data words.
 *
 * The stores go through a volatile lvalue because plain stores to members
 * of an object whose lifetime is ending are dead stores that an optimiser
 * may remove, which would leave the block contents in memory.
 */
CSquareBlock::~CSquareBlock()
{
	volatile Word *pwWipe = wData;
	for (int i = 0; i < 4; i++)			// Cleanup
		pwWipe[i] = 0;
}

void CSquareBlock::Encrypt(const CSquareKey &cSquareKey)
{
	Word wTemp[4];
	const Word *k = cSquareKey.GetEncryptionKeys();
	wData[0] ^= k[0];
	wData[1] ^= k[1];
	wData[2] ^= k[2];
	wData[3] ^= k[3];
	for (int i = 0; i < (SQUAREROUNDS - 1); i++)
		r(wData, (k + ((i + 1) * 4)), wSquareE);
	s(wTemp, wData, 0, bSquareE);
	s(wTemp, wData, 1, bSquareE);
	s(wTemp, wData, 2, bSquareE);
	s(wTemp, wData, 3, bSquareE);
	wData[0] = wTemp[0] ^ k[SQUAREROUNDS * 4 + 0];
	wData[1] = wTemp[1] ^ k[SQUAREROUNDS * 4 + 1];
	wData[2] = wTemp[2] ^ k[SQUAREROUNDS * 4 + 2];
	wData[3] = wTemp[3] ^ k[SQUAREROUNDS * 4 + 3];
}

void CSquareBlock::Decrypt(const CSquareKey &cSquareKey)
{
	Word wTemp[4];
	const Word *k = cSquareKey.GetDecryptionKeys();
	wData[0] ^= k[0];
	wData[1] ^= k[1];
	wData[2] ^= k[2];
	wData[3] ^= k[3];
	for (int i = 0; i < (SQUAREROUNDS - 1); i++)
		r(wData, (k + ((i + 1) * 4)), wSquareD);
	s(wTemp, wData, 0, bSquareD);
	s(wTemp, wData, 1, bSquareD);
	s(wTemp, wData, 2, bSquareD);
	s(wTemp, wData, 3, bSquareD);
	wData[0] = wTemp[0] ^ k[SQUAREROUNDS * 4 + 0];
	wData[1] = wTemp[1] ^ k[SQUAREROUNDS * 4 + 1];
	wData[2] = wTemp[2] ^ k[SQUAREROUNDS * 4 + 2];
	wData[3] = wTemp[3] ^ k[SQUAREROUNDS * 4 + 3];
}

/*
 * Replaces the block contents with four words read from pwData.
 *
 * pwData must point to at least four readable Words; they are stored
 * unchanged (no byte reordering).
 */
void CSquareBlock::SetData(Word *pwData)
{
	wData[0] = pwData[0];
	wData[1] = pwData[1];
	wData[2] = pwData[2];
	wData[3] = pwData[3];
}

/*
 * Returns data word number i of the block.
 *
 * No range check is performed; the block holds four words, so i must be
 * in 0..3.
 */
Word CSquareBlock::GetData(Word i)
{
	const Word wSelected = wData[i];
	return wSelected;
}

void CSquareBlock::SetData(const Byte *pbData, Word wLength)
{
	memcpy((Byte *)wData, pbData, wLength);
	memset((Byte *)wData + wLength, 0, sizeof(wData) - wLength);
	for (int i = 0; i < 4; i++)
		wData[i] = REVERSEWORD(wData[i]);
}

/*
 * Returns the block contents as bytes.
 *
 * Each data word is passed through REVERSEWORD and stored in wResult; the
 * returned pointer refers to wResult, so its contents are overwritten by
 * the next call.  wData itself is not modified.
 */
Byte *CSquareBlock::GetData()
{
	wResult[0] = REVERSEWORD(wData[0]);
	wResult[1] = REVERSEWORD(wData[1]);
	wResult[2] = REVERSEWORD(wData[2]);
	wResult[3] = REVERSEWORD(wData[3]);

	return (Byte *)wResult;
}
