#include "client.h"
#include "../packet/question.h"
#include "addr.h"
#include "binding.h"

#include <resolv.h>
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <string.h>
#include <time.h>
#include <errno.h>
#include <stdbool.h>
#include <stdint.h>
#include <sys/time.h>

/*
 * Stores the system resolver's first configured nameserver in `addr`
 * as an IPv4 address.
 *
 * res_init() is called first so that _res reflects the current resolver
 * configuration. Only the IPv4 sockaddr_in slot nsaddr_list[0] is read.
 * NOTE(review): an IPv6 first nameserver is not handled here — confirm how
 * the platform resolver reports it in nsaddr_list.
 */
void resolve_default_server(IpAddr* addr) {
    res_init();

    struct in_addr first_ns = _res.nsaddr_list[0].sin_addr;
    addr->type = V4;
    addr->data.v4 = first_ns;
}

/*
 * Appends `q` to the growable array `*buf`.
 *
 * `*buf` holds `*amount` elements in storage sized for `*capacity`
 * elements. When the storage is full it grows geometrically. Both counters
 * are uint8_t, so at most UINT8_MAX questions can be stored. Exceeding that
 * limit, or failing to allocate, prints a diagnostic and exits the process.
 * On allocation failure the original `*buf` is left untouched.
 */
static void push_question(Question** buf, Question q, uint8_t* amount, uint8_t* capacity) {
    if (*amount == UINT8_MAX) {
        fprintf(stderr, "error: too many questions (max %d)\n", UINT8_MAX);
        exit(EXIT_FAILURE);
    }

    if (*amount >= *capacity) {
        // Double the capacity, but saturate at UINT8_MAX: 128 * 2 would wrap
        // a uint8_t to 0 and shrink the buffer instead of growing it.
        uint8_t new_capacity;
        if (*capacity == 0) {
            new_capacity = 1;
        } else if (*capacity > UINT8_MAX / 2) {
            new_capacity = UINT8_MAX;
        } else {
            new_capacity = (uint8_t)(*capacity * 2);
        }

        // Keep the old pointer valid if realloc fails.
        Question* grown = realloc(*buf, sizeof(Question) * new_capacity);
        if (grown == NULL) {
            fprintf(stderr, "error: out of memory\n");
            exit(EXIT_FAILURE);
        }
        *buf = grown;
        *capacity = new_capacity;
    }

    (*buf)[*amount] = q;
    (*amount)++;
}

/*
 * Builds the list of DNS questions from command-line style arguments.
 *
 * Parsing rules:
 *  - Each argument is a domain name. It may be preceded or followed by a
 *    record type recognised by str_to_qtype(), which then applies to it.
 *  - A name with no adjacent record type gets an A question.
 *  - A record type as the last argument, with no name after it, is ignored.
 *  - A single argument that is itself a record type yields one NS question
 *    for the empty name "".
 *
 * @param argc       number of arguments in argv (must be > 0)
 * @param argv       the arguments to parse
 * @param len        out: number of questions produced
 * @param questions  out: heap-allocated array of *len questions, owned by
 *                   the caller
 *
 * Exits the process if memory cannot be allocated.
 */
void resolve_questions(int argc, char** argv, uint8_t* len, Question** questions) {
    assert(argc > 0);

    uint8_t amount = 0;
    uint8_t capacity = 1;
    Question* buf = malloc(sizeof(Question) * capacity);
    if (buf == NULL) {
        fprintf(stderr, "error: out of memory\n");
        exit(EXIT_FAILURE);
    }

    Question q;
    RecordType t;
    if (argc == 1 && str_to_qtype(argv[0], &t)) {
        // NOTE(review): the parsed type `t` is not used; an NS query for ""
        // is always issued in this case — confirm this is intended.
        create_question("", NS, &q);
        push_question(&buf, q, &amount, &capacity);
    } else {
        for (int i = 0; i < argc; i++) {
            if (str_to_qtype(argv[i], &t)) {
                // "<type> <name>": a trailing type with no name is dropped.
                if (i + 1 == argc) break;
                create_question(argv[i + 1], t, &q);
                i++;
            } else if (i + 1 < argc && str_to_qtype(argv[i + 1], &t)) {
                // "<name> <type>"
                create_question(argv[i], t, &q);
                i++;
            } else {
                // Bare name: default to an A record.
                create_question(argv[i], A, &q);
            }
            push_question(&buf, q, &amount, &capacity);
        }
    }

    *questions = buf;
    *len = amount;
}

