/*
 * asn1.c -- functions for decoding encoded ASN.1
 */

#include <stdlib.h>
#include <stdio.h>
#ifdef __STDC__
#include <limits.h>
#endif                          /* __STDC__ */
#include <ctype.h>
#include <string.h>
#include "asn1.h"

#ifndef USHRT_MAX
#define USHRT_MAX 65535
#endif

#ifndef UINT_MAX
#define UINT_MAX USHRT_MAX
#endif

/*
 * Largest value a size_t can hold.  SIZE_T_MAX and
 * SIZE_T_MAX - 1 are used in this file as "indefinite
 * length" and "length too long" sentinels, so the
 * macro has to be the real maximum of size_t (which
 * may be wider than unsigned int) to keep the
 * sentinels away from plausible lengths.
 */
#ifndef SIZE_T_MAX
#define SIZE_T_MAX ((size_t) -1)
#endif

/* Identifier octets of the universal, primitive
 * types that are decoded specially. */
#define BOOLEAN_TYPE    1
#define INTEGER_TYPE    2
#define BITSTRING_TYPE  3
#define NULL_TYPE       5

#ifndef FALSE
#define FALSE 0
#endif
#ifndef TRUE
#define TRUE 1
#endif

/* Bit masks; bits are numbered 8 (most significant)
 * down to 1 (least significant). */
#define BIT_8        128
#define BIT_1          1
#define BITS_54321    31
#define BITS_7654321 127

/*
 * Print the message x on stdout and terminate the
 * program with exit status 1.  Wrapped in
 * do { } while (0) so that "punt(...);" is exactly
 * one statement, even in front of an else.
 */
#define punt(x) \
    do { printf(" -- %s\n", (x)); exit(1); } while (0)

/*
 * Class (bits 8 and 7) and primitive/constructed flag
 * (bit 6) of the identifier octet that x points to.
 * The octet is read as unsigned char so the result
 * does not depend on how a negative char is shifted.
 */
#define the_class(x) ((((unsigned char) *(x)) >> 6) & 3)
#define the_complexity(x) \
    ((((unsigned char) *(x)) >> 5) & BIT_1)

static unsigned int
the_tag_number(encoding)
    char          **encoding;
/*
 * This function returns the tag number of an ASN.1
 * encoding. The argument must point to a pointer to
 * the first octet of the encoding (the identifier
 * octet). The pointer is advanced to the first octet
 * of the length of the encoding.
 *
 * Low tag numbers (0..30) are held in bits 5-1 of the
 * identifier octet. When those bits are all ones the
 * tag number follows in base 128, most significant
 * group first, with bit 8 set on every octet but the
 * last.
 *
 * BUGS:
 *
 * Tag numbers greater than UINT_MAX have their high
 * order bits stripped.
 *
 * The function has no way of knowing where the buffer
 * ends, so a truncated high tag number is read past
 * the end of the data.
 */
{
    unsigned int    t;
    unsigned char   octet;

    t = ((unsigned char) **encoding) & BITS_54321;
    (*encoding)++;
    if (BITS_54321 != t)        /* low tag number */
        return t;

    /* High tag number: gather 7 bits from each octet
     * until one arrives with bit 8 clear. */
    t = 0;
    do {
        octet = (unsigned char) **encoding;
        (*encoding)++;
        t <<= 7;
        t |= octet & BITS_7654321;
    } while (octet & BIT_8);
    return t;
}

static          size_t
the_length(encoding)
    char          **encoding;
