#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <time.h>

#undef _POSIX_C_SOURCE
#include <stdio.h>

#include "record.h"
#include "buffer.h"

/*
 * Map a RecordType to the numeric type code that write_record_header()
 * puts on the wire.
 *
 * Returns 0 for any type that has no code in this table.
 */
uint16_t record_to_id(RecordType type) {
    uint16_t id = 0;

    switch (type) {
        case A:     id = 1;   break;
        case NS:    id = 2;   break;
        case CNAME: id = 5;   break;
        case SOA:   id = 6;   break;
        case PTR:   id = 12;  break;
        case MX:    id = 15;  break;
        case TXT:   id = 16;  break;
        case AAAA:  id = 28;  break;
        case SRV:   id = 33;  break;
        case CAA:   id = 257; break;
        default:    break;
    }

    return id;
}

/*
 * Inverse of record_to_id(): translate a numeric type code into a RecordType.
 *
 * i    - numeric type code.
 * type - out parameter; always written. Codes that are not in the table
 *        yield UNKOWN.
 */
void record_from_id(uint16_t i, RecordType* type) {
    RecordType resolved = UNKOWN;

    switch (i) {
        case 1:   resolved = A;     break;
        case 2:   resolved = NS;    break;
        case 5:   resolved = CNAME; break;
        case 6:   resolved = SOA;   break;
        case 12:  resolved = PTR;   break;
        case 15:  resolved = MX;    break;
        case 16:  resolved = TXT;   break;
        case 28:  resolved = AAAA;  break;
        case 33:  resolved = SRV;   break;
        case 257: resolved = CAA;   break;
        default:  break;
    }

    *type = resolved;
}

/*
 * Parse a record type name into a RecordType.
 *
 * The comparison is case-insensitive (strcasecmp, declared in <strings.h>).
 *
 * qstr  - NUL-terminated type name such as "A" or "mx".
 * qtype - out parameter; written only when the name is recognised.
 *
 * Returns true when the name matched a supported type, false otherwise
 * (including when either argument is NULL).
 */
bool str_to_qtype(const char* qstr, RecordType* qtype) {
    static const struct {
        const char* name;
        RecordType type;
    } known_types[] = {
        { "A",     A     },
        { "NS",    NS    },
        { "CNAME", CNAME },
        { "SOA",   SOA   },
        { "PTR",   PTR   },
        { "MX",    MX    },
        { "TXT",   TXT   },
        { "AAAA",  AAAA  },
        { "SRV",   SRV   },
        { "CAA",   CAA   },
    };

    if (qstr == NULL || qtype == NULL) {
        return false;
    }

    for (size_t i = 0; i < sizeof known_types / sizeof known_types[0]; i++) {
        if (strcasecmp(qstr, known_types[i].name) == 0) {
            *qtype = known_types[i].type;
            return true;
        }
    }

    return false;
}

/*
 * Read the four address bytes of an A record from the buffer, in order,
 * and store them in record->data.a.
 */
static void read_a_record(PacketBuffer* buffer, Record* record) {
    ARecord parsed;

    for (int octet = 0; octet < 4; octet++) {
        parsed.addr[octet] = buffer_read(buffer);
    }

    record->data.a = parsed;
}

/*
 * Read the name carried by an NS record into record->data.ns.host.
 * The pointer is released later by free_record().
 */
static void read_ns_record(PacketBuffer* buffer, Record* record) {
    NSRecord* ns = &record->data.ns;

    buffer_read_qname(buffer, &ns->host);
}

/*
 * Read the name carried by a CNAME record into record->data.cname.host.
 * The pointer is released later by free_record().
 */
static void read_cname_record(PacketBuffer* buffer, Record* record) {
    CNAMERecord* cname = &record->data.cname;

    buffer_read_qname(buffer, &cname->host);
}

/*
 * Read an SOA record body into record->data.soa.
 *
 * Field order in the buffer: mname, nname (both names), then the five
 * integer fields serial, refresh, retry, expire and minimum.
 * Both names are released later by free_record().
 */
static void read_soa_record(PacketBuffer* buffer, Record* record) {
    SOARecord* soa = &record->data.soa;

    buffer_read_qname(buffer, &soa->mname);
    buffer_read_qname(buffer, &soa->nname);

    soa->serial  = buffer_read_int(buffer);
    soa->refresh = buffer_read_int(buffer);
    soa->retry   = buffer_read_int(buffer);
    soa->expire  = buffer_read_int(buffer);
    soa->minimum = buffer_read_int(buffer);
}

