#include "memory.h"
#include "string.h"
#include "../panic.h"
#include "stdio.h"
#include "util.h"
#include <stddef.h>

/*
 * Initialise heap `h` to manage the raw memory region [base, base + size) as
 * one single free block.
 *
 * The HeapBlock header is written at `base` itself, so the usable payload of
 * the initial block is size - sizeof(HeapBlock) bytes.
 *
 * Panics if the region cannot even hold one header. Without this check the
 * unsigned subtraction would wrap around and advertise a gigantic free block.
 */
void init_heap(Heap* h, size_t base, size_t size) {
    if (size < sizeof(HeapBlock)) {
        panic("Heap region of size %d is too small", size);
        return; // only reached if panic() returns
    }
    HeapBlock* first = (HeapBlock*)base;
    *first = (HeapBlock){
        .prev = NULL,
        .next = NULL,
        .used = false,
        .size = size - sizeof(HeapBlock),
        .tag = "_initial",
    };
    h->first = first;
}

/*
 * First-fit allocation: walk the block list starting at heap->first and hand
 * out the first unused block whose payload can hold `size` bytes.
 *
 * If the leftover space in that block is larger than one header, the block is
 * split. The remainder becomes a new free block (tag "_free_m") linked in right
 * after it, and both neighbours' links are updated.
 *
 * Returns a pointer to the payload, just past the HeapBlock header. Calls
 * panic() when no block is large enough.
 *
 * NOTE(review): `size` is not rounded up, so headers created by a split can be
 * misaligned for HeapBlock; consider rounding to pointer alignment.
 */
void* alloc(Heap* heap, size_t size) {
    HeapBlock* block = heap->first;
    while (block != NULL && (block->used || block->size < size)) {
        block = block->next;
    }
    if (block == NULL) {
        panic("Failed to allocate heap block of size %d", size);
        return NULL; // only reached if panic() returns
    }

    // The search loop guarantees block->size >= size, so this cannot wrap.
    size_t empty_space = block->size - size;
    if (empty_space > sizeof(HeapBlock)) {
        // Carve a new free block out of the tail of this one.
        HeapBlock* next_block = (HeapBlock*)((char*)block + sizeof(HeapBlock) + size);
        *next_block = (HeapBlock){
            .prev = block,
            .next = block->next,
            .used = false,
            .size = empty_space - sizeof(HeapBlock),
            .tag = "_free_m",
        };
        // The following block must now point back at the split-off block,
        // otherwise free() would later coalesce with the wrong neighbour.
        if (block->next != NULL) {
            block->next->prev = next_block;
        }
        block->next = next_block;
        block->size = size;
    }
    // Otherwise the slack is too small for a header and stays inside this block.
    block->used = true;
    block->tag = "";
    return (char*)block + sizeof(HeapBlock);
}

void free(void* ptr) {
    HeapBlock* block_ptr = ptr - sizeof(HeapBlock);
    if(!block_ptr->used) {
        panic("Double free or invalid pointer at 0x%p", ptr);
    }
    block_ptr->used = false;
    block_ptr->tag = "_free";
    if(block_ptr->next != 0 && !block_ptr->next->used) {
        // merge with next block
        block_ptr->size += block_ptr->next->size + sizeof(HeapBlock);
        if(block_ptr->next->next != 0) {
            block_ptr->next->next->prev = block_ptr;
        }
        block_ptr->next = block_ptr->next->next;
    }
    if(block_ptr->prev != 0 && !block_ptr->prev->used) {
        // merge with previous block
        HeapBlock* block_ptr_new = block_ptr->prev;
        block_ptr_new->size += block_ptr->size + sizeof(HeapBlock);
        if(block_ptr->next != 0) {
            block_ptr->next->prev = block_ptr_new;
        }
        block_ptr_new->next = block_ptr->next;
    }
}

/*
 * Resize the allocation at `ptr` so it can hold at least `size` bytes.
 *
 * - ptr == NULL: behaves like alloc(h, size).
 * - Block already large enough: returned unchanged; it is never shrunk.
 * - Next block is free and the two together are large enough: grow in place.
 *   Any remainder bigger than a header is split off as a new free block
 *   tagged "_realloc". Otherwise the whole next block is absorbed.
 * - Otherwise: allocate a new block, copy the old payload into it, and free
 *   the old block.
 *
 * Returns the (possibly moved) payload pointer.
 */
void* realloc(Heap* h, void* ptr, size_t size) {
    if (ptr == NULL) {
        return alloc(h, size);
    }
    HeapBlock* block = (HeapBlock*)((char*)ptr - sizeof(HeapBlock));
    if (block->size >= size) {
        return ptr;
    }

    HeapBlock* next = block->next;
    if (next != NULL && !next->used) {
        size_t available_space = block->size + next->size + sizeof(HeapBlock);
        if (available_space >= size) {
            // Capture the link now: the new header below may overwrite `next`.
            HeapBlock* after = next->next;
            if (available_space > size + sizeof(HeapBlock)) {
                // Grow into the free neighbour and keep the tail as a free block.
                HeapBlock* new_free = (HeapBlock*)((char*)ptr + size);
                *new_free = (HeapBlock){
                    .prev = block,
                    .next = after,
                    .used = false,
                    .size = available_space - size - sizeof(HeapBlock),
                    .tag = "_realloc",
                };
                if (after != NULL) {
                    after->prev = new_free;
                }
                block->next = new_free;
                block->size = size;
            } else {
                // Tail too small for a header: swallow the whole neighbour,
                // otherwise the caller would write into a live free block.
                block->next = after;
                if (after != NULL) {
                    after->prev = block;
                }
                block->size = available_space;
            }
            return ptr;
        }
    }

    // Cannot grow in place: move the data to a fresh block.
    void* new_ptr = alloc(h, size);
    if (new_ptr == NULL) {
        return NULL; // only reachable if alloc's panic() returns; ptr stays valid
    }
    memcpy(new_ptr, ptr, block->size);
    free(ptr);
    return new_ptr;
}

/*
 * Look up the debug tag of the allocation whose payload begins at `ptr`.
 * The tag lives in the HeapBlock header stored immediately before the payload.
 */
char* hgettag(void* ptr) {
    char* header_addr = (char*)ptr - sizeof(HeapBlock);
    return ((HeapBlock*)header_addr)->tag;
}

/*
 * Attach a debug tag to the allocation whose payload begins at `ptr`.
 * Only the pointer is stored, so `tag` must outlive the allocation.
 */
void hsettag(void* ptr, char* tag) {
    char* header_addr = (char*)ptr - sizeof(HeapBlock);
    ((HeapBlock*)header_addr)->tag = tag;
}

/*
 * Print every block in the heap's list, in list order, to the console.
 *
 * For each block it prints the header address with its tag, the payload size
 * in hex, and whether the block is USED or FREE. The "%$0X" sequences are the
 * project printfln's colour codes. The terminal colour is reset at the end.
 */
void heaptrace(Heap* heap) {
    HeapBlock* cur = heap->first;
    printfln("%$0dHEAP TRACE");
    for (; cur != 0; cur = cur->next) {
        printfln("%$0a%p: %$0f%s", cur, hgettag((char*)cur + sizeof(HeapBlock)));
        printfln("%$0b  size=0x%x", cur->size);
        if (!cur->used) {
            printfln("%$0e  FREE");
        } else {
            printfln("%$0c  USED");
        }
    }
    printfln("%$0dEND");
    resetcol();
}
