# Source code for oasislmf.pytools.lec.aggreports.outputs.wheatsheaf_mean
import numba as nb
import numpy as np
from oasislmf.pytools.lec.utils import get_wheatsheaf_items_idx_data
@nb.njit(cache=True, error_model="numpy")
def get_wheatsheaf_max_count(
    wheatsheaf_items,
    wheatsheaf_items_start_end,
    max_summary_id,
):
    """Return, per summary_id, the largest item-range size seen for it.

    Args:
        wheatsheaf_items: Record array with (at least) a ``summary_id`` field.
        wheatsheaf_items_start_end: Sequence of ``(start, end)`` index pairs
            into ``wheatsheaf_items``; a ``start`` of -1 marks an empty slot.
        max_summary_id: Number of summary_ids; summary_id ``s`` is stored at
            index ``s - 1`` of the result.

    Returns:
        int64 array of length ``max_summary_id`` whose entry ``s - 1`` is the
        maximum ``end - start`` over the ranges belonging to summary_id ``s``,
        or -1 if no range belongs to it.
    """
    maxcount = np.full((max_summary_id), -1, dtype=np.int64)
    for start, end in wheatsheaf_items_start_end:
        # Skip empty slots: -1 is a sentinel, not an index. Reading
        # wheatsheaf_items[-1] would credit a bogus range to whichever
        # summary_id owns the last item (or index out of range when
        # wheatsheaf_items is empty; numba does not bounds-check by default).
        if start == -1:
            continue
        # The range is assumed to hold a single summary_id; read it from
        # the first item.
        summary_id = wheatsheaf_items[start]["summary_id"]
        size = end - start
        if size > maxcount[summary_id - 1]:
            maxcount[summary_id - 1] = size
    return maxcount
@nb.njit(cache=True, error_model="numpy")
def fill_wheatsheaf_mean_items(
    wheatsheaf_mean_items,
    wheatsheaf_items,
    wheatsheaf_items_start_end,
    maxcounts,
    max_summary_id,
    num_sidxs,
):
    """Accumulate rank-ordered wheatsheaf values into per-summary slots.

    Each summary_id with ``maxcounts > 0`` is given a contiguous run of
    ``maxcounts`` slots in ``wheatsheaf_mean_items``, and the runs are laid
    out back to back. For every non-empty (summary_id, sidx) range, the
    values are taken largest first. The i-th largest is added to slot i of
    that summary_id's run, so each slot accumulates one rank across all sidxs.

    Args:
        wheatsheaf_mean_items: Output record array with ``summary_id`` and
            ``value`` fields. It must be long enough to hold every run.
            ``value`` is accumulated with ``+=``, so it presumably arrives
            zero-initialised — TODO confirm against the caller.
        wheatsheaf_items: Record array with a ``value`` field.
        wheatsheaf_items_start_end: ``(start, end)`` pairs into
            ``wheatsheaf_items``, one per flat index in
            ``range(max_summary_id * num_sidxs)``; ``start == -1`` marks an
            empty pair.
        maxcounts: Slot count per summary_id (index ``summary_id - 1``);
            entries ``<= 0`` get no slots.
        max_summary_id: Number of summary_ids.
        num_sidxs: Number of sidx values per summary_id; passed to
            ``get_wheatsheaf_items_idx_data`` to decode each flat index.

    Returns:
        int32 array of shape ``(max_summary_id, 2)``. Row ``summary_id - 1``
        holds the ``[start, end)`` slot range of that summary_id in
        ``wheatsheaf_mean_items``, or ``[-1, -1]`` if it has no slots.
    """
    # NOTE(review): int32 offsets overflow if sum(maxcounts) exceeds 2**31 - 1;
    # kept as int32 because callers receive this array.
    items_start_end = np.full((max_summary_id, 2), -1, dtype=np.int32)

    # Lay the per-summary runs out back to back.
    pos = 0
    for idx in range(max_summary_id):
        if maxcounts[idx] > 0:
            items_start_end[idx][0] = pos  # Start index (inclusive)
            pos += maxcounts[idx]
            items_start_end[idx][1] = pos  # End index (exclusive)

    for idx in range(max_summary_id * num_sidxs):
        ws_start, ws_end = wheatsheaf_items_start_end[idx]
        if ws_start == -1:
            continue  # No items for this (summary_id, sidx) pair.
        # Only the summary_id is needed here; the sidx is discarded.
        _, summary_id = get_wheatsheaf_items_idx_data(idx, num_sidxs)

        # Rank the pair's values largest first. Indexing the value field
        # through the argsort order avoids copying whole records.
        values = wheatsheaf_items[ws_start:ws_end]["value"]
        order = np.argsort(values)[::-1]

        wsm_start = items_start_end[summary_id - 1][0]
        for i in range(order.shape[0]):
            insert_idx = wsm_start + i
            wheatsheaf_mean_items[insert_idx]["summary_id"] = summary_id
            wheatsheaf_mean_items[insert_idx]["value"] += values[order[i]]
    return items_start_end