#include "buffer.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * Growable byte buffer with a single cursor shared by the read and write
 * helpers in this file.
 */
struct PacketBuffer {
    uint8_t* arr;   /* heap storage; `capacity` bytes long */
    int capacity;   /* allocated length of `arr` */
    int index;      /* cursor: next position read_* / write_* will touch */
    int size;       /* number of valid bytes; read/get helpers return 0 at or beyond it */
};

/*
 * Allocates an empty buffer with `capacity` zeroed bytes of backing storage.
 *
 * A negative capacity is treated as 0. Returns NULL if either allocation
 * fails; the caller owns the result and releases it with buffer_free().
 */
PacketBuffer* buffer_create(int capacity) {
    PacketBuffer* buffer = malloc(sizeof *buffer);
    if (buffer == NULL) {
        return NULL;
    }
    if (capacity < 0) {
        capacity = 0;
    }
    /* Allocate at least one byte so a NULL result unambiguously means failure
     * (malloc/calloc of 0 bytes may legitimately return NULL). */
    buffer->arr = calloc(capacity > 0 ? (size_t) capacity : 1, 1);
    if (buffer->arr == NULL) {
        free(buffer);
        return NULL;
    }
    buffer->capacity = capacity;
    buffer->size = 0;
    buffer->index = 0;
    return buffer;
}

/*
 * Releases the backing storage and the buffer itself.
 * Passing NULL is a no-op, mirroring free(NULL).
 */
void buffer_free(PacketBuffer* buffer) {
    if (buffer == NULL) {
        return;
    }
    free(buffer->arr);
    free(buffer);
}

/*
 * Moves the cursor to an absolute position. No bounds check is done here;
 * out-of-range positions are handled by the individual read/write helpers.
 */
void buffer_seek(PacketBuffer* buffer, int index) {
    (*buffer).index = index;
}

/*
 * Reads the byte at the cursor and advances the cursor by one.
 *
 * Returns 0 without moving the cursor when the cursor is outside
 * [0, size). The negative check matters because buffer_seek/buffer_step
 * accept any int.
 */
uint8_t buffer_read(PacketBuffer* buffer) {
    if (buffer->index < 0 || buffer->index >= buffer->size) {
        return 0;
    }
    return buffer->arr[buffer->index++];
}

/*
 * Reads a big-endian (network order) 16-bit value at the cursor.
 *
 * Each byte is read in its own statement: the operands of `|` are
 * unsequenced relative to each other, so calling buffer_read twice inside
 * one expression lets the compiler pick which byte is read first.
 */
uint16_t buffer_read_short(PacketBuffer* buffer) {
    uint16_t high = buffer_read(buffer);
    uint16_t low = buffer_read(buffer);
    return (uint16_t) ((high << 8) | low);
}

/*
 * Reads a big-endian (network order) 32-bit value at the cursor.
 *
 * Bytes are accumulated one per loop iteration so the read order is fixed
 * (most significant first); folding four buffer_read calls into a single
 * `|` expression leaves their evaluation order unspecified.
 */
uint32_t buffer_read_int(PacketBuffer* buffer) {
    uint32_t value = 0;
    for (int i = 0; i < 4; i++) {
        value = (value << 8) | buffer_read(buffer);
    }
    return value;
}

/*
 * Returns the byte at an absolute position without moving the cursor.
 * Positions outside [0, size) yield 0 instead of reading out of bounds.
 */
uint8_t buffer_get(PacketBuffer* buffer, int index) {
    if (index < 0 || index >= buffer->size) {
        return 0;
    }
    return buffer->arr[index];
}

/*
 * Copies `len` bytes starting at absolute position `start` into a new heap
 * array; positions outside [0, size) are copied as 0 (see buffer_get).
 *
 * Returns NULL for a negative `len` or on allocation failure. At least one
 * byte is allocated so that len == 0 still yields a freeable, non-NULL
 * pointer. The caller owns the result and must free() it.
 */
uint8_t* buffer_get_range(PacketBuffer* buffer, int start, int len) {
    if (len < 0) {
        return NULL;
    }
    uint8_t* arr = malloc(len > 0 ? (size_t) len : 1);
    if (arr == NULL) {
        return NULL;
    }
    for (int i = 0; i < len; i++) {
        arr[i] = buffer_get(buffer, start + i);
    }
    return arr;
}