/*
 * This function returns the length of an ASN.1
 * encoding. The pointer passed in must be pointing
 * to the first length octet of the encoding. On
 * success it has the side effect of advancing the
 * pointer past the length octets of the encoding.
 *
 * The length octets are read as unsigned char; a
 * plain char comparison against BIT_8 would never
 * match where char is signed, and the indefinite
 * form would be mistaken for a length of zero.
 *
 * BUGS:
 *
 * This routine does not handle indefinite encodings.
 * i.e. A length octet of 0x80 causes SIZE_T_MAX to
 * be returned, and the pointer is left on that
 * octet. This is the same value that is returned
 * when the length octets encode a length equal to
 * SIZE_T_MAX.
 *
 * This function does not handle the definite form
 * with more length octets than fit in a size_t. In
 * this case it returns SIZE_T_MAX-1 and does not
 * advance the pointer. The same value is returned
 * when the encoded length really is SIZE_T_MAX-1.
 *
 * The function cannot tell where the buffer ends, so
 * truncated length octets are read past the data.
 */
{
    unsigned char   first;
    unsigned int    n;
    size_t          length;

    first = (unsigned char) **encoding;
    if (BIT_8 == first)         /* indefinite form? */
        return SIZE_T_MAX;

    if (!(first & BIT_8)) {     /* short form */
        (*encoding)++;
        return first;
    }

    /* Long form: the low seven bits count the length
     * octets that follow, most significant first. */
    n = first & BITS_7654321;
    if (sizeof length < n)
        return SIZE_T_MAX - 1;  /* We can only take
                                 * so much. */
    (*encoding)++;
    length = 0;
    for (; n; n--) {
        length <<= 8;
        length |= (unsigned char) **encoding;
        (*encoding)++;
    }
    return length;
}

static int
isprintable(c)
    int             c;
/*
 * This is different from the standard C function
 * isprint.  This only allows characters that can be
 * in an ASN.1 PrintableString (universal type 19):
 * letters, digits, space and ' ( ) + , - . / : = ?
 *
 * Returns 1 (TRUE) when c is such a character and
 * 0 (FALSE) otherwise.
 *
 * Values outside 1..127 are rejected before anything
 * else is tried: a negative value (a sign-extended
 * char) must not reach the <ctype.h> functions, and
 * searching a string for 0 would match its
 * terminator and so report NUL as printable.
 */
{
    return 0 < c && c < 128
        && (isalpha(c) || isdigit(c)
            || NULL != strchr(" '()+,-./:=?", c));
}

static void
show_string(string, length)
    char           *string;
    size_t          length;
/*
 * Print length octets, starting at string, on stdout
 * as a double-quoted character string.
 *
 * Printable characters below 127 are printed as they
 * are, except that backslash and double quote get a
 * backslash in front. Control characters are printed
 * as the escapes \a \b \f \n \r \t \v when one
 * exists and as \xHH (two upper case hex digits)
 * otherwise. Everything else is printed as \xHH too.
 */
{
    size_t          left;
    unsigned char   ch;
    int             letter;

    putchar('\"');
    for (left = length; left; left--) {
        ch = *string++;

        if (isprint(ch) && (127 > ch)) {
            if ('\\' == ch || '\"' == ch)
                putchar('\\');
            putchar(ch);
            continue;
        }
        if (!iscntrl(ch)) {
            printf("\\x%02hX", ch);
            continue;
        }

        /* A control character: find the letter of
         * its named escape, if it has one. */
        switch (ch) {
        case '\a':
            letter = 'a';
            break;
        case '\b':
            letter = 'b';
            break;
        case '\f':
            letter = 'f';
            break;
        case '\n':
            letter = 'n';
            break;
        case '\r':
            letter = 'r';
            break;
        case '\t':
            letter = 't';
            break;
        case '\v':
            letter = 'v';
            break;
        default:
            letter = 0;
            break;
        }
        putchar('\\');
        if (letter)
            putchar(letter);
        else
            printf("x%02hX", ch);
    }
    putchar('\"');
}

static void
show_hex_string(string, length)
    char           *string;
    size_t          length;
/*
 * Print length octets, starting at string, on stdout
 * in ASN.1 hstring notation: a single quote, two
 * upper case hex digits per octet, and a closing 'H.
 *
 * Each octet is converted to unsigned char before it
 * is printed. A plain char may be signed, and a
 * sign-extended octet such as 0x80 would otherwise
 * come out as four digits (FF80).
 */
{
    size_t          i;

    putchar('\'');
    for (i = 0; i < length; i++) {
        printf("%02X",
               (unsigned int) (unsigned char) string[i]);
    }
    fputs("\'H", stdout);
}