/*
 * Read the name carried by a PTR record into record->data.ptr.pointer.
 * The pointer is released later by free_record().
 */
static void read_ptr_record(PacketBuffer* buffer, Record* record) {
    PTRRecord* ptr = &record->data.ptr;

    buffer_read_qname(buffer, &ptr->pointer);
}

/*
 * Read an MX record body into record->data.mx: a 16-bit priority
 * followed by the exchange host name.
 * The host is released later by free_record().
 */
static void read_mx_record(PacketBuffer* buffer, Record* record) {
    MXRecord* mx = &record->data.mx;

    mx->priority = buffer_read_short(buffer);
    buffer_read_qname(buffer, &mx->host);
}

/*
 * Read the strings of a TXT record into record->data.txt.
 *
 * record->len (set by read_record) gives the number of RDATA bytes; strings
 * are read until that many bytes have been accounted for. Each stored string
 * keeps its length in byte 0, so it occupies text[i][0] + 1 bytes on the wire.
 *
 * Parsing stops early when:
 *   - a zero-length string is read (it is freed and not stored; any bytes
 *     left in the RDATA are not consumed),
 *   - 255 strings have been stored (the uint8_t loop counters used elsewhere
 *     in this file cannot index more),
 *   - the string array cannot be allocated or grown.
 *
 * The array and every stored string are released by free_record().
 */
static void read_txt_record(PacketBuffer* buffer, Record* record) {
    enum { INITIAL_CAPACITY = 2, MAX_STRINGS = 255 };

    TXTRecord data;
    data.len = 0;
    data.text = malloc(sizeof *data.text * INITIAL_CAPACITY);

    /* A failed allocation leaves capacity at 0 so nothing is read below. */
    size_t capacity = (data.text != NULL) ? INITIAL_CAPACITY : 0;

    /* Signed and wider than the 16-bit length: a string that claims more
     * bytes than are left drives this negative and ends the loop instead of
     * wrapping around. */
    int remaining = record->len;

    while (remaining > 0 && capacity > 0) {
        if (data.len >= capacity) {
            if (capacity >= MAX_STRINGS) {
                break;
            }

            size_t grown_capacity = capacity * 2;
            if (grown_capacity > MAX_STRINGS) {
                grown_capacity = MAX_STRINGS;
            }

            /* Keep the old array reachable if realloc fails. */
            void* grown = realloc(data.text, sizeof *data.text * grown_capacity);
            if (grown == NULL) {
                break;
            }
            data.text = grown;
            capacity = grown_capacity;
        }

        buffer_read_string(buffer, &data.text[data.len]);
        if (data.text[data.len][0] == 0) {
            free(data.text[data.len]);
            break;
        }

        /* Length byte plus payload. */
        remaining -= data.text[data.len][0] + 1;
        data.len++;
    }

    record->data.txt = data;
}

/*
 * Read the sixteen address bytes of an AAAA record from the buffer, in
 * order, and store them in record->data.aaaa.
 */
static void read_aaaa_record(PacketBuffer* buffer, Record* record) {
    AAAARecord parsed;
    int idx = 0;

    while (idx < 16) {
        parsed.addr[idx] = buffer_read(buffer);
        idx++;
    }

    record->data.aaaa = parsed;
}

/*
 * Read an SRV record body into record->data.srv: three 16-bit fields
 * (priority, weight, port) followed by the target name.
 * The target is released later by free_record().
 */
static void read_srv_record(PacketBuffer* buffer, Record* record) {
    SRVRecord* srv = &record->data.srv;

    srv->priority = buffer_read_short(buffer);
    srv->weight   = buffer_read_short(buffer);
    srv->port     = buffer_read_short(buffer);
    buffer_read_qname(buffer, &srv->target);
}

/*
 * Read a CAA record body into record->data.caa.
 *
 * Layout consumed from the buffer: one flags byte, one tag-length byte,
 * the tag itself, then the value, which fills whatever is left of the RDATA.
 *
 * header_pos - buffer index of the first RDATA byte, as captured by
 *              read_record(); header_pos + record->len is the end of the RDATA.
 *
 * The value byte count is passed on as a uint8_t, so it is clamped to
 * [0, 255] first. Without the clamp, a tag length that overruns the RDATA
 * (negative remainder) or a value longer than 255 bytes would wrap to an
 * unrelated count.
 * NOTE(review): values longer than 255 bytes are truncated here; confirm
 * whether buffer_read_n can be given a wider count.
 *
 * tag and value are released later by free_record().
 */
