Source code for oasislmf.execution.load_balancer

#!/usr/bin/env python
import argparse

import select
import numpy as np
from numba import njit
import time
from io import BytesIO
import queue
import concurrent.futures

import logging
logger = logging.getLogger(__name__)

last_event_padding = b'\x00\x00\x00\x00\x00\x00\x00\x00'
number_size = 4
CHECK_STOPPER_DURATION = 1
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--pipe-in', help='names of the input file_path', nargs='+')
parser.add_argument('-o', '--pipe-out', help='names of the output file_path', nargs='+')
parser.add_argument('-r', '--read-size', help='maximum size of chunk read from input', default=1_048_576, type=int)
parser.add_argument('-w', '--write-size', help='maximum size of chunk written to output', default=1024, type=int)
parser.add_argument('-q', '--queue-size', help='maximum size of the queue', default=50, type=int)
parser.add_argument('-v', '--logging-level', help='logging level (debug:10, info:20, warning:30, error:40, critical:50)', default=30, type=int)
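
# A hypothetical invocation sketch (pipe names are illustrative): with the package installed,
# the module can be run directly to balance events from two input named pipes across three
# output named pipes, with info-level logging:
#
#   python -m oasislmf.execution.load_balancer \
#       -i /tmp/in_pipe_1 /tmp/in_pipe_2 \
#       -o /tmp/out_pipe_1 /tmp/out_pipe_2 /tmp/out_pipe_3 \
#       -v 20
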
class ProducerStopped(RuntimeError):
    pass

@njit(cache=True)
def get_next_event_index(read_buffer, last_event_index, last_event_id, max_cursor):
    """
    Try to find the index of the end of the current event.

    If the end of the event is found, return its index and 1 to indicate it was found.
    If not found, return the index of the last item parsed and the last event id seen.

    :param read_buffer: int32 view of the byte buffer to parse
    :param last_event_index: index of the last item parsed
    :param last_event_id: last event id parsed (0 means no event seen yet)
    :param max_cursor: number of valid int32 values in read_buffer
    :return: (last index parsed, last event id parsed,
              1 if read_buffer[:last index parsed] is a full event else 0)
    """
    cursor = last_event_index
    while cursor < max_cursor - 4:
        cur_event_id = read_buffer[cursor]
        if last_event_id != cur_event_id:
            if last_event_id == 0:
                last_event_id = read_buffer[cursor]
            else:
                return cursor, last_event_id, 1
        cursor += 2
        while cursor < max_cursor - 2:
            sidx = read_buffer[cursor]
            cursor += 2
            if sidx == 0:
                last_event_index = cursor
                break
    return last_event_index, last_event_id, 0

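
# A minimal, self-contained sketch (not part of the module): get_next_event_index scans an
# int32 view of the stream, where each item is event_id, item_id, then (sidx, value) pairs
# terminated by sidx == 0. The values below are illustrative only.
example_buf = np.array([1, 1, -1, 5, 10, 7, 0, 0,   # event 1, item 1 (two sidx/value pairs)
                        1, 2, -1, 3, 0, 0,          # event 1, item 2
                        2, 1, -1, 9, 0, 0],         # event 2 starts, so event 1 is complete
                       dtype=np.int32)
# returns (14, 1, 1): example_buf[:14] holds the whole of event 1, whose event id is 1
example_index, example_event_id, example_finished = get_next_event_index(example_buf, 0, 0, len(example_buf))
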
def produce(even_queue, event, stopper):
    while not stopper[0]:
        try:
            even_queue.put(event, timeout=CHECK_STOPPER_DURATION)
            break
        except queue.Full:
            pass
    else:
        raise ProducerStopped()

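
# A minimal sketch (not part of the module) of the stopper protocol: a consumer that fails
# sets stopper[0], after which a producer waiting on a full queue raises ProducerStopped
# instead of blocking forever.
example_queue = queue.Queue(maxsize=1)
example_queue.put(BytesIO(b'already queued'))   # queue is now full
example_stopper = np.zeros(1, dtype=np.bool_)
example_stopper[0] = True                       # simulate a consumer flagging a failure
try:
    produce(example_queue, BytesIO(b'next event'), example_stopper)
except ProducerStopped:
    pass                                        # the producer gives up cleanly
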
def producer(in_stream, pipeline, read_size, stopper):
    read_buffer = memoryview(bytearray(read_size))
    buf_as_int32 = np.ndarray(read_size // number_size, buffer=read_buffer, dtype=np.int32)
    event_buffer = BytesIO()
    left_over = 0
    last_event_index = 0
    last_event_id = 0

    wait_read_time = 0
    read_input_time = 0
    parse_input_time = 0
    buffer_management_time = 0

    while True:
        tw = time.time()
        select.select([in_stream], [], [])
        tr = time.time()
        wait_read_time += tr - tw
        len_read = in_stream.readinto1(read_buffer[left_over:])
        read_input_time += time.time() - tr
        if not len_read:
            in_stream.close()
            event_buffer.write(read_buffer[:left_over])
            if read_buffer[left_over - 8:left_over - 4] != b'\x00\x00\x00\x00':
                event_buffer.write(last_event_padding)
            event_buffer.seek(0)
            produce(pipeline, event_buffer, stopper)
            break

        valid_buf = len_read + left_over
        while True:
            tp = time.time()
            event_index, last_event_id, event_finished = get_next_event_index(buf_as_int32, last_event_index, last_event_id, valid_buf // number_size)
            tm = time.time()
            parse_input_time += tm - tp
            event_buffer.write(read_buffer[last_event_index * number_size: event_index * number_size])
            if event_finished:
                event_buffer.seek(0)
                produce(pipeline, event_buffer, stopper)
                event_buffer = BytesIO()
                last_event_index = event_index
                last_event_id = 0
            else:
                left_over = valid_buf - number_size * event_index
                read_buffer[:left_over] = read_buffer[number_size * event_index: valid_buf]
                last_event_index = 0
                break
            buffer_management_time += time.time() - tm

    return wait_read_time, read_input_time, parse_input_time, buffer_management_time

def consumer(out_stream, pipeline, write_size, sentinel, stopper):
    s_tot = 0
    w_tot = 0
    p_tot = 0
    while True:
        tp = time.time()
        event_buf = pipeline.get()
        p_tot += time.time() - tp
        if event_buf is sentinel:
            break
        else:
            while True:
                data = event_buf.read(write_size)
                if not data:
                    break
                ts = time.time()
                select.select([], [out_stream], [])
                tw = time.time()
                s_tot += tw - ts
                try:
                    out_stream.write(data)
                except Exception:
                    stopper[0] = True
                    raise
                w_tot += time.time() - tw
    return s_tot, w_tot, p_tot

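
# A minimal sketch (not part of the module) of the consumer loop: events are taken from the
# queue and written to an output stream until the shared sentinel object is seen. An
# anonymous OS pipe stands in for the real output named pipe.
import os

example_read_fd, example_write_fd = os.pipe()
example_out = os.fdopen(example_write_fd, 'wb')
example_pipeline = queue.Queue()
example_sentinel = object()
example_pipeline.put(BytesIO(b'some event bytes'))
example_pipeline.put(example_sentinel)
consumer(example_out, example_pipeline, 1024, example_sentinel, np.zeros(1, dtype=np.bool_))
example_out.close()
os.close(example_read_fd)
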
def balance(pipe_in, pipe_out, read_size, write_size, queue_size):
    """
    Load balance events from a list of input file paths to a list of output file paths.

    :param pipe_in: list of file paths to read events from
    :param pipe_out: list of file paths to write events to
    :param read_size: int maximum number of bytes read from one input at a time
    :param write_size: int maximum number of bytes written to one output at a time
    :param queue_size: int maximum size of the buffer queue
    """
    inputs = [open(p, 'rb') for p in pipe_in]
    outputs = [open(p, 'wb') for p in pipe_out]
    pipeline = queue.Queue(maxsize=queue_size)
    sentinel = object()
    stopper = np.zeros(1, dtype=np.bool_)
    try:
        # check stream input header and write it to the stream output
        headers = set([s.read(8) for s in inputs])
        if len(headers) != 1:
            raise Exception('input streams have different header type')
        header = headers.pop()
        [s.write(header) for s in outputs]

        with concurrent.futures.ThreadPoolExecutor(max_workers=len(inputs) + len(outputs)) as executor:
            producer_task = [executor.submit(producer, s, pipeline, read_size, stopper) for s in inputs]
            consumer_task = [executor.submit(consumer, s, pipeline, write_size, sentinel, stopper) for s in outputs]

            try:
                prod_t = [t.result() for t in producer_task]
            finally:
                for _ in pipe_out:
                    pipeline.put(sentinel)

            wait_read_time, read_input_time, parse_input_time, buffer_management_time = 0, 0, 0, 0
            for t in prod_t:
                wait_read_time += t[0]
                read_input_time += t[1]
                parse_input_time += t[2]
                buffer_management_time += t[3]

            cons_t = [t.result() for t in consumer_task]
            wait_write_time, write_output_time, wait_pipeline = 0, 0, 0
            for t in cons_t:
                wait_write_time += t[0]
                write_output_time += t[1]
                wait_pipeline += t[2]

            logger.info(f"""
wait_read_time = {wait_read_time}, {wait_read_time / len(inputs)}
read_input_time = {read_input_time}, {read_input_time / len(inputs)}
parse_input_time = {parse_input_time}, {parse_input_time / len(inputs)}
buffer_management_time = {buffer_management_time}, {buffer_management_time / len(inputs)}
wait_write_time = {wait_write_time}, {wait_write_time / len(outputs)}
write_output_time = {write_output_time}, {write_output_time / len(outputs)}
wait_pipeline = {wait_pipeline}, {wait_pipeline / len(outputs)}
""")
    finally:
        [s.close() for s in inputs]
        [s.close() for s in outputs]

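
# A hypothetical usage sketch (paths are illustrative): balance expects its named pipes to
# already exist (e.g. created with os.mkfifo) and to be fed and drained by other processes,
# so the call is shown as a comment rather than executed.
#
#   balance(
#       pipe_in=['/tmp/in_pipe_1', '/tmp/in_pipe_2'],
#       pipe_out=['/tmp/out_pipe_1', '/tmp/out_pipe_2', '/tmp/out_pipe_3'],
#       read_size=1_048_576,
#       write_size=1024,
#       queue_size=50,
#   )
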
def main():
    kwargs = vars(parser.parse_args())

    # add handler to the load balancer logger
    ch = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    logging_level = kwargs.pop('logging_level')
    logger.setLevel(logging_level)

    balance(**kwargs)

if __name__ == '__main__':
    main()