/**
 * Cleversafe open-source code header - Version 1.1 - December 1, 2006
 *
 * Cleversafe Dispersed Storage(TM) is software for secure, private and
 * reliable storage of the world's data using information dispersal.
 *
 * Copyright (C) 2005-2007 Cleversafe, Inc.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301,
 * USA.
 *
 * Contact Information: Cleversafe, 10 W. 35th Street, 16th Floor #84,
 * Chicago IL 60616
 * email licensing@cleversafe.org
 *
 * Author: Greg Dhuse <gdhuse@cleversafe.com>
 *
 */

#include "dsd.h"

/* Net workqueue */

/* Work item type: send a message on a socket (queued by dsd_send_atomic) */
#define DSD_NET_WORK_SEND  0xc1

/**
 * A deferred network operation.  Allocated by dsd_send_atomic(), queued on
 * net_workqueue and processed by __dsd_net_work().
 */
struct dsd_net_work
{
   int work_type;          /* DSD_NET_WORK_* */
   union
   {
      /* DSD_NET_WORK_SEND */
      struct               
      {
         struct socket* sk;   /* Connected socket to send on */
         int msg_type;        /* Passed through to dsd_send() (eg. DSD_MSG_*) */
         uint8_t* data;       /* Payload; handed to dsd_send(), which kfree's it */
         int len;             /* Payload length in bytes */
      } send;
   };

   /* If non-null, this will be decremented when work is done
    * (dsd_send_atomic increments it when the work is queued) */
   atomic_t* pending_work;

   /* Embedded workqueue item; the handler recovers the enclosing
    * dsd_net_work from it with container_of() */
   struct work_struct work;
};

/* Queue that runs deferred sends; created in dsd_net_init() with
 * create_singlethread_workqueue() and destroyed in dsd_net_exit() */
struct workqueue_struct* net_workqueue = NULL;
static void __dsd_net_work( struct work_struct* work );

/* Net helpers: loop until exactly len bytes are transferred or an error
 * occurs; both return len on success or a negative error code */
static int __dsd_recv_all( struct socket* sk, 
                           uint8_t* msg, 
                           int len );
static int __dsd_send_all( struct socket* sk, 
                           const uint8_t* msg, 
                           int len );

/**
 * Open a TCP socket and connect it to the daemon.
 *
 * @param ip_addr Daemon IPv4 address, host byte order (converted with htonl)
 * @param port    Daemon TCP port, host byte order (converted with htons)
 * @return The connected socket (caller releases it with sock_release()),
 *         or NULL if the socket could not be created or connected
 */
struct socket* dsd_connect( uint32_t ip_addr, uint16_t port )
{
   struct socket* sock = NULL;
   struct sockaddr_in daemon_addr;
   int rc;

   rc = sock_create( PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock );
   if(unlikely( rc ))
   {
      printk( KERN_ALERT DSD_TAG "Cannot create socket(%d)\n", rc );
      return NULL;
   }

   memset( &daemon_addr, 0, sizeof(daemon_addr) );
   daemon_addr.sin_family      = AF_INET;
   daemon_addr.sin_addr.s_addr = htonl( ip_addr );
   daemon_addr.sin_port        = htons( port );

   #ifdef DSD_DEBUG
   printk( KERN_ALERT DSD_TAG "Connecting(%u,%u)\n", ip_addr, port );
   #endif

   /* NOTE(review): the last argument is a file-flags word; O_RDWR carries
    * no O_NONBLOCK bit, so this presumably performs a blocking connect.
    * -EINPROGRESS is nevertheless tolerated as "connection underway". */
   rc = sock->ops->connect( sock, (struct sockaddr*)&daemon_addr,
                            sizeof(daemon_addr), O_RDWR );
   if( rc != 0 && rc != -EINPROGRESS )
   {
      printk( KERN_ALERT DSD_TAG "Connect error(%d)\n", rc );
      sock_release( sock );
      return NULL;
   }

   #ifdef DSD_DEBUG
   printk( KERN_ALERT DSD_TAG "Connected to daemon\n" );
   #endif

   return sock;
}

/**
 * Block until a complete message (header + payload) is read from the socket.
 *
 * The header's len field is converted from network byte order (ntohl)
 * before use; magic and msg_type are used exactly as received.
 *
 * @param sk   Connected socket
 * @param data Out: on success, a kmalloc'd buffer holding the payload.  If
 *             the return value is not a negative error code, the caller is
 *             responsible for calling kfree(*data).  On error, *data is set
 *             to NULL and nothing needs to be freed.
 * @return msg_type (eg. DSD_MSG_*), or a negative error code
 */
