/*
 * SMTP client program: sends test mail transactions to an SMTP server.
 *
 * Based on proxy.c.
 * $Id: smtpc.c,v 1.28 2005/08/17 17:40:54 ca Exp $
 */

/*
 * Portions created by SGI are Copyright (C) 2000 Silicon Graphics, Inc.
 * All Rights Reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met: 
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer. 
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. Neither the name of Silicon Graphics, Inc. nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission. 
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * HOLDERS AND CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
 * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include "sm/generic.h"
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <signal.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <assert.h>
#include "st.h"

#if !HAVE_SNPRINTF
# define snprintf sm_snprintf
# include "sm/string.h"
#endif /* !HAVE_SNPRINTF */

/* size of the stack buffers used for SMTP commands and replies */
#define IOBUFSIZE (16*1024)

/* fallback; read_address() compares inet_addr()'s result against it */
#ifndef INADDR_NONE
#define INADDR_NONE 0xffffffff
#endif

#define MAXTC	65536L		/* max. threads * transactions (checked in main) */
#define SC_MAXADDRS	256		/* max. number of -f / -R addresses */
/* default for req_timeout (-O), in seconds */
#define REQUEST_TIMEOUT 3
/* seconds -> microseconds, for the timeout argument of the st_*() calls */
#define SEC2USEC(s) ((s)*1000000LL)

static char *prog;                     /* Program name   */
static struct sockaddr_in rmt_addr;    /* Remote address */
/* transactions per session (-T/-c) and recipients per transaction (-n) */
static int transactions, rcpts;
/* debug level (-d); larger values enable more logging */
static int debug = 0;
/*
 * Number of client threads currently inside handle_request(); main()
 * polls it to know when all of them have finished.
 * NOTE(review): busy, total and sequence are modified by every client
 * thread without locking; presumably this relies on State Threads'
 * cooperative scheduling -- confirm.
 */
static int busy = 0;
/* transactions counted by the client threads, reported at exit */
static long int total = 0;
/* -f / -R addresses; handle_request() picks them round-robin (index % count) */
static char *from[SC_MAXADDRS], *rcpt[SC_MAXADDRS];
static unsigned int fromaddrs = 0;
static unsigned int rcptaddrs = 0;
/*
 * -s / -S: next sequence number to hand out, and the value at which
 * handing out stops.  The default -1 keeps the non-sequence address
 * and message formats.
 */
static int sequence = -1;
static int seqfirst = 0;
/* -p: extra identifier embedded in generated sender addresses */
static int postfix = 0;
/* -O: timeout in seconds for connect/read/write */
static int req_timeout = REQUEST_TIMEOUT;

/* domains for generated recipient (-D) and sender (-F) addresses */
static char *rcptdom = "local.host";
static char *maildom = "local.host";

static void read_address(const char *str, struct sockaddr_in *sin);
static void *handle_request(void *arg);
static void print_sys_error(const char *msg);

/*
 * sce_usage -- print a summary of the command line options on stderr.
 *
 * Parameters:
 *	prg -- program name shown in the "Usage:" line
 */
static void
sce_usage(const char *prg)
{
	static const char usage_fmt[] =
		"Usage: %s [options] -r <host:port>\n"
		"-D domain   use domain for recipient addresses\n"
		"-d n        debug level\n"
		"-F domain   use domain for sender addresses\n"
		"-f address  from address\n"
		"-n n        number of recipients per transaction\n"
		"-O n        set timeout to n seconds [%d]\n"
		"-p n        use n in addresses as identifier\n"
		"-R address  recipient address (can be specified multiple times)\n"
		"-s n        send n messages with 1..n as body.\n"
		"-S m        use m as first element in sequence: m..n.\n"
		"-t n        concurrent threads\n"
		"-T n        transactions\n";

	/* the only substitutions are the program name and the default timeout */
	fprintf(stderr, usage_fmt, prg, REQUEST_TIMEOUT);
}

/*
 * main -- parse options, start the client threads and wait for them.
 *
 * Options are listed by sce_usage(); -c is accepted as an alias for -T.
 * Starts `threads' ST threads running handle_request(), each receiving
 * its index as argument.  Then polls once per second until the shared
 * `busy' counter drops to zero, and reports the counted total.
 *
 * Exits with status 1 on invalid options, or if ST initialization or
 * thread creation fails.  Otherwise returns 0.
 */
