#!/usr/bin/env python

import socket
import sys

import sha256
import srp6a
import aes
import gcm

# SRP-6a group parameters (modulus N, generator g); presumably these must
# match the server's parameters exactly — confirm against the server.
# This number is known to be prime, and 2 is one of its primitive roots.
# NOTE(review): N is only about 256 bits; standard SRP groups (RFC 5054)
# start at 1024 bits — confirm this size is intentional (e.g. test-only).
N = 125617018995153554710546479714086468244499594888726646874671447258204721048803L
g = 2
# Module-wide SRP helper built from (N, g) with sha256.hash as its hash;
# used below for both account setup and login.
srp = srp6a.SRP(N, g, sha256.hash)

class ClientError(Exception):
    """Protocol failure, rejected login, or an error reply from the server."""

class Client:
    """Plaintext, line-oriented connection to a Passpet server.

    Every message is one line terminated by '\\n'.  A reply beginning with
    '+' is a success (the '+' is stripped); any other reply is raised as a
    ClientError carrying the whole line.
    """

    def __init__(self, hostname='', port=7277):
        """Connect to hostname:port and perform the opening handshake.

        The server's first line must start with the word 'Passpet'; the
        client then sends '1' and reads one more success reply.

        Raises:
            ClientError: wrong greeting, error reply, or connection closed.
            socket.error: the TCP connection could not be made.
        The socket is closed before any exception propagates.
        """
        import socket
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # Honour the port argument; it used to be ignored in favour of a
            # hard-coded 7277 (still the default, so callers are unaffected).
            self.sock.connect((hostname, port))
            # Slice instead of [0]: an empty greeting splits to [] and
            # indexing it raised an IndexError nobody caught.
            if self.get().split()[:1] != ['Passpet']:
                raise ClientError('not a Passpet server')
            self.put(1)
            self.get()
        except:
            # Cleanup only: don't leak the socket, then re-raise unchanged.
            self.sock.close()
            raise

    def put(self, *messages):
        """Send each message, in order, as its own '\\n'-terminated line."""
        for message in messages:
            print('-> %s' % message)
            # sendall() keeps writing until the whole line is sent; plain
            # send() may transmit only part of the buffer.
            self.sock.sendall('%s\n' % message)

    def get(self):
        """Read one line and return it without its leading '+'.

        Raises:
            ClientError: the connection closed before a full line arrived,
                or the line does not start with '+' (the line is the message).
        """
        chars = []
        while 1:
            # One byte at a time so nothing past the newline is consumed
            # from the socket.
            c = self.sock.recv(1)
            if c == '\n':
                break
            if not c:
                raise ClientError('connection closed')
            chars.append(c)
        message = ''.join(chars)
        print('<- %s' % message)
        if not message.startswith('+'):
            raise ClientError(message)
        return message[1:]

class EncryptedClient:
    """Wrap a line-based client so every message is encrypted on the wire.

    Outgoing: ``outmode.encrypt(message)`` yields (data, mac), and the
    concatenation data + mac is sent as one line of two-digit hex bytes.
    Incoming: a hex line is decoded, its last 16 bytes are treated as the
    MAC, and ``inmode.decrypt(data, mac)`` recovers the message.  Decrypted
    replies follow the plaintext convention: a leading '+' means success.
    """

    # Length in bytes of the MAC appended to every encrypted message.
    MAC_SIZE = 16

    def __init__(self, client, inmode, outmode):
        self.client = client    # underlying object with put()/get() of lines
        self.inmode = inmode    # used to decrypt incoming lines
        self.outmode = outmode  # used to encrypt outgoing lines

    def put(self, *messages):
        """Encrypt and send each message, in order, as one hex line."""
        for message in messages:
            print('=> %s' % message)
            # Use this instance's own mode.  The code previously read a
            # module-global ``outmode`` that only happened to exist.
            data, mac = self.outmode.encrypt(message)
            line = ''.join(['%02x' % ord(c) for c in data + mac])
            self.client.put(line)

    def get(self):
        """Receive, decode, and decrypt one reply; return it without '+'.

        Raises:
            ClientError: the line is not well-formed hex, is shorter than
                one MAC, fails decryption (ValueError from the mode), or the
                decrypted reply does not start with '+'.
        """
        line = self.client.get()
        # Require an even-length run of hex digits.  int(..., 16) alone
        # would accept whitespace and signs, and an odd length would
        # silently decode a trailing half byte.
        hexdigits = '0123456789abcdefABCDEF'
        if len(line) % 2 or [c for c in line if c not in hexdigits]:
            raise ClientError('invalid encrypted message')
        data = ''.join([chr(int(line[i:i + 2], 16))
                        for i in range(0, len(line), 2)])
        if len(data) < self.MAC_SIZE:
            raise ClientError('invalid encrypted message')
        try:
            message = self.inmode.decrypt(data[:-self.MAC_SIZE],
                                          data[-self.MAC_SIZE:])
        except ValueError:
            raise ClientError('invalid encrypted message')
        print('<= %s' % message)
        if not message.startswith('+'):
            raise ClientError(message)
        return message[1:]

hostname = (sys.argv + ['localhost'])[1]

while 1:
    print '''
create <username> <password> <k1>
list <username>
delete <username> <index> <password>
read <username> <index> <password>
write <username> <index> <password>

>''',

    words = raw_input().split()

    try:
        command, args = words[0], words[1:]
    except:
        continue

    try:
        if command == 'quit':
            break

        elif command == 'create':
            username, password, k1 = args
            salt, verifier = srp.setup(password)
            print 'password:', password
            print 'salt:', salt
            print 'verifier:', verifier
            c = Client(hostname)
            c.put('create', username, k1, salt, verifier)
            index = c.get()
            print 'ok, index is', index

        elif command == 'list':
            username, = args
            c = Client(hostname)
            c.put('list', username)
            results = c.get().split()
            for result in results:
                index, k1 = result.split(':')
                print 'index: %s, k1: %s' % (index, k1)

        elif command in ['delete', 'read', 'write']:
            username, index, password = args
            c = Client(hostname)
            c.put('login', username, index)
            R = srp6a.Object()
            client = srp.login(username, password, R)
            client_push = client.next()
            try:
                for message in client:
                    if message is None:
                        client_push(int(c.get()))
                    else:
                        c.put(message)
                noncebytes = [int(b, 16) for b in c.get().split()]
                nonce = aes.tolong(noncebytes)
            except srp6a.SRPError:
                raise ClientError('login failed')

            cipher = aes.AES(128, R.key)
            inmode = gcm.GCM(cipher, nonce=nonce)
            outmode = gcm.GCM(cipher, nonce=nonce ^ 1)
            ec = EncryptedClient(c, inmode, outmode)

            if command == 'delete':
                ec.put('delete')
                ec.get()

            elif command == 'read':
                ec.put('read')
                file = ec.get()
                print 'ok, file is %r' % file

            elif command == 'write':
                print 'old_mac:',
                old_mac = raw_input().strip()
                print 'new_file:',
                new_file = raw_input().strip()
                ec.put('write', old_mac, new_file)
                ec.get()
    except ClientError, e:
        print 'error', e.args