static void read_caa_record(PacketBuffer* buffer, Record* record, int header_pos) {
    CAARecord data;
    data.flags = buffer_read(buffer);
    data.length = buffer_read(buffer);
    buffer_read_n(buffer, &data.tag, data.length);

    int value_len = ((int)record->len) + header_pos - buffer_get_index(buffer);
    if (value_len < 0) {
        value_len = 0;
    } else if (value_len > 255) {
        value_len = 255;
    }
    buffer_read_n(buffer, &data.value, (uint8_t)value_len);

    record->data.caa = data;
}

/*
 * Read one resource record from the buffer into *record.
 *
 * The common header is read first (owner name, 16-bit type code, 16-bit
 * class, 32-bit TTL, 16-bit RDATA length), then the type-specific reader
 * fills record->data.
 *
 * Returns true when the record type is supported. On success the caller owns
 * the record and releases it with free_record().
 *
 * Returns false for an unsupported type: the RDATA is skipped with
 * buffer_step(), the owner name is freed, and record->domain is set to NULL
 * so that a later free_record() on the same record does not free it twice.
 */
bool read_record(PacketBuffer* buffer, Record* record) {
    buffer_read_qname(buffer, &record->domain);

    uint16_t qtype_num = buffer_read_short(buffer);
    record_from_id(qtype_num, &record->type);

    record->cls = buffer_read_short(buffer);
    record->ttl = buffer_read_int(buffer);
    record->len = buffer_read_short(buffer);

    /* Index of the first RDATA byte; the CAA reader needs it to work out
     * how long the trailing value is. */
    int header_pos = buffer_get_index(buffer);

    switch (record->type) {
        case A:
            read_a_record(buffer, record);
            break;
        case NS:
            read_ns_record(buffer, record);
            break;
        case CNAME:
            read_cname_record(buffer, record);
            break;
        case SOA:
            read_soa_record(buffer, record);
            break;
        case PTR:
            read_ptr_record(buffer, record);
            break;
        case MX:
            read_mx_record(buffer, record);
            break;
        case TXT:
            read_txt_record(buffer, record);
            break;
        case AAAA:
            read_aaaa_record(buffer, record);
            break;
        case SRV:
            read_srv_record(buffer, record);
            break;
        case CAA:
            read_caa_record(buffer, record, header_pos);
            break;
        default:
            buffer_step(buffer, record->len);
            free(record->domain);
            record->domain = NULL;
            return false;
    }

    return true;
}

/*
 * Write an A record body: a 16-bit length of 4 followed by the four
 * address bytes in order.
 */
static void write_a_record(PacketBuffer* buffer, ARecord* data) {
    buffer_write_short(buffer, 4);

    for (int octet = 0; octet < 4; octet++) {
        buffer_write(buffer, data->addr[octet]);
    }
}

/*
 * Write an NS record body: a 16-bit length field followed by the host name.
 *
 * The length is not known until the name has been written, so a zero
 * placeholder is emitted first and patched afterwards.
 */
static void write_ns_record(PacketBuffer* buffer, NSRecord* data) {
    int len_field = buffer_get_index(buffer);
    buffer_write_short(buffer, 0);

    buffer_write_qname(buffer, data->host);

    /* Bytes written since the placeholder, excluding the placeholder itself. */
    int body_len = buffer_get_index(buffer) - len_field - 2;
    buffer_set_uint16_t(buffer, (uint16_t)body_len, len_field);
}

/*
 * Write a CNAME record body: a 16-bit length field followed by the host name.
 *
 * The length is not known until the name has been written, so a zero
 * placeholder is emitted first and patched afterwards.
 */
static void write_cname_record(PacketBuffer* buffer, CNAMERecord* data) {
    int len_field = buffer_get_index(buffer);
    buffer_write_short(buffer, 0);

    buffer_write_qname(buffer, data->host);

    /* Bytes written since the placeholder, excluding the placeholder itself. */
    int body_len = buffer_get_index(buffer) - len_field - 2;
    buffer_set_uint16_t(buffer, (uint16_t)body_len, len_field);
}