int 
main(int argc, char *argv[])
{
	extern char    *optarg;
	int             opt, n, threads;
	int             raddr;
	long int        tc;

	prog = argv[0];
	raddr = 0;
	transactions = threads = rcpts = 1;
	rcpt[0] = from[0] = NULL;

	/* Parse arguments */
	while ((opt = getopt(argc, argv, "c:D:d:F:f:hn:O:p:R:r:S:s:T:t:")) != -1)
	{
		switch (opt)
		{
		case 'T':
		case 'c':
			transactions = atoi(optarg);
			if (transactions < 1)
			{
				fprintf(stderr,
					"%s: invalid number of transaction: %s\n"
					, prog, optarg);
				exit(1);
			}
			break;
		case 'D':
			rcptdom = strdup(optarg);
			if (rcptdom == NULL)
			{
				fprintf(stderr,
					"%s: failed to strdup() %s\n"
					, prog, optarg);
				exit(1);
			}
			break;
		case 'd':
			debug = atoi(optarg);
			if (debug < 0)
			{
				fprintf(stderr,
					"%s: invalid number for debug: %s\n"
					, prog, optarg);
				exit(1);
			}
			break;
		case 'F':
			maildom = strdup(optarg);
			if (maildom == NULL)
			{
				fprintf(stderr,
					"%s: failed to strdup() %s\n"
					, prog, optarg);
				exit(1);
			}
			break;
		case 'f':
			/* from[] has SC_MAXADDRS slots: indices 0..SC_MAXADDRS-1 */
			if (fromaddrs >= SC_MAXADDRS)
			{
				fprintf(stderr,
					"%s: too many addresses=%u, max=%d\n"
					, prog, fromaddrs, SC_MAXADDRS);
				exit(1);
			}
			from[fromaddrs++] = optarg;
			break;
		case 'n':
			rcpts = atoi(optarg);
			if (rcpts < 1)
			{
				fprintf(stderr, "%s: invalid number of rcpts: %s\n"
					, prog, optarg);
				exit(1);
			}
			break;
		case 'O':
			req_timeout = atoi(optarg);
			if (req_timeout < 1)
			{
				fprintf(stderr, "%s: invalid timeout: %s\n"
					, prog, optarg);
				exit(1);
			}
			break;
		case 'p':
			postfix = atoi(optarg);
			break;
		case 'R':
			/* rcpt[] has SC_MAXADDRS slots: indices 0..SC_MAXADDRS-1 */
			if (rcptaddrs >= SC_MAXADDRS)
			{
				fprintf(stderr, "%s: too many addresses=%u, max=%d\n"
					, prog, rcptaddrs, SC_MAXADDRS);
				exit(1);
			}
			rcpt[rcptaddrs++] = optarg;
			break;
		case 'r':
			read_address(optarg, &rmt_addr);
			if (rmt_addr.sin_addr.s_addr == INADDR_ANY)
			{
				fprintf(stderr, "%s: invalid remote address: %s\n"
					, prog, optarg);
				exit(1);
			}
			raddr = 1;
			break;
		case 's':
			sequence = atoi(optarg);
			if (sequence < 1)
			{
				fprintf(stderr, "%s: invalid number for sequence: %s\n"
					, prog, optarg);
				exit(1);
			}
			break;
		case 'S':
			seqfirst = atoi(optarg);
			if (seqfirst < 0)
			{
				fprintf(stderr, "%s: invalid number for sequence begin: %s\n"
					, prog, optarg);
				exit(1);
			}
			break;
		case 't':
			threads = atoi(optarg);
			if (threads < 1)
			{
				fprintf(stderr, "%s: invalid number of threads: %s\n"
					, prog, optarg);
				exit(1);
			}
			break;
		case 'h':
		case '?':
		default:
			sce_usage(prog);
			exit(1);
		}
	}
	if (!raddr)
	{
		fprintf(stderr, "%s: remote address required\n", prog);
		exit(1);
	}

	/*
	 * threads * transactions may not fit into a (32-bit) long, and
	 * signed overflow is undefined: compare by division first
	 * (threads >= 1 here).
	 */
	if (transactions > MAXTC / threads)
	{
		fprintf(stderr, "%s: too many total messages  (%d * %d > %ld)\n",
			prog, threads, transactions, MAXTC);
		exit(1);
	}
	tc = (long) threads * (long) transactions;

	/* with several -R addresses and no -n, use each address once */
	if (rcpts == 1 && rcptaddrs > 1)
		rcpts = rcptaddrs;

	if (debug)
		fprintf(stderr, "%s: starting client [%d]\n", prog, threads);

	/* Initialize the ST library */
	if (st_init() < 0)
	{
		print_sys_error("st_init");
		exit(1);
	}
	for (n = 0; n < threads; n++)
	{
		if (debug)
			fprintf(stderr, "%s: starting client %d/%d\n", prog, n,
				threads);

		/* pass the thread index through the void * via intptr_t */
		if (st_thread_create(handle_request, (void *) (intptr_t) n, 0, 0)
		    == NULL)
		{
			print_sys_error("st_thread_create");
			exit(1);
		}
	}

	/*
	 * Wait for the clients: give them a chance to start (each one
	 * increments `busy' on entry), then poll until all have left.
	 */
	st_sleep(1);
	while (busy > 0)
		st_sleep(1);

	fprintf(stderr, "%s: total=%ld (should be %ld)\n", prog, total, tc);

	return 0;
}