/*
 * Prints a human-readable dump of a response packet to stdout.
 *
 * Output order: header id and response code, then the section counts, then
 * each non-empty section (questions, answers, authorities, additional
 * resources), and finally the query time.
 *
 * @param packet  the parsed response; section arrays are indexed up to the
 *                counts stored in packet->header
 * @param type    transport label shown next to the query time (e.g. "UDP")
 * @param ms      query time in milliseconds
 */
static void print_result(Packet* packet, const char* type, uint32_t ms) {
    printf(">> Received response\n");
    printf("Id: %hu, Code: %s\n",
        packet->header.id,
        str_from_code(packet->header.rescode)
    );
    printf("Questions: %hu, Answers: %hu, Authorities: %hu, Resources: %hu\n",
        packet->header.questions,
        packet->header.answers,
        packet->header.authoritative_entries,
        packet->header.resource_entries
    );

    if (packet->header.questions > 0) {
        printf("\n>> Question Section\n");
        for (uint16_t i = 0; i < packet->header.questions; i++) {
            print_question(&packet->questions[i]);
        }
    }

    if (packet->header.answers > 0) {
        printf("\n>> Answer Section\n");
        for (uint16_t i = 0; i < packet->header.answers; i++) {
            print_record(&packet->answers[i]);
        }
    }

    if (packet->header.authoritative_entries > 0) {
        printf("\n>> Authority Section\n");
        for (uint16_t i = 0; i < packet->header.authoritative_entries; i++) {
            print_record(&packet->authorities[i]);
        }
    }

    if (packet->header.resource_entries > 0) {
        printf("\n>> Resource Section\n");
        for (uint16_t i = 0; i < packet->header.resource_entries; i++) {
            print_record(&packet->resources[i]);
        }
    }

    printf("\n>> Query time: %ums (%s)\n", ms, type);
}

void resolve(IpAddr server, Options options, Question* questions, uint8_t len) {
    struct timeval stop, start;
    gettimeofday(&start, NULL);

    SocketAddr addr;
    create_socket_addr(options.port, server, &addr); 

    Packet req;
    memset(&req, 0, sizeof(Packet));
    srand(time(NULL));
    req.header.id = rand() % 65536;
    req.header.questions = len;
    req.header.recursion_desired = true;
    req.questions = questions;

    Connection conn;
    if (options.force_tcp) goto tcp;

    if (!create_request(UDP, &addr, &conn)) {
        printf("error: failed to create udp request: %s\n", strerror(errno));
        exit(EXIT_FAILURE);
    }

    Packet res;
    if (!request_packet(&conn, &req, &res)) {
        free_request(&conn);
        free_packet(&req);
        printf("error: failed to request udp packet: %s\n", strerror(errno));
        exit(EXIT_FAILURE);
    }

    free_request(&conn);

    if (!res.header.truncated_message) {
        gettimeofday(&stop, NULL);
        uint32_t ms = (stop.tv_sec - start.tv_sec) * 1000000 + stop.tv_usec - start.tv_usec;

        print_result(&res, "UDP", ms / 1000);
        free_packet(&req);
        free_packet(&res);
        return;
    }

    free_packet(&res);
    printf("Response truncated, retrying in TCP...\n");
    
tcp:

    if (!create_request(TCP, &addr, &conn)) {
        printf("error: failed to create tcp request: %s\n", strerror(errno));
        exit(EXIT_FAILURE);
    }

    if (!request_packet(&conn, &req, &res)) {
        free_request(&conn);
        printf("error: failed to request tcp packet: %s\n", strerror(errno));
        exit(EXIT_FAILURE);
    }

    gettimeofday(&stop, NULL);
    uint32_t ms = (stop.tv_sec - start.tv_sec) * 1000000 + stop.tv_usec - start.tv_usec;

    print_result(&res, "TCP", ms / 1000);

    free_request(&conn);
    free_packet(&req);
    free_packet(&res);
}