static void
show_octet_string(string, length)
    char           *string;
    size_t          length;
/*
 * Print length octets, starting at string, on
 * stdout. If every octet satisfies isprintable the
 * octets are printed by show_string; as soon as one
 * does not, the whole string is printed by
 * show_hex_string instead. An empty string is
 * printed by show_string.
 */
{
    size_t          i;

    for (i = 0; i < length; i++) {
        /* Hand the octet over as a value in 0..255;
         * a plain char may be negative, which is not
         * a valid argument for the <ctype.h> tests
         * that isprintable makes. */
        if (!isprintable((unsigned char) string[i]))
            break;
    }
    if (i < length)
        show_hex_string(string, length);
    else
        show_string(string, length);
}

static void
show_bit_string(string, length)
    char           *string;
    size_t          length;
/*
 * Print the contents octets of a primitive BIT
 * STRING on stdout in ASN.1 bstring notation: a
 * single quote, one '0' or '1' per bit, and a
 * closing 'B.
 *
 * string points to the contents and length is their
 * size in octets. The first octet gives the number
 * of unused bits (0..7) at the end of the last
 * octet; the bits themselves start at the second
 * octet. The unused bits are not printed.
 *
 * The program is terminated through punt when there
 * are no contents octets at all, when an empty bit
 * string claims to have unused bits, or when the
 * unused bit count is greater than 7. The count is
 * read as unsigned char so that values 0x80..0xFF
 * are rejected instead of turning into a negative
 * shift count.
 */
{
    size_t          i;
    size_t          last;
    unsigned char   c;
    unsigned char   mask;
    unsigned char   unused;
    unsigned char   first_unused_bit;

    if (1 > length) {
        punt("Bitstring length must be > 0.");
    }
    unused = (unsigned char) *string;
    if (1 == length && 0 != unused) {
        punt("An empty bitstring must have no "
             "unused bits.");
    }
    if (7 < unused) {
        punt("can't have more than 7 unused bits "
             "in last octet.");
    }
    putchar('\'');
    last = length - 1;          /* index of the last
                                 * contents octet */
    for (i = 1; i < last; i++) {
        c = (unsigned char) string[i];
        for (mask = BIT_8; mask; mask >>= 1)
            putchar(mask & c ? '1' : '0');
    }
    if (0 < last) {
        c = (unsigned char) string[last];
        /* Mask of the most significant unused bit,
         * or 0 when all eight bits are in use; the
         * loop below stops when it gets there. */
        first_unused_bit = (unsigned char)
                           (1 << unused >> 1);
        for (mask = BIT_8;
             mask ^ first_unused_bit;
             mask >>= 1)
            putchar(mask & c ? '1' : '0');
    }
    fputs("\'B", stdout);
}

static void
show_boolean(string, length)
    char           *string;
    size_t          length;
{
    if (0 == *string)
        printf("FALSE");
    else
        printf("TRUE");
}

static void
show_integer(string, length)
    char           *string;
    size_t          length;
/*
 * Print the value of an INTEGER on stdout in
 * decimal. string points to the contents octets, a
 * two's complement number with the most significant
 * octet first, and length is the number of octets.
 * No octets at all are printed as 0.
 *
 * Only the first octet carries the sign. The octets
 * are therefore collected as unsigned values in an
 * accumulator that starts out filled with the sign
 * bit; sign-extending every octet would turn e.g.
 * the octets 00 80 (128) into -128.
 *
 * BUG: Integers with more octets than a long are
 * silently truncated to their low order octets.
 */
{
    const unsigned char *octet;
    unsigned long   bits;
    long            x;
    size_t          i;

    octet = (const unsigned char *) string;

    /* 0x80 is the sign bit of the first octet. */
    if (0 < length && (octet[0] & 0x80))
        bits = ~0UL;
    else
        bits = 0UL;
    for (i = 0; i < length; i++)
        bits = (bits << 8) | octet[i];

    /* Convert the two's complement bit pattern to a
     * long without relying on an out-of-range
     * conversion. */
    if (bits & ~(~0UL >> 1))    /* sign bit set? */
        x = -(long) (~bits) - 1;
    else
        x = (long) bits;
    printf("%ld", x);
}