/*
 * read_address -- parse "host:port" into an IPv4 socket address.
 *
 * Parameters:
 *	str -- "host:port".  host is a dotted-quad, a name resolved with
 *		gethostbyname(), or empty, which stores INADDR_ANY.  port
 *		must be a decimal number in 1..65535.
 *	sin -- output; zeroed and then filled in
 *
 * Exits the program with status 1 on any parse or resolution error.
 */
static void 
read_address(const char *str, struct sockaddr_in * sin)
{
	char            host[128], *p, *end;
	struct hostent *hp;
	long            port;

	/* refuse to silently truncate (and then misparse) long input */
	if (strlcpy(host, str, sizeof(host)) >= sizeof(host))
	{
		fprintf(stderr, "%s: address too long: %s\n", prog, str);
		exit(1);
	}
	if ((p = strchr(host, ':')) == NULL)
	{
		fprintf(stderr, "%s: invalid address: %s\n", prog, host);
		exit(1);
	}
	*p++ = '\0';

	/*
	 * Parse into a long: a short cannot represent ports above 32767,
	 * which previously wrapped negative and were rejected.
	 */
	port = strtol(p, &end, 10);
	if (end == p || *end != '\0' || port < 1 || port > 65535)
	{
		fprintf(stderr, "%s: invalid port: %s\n", prog, p);
		exit(1);
	}
	memset(sin, 0, sizeof(*sin));
	sin->sin_family = AF_INET;
	sin->sin_port = htons((unsigned short) port);
	if (host[0] == '\0')
	{
		sin->sin_addr.s_addr = INADDR_ANY;
		return;
	}
	sin->sin_addr.s_addr = inet_addr(host);
	if (sin->sin_addr.s_addr == INADDR_NONE)
	{
		/* not dotted-decimal: resolve as a host name */
		hp = gethostbyname(host);

		/* only copy an address that really is an IPv4 address */
		if (hp == NULL || hp->h_addrtype != AF_INET
		    || hp->h_length != (int) sizeof(sin->sin_addr))
		{
			fprintf(stderr, "%s: can't resolve address: %s\n", prog, host);
			exit(1);
		}
		memcpy(&sin->sin_addr, hp->h_addr, sizeof(sin->sin_addr));
	}
}

/*
 * Result codes of smtpread() and smtpcommand().  handle_request()
 * reacts to SMTP_SSD by closing the connection without sending QUIT.
 */
#define SMTP_OK	0	/* reply code starts with 2 or 3 */
#define SMTP_RD	(-1)	/* read error */
#define SMTP_WR	(-2)	/* write error */
#define SMTP_AN	(-3)	/* SMTP reply code isn't 2 or 3 */
#define SMTP_SSD (-4)	/* SMTP reply code 421 */

