# Source code for oasislmf.warmup

#!/usr/bin/env python3
"""Warm the Numba JIT cache for all pytools modules.

Runs each tool's compilation in a parallel process pool so that all ~191
Numba-compiled functions are cached in __pycache__ before the first real
model run.  This eliminates the 163-365 s cold-start overhead.

Test assets are bundled in oasislmf/_data/warmup/ (~400 KB) so this works
both from a source checkout and after ``pip install oasislmf`` from PyPI.

Usage (CLI):
    oasislmf warmup

Usage (standalone):
    python -m oasislmf.warmup

Usage (Dockerfile):
    RUN pip install oasislmf && oasislmf warmup

Usage (pytest):
    pytest tests/pytools/test_jit_compilation.py::test_jit_compile_all -v

Skip (warmup is a no-op when Numba's JIT is disabled):
    NUMBA_DISABLE_JIT=1 oasislmf warmup
"""

import contextlib
import functools
import os
import shutil
import subprocess
import sys
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from tempfile import TemporaryDirectory


# ---------------------------------------------------------------------------
# Locate bundled warmup assets inside the installed package.
# ---------------------------------------------------------------------------

def _get_data_dir():
    """Locate the directory holding the bundled warmup assets.

    The assets sit next to this module, so resolving ``__file__`` gives the
    same answer for a source checkout and for an installed package.

    Returns:
        Absolute ``Path`` to ``<package dir>/_data/warmup``.
    """
    package_dir = Path(__file__).resolve().parent
    return package_dir.joinpath("_data", "warmup")


# Resolved once at import time; the _compile_* workers read their bundled
# inputs relative to this directory.
_DATA_DIR = _get_data_dir()


# Paths of the bundled warmup inputs, relative to a run directory (and to
# _DATA_DIR, which mirrors that layout). Workers copy the subset they need
# into a temporary workspace.

# input/ — per-run inputs
occurrence_rel_path = Path("input", "occurrence.bin")
periods_rel_path = Path("input", "periods.bin")
returnperiods_rel_path = Path("input", "returnperiods.bin")
quantile_rel_path = Path("input", "quantile.bin")
correlations_rel_path = Path("input", "correlations.bin")
coverages_rel_path = Path("input", "coverages.bin")
events_rel_path = Path("input", "events.bin")
items_rel_path = Path("input", "items.bin")
amplifications_rel_path = Path("input", "amplifications.bin")

# static/ — model data
damage_bin_dict_rel_path = Path("static", "damage_bin_dict.bin")
footprint_rel_path = Path("static", "footprint.bin")
footprint_idx_rel_path = Path("static", "footprint.idx")
vulnerability_rel_path = Path("static", "vulnerability.bin")
lossfactors_rel_path = Path("static", "lossfactors.bin")

# work/ — intermediate summary output consumed by the post-processing tools
summary_rel_path = Path("work", "gul", "summarypy.bin")
def _copy_rel_path(src, dest):
    """Copy file *src* to *dest*, creating *dest*'s parent directories first."""
    os.makedirs(Path(dest).parent, exist_ok=True)
    shutil.copyfile(src, dest)


# ---------------------------------------------------------------------------
# Worker functions — top-level for pickling by ProcessPoolExecutor.
# Each does all imports locally so child processes start clean.
# ---------------------------------------------------------------------------

def _compile_fmpy():
    """FM pipeline — normal + stepped calcrules (sequential to avoid cache races).

    Runs ``RunExposure`` on each bundled FM example; example directories that
    are missing under ``_DATA_DIR`` are skipped.
    """
    from oasislmf.computation.run.exposure import RunExposure

    for subdir, perils in [
        ("fmpy/Q1_1", ["WTC"]),
        ("fmpy/fm54", ["WTC"]),
    ]:
        src = _DATA_DIR / subdir
        if not src.exists():
            continue
        with TemporaryDirectory() as tmpdir:
            RunExposure(
                src_dir=str(src),
                run_dir=tmpdir,
                loss_factor=[1.0],
                output_level='port',
                output_file=str(Path(tmpdir) / "loc_summary.csv"),
                fmpy_sort_output=True,
                kernel_alloc_rule_il=2,
                kernel_alloc_rule_ri=2,
                intermediary_csv=True,
                model_perils_covered=perils,
            ).run()


