/**
 * Cleversafe open-source code header - Version 1.1 - December 1, 2006
 *
 * Cleversafe Dispersed Storage(TM) is software for secure, private and
 * reliable storage of the world's data using information dispersal.
 *
 * Copyright (C) 2005-2007 Cleversafe, Inc.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301,
 * USA.
 *
 * Contact Information: Cleversafe, 10 W. 35th Street, 16th Floor #84,
 * Chicago IL 60616
 * email licensing@cleversafe.org
 *
 * Author: Greg Dhuse <gdhuse@cleversafe.com>
 *
 */

#include "dsdnet.h"
#include "ksocket.h"

// Net workitem: context attached to each KMDF work item created by
// DsdNetSendAtomic and consumed by DsdNetWorkitem

// workType tag for a deferred message send
#define DSD_NET_WORK_SEND  0xc1
typedef struct _DSD_NET_WORK
{
   int workType;           // DSD_NET_WORK_*
   WDFDEVICE device;       // Device whose socket the work operates on

   union
   {
      // DSD_NET_WORK_SEND
      struct
      {
         uint8_t msgType;  // DSD_MSG_* type placed in the message header
         uint8_t* data;    // Payload; released with ExFreePool by DsdNetWorkitem
         int len;          // Payload length in bytes
      } Send;
   } Type; 

} DSD_NET_WORK, *PDSD_NET_WORK;

// Registers DSD_NET_WORK as a WDF object context type; GetWork(obj) returns it
WDF_DECLARE_CONTEXT_TYPE_WITH_NAME( DSD_NET_WORK, GetWork );
// Work-item callback; defined later in this file
static void DsdNetWorkitem( IN WDFWORKITEM workitem );

/**
 * Attempt to connect a TCP (SOCK_STREAM) socket to the provided IPv4
 * address and port.
 *
 * @param ipAddr IPv4 address in host byte order (converted with htonl)
 * @param port   Port in host byte order (converted with htons)
 * @param sock   Receives the socket returned by socket(); it is written
 *               before the connect is attempted, so it is set even when
 *               the connect fails
 * @return The value returned by connect()
 *
 * NOTE(review): the socket() result is not checked here, and the socket is
 * not closed if connect() fails - confirm callers close *sock on failure.
 */
NTSTATUS
DsdNetConnect( IN uint32_t ipAddr, 
               IN uint16_t port,
               OUT int* sock )
{
   NTSTATUS status;
   struct sockaddr_in addr;

   KdPrint(( DSDNET_TAG "Connecting: %u,%u\n", ipAddr, port ));

   *sock = socket( AF_INET, SOCK_STREAM, 0 );

   // Clear the whole address first so sin_zero (and any padding) does not
   // carry uninitialized stack bytes into connect()
   RtlZeroMemory( &addr, sizeof(addr) );
   addr.sin_family      = AF_INET;
   addr.sin_port        = htons( port );
   addr.sin_addr.s_addr = htonl( ipAddr );

   status = connect( *sock, (struct sockaddr*)&addr, sizeof(addr) );

   return status;
}

/**
 * Close the socket stored in the device context.
 *
 * Once this returns the socket has been released: any further use of it
 * will result in a fault. The result of close() is not inspected.
 *
 * @param device Device whose socket is closed
 * @return Always STATUS_SUCCESS
 */
NTSTATUS
DsdNetClose( IN WDFDEVICE device )
{
   close( GetDsd( device )->socket );
   return STATUS_SUCCESS;
}

/**
 * Block until one complete message has been read from the device's socket.
 *
 * A message is a struct dsd_msg_header (whose len field arrives in network
 * byte order) followed by header.len payload bytes. The payload is
 * returned in a newly allocated PagedPool buffer (tag 'gerG'); on success
 * ownership of that buffer passes to the caller.
 *
 * Note: It is assumed that only one thread will be blocking on recv,
 *       so no locking is done on the socket
 * NOTE: Must be called at IRQL == PASSIVE_LEVEL
 *
 * @param device Device containing socket
 * @param data Set to the new payload buffer on success; set to NULL on
 *             every failure path, so it never refers to freed memory
 * @param msgType Set to the DSD_MSG_* type of the message on success,
 *                0 otherwise
 * @return STATUS_SUCCESS on success; STATUS_PORT_DISCONNECTED if recv()
 *         returned 0; STATUS_INSUFFICIENT_RESOURCES if the payload buffer
 *         could not be allocated; -1 if the header magic is wrong;
 *         otherwise the negative value returned by recv()
 */