/*
 * smtpread -- read one SMTP reply (e.g. the greeting) and classify it.
 *
 * Parameters:
 *	fd -- connected ST socket
 *	tid -- client thread id (log messages only)
 *	i -- transaction index (log messages only)
 *
 * Returns:
 *	SMTP_OK  if the reply code starts with 2 or 3,
 *	SMTP_SSD for a 421 reply,
 *	SMTP_AN  for any other reply,
 *	SMTP_RD  on EOF or any other st_read() failure (e.g. no data
 *		 within req_timeout seconds).
 *
 * Only the first chunk returned by st_read() is examined.  Any further
 * data of the reply is left unread.
 */
static int
smtpread(st_netfd_t fd, int tid, int i)
{
	int             n;
	char            buf[IOBUFSIZE];

	errno = 0;
	while ((n = (int) st_read(fd, buf, sizeof(buf), SEC2USEC(req_timeout)))
		<= 0)
	{
		/* nothing available yet: pause and retry */
		if ((n == 0 || n == -1) && errno == EAGAIN)
		{
			st_sleep(1);
			errno = 0;
			continue;
		}
		if (n == 0)
			return SMTP_RD;	/* EOF: peer closed the connection */
		fprintf(stderr,
			"[%d] st_read=error, i=%d, n=%d, errno=%d\n",
			tid, i, n, errno);
		return SMTP_RD;
	}

	/* here n > 0, so at least buf[0] has been filled in */
	if (debug > 3)
	{
		fprintf(stderr, "[%d] rcvd: ", tid);
		write(STDERR_FILENO, buf, n);
	}

	/*
	 * Check the reply code.  Only look at buf[1]/buf[2] if they were
	 * actually received.  421 must match all three digits.
	 */
	if (n >= 3 && buf[0] == '4' && buf[1] == '2' && buf[2] == '1')
		return SMTP_SSD;

	if (buf[0] != '2' && buf[0] != '3')
		return SMTP_AN;

	return SMTP_OK;
}

/*
 * smtpcommand -- send one SMTP command, then read and classify the reply.
 *
 * Parameters:
 *	str -- command text, including the trailing CRLF
 *	l -- number of bytes of str to send (must be > 0)
 *	fd -- connected ST socket
 *	tid -- client thread id (log messages only)
 *	i -- transaction index (log messages only)
 *
 * Returns:
 *	SMTP_OK  if the reply code starts with 2 or 3,
 *	SMTP_SSD for a 421 reply,
 *	SMTP_AN  for any other reply,
 *	SMTP_WR  if st_write() did not write all l bytes,
 *	SMTP_RD  on EOF or any other st_read() failure (e.g. no data
 *		 within req_timeout seconds).
 *
 * On write/read errors the command text is echoed to stderr.  Only
 * the first chunk returned by st_read() is examined.
 */
static int
smtpcommand(char *str, int l, st_netfd_t fd, int tid, int i)
{
	int             n, r;
	char            buf[IOBUFSIZE];

	assert(l > 0);
	if (debug > 3)
	{
		fprintf(stderr, "[%d] send: ", tid);
		write(STDERR_FILENO, str, l);
	}
	if ((r = st_write(fd, str, l, SEC2USEC(req_timeout))) != l)
	{
		fprintf(stderr,
			"[%d] st_write=error, i=%d, n=%d, r=%d, errno=%d\n",
			tid, i, l, r, errno);
		write(STDERR_FILENO, str, l);
		return SMTP_WR;
	}

	while ((n = (int) st_read(fd, buf, sizeof(buf), SEC2USEC(req_timeout)))
		<= 0)
	{
		if (n == 0)
			return SMTP_RD;	/* EOF: peer closed the connection */

		/* nothing available yet: pause and retry */
		if (n == -1 && errno == EAGAIN)
		{
			st_sleep(1);
			continue;
		}
		fprintf(stderr,
			"[%d] st_read=error, i=%d, n=%d, errno=%d\n",
			tid, i, n, errno);
		write(STDERR_FILENO, str, l);
		return SMTP_RD;
	}

	/* here n > 0, so at least buf[0] has been filled in */
	if (debug > 3)
	{
		fprintf(stderr, "[%d] rcvd: ", tid);
		write(STDERR_FILENO, buf, n);
	}

	/*
	 * Check the reply code.  Only look at buf[1]/buf[2] if they were
	 * actually received.  421 must match all three digits.
	 */
	if (n >= 3 && buf[0] == '4' && buf[1] == '2' && buf[2] == '1')
		return SMTP_SSD;

	if (buf[0] != '2' && buf[0] != '3')
		return SMTP_AN;

	return SMTP_OK;
}

