Source code for oasislmf.pytools.common.hashmap

"""
Robin Hood hash table for numba JIT code.

Supports structured-record keys (e.g. np.dtype([('a','i4'),('b','u1')])) and
scalar keys — numeric (int8..int64, uint8..uint64, bool, float32, float64) or
unicode (fixed-width 'U<n>' / variable-length unicode strings). Hash and
equality are synthesized per dtype at compile time via @numba.extending.overload,
so each call site does typed field accesses only — no byte view, no buffer
allocation.

All hashmap state is stored in a single packed uint8 buffer ("table") holding
[info | lookup_table | index_table]. Use unpack(table) to get views.
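
The byte layout described above can be sketched in plain Python. This is only
an illustration, assuming the default widths used later in this module (uint32
info and index entries, uint16 lookup entries); `section_offsets` is a
hypothetical helper, not part of the module.

```python
# Sizes mirror the module defaults: uint32 info/index entries, uint16 lookup entries.
INDEX_ITEMSIZE = 4   # bytes per index_dtype (uint32) value
LOOKUP_ITEMSIZE = 2  # bytes per lookup_dtype (uint16) value
INFO_BYTES = 3 * INDEX_ITEMSIZE  # mask, n_valid, n_full


def section_offsets(table_size):
    # Return the (start, end) byte ranges of info, lookup and index.
    lookup_end = INFO_BYTES + table_size * LOOKUP_ITEMSIZE
    index_end = lookup_end + table_size * INDEX_ITEMSIZE
    return (0, INFO_BYTES), (INFO_BYTES, lookup_end), (lookup_end, index_end)


info_span, lookup_span, index_span = section_offsets(16)
print(info_span, lookup_span, index_span)  # (0, 12) (12, 44) (44, 108)
```

A 16-slot table therefore occupies 108 bytes in total; unpack(table) slices the
same three ranges and reinterprets them as typed views.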

Public API (convenience — unpacks per call):
    init_dict(hint_size)                                -> table
    try_add_key(table, key_storage, key, i_item=None)   -> slot | new_slot_bit, or i_add_key_fail
    find_key(table, key_storage, key)                   -> slot, or NOT_FOUND
    rehash(table, key_storage)                          -> new_table
    unpack(table)                                       -> (info, lookup, index) views

Internal API (hot-path — caller unpacks once):
    _try_add_key(info, lookup, index, key_storage, key, i_item=None)
    _find_key(info, lookup, index, key_storage, key)
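
The return encoding of try_add_key can be decoded as sketched below. The
constants restate the module's values for a uint32 index_dtype as plain Python
ints; `decode_add_result` and the upper-case names are hypothetical and exist
only for this illustration.

```python
# Plain-int copies of the module constants for index_dtype = uint32.
NEW_SLOT_BIT = 1 << 31        # top bit flags "key was just inserted"
SLOT_MASK = NEW_SLOT_BIT - 1  # low 31 bits carry the slot index
ADD_KEY_FAIL = 0xFFFFFFFF     # all ones: a rehash is needed


def decode_add_result(result):
    # Return (status, slot). The failure check must come before the new-bit test,
    # because the all-ones failure value also has the top bit set.
    if result == ADD_KEY_FAIL:
        return "rehash", None
    if result & NEW_SLOT_BIT:
        return "inserted", result & SLOT_MASK
    return "existing", result


print(decode_add_result(5))                 # ('existing', 5)
print(decode_add_result(5 | NEW_SLOT_BIT))  # ('inserted', 5)
print(decode_add_result(ADD_KEY_FAIL))      # ('rehash', None)
```

Masking with slot_mask is safe for both the "inserted" and the "existing" case,
which is why the examples further down apply it unconditionally.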

Performance (indicative — varies with hardware, numba / numpy versions,
and workload; here on int32 scalar keys, 1M ops, 50% hit / 50% miss):
    Public  try_add_key  : a few M ops/s     (unpack overhead per call)
    Public  find_key     : a few M ops/s
    Internal _try_add_key: ~1.5–2x faster than public
    Internal _find_key   : ~2–3x faster than public

Records, floats, and unicode keys are slower than int32 scalars roughly in
proportion to the number of fields / code points hashed.

Rule of thumb: use the public API for one-off or low-volume calls. For tight
loops (inserting/looking up thousands of keys), unpack once and call the
internal functions (_try_add_key, _find_key) directly.

try_add_key has two modes selected via `i_item` (compile-time literal branch —
both modes specialize with no runtime overhead):

    i_item is None (by-value mode):
        On insertion, the hashmap writes `key` to key_storage[info[HM_INFO_N_VALID]] and
        stores info[HM_INFO_N_VALID] in the index. The caller doesn't track positions;
        key_storage[:info[HM_INFO_N_VALID]] ends up holding the unique keys in insertion
        order.

    i_item is an int (by-position mode):
        On insertion, stores `i_item` in the index. The hashmap does NOT
        write to key_storage — the caller is asserting that
        key_storage[i_item] already holds `key`.
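
The difference between the two modes can be illustrated with a toy model. This
is not the module's Robin Hood table: a plain Python dict stands in for the
index, and `add_by_value` / `add_by_position` are hypothetical names.

```python
def add_by_value(index, key_storage, key):
    # Toy by-value insert: the map appends each new key to key_storage itself.
    if key not in index:
        index[key] = len(key_storage)  # dense index follows insertion order
        key_storage.append(key)
    return index[key]


def add_by_position(index, key_storage, key, i_item):
    # Toy by-position insert: the caller asserts key_storage[i_item] holds key.
    assert key_storage[i_item] == key
    return index.setdefault(key, i_item)  # first occurrence wins


rows = [7, 3, 7, 9]

by_value_index, unique_keys = {}, []
dense = [add_by_value(by_value_index, unique_keys, k) for k in rows]
print(unique_keys, dense)  # [7, 3, 9] [0, 1, 0, 2]

by_pos_index = {}
first_seen = [add_by_position(by_pos_index, rows, k, i) for i, k in enumerate(rows)]
print(first_seen)  # [0, 1, 0, 3]
```

By-value yields dense indices into a compact array of unique keys; by-position
yields the row number of each key's first occurrence and needs no extra storage.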

Sizing and rehash:

    init_dict(hint_size) pre-allocates the table so that `hint_size` unique
    keys fit within the load factor (n_full_factor = 0.95). When the caller
    knows an upper bound on the number of unique keys and passes it as
    hint_size, the load-factor guard (info[HM_INFO_N_VALID] >= info[HM_INFO_N_FULL])
    is guaranteed never to fire and can be omitted from the insert loop.

    The `while result == i_add_key_fail` rehash loop must still be kept even
    with a correctly pre-sized table. i_add_key_fail signals that the Robin
    Hood displacement chain exceeded max_rh due to hash collisions — this is
    independent of the load factor and can (rarely) occur at any occupancy.

    When hint_size is not known up front (e.g. streaming inserts), the caller
    must check the load-factor guard before each insertion:

        if info[HM_INFO_N_VALID] >= info[HM_INFO_N_FULL]:
            table = rehash(table, key_storage)
            info, lookup, index = unpack(table)
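
The sizing rule can be reproduced with a few lines of arithmetic. The sketch
below mirrors the loop in init_dict using plain Python numbers; `sized_for` is
a hypothetical name used only here.

```python
N_FULL_FACTOR = 0.95  # same load factor as n_full_factor in this module


def sized_for(hint_size=15):
    # Return (table_size, n_full): the smallest power of two >= 16 whose
    # load-factor budget covers hint_size unique keys.
    table_size = 16
    while table_size < hint_size * (1 / N_FULL_FACTOR):
        table_size *= 2
    return table_size, int(table_size * N_FULL_FACTOR)


print(sized_for(15))    # (16, 15)
print(sized_for(16))    # (32, 30)
print(sized_for(1000))  # (2048, 1945)
```

In each case n_full is at least hint_size, so a table sized with the true
upper bound never reaches the load-factor guard.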

Example — by-value insert, size known (pre-sized table, no load-factor check)::

    import numba as nb
    import numpy as np
    from oasislmf.pytools.common.hashmap import (
        init_dict, unpack, rehash, _try_add_key,
        i_add_key_fail, new_slot_bit, slot_mask,
        HM_INFO_N_VALID,
    )

    @nb.jit(cache=True)
    def build_vuln_dict(vuln_ids):
        key_storage = np.empty(len(vuln_ids), dtype=np.int32)
        table = init_dict(len(vuln_ids))  # pre-sized: no load-factor rehash needed
        info, lookup, index = unpack(table)

        for i in range(len(vuln_ids)):
            result = _try_add_key(info, lookup, index, key_storage, vuln_ids[i])
            while result == i_add_key_fail:  # Robin Hood collision — still possible
                table = rehash(table, key_storage)
                info, lookup, index = unpack(table)
                result = _try_add_key(info, lookup, index, key_storage, vuln_ids[i])

            dense_idx = index[result & slot_mask]  # works for both new and existing

        return table, key_storage[:info[HM_INFO_N_VALID]]

Example — by-value insert, size unknown (must check load factor)::

    import numba as nb
    import numpy as np
    from oasislmf.pytools.common.hashmap import (
        init_dict, unpack, rehash, _try_add_key,
        i_add_key_fail, new_slot_bit, slot_mask,
        HM_INFO_N_VALID, HM_INFO_N_FULL,
    )

    @nb.jit(cache=True)
    def build_vuln_dict(vuln_ids):
        key_storage = np.empty(len(vuln_ids), dtype=np.int32)
        table = init_dict()
        info, lookup, index = unpack(table)

        for i in range(len(vuln_ids)):
            if info[HM_INFO_N_VALID] >= info[HM_INFO_N_FULL]:
                table = rehash(table, key_storage)
                info, lookup, index = unpack(table)

            result = _try_add_key(info, lookup, index, key_storage, vuln_ids[i])
            while result == i_add_key_fail:
                table = rehash(table, key_storage)
                info, lookup, index = unpack(table)
                result = _try_add_key(info, lookup, index, key_storage, vuln_ids[i])

            dense_idx = index[result & slot_mask]  # works for both new and existing

        return table, key_storage[:info[HM_INFO_N_VALID]]

Example — lookup with internal API::

    import numba as nb
    from oasislmf.pytools.common.hashmap import (
        unpack, _find_key, NOT_FOUND,
    )

    @nb.jit(cache=True)
    def lookup_vuln_ids(table, key_storage, query_ids, out):
        info, lookup, index = unpack(table)
        for i in range(len(query_ids)):
            slot = _find_key(info, lookup, index, key_storage, query_ids[i])
            if slot != NOT_FOUND:
                out[i] = index[slot]   # dense index
            else:
                out[i] = -1

Notes:
    - The hash function (fnv1a) is deterministic with no random seed. Both
      `table` and `key_storage` can be saved/loaded from binary files and
      used directly with find_key — no rebuild needed.
    - Float keys (scalar or record field) are hashed AND compared bit-wise:
      IEEE 754 bytes are reinterpreted as a same-width uint for both fnv1a
      and key_eq. Consequences:
        * +0.0 and -0.0 are distinct keys (their sign bits differ).
        * NaN equals NaN iff the bit patterns match (no `NaN != NaN` rule).
        * Two NaN values with different mantissa/sign bits are distinct keys.
      Integer/uint/bool keys keep ordinary value-equality semantics.
    - Fixed-width unicode keys ('U<n>', whether scalar or record field) are
      hashed and compared code-point-by-code-point, with trailing NUL
      padding stripped (numpy's fixed-width convention). So "ab" stored in
      'U3' hashes and compares equal to "ab" stored in 'U4'.
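
The float and unicode notes above can be checked with the standard library
alone. This sketch only illustrates the semantics; `float64_bits` and
`strip_nul_padding` are hypothetical helpers, and the module itself works on
numpy / numba types rather than Python floats and strings.

```python
import struct


def float64_bits(x):
    # Reinterpret the IEEE 754 bytes of a float64 as an unsigned 64-bit integer.
    return struct.unpack("<Q", struct.pack("<d", x))[0]


def strip_nul_padding(s):
    # Drop trailing NUL code points, the padding used by fixed-width 'U<n>' storage.
    return s.rstrip(chr(0))


print(0.0 == -0.0)                              # True  (value equality)
print(float64_bits(0.0) == float64_bits(-0.0))  # False (bit-wise equality)

nan = float("nan")
print(nan == nan)                               # False (IEEE 754 rule)
print(float64_bits(nan) == float64_bits(nan))   # True  (same bit pattern)

u3 = "ab" + chr(0)       # "ab" as padded in a 'U3' field
u4 = "ab" + chr(0) * 2   # "ab" as padded in a 'U4' field
print(strip_nul_padding(u3) == strip_nul_padding(u4))  # True
```

Callers that need +0.0 and -0.0 (or all NaNs) to collapse into one key would
have to normalise the values before inserting them.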
"""
import numpy as np
import numba as nb
from numba.extending import overload
from numba.core import types as nbt
from pandas.api.types import is_numeric_dtype


# Seed and per-field accumulation prime for the FNV-1a-style accumulation step.
init_hash = np.uint64(14695981039346656037)
FNV_PRIME = np.uint64(1099511628211)

# Murmur3 fmix64 constants, used in the final avalanche step. Single-shift
# finalize (h ^= h >> 33) wasn't strong enough to mix biased low-bit inputs
# (e.g. integers spaced by a power of 2), leaving most of the table unused.
# Full Murmur3 fmix64 fixes this; cost is two extra multiplications per
# hash call.
M3_C1 = np.uint64(0xff51afd7ed558ccd)
M3_C2 = np.uint64(0xc4ceb9fe1a85ec53)

# lookup_dtype bit layout (uint16 → 16 bits total):
#   top bit : 1 = slot occupied, 0 = empty
#   4 bits  : Robin Hood displacement index (0..15)
#   11 bits : high bits of the hash (`hash_key >> hash_key_shift`)
# Concretely for uint16: full_bit=0x8000, i_rh in [0x0000..0x7800] step 0x0800,
# hash mask = 0x07FF. If lookup_dtype is widened, the i_rh field stays at 4 bits
# and the hash field absorbs the extra width.
lookup_dtype = np.uint16
nb_lookup_dtype = nb.from_dtype(lookup_dtype)
lookup_bitsize = nb_lookup_dtype(0).itemsize * 8
lookup_hash_bitsize = lookup_bitsize - 5
hash_key_shift = 64 - lookup_hash_bitsize

full_bit = lookup_dtype(1 << (lookup_bitsize - 1))  # top bit: slot-occupied flag (uint16: 0x8000)
hash_mask = lookup_dtype((2**lookup_bitsize - 1) >> 5)  # low (bitsize-5) bits: hash payload (uint16: 0x07FF)
max_rh = lookup_dtype(1 << (lookup_bitsize - 1))  # i_rh ceiling (16 increments past 0, equals full_bit numerically)
i_rh_mask = lookup_dtype(0b1111 << lookup_hash_bitsize)  # 4 bits below top: Robin Hood index (uint16: 0x7800)
i_rh_increment = lookup_dtype(0b1 << lookup_hash_bitsize)  # +1 step in the Robin Hood index (uint16: 0x0800)
full_rh = lookup_dtype(0b11111 << lookup_hash_bitsize)  # full_bit | i_rh=15: max-poor lookval (uint16: 0xF800)

n_full_factor = 0.95
inverse_n_full_factor = 1 / n_full_factor

index_dtype = np.uint32
nb_index_dtype = nb.from_dtype(index_dtype)

# Named offsets into the 3-element info array returned by unpack().
HM_INFO_MASK = 0  # bitmask for slot indexing (table_size - 1)
HM_INFO_N_VALID = 1  # number of unique keys currently stored
HM_INFO_N_FULL = 2  # max occupancy before rehash (table_size * 0.95)

i_add_key_fail = index_dtype(np.iinfo(index_dtype).max)

# try_add_key returns (slot index) | (new_slot_bit if the key was just inserted).
# Table sizes are bounded well below half the index_dtype range, so the top bit
# is always free to carry the flag. `i_add_key_fail` is all-ones and is checked
# *before* the new-bit test, so there is no ambiguity with a new-slot encoding.
index_bitsize = np.dtype(index_dtype).itemsize * 8
new_slot_bit = index_dtype(1 << (index_bitsize - 1))
slot_mask = index_dtype(~new_slot_bit)

# find_key returns NOT_FOUND when the key is not in the table.
NOT_FOUND = index_dtype(np.iinfo(index_dtype).max)

# ---------------------------------------------------------------------------
# Packed table layout: a single uint8 buffer holding [info | lookup | index].
#
#   info:   3 × index_dtype values (HM_INFO_MASK, HM_INFO_N_VALID, HM_INFO_N_FULL)
#   lookup: table_size × lookup_dtype values
#   index:  table_size × index_dtype values
#
# table_size is derived from mask (= info[HM_INFO_MASK]) + 1. All three sections are
# accessed via views into the buffer, so mutations go through to the same
# backing memory.
# ---------------------------------------------------------------------------
LOOKUP_ITEMSIZE = nb.int64(np.dtype(lookup_dtype).itemsize)
INDEX_ITEMSIZE = nb.int64(np.dtype(index_dtype).itemsize)
INFO_N = nb.int64(3)
INFO_BYTES = INFO_N * INDEX_ITEMSIZE


# ---------------------------------------------------------------------------
# Bit helpers
# ---------------------------------------------------------------------------

@nb.jit(nb_lookup_dtype(nb.uint64), cache=True)
def extract_hash_bit(hash_key):
    return hash_key >> np.uint64(hash_key_shift)


@nb.jit(nb_lookup_dtype(nb_lookup_dtype, nb_lookup_dtype, nb_lookup_dtype), cache=True)
def make_lookup_val(is_full, i_rh, hash_lookup_bit):
    return is_full | i_rh | hash_lookup_bit


# ---------------------------------------------------------------------------
# Per-dtype specialized hash and equality.
# Each of ``fnv1a`` and ``key_eq`` is defined twice:
#   1. A pure-Python body (executed when the function is called outside JIT)
#      that dispatches on the runtime numpy dtype kind ('f', 'U', else).
#   2. One or more ``@overload`` arms that numba's typing pass uses to
#      synthesize a typed impl per call site (records, numeric scalars,
#      unicode scalars). The arms do typed field accesses (no byte view,
#      no buffer allocation) and each compiles to specialized code.
# ---------------------------------------------------------------------------

def _float_as_uint(fld):
    """Reinterpret a numpy float scalar's IEEE 754 bytes as a same-width uint."""
    return fld.view('u4' if fld.itemsize == 4 else 'u8')


def _fmix64(h):
    """Murmur3 64-bit finalizer (fmix64): three rounds of shift / multiply / shift.

    Strong avalanche: even a single-bit input change touches every output bit
    with ~50% probability. Eliminates the bucket skew seen with biased low-bit
    inputs (floats in [0,1), stride-spaced ints). Bijective, so distinct inputs
    always produce distinct outputs."""
    h = h ^ np.uint64(h >> np.uint64(33))
    h = h * M3_C1
    h = h ^ np.uint64(h >> np.uint64(33))
    h = h * M3_C2
    h = h ^ np.uint64(h >> np.uint64(33))
    return h


def fnv1a(record, h=init_hash):
    """FNV-1a hash of a numpy scalar or structured record, followed by a
    finalizer mix. Under JIT, @overload synthesizes a typed impl with the same
    contract.

    Floats are hashed via bit-cast (IEEE 754 bytes reinterpreted as uint), so
    +0.0 and -0.0 hash distinctly and NaN bit patterns are preserved.
    Fixed-width unicode scalars and record fields ('U<n>') are hashed
    code-point-by-code-point after stripping trailing NUL padding (numpy's
    fixed-width convention).

    The finalizer is Murmur3 fmix64 (three shift/multiply rounds; see
    ``_fmix64``). Without it, biased low-bit inputs (e.g. floats in [0, 1) or
    integers spaced by a power of 2) produce a non-uniform bucket distribution
    and exceed the Robin Hood probe limit.

    Note: plain Python ``float`` / ``int`` keys (without a numpy dtype) take
    the ``np.uint64(record)`` value-cast path. Production code uses numpy types
    throughout, so this restriction is academic."""
    if hasattr(record, 'dtype') and record.dtype.names is not None:
        for fname in record.dtype.names:
            fld = record[fname]
            if fld.dtype.kind == 'f':
                fld = _float_as_uint(fld)
                h = (h ^ np.uint64(fld)) * FNV_PRIME
            elif fld.dtype.kind == 'U':
                for c in str(fld):
                    h = (h ^ np.uint64(ord(c))) * FNV_PRIME
            else:
                h = (h ^ np.uint64(fld)) * FNV_PRIME
        return _fmix64(h)
    if hasattr(record, 'dtype') and record.dtype.kind == 'U':
        for c in str(record):
            h = (h ^ np.uint64(ord(c))) * FNV_PRIME
        return _fmix64(h)
    if hasattr(record, 'dtype') and record.dtype.kind == 'f':
        record = _float_as_uint(record)
    return _fmix64((h ^ np.uint64(record)) * FNV_PRIME)


@overload(fnv1a)
def fnv1a_overload_record(record, h=init_hash):
    if not isinstance(record, nbt.Record):
        return None
    field_names = list(record.fields)
    src = ["def impl(record, h=init_hash):"]
    for fname in field_names:
        ftype = record.fields[fname][0]
        if isinstance(ftype, nbt.Float):
            # bit-cast: reinterpret IEEE 754 bytes as uint64 via a 1-element buffer.
            # The buffer is stack-allocated by LLVM (no heap alloc) since its
            # size is known at compile time and it doesn't escape the function.
            if ftype.bitwidth == 64:
                src.append("    _buf = np.empty(1, dtype=np.float64)")
                src.append(f"    _buf[0] = record['{fname}']")
                src.append(f"    h = (h ^ _buf.view(np.uint64)[0]) * {FNV_PRIME}")
            else:
                # float32 → zero-extend to 64 bits after bit-cast to uint32
                src.append("    _buf = np.empty(1, dtype=np.float32)")
                src.append(f"    _buf[0] = record['{fname}']")
                src.append(f"    h = (h ^ np.uint64(_buf.view(np.uint32)[0])) * {FNV_PRIME}")
        elif isinstance(ftype, nbt.UnicodeCharSeq):
            # Fixed-width unicode (e.g. 'U3') field. str(field) trims trailing
            # NUL padding via numba's UnicodeCharSeq.__len__ (numpy's fixed-
            # width convention), so "ab" stored in U3 hashes identically to
            # "ab" stored in U4. Hash each code point with FNV-1a.
            src.append(f"    _s = str(record['{fname}'])")
            src.append("    for _i in range(len(_s)):")
            src.append("        h = (h ^ np.uint64(ord(_s[_i]))) * FNV_PRIME")
        else:
            # int/uint/bool: np.uint64() is already a zero-extension, equivalent to bitcast
            src.append(
                f"    h = (h ^ np.uint64(record['{fname}'])) * {FNV_PRIME}"
            )
    # Final Murmur3 fmix64: three rounds of shift / mult / shift. See _fmix64().
    src.append("    h ^= h >> np.uint64(33)")
    src.append("    h = h * M3_C1")
    src.append("    h ^= h >> np.uint64(33)")
    src.append("    h = h * M3_C2")
    src.append("    h ^= h >> np.uint64(33)")
    src.append("    return h")
    ns = {'np': np, 'init_hash': init_hash, 'FNV_PRIME': FNV_PRIME,
          'M3_C1': M3_C1, 'M3_C2': M3_C2}
    exec("\n".join(src), ns)
    return ns['impl']


@overload(fnv1a)
def fnv1a_overload_scalar(key, h=init_hash):
    """Handles any numeric scalar: int8..int64, uint8..uint64, float32, float64,
    bool. Floats are bit-cast to match the record overload's behavior.

    Unicode scalars (``UnicodeCharSeq`` / ``UnicodeType``) are handled by
    ``fnv1a_overload_unichr`` (the next overload arm)."""
    if not isinstance(key, nbt.Number):
        return None
    if isinstance(key, nbt.Float):
        if key.bitwidth == 64:
            def impl(key, h=init_hash):
                _buf = np.empty(1, dtype=np.float64)
                _buf[0] = key
                h = (h ^ _buf.view(np.uint64)[0]) * FNV_PRIME
                h ^= h >> np.uint64(33)
                h = h * M3_C1
                h ^= h >> np.uint64(33)
                h = h * M3_C2
                h ^= h >> np.uint64(33)
                return h
            return impl

        def impl(key, h=init_hash):  # float32
            _buf = np.empty(1, dtype=np.float32)
            _buf[0] = key
            h = (h ^ np.uint64(_buf.view(np.uint32)[0])) * FNV_PRIME
            h ^= h >> np.uint64(33)
            h = h * M3_C1
            h ^= h >> np.uint64(33)
            h = h * M3_C2
            h ^= h >> np.uint64(33)
            return h
        return impl

    def impl(key, h=init_hash):
        h = (h ^ np.uint64(key)) * FNV_PRIME
        h ^= h >> np.uint64(33)
        h = h * M3_C1
        h ^= h >> np.uint64(33)
        h = h * M3_C2
        h ^= h >> np.uint64(33)
        return h
    return impl


@overload(fnv1a)
def fnv1a_overload_unichr(key, h=init_hash):
    """Handles unicode scalar keys: both fixed-width ``UnicodeCharSeq`` (e.g.
    when indexed inside JIT from a 'U<n>' array) and variable-length
    ``UnicodeType`` (e.g. when a numpy.str_ scalar is passed in from Python and
    lowered to a unicode string).

    ``str(key)`` is a no-op for UnicodeType and trims trailing NUL padding for
    UnicodeCharSeq, so both converge on the same hash for the same logical
    string."""
    if not isinstance(key, (nbt.UnicodeCharSeq, nbt.UnicodeType)):
        return None

    def impl(key, h=init_hash):
        _s = str(key)
        for _i in range(len(_s)):
            h = (h ^ np.uint64(ord(_s[_i]))) * FNV_PRIME
        h ^= h >> np.uint64(33)
        h = h * M3_C1
        h ^= h >> np.uint64(33)
        h = h * M3_C2
        h ^= h >> np.uint64(33)
        return h
    return impl


def key_eq(a, b):
    """Equality of two keys for numpy scalars / structured records. Under JIT,
    @overload synthesizes a typed impl with the same contract.

    Records are compared field-by-field (NOT via ``a.tobytes() == b.tobytes()``,
    which would also compare any padding bytes left uninitialized by
    ``np.empty`` under aligned dtypes). Float fields and float scalars are
    bit-compared so +0.0 ≠ -0.0 and NaN == NaN when bit patterns match;
    int/uint/bool fields use value equality. Fixed-width unicode scalars and
    record fields ('U<n>') compare via numpy / numba's UnicodeCharSeq equality,
    which strips trailing NUL padding, matching fnv1a's hashing semantics.

    Note: plain Python ``float`` / ``int`` keys (without a numpy dtype) fall
    back to Python ``==`` semantics. Production code uses numpy types
    throughout, so this restriction is academic."""
    if hasattr(a, 'dtype') and a.dtype.names is not None:
        for fname in a.dtype.names:
            fa, fb = a[fname], b[fname]
            if fa.dtype.kind == 'f':
                if fa.tobytes() != fb.tobytes():
                    return False
            elif fa != fb:
                return False
        return True
    if hasattr(a, 'dtype') and a.dtype.kind == 'f':
        return a.tobytes() == b.tobytes()
    return a == b


@overload(key_eq)
def key_eq_overload_record(a, b):
    if not (isinstance(a, nbt.Record) and isinstance(b, nbt.Record)):
        return None
    field_names = list(a.fields)
    src = ["def impl(a, b):"]
    for i, fname in enumerate(field_names):
        ftype = a.fields[fname][0]
        if isinstance(ftype, nbt.Float):
            # bit-compare: load each field into a 1-element buffer and compare
            # as the matching-width uint. Matches fnv1a's bit-cast hashing so
            # +0.0 ≠ -0.0 and NaN equals NaN iff bit patterns match.
            if ftype.bitwidth == 64:
                src.append(f"    _a{i} = np.empty(1, dtype=np.float64)")
                src.append(f"    _b{i} = np.empty(1, dtype=np.float64)")
                src.append(f"    _a{i}[0] = a['{fname}']")
                src.append(f"    _b{i}[0] = b['{fname}']")
                src.append(f"    if _a{i}.view(np.uint64)[0] != _b{i}.view(np.uint64)[0]: return False")
            else:
                src.append(f"    _a{i} = np.empty(1, dtype=np.float32)")
                src.append(f"    _b{i} = np.empty(1, dtype=np.float32)")
                src.append(f"    _a{i}[0] = a['{fname}']")
                src.append(f"    _b{i}[0] = b['{fname}']")
                src.append(f"    if _a{i}.view(np.uint32)[0] != _b{i}.view(np.uint32)[0]: return False")
        else:
            src.append(f"    if a['{fname}'] != b['{fname}']: return False")
    src.append("    return True")
    ns = {'np': np}
    exec("\n".join(src), ns)
    return ns['impl']


@overload(key_eq)
def key_eq_overload_scalar(a, b):
    """Handles any numeric scalar. Floats are bit-compared to match the fnv1a
    hash side: +0.0 ≠ -0.0, NaN equals NaN iff bit patterns match.

    Unicode scalar pairs are handled by ``key_eq_overload_unichr`` (the next
    overload arm)."""
    if not (isinstance(a, nbt.Number) and isinstance(b, nbt.Number)):
        return None
    if isinstance(a, nbt.Float):
        if a.bitwidth == 64:
            def impl(a, b):
                _a = np.empty(1, dtype=np.float64)
                _b = np.empty(1, dtype=np.float64)
                _a[0] = a
                _b[0] = b
                return _a.view(np.uint64)[0] == _b.view(np.uint64)[0]
            return impl

        def impl(a, b):  # float32
            _a = np.empty(1, dtype=np.float32)
            _b = np.empty(1, dtype=np.float32)
            _a[0] = a
            _b[0] = b
            return _a.view(np.uint32)[0] == _b.view(np.uint32)[0]
        return impl

    def impl(a, b):
        return a == b
    return impl


@overload(key_eq)
def key_eq_overload_unichr(a, b):
    """Handles unicode scalar key pairs: any combination of ``UnicodeCharSeq``
    and ``UnicodeType``. Numba's `==` for these types (``charseq_eq``) compares
    code-by-code after trimming trailing NUL padding on the UnicodeCharSeq
    side, matching fnv1a's hashing semantics."""
    str_types = (nbt.UnicodeCharSeq, nbt.UnicodeType)
    if not (isinstance(a, str_types) and isinstance(b, str_types)):
        return None

    def impl(a, b):
        return a == b
    return impl


# ---------------------------------------------------------------------------
# Packed table: unpack and creation
# ---------------------------------------------------------------------------

@nb.jit(cache=True, inline='always')
def unpack(table):
    """Extract (info, lookup_table, index_table) views from a packed buffer.

    info:   uint8 view → .view(index_dtype) → 3-element array
            [HM_INFO_MASK, HM_INFO_N_VALID, HM_INFO_N_FULL]
    lookup: uint8 view → .view(lookup_dtype) → table_size elements
    index:  uint8 view → .view(index_dtype) → table_size elements
    """
    info = table[:INFO_BYTES].view(index_dtype)
    table_size = nb.int64(info[HM_INFO_MASK]) + nb.int64(1)
    lookup_end = INFO_BYTES + table_size * LOOKUP_ITEMSIZE
    lookup = table[INFO_BYTES:lookup_end].view(lookup_dtype)
    index_end = lookup_end + table_size * INDEX_ITEMSIZE
    index = table[lookup_end:index_end].view(index_dtype)
    return info, lookup, index


@nb.jit(cache=True)
def init_dict(hint_size=15):
    """Create a packed table buffer. Returns a single uint8 array."""
    init_size = nb.int64(16)
    mask = index_dtype(0b1111)
    while init_size < (hint_size * inverse_n_full_factor):
        mask = index_dtype((mask << index_dtype(1)) + index_dtype(1))
        init_size *= 2
    total_bytes = INFO_BYTES + init_size * LOOKUP_ITEMSIZE + init_size * INDEX_ITEMSIZE
    table = np.zeros(total_bytes, dtype=np.uint8)
    info = table[:INFO_BYTES].view(index_dtype)
    info[HM_INFO_MASK] = mask
    info[HM_INFO_N_VALID] = index_dtype(0)
    info[HM_INFO_N_FULL] = index_dtype(init_size * n_full_factor)
    return table


# ---------------------------------------------------------------------------
# Robin Hood hashmap operations
# ---------------------------------------------------------------------------

@nb.jit(cache=True)
def _move_key(lookup_table, index_table, mask, i_lookup):
    """Find the next empty slot to the right of i_lookup and shift entries up
    by one slot to make room. Returns False if a key would become 'too poor'
    to fit (signals a rehash is needed)."""
    i_lookup_start = i_lookup
    while full_bit <= lookup_table[i_lookup & mask] < full_rh:
        i_lookup += index_dtype(1)
    if lookup_table[i_lookup & mask] >= full_rh:
        return False
    while i_lookup > i_lookup_start:
        lookup_table[i_lookup & mask] = (
            (lookup_table[(i_lookup - index_dtype(1)) & mask]) + i_rh_increment
        )
        index_table[i_lookup & mask] = (
            index_table[(i_lookup - index_dtype(1)) & mask]
        )
        i_lookup -= index_dtype(1)
    return True


@nb.jit(cache=True)
def _try_add_key(info, lookup_table, index_table, key_storage, key, i_item=None):
    """Internal: insert-or-find on unpacked views. Hot-path callers use this
    directly to avoid per-call unpack overhead.

    Two modes selected via `i_item`:

        i_item is None (by-value mode):
            On insertion, writes `key` to key_storage[info[HM_INFO_N_VALID]]
            and stores info[HM_INFO_N_VALID] in the index. The caller doesn't
            track positions; `key_storage[:info[HM_INFO_N_VALID]]` ends up with
            the unique keys in insertion order.

        i_item is an int (by-position mode):
            On insertion, stores `i_item` in the index. Skips the storage
            write: the caller is asserting that key_storage[i_item] already
            holds `key` (typically because key_storage IS the input array and
            key = key_storage[i_item]).

    The `i_item is None` check is a numba compile-time literal branch, so both
    modes compile to specialized code with no runtime overhead.

    Args:
        info, lookup_table, index_table: unpacked views.
        key_storage: 1D array. Read for equality checks against stored keys;
            also written to in by-value mode.
        key: the key value to insert or find.
        i_item: None (by-value) or int (by-position).

    Returns:
        i_add_key_fail       : rehash needed
        slot                 : key already present at `slot`
        slot | new_slot_bit  : key was just inserted at `slot`
    """
    mask = info[HM_INFO_MASK]
    hash_key = np.uint64(fnv1a(key))
    i_lookup = index_dtype(hash_key & mask)
    hash_lookup_bit = extract_hash_bit(hash_key)
    i_rh = nb_lookup_dtype(0)
    i_lookup_res = i_add_key_fail
    while i_rh < max_rh:
        masked_i_lookup = i_lookup & mask
        lookup_val = make_lookup_val(full_bit, i_rh, hash_lookup_bit)
        if lookup_val < lookup_table[masked_i_lookup]:
            # poorer, keep probing
            i_rh += i_rh_increment
            i_lookup += index_dtype(1)
        elif lookup_val == lookup_table[masked_i_lookup]:
            if key_eq(key, key_storage[index_table[masked_i_lookup]]):
                i_lookup_res = masked_i_lookup  # found existing
                break
            # hash bits collide but keys differ: keep probing
            i_rh += i_rh_increment
            i_lookup += index_dtype(1)
        else:
            # we're richer: displace and insert here
            moved = _move_key(lookup_table, index_table, mask, masked_i_lookup)
            if moved:
                if i_item is None:
                    # by-value: write key to storage, use info[HM_INFO_N_VALID] as index
                    key_storage[info[HM_INFO_N_VALID]] = key
                    index_table[masked_i_lookup] = info[HM_INFO_N_VALID]
                else:
                    # by-position: caller already has key at key_storage[i_item]
                    index_table[masked_i_lookup] = i_item
                lookup_table[masked_i_lookup] = lookup_val
                info[HM_INFO_N_VALID] += index_dtype(1)
                i_lookup_res = masked_i_lookup | new_slot_bit  # just inserted
            break
    return i_lookup_res


@nb.jit(cache=True)
def try_add_key(table, key_storage, key, i_item=None):
    """Try to insert `key`. Caller must check for i_add_key_fail first, then
    mask off `new_slot_bit` to recover the slot index.

    See `_try_add_key` for the `i_item` mode semantics.

    Returns:
        i_add_key_fail       : rehash needed
        slot                 : key already present at `slot`
        slot | new_slot_bit  : key was just inserted at `slot`
    """
    info, lookup_table, index_table = unpack(table)
    return _try_add_key(info, lookup_table, index_table, key_storage, key, i_item)


@nb.jit(cache=True)
def _find_key(info, lookup_table, index_table, key_table, key):
    """Internal: lookup on unpacked views."""
    mask = info[HM_INFO_MASK]
    hash_key = np.uint64(fnv1a(key))
    hash_lookup_bit = extract_hash_bit(hash_key)
    i_lookup = index_dtype(hash_key & mask)
    i_rh = nb_lookup_dtype(0)
    while i_rh < max_rh:
        masked_i_lookup = i_lookup & mask
        lookup_val = make_lookup_val(full_bit, i_rh, hash_lookup_bit)
        if lookup_val < lookup_table[masked_i_lookup]:
            # poorer, keep probing
            i_rh += i_rh_increment
            i_lookup += index_dtype(1)
        elif lookup_val == lookup_table[masked_i_lookup]:
            if key_eq(key, key_table[index_table[masked_i_lookup]]):
                return masked_i_lookup
            # hash bits collide but keys differ: keep probing
            i_rh += i_rh_increment
            i_lookup += index_dtype(1)
        else:
            # we're richer: key can't exist past this point in Robin Hood
            return NOT_FOUND
    return NOT_FOUND


@nb.jit(cache=True)
def find_key(table, key_table, key):
    """Look up `key` in the table without inserting. Returns the slot index on
    hit, or NOT_FOUND on miss.

    Note: takes a key *value* (not an index into key_table), since the caller
    typically has a raw key in hand rather than a position in key_table.
    """
    info, lookup_table, index_table = unpack(table)
    return _find_key(info, lookup_table, index_table, key_table, key)


@nb.jit(cache=True)
def rehash(table, key_table):
    """Double the table size and re-insert every live key. Returns a new
    packed table buffer (the old one becomes stale)."""
    info, lookup_table, index_table = unpack(table)
    while info[HM_INFO_N_VALID] > (info[HM_INFO_N_FULL] >> np.uint8(3)):
        new_table_size = nb.int64(lookup_table.shape[0]) * 2
        new_table = np.zeros(
            INFO_BYTES + new_table_size * LOOKUP_ITEMSIZE + new_table_size * INDEX_ITEMSIZE,
            dtype=np.uint8
        )
        # Write info BEFORE unpack: unpack reads mask to compute view sizes
        new_mask = (info[HM_INFO_MASK] << index_dtype(1)) + index_dtype(1)
        new_table[:INFO_BYTES].view(index_dtype)[HM_INFO_MASK] = new_mask
        new_table[:INFO_BYTES].view(index_dtype)[HM_INFO_N_VALID] = index_dtype(0)
        new_table[:INFO_BYTES].view(index_dtype)[HM_INFO_N_FULL] = index_dtype(new_table_size * n_full_factor)
        new_info, new_lookup, new_index = unpack(new_table)
        for i_lookup in range(lookup_table.shape[0]):
            masked_i_lookup = index_dtype(i_lookup) & info[HM_INFO_MASK]
            if lookup_table[masked_i_lookup] >= full_bit:
                i_item = index_table[masked_i_lookup]
                added = _try_add_key(new_info, new_lookup, new_index,
                                     key_table, key_table[i_item], i_item)
                if added == i_add_key_fail:
                    break
        else:
            break
        # retry with doubled table: update refs for next iteration
        info = new_info
        lookup_table = new_lookup
        index_table = new_index
        table = new_table
    else:
        raise Exception("rehashed too many times")
    return new_table


@nb.jit(cache=True)
def jit_factorize(key_table):
    """Return a 1-based id per row of `key_table`; identical rows share an id.

    Uses _try_add_key in by-position mode (i_item passed): no extra storage
    allocated, no input modification. index[slot] points to the i_item of the
    first occurrence; agg_id is tracked via info[HM_INFO_N_VALID] (which
    increments to match the insertion order)."""
    table = init_dict()
    info, lookup_table, index_table = unpack(table)
    res = np.empty(key_table.shape[0], dtype=np.uint64)
    for i_item_int in range(key_table.shape[0]):
        i_item = index_dtype(i_item_int)
        # cold path: rehash if full
        if info[HM_INFO_N_VALID] >= info[HM_INFO_N_FULL]:
            table = rehash(table, key_table)
            info, lookup_table, index_table = unpack(table)
        result = _try_add_key(
            info, lookup_table, index_table, key_table, key_table[i_item], i_item
        )
        # cold path: insertion failed (max_rh exceeded), so rehash and retry
        while result == i_add_key_fail:
            if info[HM_INFO_N_VALID] < (info[HM_INFO_N_FULL] >> np.uint8(3)):
                raise Exception("rehashed too many times")
            table = rehash(table, key_table)
            info, lookup_table, index_table = unpack(table)
            result = _try_add_key(
                info, lookup_table, index_table, key_table, key_table[i_item], i_item
            )
        # by-position: index[slot] is i_item of first occurrence.
        # For new keys, that's the current i_item; for existing keys, an earlier i_item.
        # Either way, info[HM_INFO_N_VALID] (post-increment) is the next agg_id, and res of
        # the first occurrence holds the assigned agg_id.
        if result & new_slot_bit:
            # new: just inserted, info[HM_INFO_N_VALID] post-increment == this key's agg_id
            res[i_item] = np.uint64(info[HM_INFO_N_VALID])
        else:
            # existing: copy agg_id from the first occurrence
            res[i_item] = res[index_table[result]]
    return res


def factorize(df):
    """pd.factorize-equivalent driver for a pandas DataFrame.

    Per-column coercion to a numpy-friendly dtype:
      - Nullable integer columns ('Int*') → float32 (NaN-preserving).
      - Other numeric columns → kept as-is.
      - Everything else → fixed-width unicode 'U<max_len>' via astype(str).

    The resulting columns are packed into a single structured array and handed
    to ``jit_factorize`` for grouping."""
    np_arrays = {}
    np_dtypes = []
    for _name, _dtype in df.dtypes.items():
        if _dtype.name.startswith('Int'):
            np_dtypes.append((_name, 'f'))  # convert int-with-nan to float
            np_arrays[_name] = df[_name].to_numpy(dtype='f')
        elif is_numeric_dtype(_dtype):
            np_dtypes.append((_name, _dtype.name))
            np_arrays[_name] = df[_name].to_numpy(dtype=_dtype.name)
        else:
            _serie = df[_name].astype(str)
            max_str_len = int(np.max(_serie.str.len()))
            np_dtypes.append((_name, f'U{max_str_len}'))
            np_arrays[_name] = _serie.to_numpy(dtype=f'U{max_str_len}')
    arr = np.empty(df.shape[0], dtype=np_dtypes)
    for _name, _dtype in np_dtypes:
        arr[_name] = np_arrays[_name]
    return jit_factorize(arr)