char           *
show_value_notation(encoding, length, level)
    char           *encoding;
    size_t          length;
    unsigned int    level;
/*
 * This function converts from transfer syntax to
 * value notation, which it prints on stdout. It
 * returns a pointer to the octet past the encoding.
 * It calls itself to decode constructed encodings.
 *
 * encoding points to the identifier octet of the
 * datum, length is the number of octets available
 * for it, and level is the nesting depth, which is
 * also the number of tabs printed in front of it.
 *
 * Every datum is printed as "[CLASS number] "
 * followed by its value. Constructed encodings are
 * printed as a brace-enclosed, comma-separated list
 * with one member per line. Primitive encodings are
 * chosen by the identifier octet: BOOLEAN, INTEGER,
 * BIT STRING and NULL have their own notation and
 * everything else is shown as an octet string.
 *
 * The program is terminated through punt when the
 * nesting is deeper than 99 levels, when the length
 * is indefinite or too long, when the identifier
 * octet is 0, or when the datum does not fit in the
 * octets available. The fit is checked before the
 * contents are looked at, so that contents lying
 * outside the available octets are never read.
 *
 * BUG: The identifier and length octets themselves
 * are decoded before the fit can be checked, so a
 * truncated header is read past the available
 * octets.
 */
{
    static char    *class_name[] =
    {"UNIVERSAL ", "APPLICATION ", "", "PRIVATE "};
    char           *limit;
    char           *asn1_contents;
    char           *contents_end;
    int             asn1_identifier;
    size_t          asn1_length;
    unsigned int    id_class;
    unsigned int    id_complexity;
    unsigned int    id_tag_number;
    unsigned int    i;

    if (99 < level) {
        punt("We're in too deep.");
    }
    limit = encoding + length;
    asn1_identifier = (unsigned char) *encoding;
    id_class = the_class(encoding);
    id_complexity = the_complexity(encoding);
    id_tag_number = the_tag_number(&encoding);

    for (i = level; i; i--)
        putchar('\t');
    printf("[%s%u] ", class_name[id_class],
           id_tag_number);

    asn1_length = the_length(&encoding);
    if (asn1_length == SIZE_T_MAX - 1) {
        punt("Too long");
    }
    if (asn1_length == SIZE_T_MAX) {
        punt("Indefinite length");
    }

    /* The contents must fit in what is left of the
     * available octets. The room is computed by
     * subtraction so that a huge encoded length
     * cannot overflow a pointer addition. */
    if (limit < encoding
        || (size_t) (limit - encoding) < asn1_length) {
        punt("Datum did not end where expected.");
    }
    asn1_contents = encoding;
    contents_end = asn1_contents + asn1_length;

    if (id_complexity) {    /* Do we have a
                             * constructed encoding? */
        printf("{\n");
        while (encoding < contents_end) {
            /* A member may use only what is left of
             * this datum's contents. */
            encoding = show_value_notation(encoding,
                     (size_t) (contents_end - encoding),
                                           level + 1);
            if (encoding < contents_end)
                putchar(',');
            putchar('\n');
        }
        for (i = level; i; i--)
            putchar('\t');
        putchar('}');
    } else {
        switch (asn1_identifier) {
        case 0:
            punt("There is no such type.");
            break;
        case BOOLEAN_TYPE:
            show_boolean(asn1_contents, asn1_length);
            break;
        case INTEGER_TYPE:
            show_integer(asn1_contents, asn1_length);
            break;
        case BITSTRING_TYPE:
            show_bit_string(asn1_contents,asn1_length);
            break;
        case NULL_TYPE:
            printf("NULL");
            break;
        default:
            show_octet_string(asn1_contents,
                              asn1_length);
        }
    }
    return contents_end;
}