def _run_stage(cmd, stdin_path=None, stdout_path=None, cwd=None, timeout=300):
    """Run a single subprocess stage, reading/writing via files rather than pipes.

    Using files instead of shell pipes (cmd1 | cmd2) means each process runs
    to completion — and commits all Numba .nbi cache writes — before the next
    stage starts. Shell pipes run all stages concurrently, causing races on
    the shared .nbi index that silently drop type variants from the cache.

    Args:
        cmd: argv list passed to ``subprocess.run`` (no shell involved).
        stdin_path: file fed to the process's stdin, or None for /dev/null.
        stdout_path: file that receives stdout, or None to discard it.
        cwd: working directory for the process, or None to inherit ours.
        timeout: seconds before ``subprocess.TimeoutExpired`` is raised.

    Raises:
        RuntimeError: if the process exits non-zero; the message carries the
            program name, return code and its stderr.
    """
    with contextlib.ExitStack() as stack:
        stdin_fh = stack.enter_context(open(stdin_path, 'rb')) if stdin_path else subprocess.DEVNULL
        stdout_fh = stack.enter_context(open(stdout_path, 'wb')) if stdout_path else subprocess.DEVNULL
        result = subprocess.run(
            cmd,
            stdin=stdin_fh,
            stdout=stdout_fh,
            stderr=subprocess.PIPE,
            cwd=str(cwd) if cwd is not None else None,
            timeout=timeout,
        )
    if result.returncode != 0:
        # errors="replace": non-UTF-8 bytes on stderr must not turn the real
        # failure into a UnicodeDecodeError that hides it.
        stderr_text = result.stderr.decode(errors="replace")
        raise RuntimeError(
            f"{cmd[0]} failed (rc={result.returncode}):\n"
            f"stderr: {stderr_text}"
        )


def _compile_modelpy_gulpy_gulmc():
    """modelpy + gulpy + gulmc — each stage run sequentially via temp files.

    Shell pipes (evepy | modelpy | gulpy) run all three processes
    concurrently. modelpy and gulpy both JIT-compile shared functions
    (e.g. mv_read) and race to update the same .nbi cache index, silently
    dropping type variants.

    Running each stage sequentially via _run_stage eliminates the race:
    every process finishes and flushes its cache writes before the next one
    starts. modelpy output is captured once and reused for both gulpy and
    gulmc.
    """
    with TemporaryDirectory() as tmpdir:
        workspace = Path(tmpdir) / "workspace"
        needed_files = [correlations_rel_path, coverages_rel_path, events_rel_path, items_rel_path,
                        damage_bin_dict_rel_path, footprint_rel_path, footprint_idx_rel_path,
                        vulnerability_rel_path]
        for rel_path in needed_files:
            _copy_rel_path((_DATA_DIR / rel_path), workspace / rel_path)

        evepy_out = Path(tmpdir) / "evepy_out.bin"
        modelpy_out = Path(tmpdir) / "modelpy_out.bin"

        # Stage 1: generate events
        _run_stage(["evepy", "1", "1"], stdout_path=evepy_out, cwd=workspace)

        # Stage 2: modelpy — fully complete before gulpy/gulmc start so the
        # modelpy JIT cache is stable when the next stage also uses it
        _run_stage(["modelpy"], stdin_path=evepy_out, stdout_path=modelpy_out, cwd=workspace)

        # Stage 3a: gulpy — reads the captured modelpy output from file
        _run_stage(
            ["gulpy", "-a1", "-S1", "-L0", "--random-generator=2"],
            stdin_path=modelpy_out,
            cwd=workspace,
        )

        # Stage 3b: gulmc — sequential after gulpy, reuses same modelpy output
        _run_stage(
            ["gulmc", "-a1", "-S1", "-L0", "--ignore-correlation", "--random-generator=2"],
            stdin_path=modelpy_out,
            cwd=workspace,
        )


