/*
 * Copyright 1999, Alexander Feldman <alex@varna.net>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. Neither the name of Alexander Feldman nor the names of its contributors
 *    may be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY ALEXANDER FELDMAN AND CONTRIBUTORS ``AS IS''
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL ALEXANDER FELDMAN OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include "rabin.hpp"

/*
 * Construct an empty key: no modulus or primes are held, so Check()
 * will refuse to operate on it.  Every flag and the size field are
 * given defined values so that copying or assigning an empty key never
 * reads indeterminate data.
 */
CRabinKey::CRabinKey()
{
	wModulusSize = 0;
	fgHoldKey = false;
	fgEncryptOnly = false;
}

/*
 * Generate a fresh private key whose modulus is wModulusSize bits.
 * Sizes outside [512, 3072] are rejected with BAD_RABIN_MODULUSSIZE.
 */
CRabinKey::CRabinKey(Word wModulusSize)
{
	if (!((wModulusSize >= 512) && (wModulusSize <= 3072)))
		throw(BAD_RABIN_MODULUSSIZE);

	// The parameter shadows the member, so name the member explicitly.
	this->wModulusSize = wModulusSize;
	GenerateKeys();

	// A generated key carries both primes and can therefore decrypt.
	fgEncryptOnly = false;
	fgHoldKey = true;
}

/*
 * Build a public (encrypt-only) key from a modulus alone; the primes are
 * not known, so the key can encrypt but not decrypt.
 */
CRabinKey::CRabinKey(const CBigNumber &cModulus)
{
	N = cModulus;
	// Size is derived from the modulus' word count.
	wModulusSize = cModulus.GetWords() * BITSINWORD;

	// NOTE(review): unlike the other constructors and ReadPublicKey(), this
	// path does not range-check the modulus size.
	fgEncryptOnly = true;
	fgHoldKey = true;
}

/*
 * Build a full private key from a caller-supplied modulus and its two
 * primes.  The modulus size must lie in [512, 3072] bits.  The primes are
 * not cross-checked against the modulus here; Check() does that.
 */
CRabinKey::CRabinKey(const CBigNumber &cModulus,
							const CProbablePrime &cFirstPrime,
							const CProbablePrime &cSecondPrime)
{
	wModulusSize = cModulus.GetWords() * BITSINWORD;
	if (!((wModulusSize >= 512) && (wModulusSize <= 3072)))
		throw(BAD_RABIN_MODULUSSIZE);

	N = cModulus;
	P = cFirstPrime;
	Q = cSecondPrime;

	fgEncryptOnly = false;
	fgHoldKey = true;
}

/*
 * Copy constructor.  The copy reproduces the source's key state exactly:
 * an empty key stays empty and a public-only key stays encrypt-only.
 * (Previously every copy was marked as holding a full private key, so a
 * copied public key would let Decrypt()/Check() run without real primes,
 * and copying an empty key read an uninitialized size field.)
 *
 * Throws BAD_RABIN_MODULUSSIZE if the source holds a key whose size is
 * outside [512, 3072] bits.
 */
CRabinKey::CRabinKey(const CRabinKey &cRabinKey)
{
	P = cRabinKey.P;
	Q = cRabinKey.Q;
	N = cRabinKey.N;

	fgHoldKey = cRabinKey.fgHoldKey;
	// Only consult the source's remaining fields when it actually holds a
	// key; an empty source may never have had them set.
	fgEncryptOnly = cRabinKey.fgHoldKey ? cRabinKey.fgEncryptOnly : false;
	wModulusSize = cRabinKey.fgHoldKey ? cRabinKey.wModulusSize : 0;

	if (fgHoldKey && ((wModulusSize < 512) || (wModulusSize > 3072)))
		throw(BAD_RABIN_MODULUSSIZE);
}

/*
 * Pick two primes P and Q, each congruent to 3 mod 4, of wModulusSize / 2
 * bits, and set N = P * Q.
 */
void CRabinKey::GenerateKeys()
{
	// First prime: bump a random start up to the next value that is 3 mod 4,
	// then advance in steps of 4 (keeping that residue) until it is prime.
	for (P.SetRandom(wModulusSize / 2, true); P % 4 != 3; P += 1) {
	}
	for (; !P.IsPrime(); P += 4) {
	}

	// Second prime: the same procedure from an independent random start.
	for (Q.SetRandom(wModulusSize / 2, true); Q % 4 != 3; Q += 1) {
	}
	for (; !Q.IsPrime(); Q += 4) {
	}

	N = P * Q;
}

/*
 * Print P, Q and N to stdout, each preceded by its storage size in bits
 * (word count times BITSINWORD).
 */