NTSTATUS
DsdNetRecv( IN WDFDEVICE device, 
            OUT uint8_t** data, 
            OUT uint8_t* msgType )
{
   unsigned int received;
   NTSTATUS status;
   PDSD_DEV dsd;
   struct dsd_msg_header header;

   VERIFY_IS_IRQL_PASSIVE_LEVEL();

   *msgType = 0;
   *data = NULL;
   dsd = GetDsd( device );

   // Receive header; recv() may return fewer bytes than requested
   received = 0;
   RtlZeroMemory( &header, sizeof(header) );
   while( received < sizeof(header) )
   {
      status = recv( dsd->socket, 
                     (char*)&header + received, 
                     sizeof(header) - received, 
                     0 );
      if( status <= 0 )
      {
         status = (status == 0) ? STATUS_PORT_DISCONNECTED : status;
         goto cleanup;
      }
      received += status;
   }
   header.len = ntohl( header.len );

   // Verify magic
   if( header.magic != DSD_MAGIC )
   {
      /* FIXME: return a proper NTSTATUS for a malformed header */
      status = -1;
      goto cleanup;
   }

   // NOTE(review): header.len comes straight off the wire and is not bounded
   // here, so a corrupt or hostile peer can request an arbitrarily large
   // PagedPool allocation. Consider enforcing a protocol maximum.

   // Receive payload. An empty payload still gets a 1-byte buffer: a
   // zero-byte pool request is a Driver Verifier violation, and this keeps
   // *data non-NULL on every successful return.
   *data = ExAllocatePoolWithTag( PagedPool, 
                                  (header.len != 0) ? header.len : 1, 
                                  'gerG' );
   if( !*data )
   {
      status = STATUS_INSUFFICIENT_RESOURCES;
      goto cleanup;
   }

   received = 0;
   while( received < header.len )
   {
      status = recv( dsd->socket, 
                     (char*)*data + received, 
                     header.len - received, 
                     0 );
      if( status <= 0 )
      {
         ExFreePool( *data );
         *data = NULL;   // don't hand the caller a dangling pointer
         status = (status == 0) ? STATUS_PORT_DISCONNECTED : status;
         goto cleanup;
      }
      received += status;
   }
   *msgType = header.msg_type;
   status = STATUS_SUCCESS;

cleanup:
   return status;
}

/**
 * Queue a message to be sent on the device's socket. The message is
 * dispatched later from a KMDF work item (DsdNetWorkitem), which calls
 * DsdNetSend and then releases 'data' with ExFreePool.
 *
 * NOTE: May be called at IRQL <= DISPATCH_LEVEL
 * NOTE: This function assumes ownership of 'data' on every path: on
 *       success it is freed by the work item after the send; if the work
 *       item cannot be created it is freed here before returning. Callers
 *       must not free 'data' themselves, even on failure. Because it is
 *       released with ExFreePool, a caller running at DISPATCH_LEVEL must
 *       pass a buffer allocated from nonpaged pool.
 *
 * @param device Device containing socket; also the work item's parent
 * @param msgType DSD_MSG_* type of message
 * @param data Message payload (ownership transferred, see above)
 * @param len Payload length in bytes
 * @return STATUS_SUCCESS once the work item is queued, otherwise the
 *         failure status from WdfWorkItemCreate
 */