def _compile_summarypy():
    """summarypy manager on single_summary_set."""
    from oasislmf.pytools.summary.cli import manager

    with TemporaryDirectory() as tmpdir:
        manager.main(
            create_summarypy_files=False,
            low_memory=True,
            output_zeros=False,
            static_path=_DATA_DIR / "summarypy" / "single_summary_set",
            run_type="gul",
            files_in=[_DATA_DIR / "input" / "gul.bin"],
            summary_sets_output=["-1", str(Path(tmpdir) / 'gul_S1_summary.bin')]
        )


def _compile_eltpy():
    """eltpy manager — event loss table."""
    import numpy as np
    from unittest.mock import patch
    from oasislmf.pytools.common.data import oasis_int, oasis_float
    from oasislmf.pytools.elt.manager import main as elt_main

    with TemporaryDirectory() as tmpdir:
        out_file = Path(tmpdir) / "selt.csv"
        # read_event_rates is replaced with empty arrays — presumably so no
        # event-rate input is needed for the warmup run (TODO confirm).
        with patch(
            'oasislmf.pytools.elt.manager.read_event_rates',
            return_value=(np.array([], dtype=oasis_int), np.array([], dtype=oasis_float))
        ):
            elt_main(
                run_dir=Path(tmpdir),
                files_in=_DATA_DIR / summary_rel_path,
                ext="csv",
                selt=out_file,
            )


def _compile_pltpy():
    """pltpy manager — period loss table (with occurrence for occ JIT)."""
    from oasislmf.pytools.plt.manager import main as plt_main

    with TemporaryDirectory() as tmpdir:
        out_file = Path(tmpdir) / "splt.csv"
        plt_main(
            run_dir=_DATA_DIR,
            files_in=_DATA_DIR / summary_rel_path,
            ext="csv",
            splt=out_file,
        )


def _compile_aalpy():
    """aalpy manager — annual aggregate loss."""
    from oasislmf.pytools.aal.manager import main as aal_main

    with TemporaryDirectory() as tmpdir:
        workspace = Path(tmpdir) / "workspace"
        for rel_path in [occurrence_rel_path, summary_rel_path]:
            _copy_rel_path((_DATA_DIR / rel_path), workspace / rel_path)
        out_dir = workspace / "out"
        out_dir.mkdir()
        out_file = out_dir / "aal.csv"
        aal_main(
            run_dir=workspace,
            subfolder="gul",
            aal=out_file,
            ext="csv",
            meanonly=False,
        )


def _compile_lecpy():
    """lecpy manager — all 8 report flags for max JIT coverage."""
    from oasislmf.pytools.lec.manager import main as lec_main

    with TemporaryDirectory() as tmpdir:
        workspace = Path(tmpdir) / "workspace"
        needed_files = [occurrence_rel_path, periods_rel_path, returnperiods_rel_path, summary_rel_path]
        for rel_path in needed_files:
            _copy_rel_path((_DATA_DIR / rel_path), workspace / rel_path)
        out_dir = workspace / "out"
        out_dir.mkdir()
        ept_file = out_dir / "ept.csv"
        psept_file = out_dir / "psept.csv"
        lec_main(
            run_dir=workspace,
            subfolder="gul",
            use_return_period=True,
            agg_full_uncertainty=True,
            agg_wheatsheaf=True,
            agg_sample_mean=True,
            agg_wheatsheaf_mean=True,
            occ_full_uncertainty=True,
            occ_wheatsheaf=True,
            occ_sample_mean=True,
            occ_wheatsheaf_mean=True,
            ept=ept_file,
            psept=psept_file,
            ext="csv",
        )


def _compile_katpy():
    """katpy manager — sorted mode for nb_heapq JIT."""
    from oasislmf.pytools.kat.manager import main as kat_main

    with TemporaryDirectory() as tmpdir:
        out_file = Path(tmpdir) / "katpy_qplt.csv"
        kat_main(
            dir_in=_DATA_DIR / "katpy",
            qplt=True,
            out=out_file,
            unsorted=False,
        )