int dsd_recv( struct socket* sk, uint8_t** data )
{
   int status;
   struct dsd_msg_header header;

   /* Never hand the caller an indeterminate pointer on error paths */
   *data = NULL;

   /* Receive dsd_msg_hdr */
   status = __dsd_recv_all( sk, (uint8_t*)&header, sizeof(header) );
   if(unlikely( 0 > status ))
   {
      return status;
   }
   header.len = ntohl( header.len );

   /* Verify magic */
   if(unlikely( header.magic != DSD_MAGIC ))
   {
      printk( KERN_ALERT DSD_TAG "Bad magic number(%d)\n", header.magic );
      return -EIO;
   }

   /* Receive payload */
   *data = kmalloc( header.len, GFP_KERNEL );
   if(unlikely( !*data ))
   {
      printk( KERN_ALERT DSD_TAG "Cannot allocate recv buffer\n" );
      return -ENOMEM;
   }

   /* A zero-length payload has nothing to read; a 0-byte recv would
    * otherwise be reported as a short read */
   if( header.len > 0 )
   {
      status = __dsd_recv_all( sk, *data, header.len );
      if(unlikely( 0 > status ))
      {
         printk( KERN_ALERT DSD_TAG "Network payload read error(%d)\n", status );
         /* Don't leak the payload buffer on a failed read */
         kfree( *data );
         *data = NULL;
         return status;
      }
   }

   return header.msg_type;
}

/**
 * Helper function - receive until a known buffer is full
 * or an error is encountered
 *
 * @param sk Connected socket
 * @param msg Pre-allocated buffer of size len
 * @param len Length of buffer to fill
 * @return len on success (0 if len is 0), -EPIPE if the peer closed the
 *         connection before len bytes arrived, -EINVAL if len is negative,
 *         or another negative error code from kernel_recvmsg()
 */
static int __dsd_recv_all( struct socket* sk, 
                           uint8_t* msg, 
                           int len )
{
   int status, read = 0;
   struct kvec iov;
   struct msghdr msghdr;

   if( len < 0 )
   {
      return -EINVAL;
   }
   if( 0 == len )
   {
      /* A 0-byte recv returns 0, which would look like a closed peer */
      return 0;
   }

   /* Zero every field, not just the ones we care about */
   memset( &msghdr, 0, sizeof(msghdr) );

   /* Blocking read */
   do
   {
      /* Re-point the iovec at the unfilled tail on every pass rather than
       * relying on kernel_recvmsg() to advance it after a partial read */
      iov.iov_base   = msg + read;
      iov.iov_len    = len - read;

      status = kernel_recvmsg( sk, &msghdr, &iov, 1, len-read, 0 );
      if( status > 0 )
      {
         read += status;
      }
      else if( 0 == status )
      {
         #ifdef DSD_DEBUG
         printk( KERN_ALERT DSD_TAG "Short read\n" );
         #endif

         return -EPIPE;
      }
   }
   /* NOTE(review): retrying on -ERESTARTSYS may spin while a signal stays
    * pending -- confirm this is intended */
   while( (status >= 0 && read < len) || 
           -EAGAIN == status || 
           -ERESTARTSYS == status );

   return ( status >= 0 ) ? len : status;
}

/**
 * Queue a message to be sent later on the provided socket.  Safe to call
 * from atomic context: the work item is allocated with GFP_ATOMIC and the
 * actual (sleeping) send runs from net_workqueue via dsd_send().
 *
 * @param sk Connected socket
 * @param msg_type One of DSD_MSG_*
 * @param data Payload.  Will be kfree'd when sent.  If -ENOMEM is
 *             returned, data was not queued and has NOT been freed.
 * @param len Payload length in bytes
 * @param pending_work Optional - if non-null, incremented when 
 *                     work is queued and decremented when complete
 * @return The result of queue_work() (nonzero when the work was added to
 *         the queue), or -ENOMEM if the work item could not be allocated
 */
int dsd_send_atomic( struct socket* sk, 
                     int msg_type, 
                     uint8_t* data, 
                     int len,
                     atomic_t* pending_work )
{
   struct dsd_net_work* item;

   item = kmalloc( sizeof(*item), GFP_ATOMIC );
   if( NULL == item )
   {
      return -ENOMEM;
   }

   item->work_type     = DSD_NET_WORK_SEND;
   item->send.sk       = sk;
   item->send.msg_type = msg_type;
   item->send.data     = data;
   item->send.len      = len;
   item->pending_work  = pending_work;

   INIT_WORK_NAR( &item->work, __dsd_net_work );

   /* Account for the work before the worker can possibly run it */
   if( pending_work )
   {
      atomic_inc( pending_work );
   }

   return queue_work( net_workqueue, &item->work );
}