NTSTATUS
DsdNetSendAtomic( IN WDFDEVICE device, 
                  IN uint8_t msgType, 
                  IN uint8_t* data, 
                  IN int len )
{
   NTSTATUS status;
   PDSD_DEV dsd;
   PDSD_NET_WORK work;
   WDFWORKITEM workitem;
   WDF_WORKITEM_CONFIG workitemConfig;
   WDF_OBJECT_ATTRIBUTES workitemAttributes;

   dsd = GetDsd( device );

   // Initialize workitem; it carries a DSD_NET_WORK context and is
   // parented to the device
   WDF_OBJECT_ATTRIBUTES_INIT( &workitemAttributes );
   WDF_OBJECT_ATTRIBUTES_SET_CONTEXT_TYPE( &workitemAttributes, DSD_NET_WORK );
   WDF_WORKITEM_CONFIG_INIT( &workitemConfig, DsdNetWorkitem );

   workitemAttributes.ParentObject = device;
   status = WdfWorkItemCreate( &workitemConfig,
                               &workitemAttributes,
                               &workitem );
   if( !NT_SUCCESS(status) )
   {
      // Honour the ownership contract: no work item exists to free it later
      if( data != NULL )
      {
         ExFreePool( data );
      }
      return status;
   }

   // Initialize work
   work = GetWork( workitem );
   work->workType          = DSD_NET_WORK_SEND;
   work->device            = device;
   work->Type.Send.msgType = msgType;
   work->Type.Send.data    = data;
   work->Type.Send.len     = len;

   // Balanced by the InterlockedDecrement in DsdNetWorkitem
   InterlockedIncrement( &dsd->pendingWork );
   WdfWorkItemEnqueue( workitem );
   return STATUS_SUCCESS;
}

/**
 * Send a message on the device's socket immediately.
 *
 * The message is a struct dsd_msg_header (magic DSD_MAGIC, msg_type, and
 * len in network byte order) followed by 'len' payload bytes. Both are
 * written while holding dsd->socketLock, so the header and payload of one
 * message are not interleaved with another DsdNetSend call on the same
 * device.
 *
 * NOTE: Must be called at IRQL == PASSIVE_LEVEL
 * NOTE: Does not take ownership of 'data'
 *
 * @param device Device containing socket
 * @param msgType DSD_MSG_* type of message
 * @param data Message payload
 * @param len Payload length in bytes
 * @return STATUS_SUCCESS if the full header and payload were sent;
 *         STATUS_PORT_DISCONNECTED if send() returned 0; otherwise the
 *         negative value returned by send()
 */
NTSTATUS
DsdNetSend( IN WDFDEVICE device, 
            IN uint8_t msgType, 
            IN uint8_t* data, 
            IN int len )
{
   int sent;
   NTSTATUS status;
   PDSD_DEV dsd;
   struct dsd_msg_header header;

   VERIFY_IS_IRQL_PASSIVE_LEVEL();

   dsd = GetDsd( device );
   WdfWaitLockAcquire( dsd->socketLock, NULL );

   // Send header; send() may accept fewer bytes than requested
   memset( &header, 0, sizeof(header) );
   header.magic      = DSD_MAGIC;
   header.msg_type   = msgType;
   header.len        = htonl( len );

   sent = 0;
   while( (size_t)sent < sizeof(header) )
   {
      status = send( dsd->socket, 
                     (char*)&header + sent, 
                     sizeof(header) - sent, 
                     0 );
      if( status <= 0 )
      {
         status = (status == 0) ? STATUS_PORT_DISCONNECTED : status;
         goto cleanup;
      }
      sent += status;
   }

   // Send payload
   sent = 0;
   while( sent < len )
   {
      status = send( dsd->socket, 
                     (char*)data + sent, 
                     len - sent, 
                     0 );
      if( status <= 0 )
      {
         status = (status == 0) ? STATUS_PORT_DISCONNECTED : status;
         goto cleanup;
      }
      sent += status;
   }

   // 'status' currently holds a byte count from send(); report success
   status = STATUS_SUCCESS;

cleanup:
   WdfWaitLockRelease( dsd->socketLock );
   return status;
}

/**
 * Work-item callback queued by DsdNetSendAtomic.
 *
 * For DSD_NET_WORK_SEND it sends the stored message with DsdNetSend and
 * then releases the payload with ExFreePool, whether or not the send
 * succeeded. It then deletes the work item and decrements the device's
 * pendingWork counter, which DsdNetSendAtomic incremented.
 *
 * @param workitem Work item carrying a DSD_NET_WORK context
 */
