#!/usr/bin/env python3

import collections
import statistics
import subprocess
import sys
import threading
import time

import psutil


# Seconds the sampling thread sleeps between two consecutive samples.
INTERVAL = 0.1

# One entry per sample, appended by worker(): a mapping from memory_info()
# field name to the value summed over all child processes.
memory = []


def worker():
    """Sample the combined memory usage of all descendant processes forever.

    Every ``INTERVAL`` seconds, the ``memory_info()`` fields of every
    (recursive) child of this process are summed per field and the totals
    are appended to the module-level ``memory`` list.

    The loop never returns; it is meant to run in a daemon thread.
    """
    # The monitored process is always this one, so build the handle once.
    current_process = psutil.Process()
    while True:
        totals = collections.defaultdict(int)
        for child in current_process.children(recursive=True):
            try:
                info = child.memory_info()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                # The child exited (or cannot be inspected) between being
                # listed and being read. Skip it instead of letting the
                # exception kill the sampling thread.
                continue
            for k, v in info._asdict().items():
                totals[k] += v
        memory.append(totals)
        time.sleep(INTERVAL)


def format_size(x):
    """Render a byte count as whole mebibytes with an ``M`` suffix.

    The result is right-aligned in a field of at least 7 characters; longer
    values are returned unpadded.
    """
    mebibytes = x // 1024 // 1024
    label = f"{mebibytes}M"
    return f"{label:>7}"


# Sample memory in the background while the wrapped command runs. The thread
# is a daemon so it never keeps the interpreter alive after the command ends.
threading.Thread(target=worker, daemon=True).start()
process = subprocess.run(sys.argv[1:])

# Samples taken while no child process existed (typically the very first one,
# recorded before the command was spawned) carry no fields. Drop them: they
# would otherwise leave the report without field names and drag the minimum
# down to zero. Building a new list also gives a stable snapshot while the
# sampling thread keeps appending.
samples = [sample for sample in memory if sample]

if samples:
    print(file=sys.stderr)
    # One row per memory_info() field: min, the nine deciles, max.
    for key in samples[0]:
        data = [sample[key] for sample in samples]
        if len(data) > 1:
            # "inclusive" treats min/max as the 0th/100th percentile, so the
            # deciles can never fall outside the observed range.
            deciles = statistics.quantiles(data, n=10, method="inclusive")
        else:
            # quantiles() rejects a single data point before Python 3.13;
            # every decile of a single sample is that sample.
            deciles = data * 9
        print(
            key,
            *map(format_size, [min(data), *map(int, deciles), max(data)]),
            sep="\t",
            file=sys.stderr,
        )

# Propagate the wrapped command's exit status.
sys.exit(process.returncode)