/*
 * Write an SOA record body: a 16-bit length field, the two names
 * (mname, nname) and the five integer fields serial, refresh, retry,
 * expire and minimum.
 *
 * The length field is written as zero and patched once the body size
 * is known.
 */
static void write_soa_record(PacketBuffer* buffer, SOARecord* data) {
    int len_field = buffer_get_index(buffer);
    buffer_write_short(buffer, 0);

    buffer_write_qname(buffer, data->mname);
    buffer_write_qname(buffer, data->nname);

    buffer_write_int(buffer, data->serial);
    buffer_write_int(buffer, data->refresh);
    buffer_write_int(buffer, data->retry);
    buffer_write_int(buffer, data->expire);
    buffer_write_int(buffer, data->minimum);

    /* Bytes written since the placeholder, excluding the placeholder itself. */
    int body_len = buffer_get_index(buffer) - len_field - 2;
    buffer_set_uint16_t(buffer, (uint16_t)body_len, len_field);
}

/*
 * Write a PTR record body: a 16-bit length field followed by the
 * pointed-to name.
 *
 * The length field is written as zero and patched once the body size
 * is known.
 */
static void write_ptr_record(PacketBuffer* buffer, PTRRecord* data) {
    int len_field = buffer_get_index(buffer);
    buffer_write_short(buffer, 0);

    buffer_write_qname(buffer, data->pointer);

    /* Bytes written since the placeholder, excluding the placeholder itself. */
    int body_len = buffer_get_index(buffer) - len_field - 2;
    buffer_set_uint16_t(buffer, (uint16_t)body_len, len_field);
}

/*
 * Write an MX record body: a 16-bit length field, the 16-bit priority
 * and the exchange host name.
 *
 * The length field is written as zero and patched once the body size
 * is known.
 */
static void write_mx_record(PacketBuffer* buffer, MXRecord* data) {
    int len_field = buffer_get_index(buffer);
    buffer_write_short(buffer, 0);

    buffer_write_short(buffer, data->priority);
    buffer_write_qname(buffer, data->host);

    /* Bytes written since the placeholder, excluding the placeholder itself. */
    int body_len = buffer_get_index(buffer) - len_field - 2;
    buffer_set_uint16_t(buffer, (uint16_t)body_len, len_field);
}

/*
 * Write a TXT record body: a 16-bit length field followed by each stored
 * string in order.
 *
 * With no strings the zero placeholder is left as the final length and
 * nothing else is written.
 */
static void write_txt_record(PacketBuffer* buffer, TXTRecord* data) {
    int len_field = buffer_get_index(buffer);
    buffer_write_short(buffer, 0);

    if (data->len == 0) {
        return;
    }

    uint8_t idx = 0;
    while (idx < data->len) {
        buffer_write_string(buffer, data->text[idx]);
        idx++;
    }

    /* Bytes written since the placeholder, excluding the placeholder itself. */
    int body_len = buffer_get_index(buffer) - len_field - 2;
    buffer_set_uint16_t(buffer, (uint16_t)body_len, len_field);
}

/*
 * Write an AAAA record body: a 16-bit length of 16 followed by the
 * sixteen address bytes in order.
 */
static void write_aaaa_record(PacketBuffer* buffer, AAAARecord* data) {
    buffer_write_short(buffer, 16);

    int idx = 0;
    while (idx < 16) {
        buffer_write(buffer, data->addr[idx]);
        idx++;
    }
}

/*
 * Write an SRV record body: a 16-bit length field, the three 16-bit
 * fields priority, weight and port, then the target name.
 *
 * The length field is written as zero and patched once the body size
 * is known.
 */
static void write_srv_record(PacketBuffer* buffer, SRVRecord* data) {
    int len_field = buffer_get_index(buffer);
    buffer_write_short(buffer, 0);

    buffer_write_short(buffer, data->priority);
    buffer_write_short(buffer, data->weight);
    buffer_write_short(buffer, data->port);
    buffer_write_qname(buffer, data->target);

    /* Bytes written since the placeholder, excluding the placeholder itself. */
    int body_len = buffer_get_index(buffer) - len_field - 2;
    buffer_set_uint16_t(buffer, (uint16_t)body_len, len_field);
}