def _compile_plapy():
    """plapy manager — post-loss amplification."""
    from tempfile import NamedTemporaryFile
    from oasislmf.pytools.pla.manager import run

    with TemporaryDirectory() as tmpdir:
        workspace = Path(tmpdir) / "workspace"
        for rel_path in [amplifications_rel_path, lossfactors_rel_path]:
            _copy_rel_path((_DATA_DIR / rel_path), workspace / rel_path)
        with NamedTemporaryFile(prefix='pla', dir=str(tmpdir)) as pla_out:
            run(
                run_dir=str(workspace),
                file_in=str(_DATA_DIR / "input" / "gul.bin"),
                file_out=pla_out.name,
                input_path='input',
                static_path='static',
                secondary_factor=1,
                uniform_factor=0
            )


# ---------------------------------------------------------------------------
# Silence helper — suppresses all logging and stdout/stderr in worker processes.
# ---------------------------------------------------------------------------

class _silence:
    """Context manager that suppresses all logging and stdout/stderr.

    Safe to use both in worker processes and in the main process — restores
    original state on exit. Only the Python-level ``sys.stdout``/``sys.stderr``
    objects are swapped; writes made directly to file descriptors 1/2 are
    not intercepted.
    """

    def __enter__(self):
        import logging
        # Open the sink before touching any global state, so a failing open()
        # leaves logging and the std streams exactly as they were.
        self._devnull = open(os.devnull, 'w')
        self._prev_disable = logging.root.manager.disable
        self._prev_stdout = sys.stdout
        self._prev_stderr = sys.stderr
        logging.disable(logging.CRITICAL)
        sys.stdout = self._devnull
        sys.stderr = self._devnull
        return self

    def __exit__(self, *exc):
        import logging
        sys.stdout = self._prev_stdout
        sys.stderr = self._prev_stderr
        logging.disable(self._prev_disable)
        self._devnull.close()
        return False  # never swallow exceptions from the body


def _silence_func(func):
    """Decorator: run *func* inside a ``_silence()`` context."""
    @functools.wraps(func)
    def silenced_func(*args, **kwargs):
        with _silence():
            return func(*args, **kwargs)
    return silenced_func


def _make_silent(fn, name):
    """Wrap fn with _silence_func and fix __name__/__qualname__ for pickling.

    Pickle locates a function by ``__module__`` + ``__qualname__``, so *name*
    must match the module-level variable the wrapper is assigned to.
    """
    silent = _silence_func(fn)
    silent.__name__ = name
    silent.__qualname__ = name
    return silent


_compile_fmpy_silent = _make_silent(_compile_fmpy, '_compile_fmpy_silent')
_compile_modelpy_gulpy_gulmc_silent = _make_silent(_compile_modelpy_gulpy_gulmc, '_compile_modelpy_gulpy_gulmc_silent')
_compile_lecpy_silent = _make_silent(_compile_lecpy, '_compile_lecpy_silent')
_compile_aalpy_silent = _make_silent(_compile_aalpy, '_compile_aalpy_silent')
_compile_eltpy_silent = _make_silent(_compile_eltpy, '_compile_eltpy_silent')
_compile_pltpy_silent = _make_silent(_compile_pltpy, '_compile_pltpy_silent')
_compile_katpy_silent = _make_silent(_compile_katpy, '_compile_katpy_silent')
_compile_summarypy_silent = _make_silent(_compile_summarypy, '_compile_summarypy_silent')
_compile_plapy_silent = _make_silent(_compile_plapy, '_compile_plapy_silent')


