# Source code for oasislmf.pytools.aal.utils

import numba as nb
import numpy as np


"""
Custom numba heap implementation for summaries index (based on Python heapq)
A custom heap is required for the summaries index because entries must be stored
and sorted by a tuple of 3 np.int32 values. This compound key is currently not
supported by numba's implementation of the standard Python heapq.

Each entry stores 5 ints and is ordered lexicographically on the first 3 ints, then on the last 2 ints
First 3 ints are summary_index, period_no, file_idx
Last 2 ints are partial_file_idx, row_number in file[file_idx]
"""


@nb.njit(cache=True, error_model="numpy")
def _resize_heap(heap, current_capacity):
    """Doubles the heap capacity"""
    new_capacity = current_capacity * 2
    new_heap = np.zeros((new_capacity, 5), dtype=heap.dtype)
    for i in range(current_capacity):
        new_heap[i] = heap[i]
    return new_heap


@nb.njit(cache=True, error_model="numpy")
def _lex_compare(a, b):
    """Performs lexicographical comparison for all elements in arrays a and b.
    We compare all elements, and not just the first 3, as we also need to order
    by file_offset when summary_id, period_no, file_idx, which the partial_file_idx
    and row_num values do"""
    for i in range(5):
        if a[i] < b[i]:
            return True
        elif a[i] > b[i]:
            return False
    return False


@nb.njit(cache=True, error_model="numpy")
def _swap_rows(heap, i, j):
    """Swap rows i and j in heap"""
    for k in range(5):
        temp = heap[i, k]
        heap[i, k] = heap[j, k]
        heap[j, k] = temp


@nb.njit(cache=True, error_model="numpy")
def _sift_down(heap, startpos, pos):
    """Move the entry at ``pos`` towards the root, stopping at ``startpos``.

    Swap-based port of ``heapq._siftdown``. While the entry compares strictly
    smaller than its parent, the two rows are swapped and the entry climbs
    one level.
    """
    while pos > startpos:
        parent = (pos - 1) >> 1
        if not _lex_compare(heap[pos], heap[parent]):
            # Parent is not larger: the heap invariant holds here.
            break
        _swap_rows(heap, pos, parent)
        pos = parent


@nb.njit(cache=True, error_model="numpy")
def _sift_up(heap, pos, endpos):
    """Restore the heap invariant below ``pos`` within ``heap[:endpos]``.

    Swap-based port of ``heapq._siftup``:
    1. The entry at ``pos`` is swapped with its smaller child until it
       reaches a leaf.
    2. ``_sift_down`` then moves it back up to its final position, never
       above the original ``pos``.
    """
    root = pos
    child = 2 * pos + 1  # left child
    while child < endpos:
        right = child + 1
        # Take the right child unless the left one is strictly smaller.
        if right < endpos and not _lex_compare(heap[child], heap[right]):
            child = right
        # child > pos here, so this never swaps a row with itself.
        _swap_rows(heap, pos, child)
        pos = child
        child = 2 * pos + 1
    # The entry now sits at a leaf; bubble it up to where it belongs.
    _sift_down(heap, root, pos)


@nb.njit(cache=True, error_model="numpy")
def heap_push(heap, size, element):
    """Push ``element`` onto the heap (port of ``heapq.heappush``).

    Args:
        heap: 2D heap buffer; only the first ``size`` rows are live entries.
        size: number of live entries currently stored in ``heap``.
        element: the 5-value row to insert.

    Returns:
        tuple: ``(heap, size + 1)``. When the buffer was full, ``heap`` is a
        new, larger array, so callers must keep using the returned array.
    """
    if size >= len(heap):
        # Buffer is full: growth is delegated to _resize_heap, which is
        # expected to return a strictly larger buffer.
        heap = _resize_heap(heap, len(heap))
    heap[size] = element
    # The new entry starts at the end; move it up towards the root.
    _sift_down(heap, 0, size)
    return heap, size + 1
@nb.njit(cache=True, error_model="numpy")
def heap_pop(heap, size):
    """Pop and return the smallest entry (port of ``heapq.heappop``).

    Args:
        heap: 2D heap buffer; only the first ``size`` rows are live entries.
        size: number of live entries currently stored in ``heap``.

    Returns:
        tuple: ``(item, heap, size - 1)``, where ``item`` is a copy of the
        smallest row (ordered by ``_lex_compare``).

    Raises:
        ValueError: if ``size`` is 0 or negative.
    """
    if size <= 0:
        raise ValueError("Heap underflow: Cannot pop from an empty heap.")
    # Copy, because the slot may be overwritten below.
    lastelt = heap[size - 1].copy()
    if size - 1 > 0:
        returnitem = heap[0].copy()
        # Move the last entry to the root and sift it down over the
        # remaining size - 1 entries.
        heap[0] = lastelt
        _sift_up(heap, 0, size - 1)
        return returnitem, heap, size - 1
    # Only one entry was live: it is both the last and the smallest.
    return lastelt, heap, size - 1
@nb.njit(cache=True, error_model="numpy")
def init_heap(num_rows=4):
    """Create a zero-filled heap buffer.

    Args:
        num_rows: initial row capacity. Defaults to 4.

    Returns:
        np.ndarray: a ``(num_rows, 5)`` int32 array of zeros. The number of
        live entries is tracked separately by the caller and passed to
        ``heap_push``/``heap_pop`` as ``size``.
    """
    return np.zeros((num_rows, 5), dtype=np.int32)
@nb.njit(cache=True, error_model="numpy")