/*
 * Returns buffer->size + 1, truncated to 16 bits.
 *
 * NOTE(review): `size` already counts the bytes written (buffer_write sets
 * it to index + 1), so the extra +1 looks like an off-by-one. It is kept
 * as-is because callers outside this file may depend on it — confirm.
 */
uint16_t buffer_get_size(PacketBuffer* buffer) {
    uint16_t low_bits = (uint16_t) buffer->size;
    return (uint16_t) (low_bits + 1);
}

/*
 * Appends one byte to a growable byte array used to assemble a decoded name.
 *
 * *size is the number of bytes in use and *capacity the allocated length.
 * Both are uint8_t, so the array can never hold more than 255 bytes.
 * Capacity doubles on demand (0 grows to 1, and anything >= 128 jumps to 255).
 *
 * Once 255 bytes are stored, or if realloc fails, further bytes are dropped
 * rather than written past the end of the allocation. Without this, the
 * 256th byte would be written out of bounds and *size would wrap to 0.
 */
static void write(uint8_t** buffer, uint8_t* size, uint8_t* capacity, uint8_t data) {
    enum { MAX_CAPACITY = 255 };

    if (*size >= *capacity) {
        if (*capacity >= MAX_CAPACITY) {
            return; /* full: a uint8_t size cannot grow past 255 */
        }

        uint8_t new_capacity;
        if (*capacity >= 128) {
            new_capacity = MAX_CAPACITY;
        } else if (*capacity == 0) {
            new_capacity = 1; /* doubling 0 would never grow */
        } else {
            new_capacity = (uint8_t) (*capacity * 2);
        }

        /* Keep the original pointer if realloc fails so it is not leaked. */
        uint8_t* grown = realloc(*buffer, new_capacity);
        if (grown == NULL) {
            return;
        }
        *buffer = grown;
        *capacity = new_capacity;
    }
    (*buffer)[*size] = data;
    (*size)++;
}

/*
 * Decodes a (possibly compressed) domain name starting at the cursor into a
 * newly allocated length-prefixed string. (*out)[0] holds the number of
 * characters that follow, and labels are joined with '.'.
 *
 * A length byte with both top bits set (0xC0) is a compression pointer to a
 * 14-bit offset. At most max_jumps pointer targets are followed, which bounds
 * pointer loops in malformed input. If no pointer was taken, the cursor ends
 * just past the terminating zero label. Otherwise it ends just past the first
 * two-byte pointer.
 *
 * NOTE(review): length bytes with top bits 01 or 10 are not rejected and are
 * treated as plain label lengths.
 *
 * On allocation failure *out is set to NULL and the cursor is not moved.
 * The caller owns *out and must free() it.
 */
void buffer_read_qname(PacketBuffer* buffer, uint8_t** out) {
    int index = buffer->index;
    int jumped = 0;

    const int max_jumps = 5;
    int jumps_performed = 0;

    uint8_t length = 0;
    uint8_t capacity = 8;
    *out = malloc(capacity);
    if (*out == NULL) {
        return;
    }
    /* Placeholder for the length prefix, patched after decoding. */
    write(out, &length, &capacity, 0);

    while (jumps_performed <= max_jumps) {
        uint8_t len = buffer_get(buffer, index);

        if ((len & 0xC0) == 0xC0) {
            /* Only the first pointer determines where the caller resumes. */
            if (!jumped) {
                buffer_seek(buffer, index + 2);
            }
            uint16_t low = buffer_get(buffer, index + 1);
            index = (int) ((((uint16_t) (len & 0x3F)) << 8) | low);
            jumped = 1;
            jumps_performed++;
            continue;
        }

        index++;

        if (len == 0) {
            break; /* root label terminates the name */
        }

        /* length > 1 means at least one label has already been emitted. */
        if (length > 1) {
            write(out, &length, &capacity, '.');
        }

        /* Copy the label straight from the buffer (out-of-range bytes read
         * as 0) instead of allocating a temporary range per label. */
        for (uint8_t i = 0; i < len; i++) {
            write(out, &length, &capacity, buffer_get(buffer, index + i));
        }

        index += (int) len;
    }

    if (!jumped) {
        buffer_seek(buffer, index);
    }

    (*out)[0] = length - 1;
}