/*
 * Write a CAA record body: a 16-bit length field, the flags byte, the
 * tag-length byte, the tag bytes and the value bytes.
 *
 * tag and value keep their byte count in element 0, so the payload
 * written starts at element 1.
 *
 * The length field is written as zero and patched once the body size
 * is known.
 */
static void write_caa_record(PacketBuffer* buffer, CAARecord* data) {
    int len_field = buffer_get_index(buffer);
    buffer_write_short(buffer, 0);

    buffer_write(buffer, data->flags);
    buffer_write(buffer, data->length);

    buffer_write_n(buffer, data->tag + 1, data->tag[0]);
    buffer_write_n(buffer, data->value + 1, data->value[0]);

    /* Bytes written since the placeholder, excluding the placeholder itself. */
    int body_len = buffer_get_index(buffer) - len_field - 2;
    buffer_set_uint16_t(buffer, (uint16_t)body_len, len_field);
}

/*
 * Write the part of a record that precedes the type-specific body:
 * owner name, 16-bit type code (from record_to_id), 16-bit class and
 * 32-bit TTL. The length field is written by the per-type writers.
 */
static void write_record_header(PacketBuffer* buffer, Record* record) {
    buffer_write_qname(buffer, record->domain);
    buffer_write_short(buffer, record_to_id(record->type));
    buffer_write_short(buffer, record->cls);
    buffer_write_int(buffer, record->ttl);
}

/*
 * Serialise one record into the buffer: the common header followed by the
 * type-specific body (which includes the length field).
 *
 * Records whose type has no writer are skipped entirely; nothing is written
 * for them, not even the header.
 */
void write_record(PacketBuffer* buffer, Record* record) {
    /* Bail out before touching the buffer when the type is unsupported. */
    switch (record->type) {
        case A:
        case NS:
        case CNAME:
        case SOA:
        case PTR:
        case MX:
        case TXT:
        case AAAA:
        case SRV:
        case CAA:
            break;
        default:
            return;
    }

    write_record_header(buffer, record);

    switch (record->type) {
        case A:     write_a_record(buffer, &record->data.a);         break;
        case NS:    write_ns_record(buffer, &record->data.ns);       break;
        case CNAME: write_cname_record(buffer, &record->data.cname); break;
        case SOA:   write_soa_record(buffer, &record->data.soa);     break;
        case PTR:   write_ptr_record(buffer, &record->data.ptr);     break;
        case MX:    write_mx_record(buffer, &record->data.mx);       break;
        case TXT:   write_txt_record(buffer, &record->data.txt);     break;
        case AAAA:  write_aaaa_record(buffer, &record->data.aaaa);   break;
        case SRV:   write_srv_record(buffer, &record->data.srv);     break;
        case CAA:   write_caa_record(buffer, &record->data.caa);     break;
        default:    break;
    }
}

/*
 * Release every heap allocation owned by a record: the owner name and
 * whatever the type-specific payload holds. A and AAAA payloads (and
 * unknown types) own nothing beyond the owner name.
 *
 * The Record object itself is not freed.
 */
void free_record(Record* record) {
    const RecordType kind = record->type;

    free(record->domain);

    if (kind == NS) {
        free(record->data.ns.host);
    } else if (kind == CNAME) {
        free(record->data.cname.host);
    } else if (kind == SOA) {
        free(record->data.soa.mname);
        free(record->data.soa.nname);
    } else if (kind == PTR) {
        free(record->data.ptr.pointer);
    } else if (kind == MX) {
        free(record->data.mx.host);
    } else if (kind == TXT) {
        /* Each string first, then the array that held them. */
        uint8_t idx = 0;
        while (idx < record->data.txt.len) {
            free(record->data.txt.text[idx]);
            idx++;
        }
        free(record->data.txt.text);
    } else if (kind == SRV) {
        free(record->data.srv.target);
    } else if (kind == CAA) {
        free(record->data.caa.value);
        free(record->data.caa.tag);
    }
}

/*
 * Return the printable name for the record's class value:
 * 1 -> "IN", 3 -> "CH", 4 -> "HS", anything else -> "??".
 * The returned string is a literal and must not be freed.
 */
static const char* class_to_str(Record* record) {
    if (record->cls == 1) {
        return "IN";
    }
    if (record->cls == 3) {
        return "CH";
    }
    if (record->cls == 4) {
        return "HS";
    }
    return "??";
}