# ---------------------------------------------------------------------------
# Task registry — ordered heaviest-first so the pool starts slow tasks early.
# ---------------------------------------------------------------------------
# name -> silenced worker callable. warmup() submits these in insertion
# order, so the heaviest tasks are listed first.
ALL_SILENT_TASKS = {
    "fmpy": _compile_fmpy_silent,
    "modelpy_gulpy_gulmc": _compile_modelpy_gulpy_gulmc_silent,
    "lecpy": _compile_lecpy_silent,
    "aalpy": _compile_aalpy_silent,
    "eltpy": _compile_eltpy_silent,
    "pltpy": _compile_pltpy_silent,
    "katpy": _compile_katpy_silent,
    "summarypy": _compile_summarypy_silent,
    "plapy": _compile_plapy_silent,
}
def warmup(max_workers=None):
    """Run all JIT compilations in parallel.

    Every entry of ``ALL_SILENT_TASKS`` is submitted to a
    ``ProcessPoolExecutor``; a tqdm progress bar is shown when tqdm is
    importable.

    Args:
        max_workers: Max parallel processes. Defaults to
            ``min(os.cpu_count() or 4, len(ALL_SILENT_TASKS))``, clamped to
            at least 1.

    Returns:
        dict of task_name -> exception for any failures (empty on success,
        and also empty when the bundled warmup assets are missing).
    """
    if not _DATA_DIR.is_dir():
        print(f" Warmup assets not found at {_DATA_DIR} — skipping.")
        return {}

    if max_workers is None:
        # Clamp to >= 1: ProcessPoolExecutor raises ValueError for 0 workers,
        # which an empty task registry would otherwise produce.
        max_workers = max(1, min(os.cpu_count() or 4, len(ALL_SILENT_TASKS)))

    try:
        from tqdm import tqdm
        has_tqdm = True
    except ImportError:
        has_tqdm = False

    errors = {}
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        futures = {pool.submit(fn): name for name, fn in ALL_SILENT_TASKS.items()}
        pbar = tqdm(
            total=len(futures),
            desc="warmup",
            unit="task",
            bar_format="{desc}: {bar} {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
        ) if has_tqdm else None
        try:
            for future in as_completed(futures):
                name = futures[future]
                try:
                    future.result()
                except Exception as e:
                    # Collect rather than abort: one failing tool should not
                    # stop the others from warming their caches.
                    errors[name] = e
                if pbar is not None:
                    pbar.set_postfix_str(name)
                    pbar.update(1)
        finally:
            # Explicit None test: tqdm defines __bool__ from its total, so a
            # plain truthiness check can skip close(). The finally also closes
            # the bar when the loop is interrupted (e.g. KeyboardInterrupt).
            if pbar is not None:
                pbar.close()
    return errors
def main():
    """Command-line entry point: warm the Numba JIT cache and report the outcome.

    Returns immediately when ``NUMBA_DISABLE_JIT=1``. Otherwise runs
    ``warmup()``. On any task failure it prints each task's traceback to
    stderr and exits with status 1. On success it prints how many ``*.nbi``
    files exist under the package directory. That count includes any files
    left by earlier runs.
    """
    if os.environ.get("NUMBA_DISABLE_JIT") == "1":
        print("NUMBA_DISABLE_JIT=1 — skipping JIT warmup.")
        return

    # Compute the pool size here and hand it to warmup() so the banner shows
    # the number of workers actually used, not os.cpu_count().
    max_workers = max(1, min(os.cpu_count() or 4, len(ALL_SILENT_TASKS)))
    print(f"Warming Numba JIT cache ({len(ALL_SILENT_TASKS)} tasks, "
          f"max_workers={max_workers}) ...")
    errors = warmup(max_workers=max_workers)

    if errors:
        print(f"\nFAILED — {len(errors)} task(s):", file=sys.stderr)
        for name, err in errors.items():
            print(f" {name}:", file=sys.stderr)
            traceback.print_exception(type(err), err, err.__traceback__, file=sys.stderr)
        sys.exit(1)

    pkg_root = Path(__file__).resolve().parent
    cache_count = sum(1 for _ in pkg_root.rglob("*.nbi"))
    print(f"Done — {cache_count} Numba cache files written.")
# Script entry point for `python -m oasislmf.warmup`.
if __name__ == "__main__":
    main()