/*
 * Reads a length-prefixed string: one length byte at the cursor followed by
 * that many data bytes. The result is delegated to buffer_read_n, which
 * allocates *out in the same length-prefixed layout.
 */
void buffer_read_string(PacketBuffer* buffer, uint8_t** out) {
    buffer_read_n(buffer, out, buffer_read(buffer));
}

/*
 * Ensures buffer->arr can hold at least `capacity` bytes; newly added bytes
 * are zeroed.
 *
 * Growth is geometric (doubling, clamped to the request) so byte-at-a-time
 * writers such as buffer_write do not trigger a realloc per byte.
 * If realloc fails, the buffer is left untouched, so callers must re-check
 * buffer->capacity before writing.
 */
static void buffer_expand(PacketBuffer* buffer, int capacity) {
    if (buffer->capacity >= capacity) {
        return;
    }

    int old_capacity = buffer->capacity > 0 ? buffer->capacity : 0;
    int new_capacity = old_capacity > 0 ? old_capacity : 1;
    while (new_capacity < capacity) {
        /* Clamp instead of doubling past the request, which avoids int overflow. */
        new_capacity = (new_capacity > capacity / 2) ? capacity : new_capacity * 2;
    }

    uint8_t* grown = realloc(buffer->arr, (size_t) new_capacity);
    if (grown == NULL) {
        return; /* original allocation is still valid and still owned */
    }
    memset(grown + old_capacity, 0, (size_t) (new_capacity - old_capacity));
    buffer->arr = grown;
    buffer->capacity = new_capacity;
}

/*
 * Copies `len` bytes at the cursor into a newly allocated length-prefixed
 * array ((*out)[0] = len, then the bytes) and advances the cursor by `len`.
 *
 * The backing storage is expanded first, so bytes past the written data read
 * as the zeroes it was filled with. If the range still cannot be read
 * (negative cursor, or expansion failed), the data bytes are zero-filled.
 * On allocation failure *out is NULL and the cursor is not moved.
 * The caller owns *out and must free() it.
 */
void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len) {
    buffer_expand(buffer, buffer->index + len + 1);

    *out = malloc((size_t) len + 1);
    if (*out == NULL) {
        return;
    }
    (*out)[0] = len;

    int start = buffer->index;
    if (start >= 0 && start + len <= buffer->capacity) {
        memcpy(*out + 1, buffer->arr + start, len);
    } else {
        memset(*out + 1, 0, len);
    }
    buffer->index += len;
}

/*
 * Writes one byte at the cursor, grows `size` if the write extends the valid
 * data, and advances the cursor.
 *
 * The storage must hold index + 1 bytes before arr[index] is written. The
 * previous code requested only `index` bytes, which overflowed by one when
 * index == capacity. If the cursor is negative or the buffer cannot grow, the
 * byte is dropped and the cursor is left unchanged.
 */
void buffer_write(PacketBuffer* buffer, uint8_t data) {
    if (buffer->index < 0) {
        return;
    }
    buffer_expand(buffer, buffer->index + 1);
    if (buffer->capacity <= buffer->index) {
        return; /* growth failed */
    }
    if (buffer->size < buffer->index + 1) {
        buffer->size = buffer->index + 1;
    }
    buffer->arr[buffer->index] = data;
    buffer->index++;
}

/* Writes a 16-bit value in big-endian (network) order: high byte first. */
void buffer_write_short(PacketBuffer* buffer, uint16_t data) {
    for (int shift = 8; shift >= 0; shift -= 8) {
        buffer_write(buffer, (uint8_t) (data >> shift));
    }
}

/* Writes a 32-bit value in big-endian (network) order, most significant byte first. */
void buffer_write_int(PacketBuffer* buffer, uint32_t data) {
    for (int shift = 24; shift >= 0; shift -= 8) {
        buffer_write(buffer, (uint8_t) (data >> shift));
    }
}