void CRabinKey::Dump()
{
	// The size is cast to unsigned long and printed with %lu so the format
	// specifier matches regardless of the width/signedness of Word; the
	// old %d was a type mismatch for an unsigned or long Word.
	printf("p [%lu bits] = ", static_cast<unsigned long>(P.GetWords() * BITSINWORD)); P.Dump();
	printf("q [%lu bits] = ", static_cast<unsigned long>(Q.GetWords() * BITSINWORD)); Q.Dump();
	printf("n [%lu bits] = ", static_cast<unsigned long>(N.GetWords() * BITSINWORD)); N.Dump();
}

void CRabinKey::WritePrivateKey(int iOut, bool fgBase64)
{
	CDEREncodedBigNumber cVersion((Word)RABIN_PRIVATE_KEY_VERSION);
	CDEREncodedBigNumber cModulus(N);
	CDEREncodedBigNumber cPrime1(P);
	CDEREncodedBigNumber cPrime2(Q);

	CDEREncodedSequence cSequence;
	
	cSequence.AddPrimitive(cVersion);
	cSequence.AddPrimitive(cModulus);
	cSequence.AddPrimitive(cPrime1);
	cSequence.AddPrimitive(cPrime2);
	
	if (true == fgBase64) {
		write_string(iOut, BEGIN_RABIN_PRIVATE_KEY "\n");
		cSequence.WriteBase64(iOut, true);
		write_string(iOut, "\n" END_RABIN_PRIVATE_KEY "\n");
	} else {
		cSequence.Write(iOut);
	}
}

void CRabinKey::WritePublicKey(int iOut, bool fgBase64)
{
	CDEREncodedBigNumber cModulus(N);

	CDEREncodedSequence cSequence;
	
	cSequence.AddPrimitive(cModulus);

	if (true == fgBase64) {
		write_string(iOut, BEGIN_RABIN_PUBLIC_KEY "\n");
		cSequence.WriteBase64(iOut, true);
		write_string(iOut, "\n" END_RABIN_PUBLIC_KEY "\n");
	} else {
		cSequence.Write(iOut);
	}
}

void CRabinKey::ReadPrivateKey(int iIn, bool fgBase64)
{
	CDEREncodedSequence cSequence;
	if (true == fgBase64) {
		if (false == match_string(iIn, BEGIN_RABIN_PRIVATE_KEY "\n"))
			throw(KEYFILE_ERROR);
		cSequence.ReadBase64(iIn);
		if (false == match_string(iIn, "\n" END_RABIN_PRIVATE_KEY "\n"))
			throw(KEYFILE_ERROR);
	} else {
		cSequence.Read(iIn);
	}

	CDEREncodedBigNumber cVersion(cSequence);
	CDEREncodedBigNumber cModulus(cSequence);
	CDEREncodedBigNumber cPrime1(cSequence);
	CDEREncodedBigNumber cPrime2(cSequence);

	N = cModulus;
	P = cPrime1;
	Q = cPrime2;

	wModulusSize = cModulus.GetWords() * BITSINWORD;

	if ((wModulusSize < 512) || (wModulusSize > 3072))
		throw(BAD_RABIN_MODULUSSIZE);

	fgHoldKey = true;
	fgEncryptOnly = false;
}

void CRabinKey::ReadPublicKey(int iIn, bool fgBase64)
{
	CDEREncodedSequence cSequence;
	if (true == fgBase64) {
		if (false == match_string(iIn, BEGIN_RABIN_PUBLIC_KEY "\n"))
			throw(KEYFILE_ERROR);
		cSequence.ReadBase64(iIn);
		if (false == match_string(iIn, "\n" END_RABIN_PUBLIC_KEY "\n"))
			throw(KEYFILE_ERROR);
	} else {
		cSequence.Read(iIn);
	}

	CDEREncodedBigNumber cModulus(cSequence);

	N = cModulus;

	wModulusSize = cModulus.GetWords() * BITSINWORD;

	if ((wModulusSize < 512) || (wModulusSize > 3072))
		throw(BAD_RABIN_MODULUSSIZE);

	fgHoldKey = true;
	fgEncryptOnly = true;
}

/*
 * Validate a full private key.  Returns RABIN_OK combined with one flag
 * per failed property: P/Q not prime, P/Q not 3 mod 4, or N != P * Q.
 * Throws BAD_RABIN_OPERATION for an empty or encrypt-only key.
 */
Word CRabinKey::Check()
{
	// fgEncryptOnly is only examined once fgHoldKey is known to be set.
	if (!(fgHoldKey && !fgEncryptOnly))
		throw(BAD_RABIN_OPERATION);

	Word wProblems = RABIN_OK;

	wProblems |= (!P.IsPrime() ? RABIN_PNOTPRIME : 0);
	wProblems |= (!Q.IsPrime() ? RABIN_QNOTPRIME : 0);
	wProblems |= ((P % 4 != 3) ? RABIN_PNOTCON : 0);
	wProblems |= ((Q % 4 != 3) ? RABIN_QNOTCON : 0);
	wProblems |= ((N != P * Q) ? RABIN_BADN : 0);

	return wProblems;
}

