# katpy/manager.py
from collections import Counter
from contextlib import ExitStack
import csv
import glob
import heapq
import logging
import numba as nb
import numpy as np
import shutil
from pathlib import Path
import tempfile
from oasislmf.pytools.common.data import write_ndarray_to_fmt_csv
from oasislmf.pytools.common.utils.nb_heapq import heap_pop, heap_push, init_heap
from oasislmf.pytools.elt.data import MELT_dtype, MELT_fmt, MELT_headers, QELT_dtype, QELT_fmt, QELT_headers, SELT_dtype, SELT_fmt, SELT_headers
from oasislmf.pytools.plt.data import MPLT_dtype, MPLT_fmt, MPLT_headers, QPLT_dtype, QPLT_fmt, QPLT_headers, SPLT_dtype, SPLT_fmt, SPLT_headers
from oasislmf.pytools.utils import redirect_logging
logger = logging.getLogger(__name__)
[docs]
KAT_MAP = {
KAT_SELT: {
"name": "SELT",
"headers": SELT_headers,
"dtype": SELT_dtype,
"fmt": SELT_fmt,
},
KAT_MELT: {
"name": "MELT",
"headers": MELT_headers,
"dtype": MELT_dtype,
"fmt": MELT_fmt,
},
KAT_QELT: {
"name": "QELT",
"headers": QELT_headers,
"dtype": QELT_dtype,
"fmt": QELT_fmt,
},
KAT_SPLT: {
"name": "SPLT",
"headers": SPLT_headers,
"dtype": SPLT_dtype,
"fmt": SPLT_fmt,
},
KAT_MPLT: {
"name": "MPLT",
"headers": MPLT_headers,
"dtype": MPLT_dtype,
"fmt": MPLT_fmt,
},
KAT_QPLT: {
"name": "QPLT",
"headers": QPLT_headers,
"dtype": QPLT_dtype,
"fmt": QPLT_fmt,
},
}
def check_file_extensions(file_paths):
    """Check that all file path extensions are identical.

    Args:
        file_paths (List[str | os.PathLike]): List of input file paths.
    Returns:
        ext (str): the common file extension (including the leading dot).
    Raises:
        RuntimeError: if the paths do not all share one extension, or the
            list is empty (no common extension exists).
    """
    # Collapse to the set of distinct suffixes in one pass; exactly one
    # distinct suffix is required. An empty input now raises the explicit
    # RuntimeError below instead of a bare IndexError.
    suffixes = {fp.suffix for fp in file_paths}
    if len(suffixes) == 1:
        return suffixes.pop()
    raise RuntimeError("ERROR: katpy has input files with different file extensions. Make sure all input files are of the same type.")
def csv_concat_unsorted(
    stack,
    file_paths,
    files_with_header,
    headers,
    out_file,
):
    """Concats CSV files in order they are passed in.

    Args:
        stack (ExitStack): Exit Stack.
        file_paths (List[str | os.PathLike]): List of csv file paths.
        files_with_header (List[bool]): Bool list of files with header present
        headers (List[str]): Headers to write
        out_file (str | os.PathLike): Output Concatenated CSV file.
    """
    with stack.enter_context(out_file.open("wb")) as out:
        # Fix: write the canonical header once, at the very top of the output.
        # Previously it was written only when the first header-bearing file was
        # reached, so a headerless file appearing earlier in the list caused the
        # header row to land in the middle of the concatenated output.
        if any(files_with_header):
            out.write(",".join(headers).encode() + b"\n")
        for has_header, fp in zip(files_with_header, file_paths):
            with stack.enter_context(fp.open("rb")) as csv_file:
                if has_header:
                    # Skip this file's own header row; the canonical header
                    # has already been written above.
                    csv_file.readline()
                # Stream the remaining bytes without decoding
                shutil.copyfileobj(csv_file, out)
def bin_concat_unsorted(
    stack,
    file_paths,
    out_file,
):
    """Concatenate binary files, in the order given, into a single output file.

    Args:
        stack (ExitStack): Exit Stack.
        file_paths (List[str | os.PathLike]): List of bin file paths.
        out_file (str | os.PathLike): Output Concatenated Binary file.
    """
    with stack.enter_context(out_file.open('wb')) as merged:
        for src_path in file_paths:
            # Stream each input in binary mode straight into the output
            with src_path.open('rb') as chunk:
                shutil.copyfileobj(chunk, merged)