/*
 * Encodes a length-prefixed dotted name (in[0] = number of characters,
 * in[1..in[0]] = e.g. "www.example.com") as DNS labels at the cursor. Each
 * label is written as a length byte followed by its bytes, and the name ends
 * with a zero-length label.
 *
 * Empty labels (leading, trailing or doubled dots) are skipped. As a result,
 * "example.com." encodes the same as "example.com" instead of emitting a
 * premature zero terminator followed by stray bytes.
 *
 * NOTE(review): labels longer than 63 bytes (the RFC 1035 limit) are not
 * rejected. Lengths of 192 or more would be decoded by buffer_read_qname as
 * compression pointers.
 */
void buffer_write_qname(PacketBuffer* buffer, uint8_t* in) {
    int name_len = in[0];
    const uint8_t* name = in + 1;
    int label_start = 0;

    /* i == name_len acts as a virtual '.' that flushes the final label. */
    for (int i = 0; i <= name_len; i++) {
        if (i < name_len && name[i] != '.') {
            continue;
        }
        int label_len = i - label_start;
        if (label_len > 0) {
            buffer_write(buffer, (uint8_t) label_len);
            for (int j = label_start; j < i; j++) {
                buffer_write(buffer, name[j]);
            }
        }
        label_start = i + 1;
    }

    buffer_write(buffer, 0); /* root label terminates the name */
}

/*
 * Writes a length-prefixed string: in[0] is the byte count and in[1..] the
 * data. The length byte is emitted first, then the data via buffer_write_n.
 */
void buffer_write_string(PacketBuffer* buffer, uint8_t* in) {
    uint8_t count = in[0];
    buffer_write(buffer, count);
    buffer_write_n(buffer, &in[1], count);
}

/*
 * Copies `len` bytes from `in` to the cursor and advances the cursor past them.
 *
 * Capacity is ensured for index + len, which is where the copy actually
 * lands; sizing from `size` overflowed when the cursor was beyond `size`.
 * `size` becomes max(size, index + len), matching buffer_write, so
 * overwriting existing bytes no longer inflates it. Non-positive lengths, a
 * negative cursor, or failed growth leave the buffer unchanged.
 */
void buffer_write_n(PacketBuffer* buffer, uint8_t* in, int len) {
    if (len <= 0 || buffer->index < 0) {
        return;
    }
    int end = buffer->index + len;
    buffer_expand(buffer, end);
    if (buffer->capacity < end) {
        return; /* growth failed */
    }
    memcpy(buffer->arr + buffer->index, in, (size_t) len);
    if (buffer->size < end) {
        buffer->size = end;
    }
    buffer->index = end;
}

/*
 * Overwrites an already-written byte at an absolute position without moving
 * the cursor. Positions outside [0, size) are ignored.
 *
 * The previous check (index > size) accepted index == size. That could write
 * one byte past capacity, and the byte was never counted in `size` anyway.
 */
void buffer_set(PacketBuffer* buffer, uint8_t data, int index) {
    if (index < 0 || index >= buffer->size) {
        return;
    }
    buffer->arr[index] = data;
}

/*
 * Overwrites two bytes at `index` with `data` in big-endian order (high byte
 * at `index`, low byte at `index + 1`) via buffer_set; the cursor is untouched.
 */
void buffer_set_uint16_t(PacketBuffer* buffer, uint16_t data, int index) {
    const uint8_t hi = (uint8_t) (data >> 8);
    const uint8_t lo = (uint8_t) (data & 0xFF);
    buffer_set(buffer, hi, index);
    buffer_set(buffer, lo, index + 1);
}

/* Returns the current cursor position. */
int buffer_get_index(PacketBuffer* buffer) {
    const int position = buffer->index;
    return position;
}

/*
 * Moves the cursor relative to its current position (a negative `len`
 * rewinds). No bounds check is performed here.
 */
void buffer_step(PacketBuffer* buffer, int len) {
    int target = buffer->index + len;
    buffer->index = target;
}

/*
 * Exposes the backing storage directly. Growing writes may realloc it, so the
 * returned pointer should not be held across writes.
 */
uint8_t* buffer_get_ptr(PacketBuffer* buffer) {
    uint8_t* storage = buffer->arr;
    return storage;
}