/*
 * Return the printable name of the record's type; types without a name
 * in this table yield "UNKOWN".
 * The returned string is a literal and must not be freed.
 */
static const char* qtype_to_str(Record* record) {
    const char* label = "UNKOWN";

    switch (record->type) {
        case A:     label = "A";     break;
        case NS:    label = "NS";    break;
        case CNAME: label = "CNAME"; break;
        case SOA:   label = "SOA";   break;
        case PTR:   label = "PTR";   break;
        case MX:    label = "MX";    break;
        case TXT:   label = "TXT";   break;
        case AAAA:  label = "AAAA";  break;
        case SRV:   label = "SRV";   break;
        case CAA:   label = "CAA";   break;
        default:    break;
    }

    return label;
}

/*
 * Print the type-specific payload of a record to stdout, without a
 * trailing newline. Unknown types print nothing.
 *
 * Names and strings are stored with their length in byte 0 and the text
 * from byte 1, so they are printed with "%.*s" using byte 0 as the precision.
 *
 * Output per type:
 *   A     dotted decimal, e.g. 192.0.2.1
 *   AAAA  eight colon-separated groups of four hex digits (no compression)
 *   MX    host, then priority
 *   SRV   priority weight port target
 *   TXT   the strings back to back, with no separator
 *   CAA   flags tag value
 */
static void print_record_data(Record* record) {
    switch(record->type) {
        case A:
            printf("%hhu.%hhu.%hhu.%hhu",
                record->data.a.addr[0],
                record->data.a.addr[1],
                record->data.a.addr[2],
                record->data.a.addr[3]
            );
            break;
        case NS:
            printf("%.*s",
                record->data.ns.host[0],
                record->data.ns.host + 1
            );
            break;
        case CNAME:
            printf("%.*s",
                record->data.cname.host[0],
                record->data.cname.host + 1
            );
            break;
        case SOA:
            printf("%.*s %.*s %u %u %u %u %u",
                record->data.soa.mname[0],
                record->data.soa.mname + 1,
                record->data.soa.nname[0],
                record->data.soa.nname + 1,
                record->data.soa.serial,
                record->data.soa.refresh,
                record->data.soa.retry,
                record->data.soa.expire,
                record->data.soa.minimum
            );
            break;
        case PTR:
            printf("%.*s",
                record->data.ptr.pointer[0],
                record->data.ptr.pointer + 1
            );
            break;
        case MX:
            printf("%.*s %hu",
                record->data.mx.host[0],
                record->data.mx.host + 1,
                record->data.mx.priority
            );
            break;
        case TXT:
            for(uint8_t i = 0; i < record->data.txt.len; i++) {
                printf("%.*s",
                    record->data.txt.text[i][0],
                    record->data.txt.text[i] + 1
                );
            }
            break;
        case AAAA:
            /* Read the AAAA member (16 bytes), not the 4-byte A member, and
             * put the colon between groups only, so the output does not end
             * in "::". */
            for(int i = 0; i < 8; i++) {
                printf("%s%02hhx%02hhx",
                    i == 0 ? "" : ":",
                    record->data.aaaa.addr[i*2 + 0],
                    record->data.aaaa.addr[i*2 + 1]
                );
            }
            break;
        case SRV:
            printf("%hu %hu %hu %.*s",
                record->data.srv.priority,
                record->data.srv.weight,
                record->data.srv.port,
                record->data.srv.target[0],
                record->data.srv.target + 1
            );
            break;
        case CAA:
            printf("%hhu %.*s %.*s",
                record->data.caa.flags,
                record->data.caa.tag[0],
                record->data.caa.tag + 1,
                record->data.caa.value[0],
                record->data.caa.value + 1
            );
            break;
        default:
            break;
    }
}

/*
 * Print one record to stdout as a single line: the owner name followed by
 * a dot, a tab, the class and type names, a tab, the type-specific payload
 * and a newline.
 *
 * The owner name keeps its length in byte 0 and its text from byte 1.
 */
void print_record(Record* record) {
    const char* class_name = class_to_str(record);
    const char* type_name = qtype_to_str(record);

    printf("%.*s.\t%s %s\t", record->domain[0], record->domain + 1, class_name, type_name);
    print_record_data(record);
    printf("\n");
}