@nb.njit(cache=True, error_model="numpy")
def merge_elt_data(memmaps):
    """Merge sorted chunks using a k-way merge algorithm.

    Each input memmap is assumed to be pre-sorted by EventId; the heap
    performs a streaming k-way merge across them.

    Args:
        memmaps (List[np.memmap]): List of temporary file memmaps
    Yields:
        buffer (ndarray): yields sorted buffer from memmaps
    """
    # Heap entries are int32 triples [EventId, file_idx, row_num];
    # num_compare=1 means only the leading EventId field is compared for
    # ordering — file_idx/row_num are payload used to locate the row.
    # NOTE(review): row_num is stored as int32, so a single chunk with more
    # than 2**31-1 rows would overflow — confirm chunk sizes stay below that.
    min_heap = init_heap(num_compare=1)
    size = 0
    # Initialize the min_heap with the first row of each memmap
    for i, mmap in enumerate(memmaps):
        if len(mmap) > 0:
            first_row = mmap[0]
            min_heap, size = heap_push(min_heap, size, np.array(
                [first_row["EventId"], i, 0], dtype=np.int32
            ))
    buffer_size = 1000000
    buffer = np.empty(buffer_size, dtype=memmaps[0].dtype)
    bidx = 0
    # Perform the k-way merge
    while size > 0:
        # The min heap will store the smallest row at the top when popped
        element, min_heap, size = heap_pop(min_heap, size)
        file_idx = element[-2]
        row_num = element[-1]
        smallest_row = memmaps[file_idx][row_num]
        # Add to buffer and yield when full
        buffer[bidx] = smallest_row
        bidx += 1
        if bidx >= buffer_size:
            # NOTE: the yielded array is a view over the reused buffer — the
            # consumer must finish with it before advancing the generator.
            yield buffer[:bidx]
            bidx = 0
        # Push the next row from the same file into the heap if there are any more rows
        if row_num + 1 < len(memmaps[file_idx]):
            next_row = memmaps[file_idx][row_num + 1]
            min_heap, size = heap_push(min_heap, size, np.array(
                [next_row["EventId"], file_idx, row_num + 1], dtype=np.int32
            ))
    # Flush the partially-filled tail (may be empty if the row count is an
    # exact multiple of buffer_size).
    yield buffer[:bidx]
@nb.njit(cache=True, error_model="numpy")
def merge_plt_data(memmaps):
    """Merge sorted chunks using a k-way merge algorithm.

    Each input memmap is assumed to be pre-sorted by (EventId, Period);
    the heap performs a streaming k-way merge across them.

    Args:
        memmaps (List[np.memmap]): List of temporary file memmaps
    Yields:
        buffer (ndarray): yields sorted buffer from memmaps
    """
    # Heap entries are int32 quadruples [EventId, Period, file_idx, row_num];
    # num_compare=2 means the leading (EventId, Period) pair is compared for
    # ordering — file_idx/row_num are payload used to locate the row.
    # NOTE(review): row_num is stored as int32, so a single chunk with more
    # than 2**31-1 rows would overflow — confirm chunk sizes stay below that.
    min_heap = init_heap(num_compare=2)
    size = 0
    # Initialize the min_heap with the first row of each memmap
    for i, mmap in enumerate(memmaps):
        if len(mmap) > 0:
            first_row = mmap[0]
            min_heap, size = heap_push(min_heap, size, np.array(
                [first_row["EventId"], first_row["Period"], i, 0], dtype=np.int32
            ))
    buffer_size = 1000000
    buffer = np.empty(buffer_size, dtype=memmaps[0].dtype)
    bidx = 0
    # Perform the k-way merge
    while size > 0:
        # The min heap will store the smallest row at the top when popped
        element, min_heap, size = heap_pop(min_heap, size)
        file_idx = element[-2]
        row_num = element[-1]
        smallest_row = memmaps[file_idx][row_num]
        # Add to buffer and yield when full
        buffer[bidx] = smallest_row
        bidx += 1
        if bidx >= buffer_size:
            # NOTE: the yielded array is a view over the reused buffer — the
            # consumer must finish with it before advancing the generator.
            yield buffer[:bidx]
            bidx = 0
        # Push the next row from the same file into the heap if there are any more rows
        if row_num + 1 < len(memmaps[file_idx]):
            next_row = memmaps[file_idx][row_num + 1]
            min_heap, size = heap_push(min_heap, size, np.array(
                [next_row["EventId"], next_row["Period"], file_idx, row_num + 1], dtype=np.int32
            ))
    # Flush the partially-filled tail (may be empty if the row count is an
    # exact multiple of buffer_size).
    yield buffer[:bidx]