/*
 * Create an empty block bound to cRabinKey.  The key is assigned (not
 * copy-constructed) so the CRabinKey assignment semantics apply.
 */
CRabinBlock::CRabinBlock(const CRabinKey &cRabinKey)
{
	this->cKey = cRabinKey;
}

/*
 * Create a block bound to cRabinKey whose data is built from the wData-long
 * raw buffer at pvData.
 */
CRabinBlock::CRabinBlock(const CRabinKey &cRabinKey, void *pvData, Word wData)
	: cData(pvData, wData)
{
	// Assigned rather than listed above, so CRabinKey's assignment
	// (not its copy constructor) is used, as before.
	this->cKey = cRabinKey;
}

/*
 * Create a block bound to cRabinKey holding a copy of cRabinData.
 */
CRabinBlock::CRabinBlock(const CRabinKey &cRabinKey, const CBigNumber &cRabinData)
	: cData(cRabinData)
{
	// Assigned rather than listed above, so CRabinKey's assignment
	// (not its copy constructor) is used, as before.
	this->cKey = cRabinKey;
}

/*
 * Encrypt the block in place: attach a checksum, square, and reduce
 * modulo the key's modulus.  Any key that holds material (public-only or
 * full) is sufficient; otherwise BAD_RABIN_OPERATION is thrown.
 */
void CRabinBlock::Encrypt()
{
	if (cKey.HoldKeyFlag()) {
		// The checksum is what lets Decrypt() pick the correct one of the
		// four square roots.
		cData.SetCheckSum();
		cData.Sqr();
		cData.Mod(cKey.GetModulus());
		return;
	}

	throw(BAD_RABIN_OPERATION);
}

void CRabinBlock::Decrypt()
{
	if (!cKey.HoldKeyFlag() || cKey.EncryptOnlyFlag())
		throw(BAD_RSA_OPERATION);

	CBigNumber D = cKey.GetP() + 1;
	CBigNumber E = cKey.GetQ() + 1;
	D.Shr(2);
	E.Shr(2);
	CBigNumber m1 = CBigNumber::ModExp(cData, D, cKey.GetP());
	CBigNumber m2 = (cKey.GetP() - m1) % cKey.GetModulus();
	CBigNumber m3 = CBigNumber::ModExp(cData, E, cKey.GetQ());
	CBigNumber m4 = (cKey.GetQ() - m3) % cKey.GetModulus();
	
	CBigNumber a = cKey.GetQ() * CBigNumber::ModInv(cKey.GetQ(), cKey.GetP());
	CBigNumber b = cKey.GetP() * CBigNumber::ModInv(cKey.GetP(), cKey.GetQ());
	cData = (a * m1 + b * m3) % cKey.GetModulus();
	if (cData.CheckSum())
		return;
	cData = (a * m1 + b * m4) % cKey.GetModulus();
	if (cData.CheckSum())
		return;
	cData = (a * m2 + b * m3) % cKey.GetModulus();
	if (cData.CheckSum())
		return;
	cData = (a * m2 + b * m4) % cKey.GetModulus();
	cData.CheckSum();							// Strip checksum as we do not need that
}

/*
 * Print the block's current data to stdout, labelled "d = ".
 */
void CRabinBlock::Dump()
{
	fputs("d = ", stdout);
	cData.Dump();
}

void CRabinBlock::Write(int iOut)
{
	CDEREncodedBigNumber cAll(cData);
	cAll.Write(iOut);
}

void CRabinBlock::Read(int iIn)
{
	CDEREncodedBigNumber cAll;

	cAll.Read(iIn);

	cData = cAll;
}

/*
 * Replace the block's data with a number built from the wDataLength-long
 * buffer at pbData.
 */
void CRabinBlock::SetData(Byte *pbData, Word wDataLength)
{
	CBigNumber cFresh(pbData, wDataLength);
	cData = cFresh;
}

/*
 * Return a byte pointer to the block's payload.  This skips the first
 * storage word, which GetDataSize() reports.
 */
Byte *CRabinBlock::GetData()
{
	return (Byte *)&cData.GetData()[1];
}

/*
 * Return the first storage word of the block's data, which is used as the
 * payload length.  NOTE(review): units (bytes vs. words) are not visible
 * here; confirm against CBigNumber.
 */
Word CRabinBlock::GetDataSize()
{
	return *cData.GetData();
}