/**
 * Send a message (header + payload) on the provided socket immediately.
 * Might sleep.
 *
 * The header carries DSD_MAGIC, msg_type, and len in network byte order.
 *
 * @param sk Connected socket
 * @param msg_type One of DSD_MSG_*
 * @param data Payload.  Ownership passes to this function: it is kfree'd
 *             on every path, success or error.
 * @param len Payload length in bytes
 * @return Number of payload bytes sent, or a negative error code
 */
int dsd_send( struct socket* sk, 
              int msg_type, 
              uint8_t* data, 
              int len )
{
   int status;
   struct dsd_msg_header header; 

   might_sleep();

   /* Send header */
   memset( &header, 0, sizeof(header) );
   header.magic      = DSD_MAGIC;
   header.msg_type   = msg_type;
   header.len        = htonl( len );

   status = __dsd_send_all( sk, (uint8_t*)&header, sizeof(header) );
   if( status < 0 )
   {
      printk( KERN_ALERT DSD_TAG "Network header send error(%d)\n", status );
      goto out;
   }

   /* Send payload */
   status = __dsd_send_all( sk, data, len );
   if( status < 0 )
   {
      printk( KERN_ALERT DSD_TAG "Network payload send error(%d)\n", status );
   }

out:
   /* We own data; release it on error paths too, or it leaks (e.g. the
    * workqueue caller ignores our return value) */
   kfree( data );
   return status;
}

/**
 * Helper function - send a known buffer until all len bytes
 * are sent or an error is encountered
 *
 * @param sk Connected socket
 * @param msg Pre-allocted buffer of size len
 * @param len Length of buffer to send
 * @return amount sent or error
 */
static int __dsd_send_all( struct socket* sk, 
                           const uint8_t* msg, 
                           int len )
{
   int status, written = 0;
   struct kvec iov;
   struct msghdr msghdr;

   msghdr.msg_name         = 0;
   msghdr.msg_namelen      = 0;
   msghdr.msg_control      = NULL;
   msghdr.msg_controllen   = 0;
   msghdr.msg_flags        = 0;   

   iov.iov_base   = (uint8_t*)msg;
   iov.iov_len    = len;

   /* Blocking write */
   do
   {
      status = kernel_sendmsg( sk, &msghdr, &iov, 1, len-written );
      if( status >= 0 )
      {
         written += status;
         iov.iov_len = len-written;
      }
   }
   while( (status >= 0 && written < len) || 
           -EAGAIN == status  || 
           -ERESTARTSYS == status );

   return ( status >= 0 ) ? len : status;
}

/**
 * Callback to process net workqueue items.
 *
 * Runs the queued operation, decrements the optional pending_work counter,
 * then frees the dsd_net_work that dsd_send_atomic() allocated.
 *
 * @param work The work_struct embedded in a kmalloc'd struct dsd_net_work
 */
static void __dsd_net_work( struct work_struct* work )
{
   struct dsd_net_work* net_work 
      = container_of( work, struct dsd_net_work, work );

   #ifdef DSD_DEBUG
   printk( KERN_ALERT DSD_TAG "net: Running workqueue\n" );
   #endif

   if( net_work->work_type == DSD_NET_WORK_SEND )
   {
      /* dsd_send() takes ownership of (and frees) send.data; its errors
       * are already logged there */
      dsd_send( net_work->send.sk, 
                net_work->send.msg_type, 
                net_work->send.data, 
                net_work->send.len );
   }

   if( net_work->pending_work )
   {
      atomic_dec( net_work->pending_work );
   }

   /* Item was set up with INIT_WORK_NAR, so release it explicitly, then
    * free the container: each item is allocated for a single use */
   work_release( work );
   kfree( net_work );

   #ifdef DSD_DEBUG
   printk( KERN_ALERT DSD_TAG "net: Workqueue done\n" );
   #endif
}

/**
 * Initialize network module: create the workqueue used for deferred sends.
 *
 * @return 0 on success, -ENOMEM if the workqueue could not be created
 */
int __init dsd_net_init( void )
{
   net_workqueue = create_singlethread_workqueue( DSD_DEVICE_PREFIX "_net" );
   if(unlikely( !net_workqueue ))
   {
      printk( KERN_ALERT DSD_TAG "Cannot create net workqueue\n" );
      return -ENOMEM;
   }

   return 0;
}

/**
 * Cleanup network module: destroy the deferred-send workqueue, if one was
 * created, and clear the global so it cannot be used after destruction.
 */
void __exit dsd_net_exit( void )
{
   if( net_workqueue )
   {
      destroy_workqueue( net_workqueue );
      net_workqueue = NULL;
   }
}