def run(
    out_file,
    file_type=None,
    files_in=None,
    dir_in=None,
    concatenate_selt=False,
    concatenate_melt=False,
    concatenate_qelt=False,
    concatenate_splt=False,
    concatenate_mplt=False,
    concatenate_qplt=False,
    unsorted=False,
):
    """Concatenate CSV files (optionally sorted)

    Args:
        out_file (str | os.PathLike): Output Concatenated CSV file.
        file_type (str, optional): Input file type suffix, if not discernible from input files. Defaults to None.
        files_in (List[str | os.PathLike], optional): Individual CSV file paths to concatenate. Defaults to None.
        dir_in (str | os.PathLike, optional): Path to the directory containing files for concatenation. Defaults to None.
        concatenate_selt (bool, optional): Concatenate SELT CSV file. Defaults to False.
        concatenate_melt (bool, optional): Concatenate MELT CSV file. Defaults to False.
        concatenate_qelt (bool, optional): Concatenate QELT CSV file. Defaults to False.
        concatenate_splt (bool, optional): Concatenate SPLT CSV file. Defaults to False.
        concatenate_mplt (bool, optional): Concatenate MPLT CSV file. Defaults to False.
        concatenate_qplt (bool, optional): Concatenate QPLT CSV file. Defaults to False.
        unsorted (bool, optional): Do not sort by event/period ID. Defaults to False.
    Raises:
        FileNotFoundError: dir_in or an individual input file does not exist.
        ValueError: dir_in is not a directory.
        RuntimeError: no inputs, ambiguous/unsupported file types, or an
            invalid output-flag combination.
    """
    input_files = []
    # Check and add files from dir_in (sorted for a deterministic order)
    if dir_in:
        dir_in = Path(dir_in)
        if not dir_in.exists():
            raise FileNotFoundError(f"ERROR: Directory '{dir_in}' does not exist")
        if not dir_in.is_dir():
            raise ValueError(f"ERROR: '{dir_in}' is not a directory.")
        dir_csv_input_files = glob.glob(str(dir_in / "*.csv"))
        dir_bin_input_files = glob.glob(str(dir_in / "*.bin"))
        if not dir_csv_input_files and not dir_bin_input_files:
            logger.warning(f"Warning: No valid files found in directory '{dir_in}'")
        input_files += [Path(file).resolve() for file in dir_csv_input_files + dir_bin_input_files]
        input_files.sort()

    # Check and add files from files_in (kept in the order supplied)
    if files_in:
        for file in files_in:
            path = Path(file).resolve()
            if not path.exists():
                raise FileNotFoundError(f"ERROR: File '{path}' does not exist.")
            if not path.is_file():
                raise FileNotFoundError(f"ERROR: File '{path}' is not a valid file.")
            input_files.append(path)

    if not input_files:
        raise RuntimeError("ERROR: katpy has no input CSV files to join")

    out_file = Path(out_file).resolve()
    input_type = check_file_extensions(input_files)
    if file_type:
        # Explicit file_type overrides the suffix detected from the inputs
        input_type = "." + file_type
    elif input_type == "":  # Inputs are pipes for example
        raise RuntimeError("ERROR: katpy, no discernible file type suffix found from input files, please provide a file_type")

    # If out_file is a csv and input_files are not csvs, then output to a
    # temporary file of type input_type and convert it to csv afterwards.
    convert_to_csv = False
    csv_out_file_path = None
    if out_file.suffix != input_type:
        if out_file.suffix == ".csv":
            csv_out_file_path = out_file
            temp_file = tempfile.NamedTemporaryFile(suffix=input_type, delete=False)
            # Fix: only the path is needed — close the open handle immediately
            # instead of leaking it for the life of the process.
            temp_file.close()
            out_file = Path(temp_file.name)
            convert_to_csv = True
        else:
            raise RuntimeError(f"ERROR: katpy does not support concatenating input files of type {input_type} to output type {out_file.suffix}")

    # Positions must line up with the KAT_* indices used by KAT_MAP
    output_flags = [
        concatenate_selt,
        concatenate_melt,
        concatenate_qelt,
        concatenate_splt,
        concatenate_mplt,
        concatenate_qplt,
    ]
    # ELT outputs sort by EventId; PLT outputs sort by (EventId, Period)
    sort_by_event = any(output_flags[KAT_SELT:KAT_QELT + 1])
    sort_by_period = any(output_flags[KAT_SPLT:KAT_QPLT + 1])
    # Fix: explicit raise instead of `assert` so the check survives `python -O`
    if sort_by_event + sort_by_period != 1:
        raise RuntimeError("incorrect flag config in katpy")
    # Fix: do not reuse the `file_type` parameter (a suffix string) for the
    # integer KAT_MAP index — use a distinct local name.
    kat_type = output_flags.index(True)

    with ExitStack() as stack:
        if input_type == ".csv":
            files_with_header, header = find_csv_with_header(stack, input_files)
            headers = header.strip().split(",")
            check_correct_headers(headers, kat_type)
            if unsorted:
                csv_concat_unsorted(stack, input_files, files_with_header, headers, out_file)
            elif sort_by_event:
                header_idxs = get_header_idxs(headers, ["EventId"])
                csv_concat_sort_by_headers(
                    stack,
                    input_files,
                    files_with_header,
                    headers,
                    header_idxs,
                    lambda values: int(values[0]),
                    out_file,
                )
            elif sort_by_period:
                header_idxs = get_header_idxs(headers, ["EventId", "Period"])
                csv_concat_sort_by_headers(
                    stack,
                    input_files,
                    files_with_header,
                    headers,
                    header_idxs,
                    lambda values: (int(values[0]), int(values[1])),
                    out_file,
                )
        elif input_type == ".bin":
            if unsorted:
                bin_concat_unsorted(stack, input_files, out_file)
            elif sort_by_event:
                bin_concat_sort_by_headers(
                    stack,
                    input_files,
                    kat_type,
                    "elt",
                    out_file,
                )
            elif sort_by_period:
                bin_concat_sort_by_headers(
                    stack,
                    input_files,
                    kat_type,
                    "plt",
                    out_file,
                )
        else:
            raise RuntimeError(f"ERROR: katpy, file type {input_type} not supported.")

    # Convert the intermediate binary output to csv, after the stack has closed
    # so all concatenated data is flushed to disk before it is memmapped.
    if convert_to_csv:
        data = np.memmap(out_file, dtype=KAT_MAP[kat_type]["dtype"])
        headers = KAT_MAP[kat_type]["headers"]
        fmt = KAT_MAP[kat_type]["fmt"]
        num_rows = data.shape[0]
        buffer_size = 1000000
        # Fix: context manager guarantees the csv file is closed on error
        with open(csv_out_file_path, "w") as csv_out_file:
            csv_out_file.write(",".join(headers) + "\n")
            # Stream the conversion in chunks to bound memory usage
            for start in range(0, num_rows, buffer_size):
                end = min(start + buffer_size, num_rows)
                buffer_data = data[start:end]
                write_ndarray_to_fmt_csv(csv_out_file, buffer_data, headers, fmt)
        # Fix: best-effort removal of the intermediate temporary binary file,
        # which was previously left behind (NamedTemporaryFile(delete=False)).
        del data  # release the memmap handle before unlinking (Windows)
        try:
            out_file.unlink()
        except OSError:
            logger.warning(f"Warning: could not remove temporary file '{out_file}'")
@redirect_logging(exec_name='katpy')
def main(
    out=None,
    file_type=None,
    files_in=None,
    dir_in=None,
    selt=False,
    melt=False,
    qelt=False,
    splt=False,
    mplt=False,
    qplt=False,
    unsorted=False,
    **kwargs
):
    """CLI entry point: map parsed argument names onto :func:`run` keywords.

    Extra keyword arguments from the argument parser are accepted and ignored.
    """
    run_options = {
        "out_file": out,
        "file_type": file_type,
        "files_in": files_in,
        "dir_in": dir_in,
        "concatenate_selt": selt,
        "concatenate_melt": melt,
        "concatenate_qelt": qelt,
        "concatenate_splt": splt,
        "concatenate_mplt": mplt,
        "concatenate_qplt": qplt,
        "unsorted": unsorted,
    }
    run(**run_options)