static void DsdNetWorkitem( IN WDFWORKITEM workitem )
{
   PDSD_DEV dsd;
   NTSTATUS status;
   PDSD_NET_WORK work;

   work = GetWork( workitem );
   // Capture the device context now: 'work' lives in the work item's
   // context and is gone once the work item is deleted below
   dsd = GetDsd( work->device );
   if( work->workType == DSD_NET_WORK_SEND )
   {
      status = DsdNetSend( work->device,
                           work->Type.Send.msgType, 
                           work->Type.Send.data, 
                           work->Type.Send.len );
      if( !NT_SUCCESS(status) )
      {
         // There is no caller to return this to, so at least report it
         KdPrint(( DSDNET_TAG "Deferred send failed: type %u, len %d, status 0x%08x\n",
                   (unsigned int)work->Type.Send.msgType,
                   work->Type.Send.len,
                   (unsigned int)status ));
      }
      // Payload ownership was transferred to us by DsdNetSendAtomic
      ExFreePool( work->Type.Send.data );
   }

   WdfObjectDelete( workitem );
   InterlockedDecrement( &dsd->pendingWork );
   return;
}

/**
 * Network byte-swapping: copy 'len' bytes from src to dst in reverse order,
 * so that dst[0] = src[len-1], ..., dst[len-1] = src[0].
 *
 * dst and src must not overlap. The copy walks src forward and writes dst
 * backward, so an in-place call would overwrite bytes before they are
 * read. A zero 'len' is a no-op.
 *
 * @param dst Destination buffer of at least 'len' bytes
 * @param src Source buffer of at least 'len' bytes
 * @param len Number of bytes to reverse-copy
 */
void bswap( OUT PVOID dst, IN PVOID src, IN size_t len )
{
   size_t i;
   const uint8_t* bs = (const uint8_t*)src;
   uint8_t* bd = (uint8_t*)dst;

   // Index from the end inside the loop instead of precomputing
   // dst + len - 1: for len == 0 that pointer would lie outside the
   // object, which is undefined behavior even if never dereferenced.
   for( i = 0; i < len; ++i )
   {
      bd[len - 1 - i] = bs[i];
   }
}

/**
 * Host-to-network conversion for a 64-bit value.
 * The byte order is reversed unconditionally, so this is a correct
 * host-to-network conversion only on a little-endian host.
 */
uint64_t htonll( IN uint64_t val )
{
   uint64_t swapped = 0;
   unsigned int n;

   // Pop bytes off the low end of val and push them onto the low end of
   // 'swapped'; after all eight bytes their order is reversed
   for( n = 0; n < sizeof(val); ++n )
   {
      swapped = (swapped << 8) | (val & 0xFF);
      val >>= 8;
   }
   return swapped;
}

/**
 * Host-to-network conversion for a 32-bit value.
 * The byte order is reversed unconditionally, so this is a correct
 * host-to-network conversion only on a little-endian host.
 */
uint32_t htonl( IN uint32_t val )
{
   return (uint32_t)( ( (val & 0x000000FFu) << 24 ) |
                      ( (val & 0x0000FF00u) <<  8 ) |
                      ( (val >>  8) & 0x0000FF00u ) |
                      ( (val >> 24) & 0x000000FFu ) );
}

/**
 * Host-to-network conversion for a 16-bit value.
 * The two bytes are always exchanged, so this is a correct
 * host-to-network conversion only on a little-endian host.
 */
uint16_t htons( IN uint16_t val )
{
   uint16_t high = (uint16_t)(val >> 8);
   uint16_t low  = (uint16_t)(val & 0xFF);

   return (uint16_t)( (low << 8) | high );
}

/*
 * Network-to-host conversions. Reversing byte order is its own inverse,
 * so each of these just applies the matching host-to-network conversion.
 */

/** Network-to-host conversion for a 64-bit value. */
uint64_t ntohll( IN uint64_t val ) { return htonll( val ); }

/** Network-to-host conversion for a 32-bit value. */
uint32_t ntohl( IN uint32_t val ) { return htonl( val ); }

/** Network-to-host conversion for a 16-bit value. */
uint16_t ntohs( IN uint16_t val ) { return htons( val ); }
