/*
 * Copyright 1999, Alexander Feldman <alex@varna.net>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. Neither the name of Alexander Feldman nor the names of its contributors
 *    may be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY ALEXANDER FELDMAN AND CONTRIBUTORS ``AS IS''
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL ALEXANDER FELDMAN OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include "sssl.hpp"

// Code of the last error recorded by the functions in this file. SSSLerror()
// uses it as an index into ppszSSSLerrors. SSSLread() and SSSLwrite() reset
// it to SSSL_NO_ERROR on entry; SSSLaccept() and SSSLconnect() only assign it
// on their failure paths.
SSSL_API int iSSSLError = SSSL_NO_ERROR;

// Collection of CSSSLSocket objects. In this file sockets are registered with
// AddSocket(), looked up with FindSocket() and removed with DeleteSocket().
extern CSSSLSpace cSSSLSpace;				// Defined in space.cpp

// SSSLaccept - accept() a connection on sockfd and run the server side of the
// SSSL handshake on the new socket.
//
// Steps, in order:
//   1. load the configuration file and accept() the connection;
//   2. negotiate the algorithms with the peer (ServerAgree);
//   3. read the server private key named by the configuration;
//   4. without key exchange: answer the client's GET_HASH request (send the
//      digest of the public key file and wait for HASH_OK) or its GET_KEY
//      request (send the public key);
//      with key exchange: write our public value, read the peer's and agree;
//   5. read the IV and, without key exchange, the encrypted secret key, which
//      is then decrypted with the private key;
//   6. create the CSSSLSocket and register it in cSSSLSpace.
//
// addr and addrlen are passed unchanged to accept().
//
// Returns the accepted socket descriptor on success. On failure returns -1
// and stores the reason in iSSSLError; once the connection has been accepted
// the new socket is closed on every failure path, because the caller never
// learns its descriptor.
SSSL_API int SSSLaccept(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
{
	AGREEMENT_STR strAgreement;

	Byte bRequest;
	Byte bHash;

	char szSecretKey[256];				// The full name of the private key file.
	char szPublicKey[256];				// The full name of the public key file.
	char *pszResult = NULL;
	char *pszAsymmetric = NULL;		// The name of the asymmetric method.
												// We need it to read the proper config
												// files.
	int iPublicKeySize = 0;
	int iSocket = -1;						// The accepted socket handler.
	int iResult = -1;
	int iPublicKey = -1;

	Byte *pbIV = NULL;
	Byte *pbPublicKey = NULL;

	CKeyExchange *pKeyExchange = NULL;
	CAsymmetricBlock *pSecretKey = NULL;
	CAsymmetricKey *pPrivateKey = NULL;
	CDigest *pKeyDigest = NULL;

	CSSSLSocket *pNewSSSLSocket;		// Automatically deleted!

// Set the soft core-file size limit to zero.
#ifdef HAVE_SETRLIMIT
# ifdef RLIMIT_CORE
	struct rlimit r;
	getrlimit(RLIMIT_CORE, &r);
	r.rlim_cur = 0;
	setrlimit(RLIMIT_CORE, &r);
# endif
#endif

	CConfigurationFile *pConfigFile = new CConfigurationFile(SSSL_CONFIG_FILE, &pszResult);
	if (NULL == pConfigFile || NULL != pszResult) {
// The object exists even when the constructor reported an error through
// pszResult, so release it (deleting NULL is a no-op).
		delete pConfigFile;
		iSSSLError = BAD_CONFIGFILE;
		return -1;
	}

	if (-1 == (iSocket = accept(sockfd, addr, addrlen))) {
		delete pConfigFile;
		iSSSLError = BAD_ACCEPTCALL;
		return -1;
	}
	if (false == ServerAgree(pConfigFile, iSocket, &strAgreement)) {
		delete pConfigFile;
		close(iSocket);
		iSSSLError = BAD_AGREEMENT;
		return -1;
	}
	if (NULL == (pszAsymmetric = FindName(strAgreement.bAsymmetric))) {
		delete pConfigFile;
		close(iSocket);
		iSSSLError = BAD_EXTENDED;
		return -1;
	}
// Key file names are "<configured base>.<asymmetric method>.sec" / ".pub".
	snprintf(szSecretKey, sizeof(szSecretKey), "%s.%s.sec", pConfigFile->GetString("server", "secret", "hostkey"), pszAsymmetric);
	snprintf(szPublicKey, sizeof(szPublicKey), "%s.%s.pub", pConfigFile->GetString("server", "public", "hostkey"), pszAsymmetric);
	delete pConfigFile;

	if (NULL == (pPrivateKey = NewAsymmetricKey(strAgreement.bAsymmetric))) {
		iResult = -1;
		iSSSLError = BAD_ALLOCCALL;
		goto accept_exit;
	}
	try
	{
		pPrivateKey->ReadPrivate(szSecretKey);
	}
	catch (const char *)
	{
		iResult = -1;
		iSSSLError = BAD_PRIVATEKEY;
		goto accept_exit;
	}

	if (false == strAgreement.fgKeyExchange) {
// The client tells us with one byte whether it wants the hash of our public
// key (GET_HASH) or the key itself (GET_KEY).
		if (sizeof(Byte) != scl_read(iSocket, &bRequest, sizeof(Byte))) {
			iResult = -1;
			iSSSLError = BAD_READCALL;
			goto accept_exit;
		}
		if (bRequest == GET_HASH) {
// Open the public key of the server. Allocate a buffer and read the whole
// file in the buffer. Get the fingerprint of the buffer, discard buffer and
// close file. We will need only the fingerprint value.
			if (-1 == (iPublicKey = open(szPublicKey, O_RDONLY)) ||
				 -1 == (iPublicKeySize = lseek(iPublicKey, 0, SEEK_END))) {
				iResult = -1;
				iSSSLError = BAD_PUBLICKEY;
				goto accept_exit;
			}
			if (-1 == lseek(iPublicKey, 0, SEEK_SET) ||
				 NULL == (pbPublicKey = new Byte[iPublicKeySize]) ||
				 -1 == (iPublicKeySize = read(iPublicKey, pbPublicKey, iPublicKeySize)) ||
				 NULL == (pKeyDigest = NewDigest(strAgreement.bDigest, pbPublicKey, iPublicKeySize)) ||
// Send the calculated hash to the client.
				 strAgreement.wDigestSize != (Word)scl_write(iSocket, (void *)pKeyDigest->GetFingerPrint(), strAgreement.wDigestSize)) {
				iResult = -1;
				iSSSLError = BAD_PUBLICKEY;
				goto accept_exit;
			}
// The client answers with one byte; anything but HASH_OK aborts the handshake.
			if (sizeof(Byte) != scl_read(iSocket, &bHash, sizeof(Byte)) ||
				 bHash != HASH_OK) {
				iResult = -1;
				iSSSLError = BAD_CLIENTPUBLIC;
				goto accept_exit;
			}
		} else if (bRequest == GET_KEY) {
// Send the public key in DER encoded binary format to the client.
			try
			{
				pPrivateKey->WritePublicKey(iSocket, false);
			}
			catch (const char *)
			{
				iResult = -1;
				iSSSLError = BAD_KEYEXCHANGE;
				goto accept_exit;
			}
		} else {
			iResult = -1;
			iSSSLError = BAD_EXTENDED;
			goto accept_exit;
		}
	} else {
		pKeyExchange = NewKeyExchange(strAgreement.bAsymmetric, pPrivateKey);
		if (NULL == pKeyExchange) {
			iResult = -1;
			iSSSLError = BAD_ALLOCCALL;
			goto accept_exit;
		}
// The server writes its public value first; SSSLconnect reads first.
		try
		{
			pKeyExchange->WritePublic(iSocket);
			pKeyExchange->ReadOtherPublic(iSocket);
			pKeyExchange->Agree();
		}
		catch (const char *)
		{
			iResult = -1;
			iSSSLError = BAD_KEYEXCHANGE;
			goto accept_exit;
		}
	}

// The client sends the IV (wBlockSize bytes) next.
	if (NULL == (pbIV = new Byte[strAgreement.wBlockSize]) ||
		 strAgreement.wBlockSize != (Word)scl_read(iSocket, pbIV, strAgreement.wBlockSize)) {
		iResult = -1;
		iSSSLError = BAD_ALLOCCALL;
		goto accept_exit;
	}

	if (false == strAgreement.fgKeyExchange) {
		if (NULL == (pSecretKey = NewAsymmetricBlock(strAgreement.bAsymmetric, pPrivateKey))) {
			iResult = -1;
			iSSSLError = BAD_ALLOCCALL;
			goto accept_exit;
		}
		try
		{
			pSecretKey->Read(iSocket);
		}
		catch (const char *)
		{
// The error code belongs in iSSSLError; iResult must stay -1 so that the
// caller does not mistake the error code for a socket descriptor.
			iResult = -1;
			iSSSLError = BAD_KEYEXCHANGE;
			goto accept_exit;
		}

		pSecretKey->Decrypt();
	}

// This object (pointed by pNewSSSLSocket) will be automagically deleted on
// SSSLclose or SSSLSpace destruction
	pNewSSSLSocket = new CSSSLSocket(iSocket,
												strAgreement.bSymmetric,
												strAgreement.wBlockSize,
												strAgreement.wKeySize,
												pbIV,
												(strAgreement.fgKeyExchange ? pKeyExchange->GetKey() : pSecretKey->GetData()));
	if (NULL == pNewSSSLSocket) {
		iResult = -1;
		iSSSLError = BAD_ALLOCCALL;
		goto accept_exit;
	}

	if (true == cSSSLSpace.AddSocket(pNewSSSLSocket)) {
		iResult = iSocket;
	} else {
// NOTE(review): who owns pNewSSSLSocket when AddSocket() fails cannot be told
// from this file, so it is deliberately not deleted here - confirm.
		iResult = -1;
		iSSSLError = BAD_EXTENDED;
	}

accept_exit:
// Memory cleanup (deleting a NULL pointer is a no-op). The buffers were
// allocated with new[] and therefore need delete[].
	delete pSecretKey;
	delete pKeyExchange;
	delete[] pbIV;
	delete[] pbPublicKey;
	delete pPrivateKey;
	delete pKeyDigest;
	if (-1 != iPublicKey)
		close(iPublicKey);
// A failed handshake must not leak the accepted socket.
	if (-1 == iResult)
		close(iSocket);

	return iResult;
}

// SSSLconnect - connect() sockfd to addr and run the client side of the SSSL
// handshake.
//
// Steps, in order:
//   1. load the configuration file and expand the public key directory name
//      (wordexp() when available, otherwise a manual '~' -> $HOME expansion);
//   2. connect() and negotiate the algorithms with the peer (ClientAgree);
//   3. with key exchange: read the private key, read the peer's public value,
//      write ours and agree;
//   4. create the CSSSLSocket and register it in cSSSLSpace;
//   5. without key exchange: if a saved public key file for the server exists,
//      send GET_HASH and compare the server's digest with the digest of the
//      saved file; otherwise send GET_KEY, read the key from the server and
//      save it;
//   6. send the IV and, without key exchange, the secret key encrypted with
//      the server's public key.
//
// Returns the value returned by connect() on success. On failure returns -1
// and stores the reason in iSSSLError. sockfd is closed when the failure
// happens after connect() succeeded; it is left open when the configuration
// file, the directory expansion or connect() itself fails.
SSSL_API int SSSLconnect(int sockfd, const struct sockaddr *addr, socklen_t addrlen)
{
	AGREEMENT_STR strAgreement;

	Byte bHash;
	Byte bRequest;

	char szSecretKey[256];				// The full name of the private key file.
	char szPublicKeyDirectory[256];
	char szPublicKeyFile[256];
	char *pszResult = NULL;
	char *pszAsymmetric = NULL;

	int iPublicKeySize;
	int iResult = -1;

	Byte *pbPublicKeyHash = NULL;
	Byte *pbPublicKey = NULL;

// -1 means "not open". Starting at 0 would make the cleanup code close
// descriptor 0 whenever the key file was never opened.
	int iPublicKey = -1;

	CKeyExchange *pKeyExchange = NULL;
	CAsymmetricBlock *pSecretKey = NULL;
	CAsymmetricKey *pPrivateKey = NULL;
	CAsymmetricKey *pPublicKey = NULL;
	CDigest *pKeyDigest = NULL;

	CSSSLSocket *pNewSSSLSocket;		// Automatically deleted!

// Set the soft core-file size limit to zero.
#ifdef HAVE_SETRLIMIT
#ifdef RLIMIT_CORE
	struct rlimit r;
	getrlimit(RLIMIT_CORE, &r);
	r.rlim_cur = 0;
	setrlimit(RLIMIT_CORE, &r);
#endif
#endif

	CConfigurationFile *pConfigFile = new CConfigurationFile(SSSL_CONFIG_FILE, &pszResult);
	if (NULL == pConfigFile || NULL != pszResult) {
// The object exists even when the constructor reported an error through
// pszResult, so release it (deleting NULL is a no-op).
		delete pConfigFile;
		iSSSLError = BAD_CONFIGFILE;
		return -1;
	}

// strncpy() does not terminate the destination when the source fills it, so
// copy one byte less and terminate explicitly.
	strncpy(szPublicKeyDirectory, pConfigFile->GetString("client", "directory", PATH_PUBLICKEYDIR), sizeof(szPublicKeyDirectory) - 1);
	szPublicKeyDirectory[sizeof(szPublicKeyDirectory) - 1] = '\0';
#ifdef HAVE_WORDEXP
	wordexp_t strWordExp;
	if (0 != wordexp(szPublicKeyDirectory, &strWordExp, 0)) {
		delete pConfigFile;
		iSSSLError = BAD_EXTENDED;
		return -1;
	}
// An expansion may legitimately produce no words at all.
	if (0 == strWordExp.we_wordc) {
		wordfree(&strWordExp);
		delete pConfigFile;
		iSSSLError = BAD_EXTENDED;
		return -1;
	}
	strncpy(szPublicKeyDirectory, strWordExp.we_wordv[0], sizeof(szPublicKeyDirectory) - 1);
	szPublicKeyDirectory[sizeof(szPublicKeyDirectory) - 1] = '\0';
	wordfree(&strWordExp);
#else
// Manual expansion: every '~' is replaced with the value of $HOME. The output
// is clipped to the size of szPublicKeyDirectory; a '~' is copied literally
// when HOME is not set.
	char szTmp[sizeof(szPublicKeyDirectory)];
	strcpy(szTmp, szPublicKeyDirectory);	// Same size, source is terminated.
	const char *pszHome = getenv("HOME");
	char *p = szTmp;
	char *q = szPublicKeyDirectory;
	char *const pEnd = szPublicKeyDirectory + sizeof(szPublicKeyDirectory) - 1;
	while (*p && q < pEnd) {
		if (*p == '~' && NULL != pszHome) {
			size_t cbHome = strlen(pszHome);
			size_t cbRoom = (size_t)(pEnd - q);
			if (cbHome > cbRoom)
				cbHome = cbRoom;
			memcpy(q, pszHome, cbHome);
			q += cbHome;
			p++;
			continue;
		}
		*q++ = *p++;
	}
	*q = '\0';
#endif

	iResult = connect(sockfd, addr, addrlen);
	if (-1 == iResult) {
		delete pConfigFile;
		iSSSLError = BAD_CONNECTCALL;
		return -1;
	}
	if (false == ClientAgree(pConfigFile, sockfd, &strAgreement)) {
		delete pConfigFile;
		close(sockfd);
		iSSSLError = BAD_AGREEMENT;
		return -1;
	}
	if (NULL == (pszAsymmetric = FindName(strAgreement.bAsymmetric))) {
		delete pConfigFile;
		close(sockfd);
		iSSSLError = BAD_EXTENDED;
		return -1;
	}
	snprintf(szSecretKey, sizeof(szSecretKey), "%s.%s.sec", pConfigFile->GetString("server", "secret", "hostkey"), pszAsymmetric);
// The saved server key is named "<directory>/<server IP>.<method>.pub".
// NOTE(review): the cast below treats addr as a struct sockaddr_in.
//	snprintf(szPublicKeyFile, sizeof(szPublicKeyFile), "%s/%s:%d.%s.pub", szPublicKeyDirectory, inet_ntoa(((struct sockaddr_in *)addr)->sin_addr), ntohs(((struct sockaddr_in *)addr)->sin_port), pszAsymmetric);
#ifndef WIN32
	snprintf(szPublicKeyFile, sizeof(szPublicKeyFile), "%s/%s.%s.pub", szPublicKeyDirectory, inet_ntoa(((struct sockaddr_in *)addr)->sin_addr), pszAsymmetric);
#else
	snprintf(szPublicKeyFile, sizeof(szPublicKeyFile), "%s\\%s.%s.pub", szPublicKeyDirectory, inet_ntoa(((struct sockaddr_in *)addr)->sin_addr), pszAsymmetric);
#endif

	delete pConfigFile;

	if (true == strAgreement.fgKeyExchange) {
		if (NULL == (pPrivateKey = NewAsymmetricKey(strAgreement.bAsymmetric))) {
			iResult = -1;
			iSSSLError = BAD_ALLOCCALL;
			goto connect_exit;
		}
		try
		{
			pPrivateKey->ReadPrivate(szSecretKey);
		}
		catch (const char *)
		{
			iResult = -1;
			iSSSLError = BAD_PRIVATEKEY;
			goto connect_exit;
		}

		if (NULL == (pKeyExchange = NewKeyExchange(strAgreement.bAsymmetric, pPrivateKey))) {
			iResult = -1;
			iSSSLError = BAD_ALLOCCALL;
			goto connect_exit;
		}
// The client reads the peer's public value first; SSSLaccept writes first.
		try
		{
			pKeyExchange->ReadOtherPublic(sockfd);
			pKeyExchange->WritePublic(sockfd);
			pKeyExchange->Agree();
		}
		catch (const char *)
		{
			iResult = -1;
			iSSSLError = BAD_KEYEXCHANGE;
			goto connect_exit;
		}
	}

// This object (pointed by pNewSSSLSocket) will be automagically deleted on
// SSSLclose or SSSLSpace destruction
	pNewSSSLSocket = new CSSSLSocket(sockfd,
												strAgreement.bSymmetric,
												strAgreement.wBlockSize,
												strAgreement.wKeySize,
												NULL,
												(strAgreement.fgKeyExchange ? pKeyExchange->GetKey() : NULL));
	if (NULL == pNewSSSLSocket) {
// Nothing has been registered in cSSSLSpace yet, so close the socket directly
// and release what the key exchange allocated.
		delete pKeyExchange;
		delete pPrivateKey;
		iSSSLError = BAD_ALLOCCALL;
		close(sockfd);
		return -1;
	}

// NOTE(review): the result of AddSocket() is not checked here, unlike in
// SSSLaccept.
	cSSSLSpace.AddSocket(pNewSSSLSocket);

	if (false == strAgreement.fgKeyExchange) {
		if (NULL == (pPublicKey = NewAsymmetricKey(strAgreement.bAsymmetric))) {
			iResult = -1;
			iSSSLError = BAD_ALLOCCALL;
			goto connect_exit;
		}
		if (-1 != (iPublicKey = open(szPublicKeyFile, O_RDONLY))) {
// The file exists (perhaps).
			if (-1 == (iPublicKeySize = lseek(iPublicKey, 0, SEEK_END))) {
				iResult = -1;
				iSSSLError = BAD_EXTENDED;
				goto connect_exit;
			}
			if (-1 == lseek(iPublicKey, 0, SEEK_SET)) {
				iResult = -1;
				iSSSLError = BAD_PUBLICKEY;
				goto connect_exit;
			}
			if (NULL == (pbPublicKey = new Byte[iPublicKeySize])) {
				iResult = -1;
				iSSSLError = BAD_PUBLICKEY;
				goto connect_exit;
			}
			if (-1 == (iPublicKeySize = read(iPublicKey, pbPublicKey, iPublicKeySize))) {
				iResult = -1;
				iSSSLError = BAD_PUBLICKEY;
				goto connect_exit;
			}
			if (-1 == lseek(iPublicKey, 0, SEEK_SET)) {
				iResult = -1;
				iSSSLError = BAD_PUBLICKEY;
				goto connect_exit;
			}
			if (NULL == (pKeyDigest = NewDigest(strAgreement.bDigest, pbPublicKey, iPublicKeySize))) {
				iResult = -1;
				iSSSLError = BAD_PUBLICKEY;
				goto connect_exit;
			}

			bRequest = GET_HASH;
			if (sizeof(Byte) != scl_write(sockfd, &bRequest, sizeof(Byte))) {
				iResult = -1;
				iSSSLError = BAD_WRITECALL;
				goto connect_exit;
			}

// The server will (hopefully) send us a hash of its public key. Read it.
			if (NULL == (pbPublicKeyHash = new Byte[strAgreement.wDigestSize]) ||
				 strAgreement.wDigestSize != (Word)scl_read(sockfd, pbPublicKeyHash, strAgreement.wDigestSize) ||
				 0 != memcmp(pbPublicKeyHash, pKeyDigest->GetFingerPrint(), strAgreement.wDigestSize)) {
// Tell the server that the hashes differ before giving up.
				bHash = HASH_BAD;
				scl_write(sockfd, &bHash, sizeof(Byte));
				iResult = -1;
				iSSSLError = BAD_HASH;
				goto connect_exit;
			}
			bHash = HASH_OK;
			if (sizeof(Byte) != scl_write(sockfd, &bHash, sizeof(Byte))) {
				iResult = -1;
				iSSSLError = BAD_WRITECALL;
				goto connect_exit;
			}
// Mark the descriptor as closed so that the cleanup code does not close it a
// second time.
			close(iPublicKey);
			iPublicKey = -1;

			try
			{
				pPublicKey->ReadPublic(szPublicKeyFile, true);
			}
			catch (const char *)
			{
				iResult = -1;
				iSSSLError = BAD_PUBLICKEY;
				goto connect_exit;
			}
		} else {
			bRequest = GET_KEY;
			if (sizeof(Byte) != scl_write(sockfd, &bRequest, sizeof(Byte))) {
				iResult = -1;
				iSSSLError = BAD_WRITECALL;
				goto connect_exit;
			}
// Get the public key in DER encoded binary format from the server.
			try
			{
				pPublicKey->ReadPublicKey(sockfd, false);
			}
			catch (const char *)
			{
				iResult = -1;
				iSSSLError = BAD_KEYEXCHANGE;
				goto connect_exit;
			}

// Write it to a file for further use.

			mkdir(szPublicKeyDirectory, 0700);

			try
			{
				pPublicKey->WritePublic(szPublicKeyFile);
			}
			catch (const char *)
			{
				iResult = -1;
				iSSSLError = BAD_SAVEPUBLICKEY;
				goto connect_exit;
			}
		}
	}

// Send the IV of the new socket to the server.
	if (strAgreement.wBlockSize != (Word)scl_write(sockfd,
																  pNewSSSLSocket->GetIV(),
																  pNewSSSLSocket->GetIVSize())) {
		iResult = -1;
		iSSSLError = BAD_KEYEXCHANGE;
		goto connect_exit;
	}

	if (false == strAgreement.fgKeyExchange) {
// Encrypt the key of the new socket with the server's public key and send it.
		if (NULL == (pSecretKey = NewAsymmetricBlock(strAgreement.bAsymmetric, pPublicKey))) {
			iResult = -1;
			iSSSLError = BAD_ALLOCCALL;
			goto connect_exit;
		}
		pSecretKey->SetData(pNewSSSLSocket->GetKey(), pNewSSSLSocket->GetKeySize());

		pSecretKey->Encrypt();

		try
		{
			pSecretKey->Write(sockfd);
		}
		catch (const char *)
		{
			iSSSLError = BAD_KEYEXCHANGE;
			iResult = -1;
		}
	}

connect_exit:
// Memory cleanup (deleting a NULL pointer is a no-op). The buffers were
// allocated with new[] and therefore need delete[]. The private key is
// released after the key exchange object, in the same order as in SSSLaccept.
	delete pSecretKey;
	delete pKeyExchange;
	delete[] pbPublicKeyHash;
	delete[] pbPublicKey;
	delete pPublicKey;
	delete pPrivateKey;
	delete pKeyDigest;
// Buffered I/O cleanup
	if (-1 == iResult)
		SSSLclose(sockfd);
	if (-1 != iPublicKey)
		close(iPublicKey);

	return iResult;
}

// SSSLread - read up to count bytes from sockfd into buf and decrypt them in
// place with the cryptographer of the SSSL socket registered for sockfd.
//
// The underlying read is repeated until count bytes have arrived, the peer
// closes the connection (a read of 0 bytes) or a read fails. A failed read
// sets iSSSLError to BAD_READCALL; otherwise iSSSLError is SSSL_NO_ERROR.
//
// Returns the number of bytes stored in buf, which is less than count after
// an end of stream or an error. Only the bytes actually received are
// decrypted, so the cipher never runs over stale buffer contents.
SSSL_API ssize_t SSSLread(int sockfd, void *buf, size_t count)
{
	iSSSLError = SSSL_NO_ERROR;

	size_t cbTotal = 0;			// Bytes received so far.
	ssize_t cbChunk;				// Result of the last read; signed, -1 on error.

	do {
#ifdef WIN32
		cbChunk = recv(sockfd, (char *)buf + cbTotal, count - cbTotal, 0);
#else
		cbChunk = read(sockfd, (char *)buf + cbTotal, count - cbTotal);
#endif
		if (cbChunk < 0) {
			iSSSLError = BAD_READCALL;
			break;
		}
		cbTotal += (size_t)cbChunk;
	} while (cbTotal != count && cbChunk != 0);

	if (cbTotal > 0)
		cSSSLSpace.FindSocket(sockfd)->GetCryptographer()->DecryptData((Byte *)buf, (int)cbTotal);

	return cbTotal;
}

// SSSLwrite - encrypt count bytes from buf with the cryptographer of the SSSL
// socket registered for sockfd and write the result to sockfd.
//
// The data is encrypted into a temporary copy, so buf itself is not modified.
// The underlying write is repeated until everything has been sent, a write
// transfers 0 bytes or a write fails. A failed write sets iSSSLError to
// BAD_WRITECALL.
//
// Returns the number of bytes written, or -1 with iSSSLError set to
// BAD_ALLOCCALL when the temporary buffer cannot be allocated.
SSSL_API ssize_t SSSLwrite(int sockfd, const void *buf, size_t count)
{
	iSSSLError = SSSL_NO_ERROR;

	size_t cbTotal = 0;			// Bytes sent so far.
	ssize_t cbChunk;				// Result of the last write; signed, -1 on error.

	Byte *copy = new Byte[count];
	if (NULL == copy) {
		iSSSLError = BAD_ALLOCCALL;
		return -1;
	}
	cSSSLSpace.FindSocket(sockfd)->GetCryptographer()->EncryptData((Byte *)buf, (Byte *)copy, (int)count);

	do {
#ifdef WIN32
		cbChunk = send(sockfd, (const char *)copy + cbTotal, count - cbTotal, 0);
#else
		cbChunk = write(sockfd, copy + cbTotal, count - cbTotal);
#endif
		if (cbChunk < 0) {
			iSSSLError = BAD_WRITECALL;
			break;
		}
		cbTotal += (size_t)cbChunk;
	} while (cbTotal != count && cbChunk != 0);

// The buffer was allocated with new[], so it must be released with delete[].
	delete[] copy;
	return cbTotal;
}

// SSSLclose - remove the SSSL socket registered for sockfd from cSSSLSpace
// and then close the descriptor. Returns whatever close() returns.
SSSL_API int SSSLclose(int sockfd)
{
	int iClosed;

	// Drop the registration first, then release the descriptor itself.
	cSSSLSpace.DeleteSocket(sockfd);
	iClosed = close(sockfd);

	return iClosed;
}

// Human-readable messages returned by SSSLerror(). The table is indexed by
// the value of iSSSLError, so the order of the entries has to follow the
// numeric values of the error codes.
// NOTE(review): the error constants are declared outside this file; confirm
// that their values match the positions below before adding or reordering
// entries.
char *ppszSSSLerrors[] =
{
	"success",
	"memory allocation error",
	"invalid 'accept' call",
	"invalid 'connect' call",
	"error in private key",
	"error in public key",
	"invalid public key (delete the saved public key if the original has changed)",
	"bad configuration file",
	"could not negotiate a compatible set of protocols",
	"invalid 'open' call",
	"error in symmetric key exchange",
	"bad SSSL socket",
	"invalid 'read' call",
	"invalid 'write' call",
	"extended error",
	"error in saving a public key (check the ownership of the public directory)",
	"the client's copy of the public key does not match"
};

// SSSLerror - return the message for the current value of iSSSLError.
//
// iSSSLError is an exported variable, so its value is range-checked before it
// is used as an index into ppszSSSLerrors; a value outside the table yields a
// generic fallback message instead of an out-of-bounds read.
//
// The returned string is owned by the library and must not be freed.
SSSL_API char *SSSLerror()
{
	static char szUnknown[] = "unknown error";
	const int cErrors = (int)(sizeof(ppszSSSLerrors) / sizeof(ppszSSSLerrors[0]));

	if (iSSSLError < 0 || iSSSLError >= cErrors)
		return szUnknown;

	return ppszSSSLerrors[iSSSLError];
}