/*
 * sc_fits -- nonzero if an snprintf() result n means the complete
 * output fit, untruncated, into a buffer of size bytes.
 */
static int
sc_fits(int n, size_t size)
{
	return n >= 0 && (size_t) n < size;
}

/*
 * handle_request -- body of one client thread.
 *
 * Parameters:
 *	arg -- the thread index, an int passed through the pointer by main()
 *
 * Returns: NULL.
 *
 * Connects to rmt_addr, reads the greeting and sends EHLO.  Then it
 * performs up to `transactions' transactions: MAIL, `rcpts' RCPTs, DATA
 * and the message, with RSET between transactions.  The session ends
 * with QUIT.  After a reply classified as SMTP_SSD the connection is
 * closed without QUIT.
 *
 * Sequence mode (-s): the session, and each transaction after an RSET,
 * takes its number from the shared counter `sequence', which is
 * decremented down to `seqfirst'.  New sessions are opened until it
 * reaches `seqfirst'.  With the default sequence of -1, at most one
 * session is run.
 *
 * Shared state: increments `busy' on entry and decrements it on exit.
 * Increments `total' for every message whose content was accepted.
 */
static void *
handle_request(void *arg)
{
	st_netfd_t      rmt_nfd;
	int             sock, n, i, tid, r, j, myseq, okrcpts;
	char            buf[IOBUFSIZE];

	++busy;
	tid = (int) (intptr_t) arg;
	i = 0;
	if (debug)
		fprintf(stderr, "client[%d]: transactions=%d\n", tid,
			transactions);

  newsession:
	/* claim the next sequence number (a no-op unless -s was given) */
	myseq = sequence;
	if (debug)
		fprintf(stderr,
			"client[%d]: myseq=%d, sequence=%d, seqfirst=%d\n"
			, tid, myseq, sequence, seqfirst);
	if (myseq == seqfirst)
		goto done;	/* sequence exhausted */
	if (sequence > seqfirst)
		--sequence;

	/* Connect to remote host */
	if ((sock = socket(PF_INET, SOCK_STREAM, 0)) < 0)
	{
		snprintf(buf, sizeof(buf), "[%d] socket i=%d", tid, i);
		print_sys_error(buf);
		goto done;
	}
	if ((rmt_nfd = st_netfd_open_socket(sock)) == NULL)
	{
		snprintf(buf, sizeof(buf), "[%d] st_netfd_open_socket i=%d",
			tid, i);
		print_sys_error(buf);
		close(sock);
		goto done;
	}
	if (st_connect(rmt_nfd, (struct sockaddr *) & rmt_addr,
		       sizeof(rmt_addr), SEC2USEC(req_timeout)) < 0)
	{
		snprintf(buf, sizeof(buf), "[%d] connect i=%d", tid, i);
		print_sys_error(buf);
		st_netfd_close(rmt_nfd);
		goto done;
	}
	if (debug > 2)
		fprintf(stderr, "client[%d]: connected (%d/%d)\n", tid, i,
			transactions);
	r = smtpread(rmt_nfd, tid, 0);
	if (r == SMTP_SSD)
		goto fail;
	if (r != SMTP_OK)
	{
		fprintf(stderr, "[%d] error (greeting) r=%d, errno=%d\n", tid,
			r, errno);
		goto quit;
	}
	n = strlcpy(buf, "EHLO me.local\r\n", sizeof(buf));
	r = smtpcommand(buf, n, rmt_nfd, tid, 0);
	if (r == SMTP_SSD)
		goto fail;
	if (r != SMTP_OK)
		goto quit;

	for (i = 0; i < transactions && myseq != seqfirst; i++)
	{
		if (fromaddrs > 0)
			n = snprintf(buf, sizeof(buf), "MAIL FROM:<%s>\r\n",
				from[i % fromaddrs]);
		else if (myseq > 0)
			n = snprintf(buf, sizeof(buf),
				"MAIL FROM:<nobody-%d-%d-%d_%d@%s>\r\n",
				myseq, i, tid, postfix, maildom);
		else
			n = snprintf(buf, sizeof(buf),
				"MAIL FROM:<nobody-%d-%d_%d@%s>\r\n",
				i, tid, postfix, maildom);

		/*
		 * Addresses/domains come from the command line and may not
		 * fit.  A truncated command would make smtpcommand() send n
		 * bytes from a buffer that is shorter than n.
		 */
		if (!sc_fits(n, sizeof(buf)))
		{
			fprintf(stderr, "[%d] MAIL command too long, i=%d\n",
				tid, i);
			break;
		}
		r = smtpcommand(buf, n, rmt_nfd, tid, i);
		if (r == SMTP_SSD)
			goto fail;
		if (r != SMTP_OK)
			break;

		okrcpts = 0;
		for (j = 0; j < rcpts; j++)
		{
			if (rcptaddrs > 0)
				n = snprintf(buf, sizeof(buf),
					"RCPT To:<%s>\r\n",
					rcpt[j % rcptaddrs]);
			else if (myseq > 0)
				n = snprintf(buf, sizeof(buf),
					"RCPT To:<nobody-%d-%d-%d-%d@%s>\r\n",
					myseq, i, j, tid, rcptdom);
			else
				n = snprintf(buf, sizeof(buf),
					"RCPT To:<nobody-%d-%d-%d@%s>\r\n",
					i, j, tid, rcptdom);
			if (!sc_fits(n, sizeof(buf)))
			{
				fprintf(stderr,
					"[%d] RCPT command too long, i=%d, j=%d\n",
					tid, i, j);
				continue;
			}
			r = smtpcommand(buf, n, rmt_nfd, tid, i);
			if (r == SMTP_SSD)
				goto fail;

			/*
			 * Send all recipients even if some are rejected; the
			 * transaction continues if at least one was accepted.
			 */
			if (r == SMTP_OK)
				++okrcpts;
		}
		if (okrcpts == 0)
			break;

		n = strlcpy(buf, "DATA\r\n", sizeof(buf));
		r = smtpcommand(buf, n, rmt_nfd, tid, i);
		if (r == SMTP_SSD)
			goto fail;
		if (r != SMTP_OK)
			break;

		/* message content including the terminating "." line */
		if (myseq > 0)
		{
			n = snprintf(buf, sizeof(buf),
				"From: me+%d\r\nTo: you+%d\r\nSubject: test+%d\r\n\r\n%d\r\n.\r\n",
				myseq, myseq, myseq, myseq);
		}
		else
		{
			n = strlcpy(buf,
				"From: me\r\nTo: you\r\nSubject: test\r\n\r\nbody\r\n.\r\n",
				sizeof(buf));
		}

		r = smtpcommand(buf, n, rmt_nfd, tid, i);
		if (r == SMTP_SSD)
			goto fail;
		if (r != SMTP_OK)
			break;

		/* the message was accepted: count it even if RSET fails below */
		++total;

		if (i < transactions - 1)
		{
			n = strlcpy(buf, "RSET\r\n", sizeof(buf));
			r = smtpcommand(buf, n, rmt_nfd, tid, i);
			if (r == SMTP_SSD)
				goto fail;
			if (r != SMTP_OK)
				break;

			/* claim the sequence number for the next transaction */
			myseq = sequence;
			if (sequence > seqfirst)
				--sequence;
		}
	}
quit:
	/* the reply to QUIT is deliberately not checked */
	n = strlcpy(buf, "QUIT\r\n", sizeof(buf));
	r = smtpcommand(buf, n, rmt_nfd, tid, i);

fail:
	st_netfd_close(rmt_nfd);
	if (sequence > seqfirst)
		goto newsession;
done:
	--busy;
	return NULL;
}

/*
 * print_sys_error -- report a failed system or library call on stderr.
 *
 * Writes "<prog>: <msg>: <description of the current errno>" and a
 * newline.
 *
 * Parameters:
 *	msg -- short description of the failed operation
 */
static void 
print_sys_error(const char *msg)
{
	const char *reason = strerror(errno);

	fprintf(stderr, "%s: %s: %s\n", prog, msg, reason);
}
