#!/usr/bin/env python3

import argparse

import bt2

def main(args):
    """Print ICE transport packet statistics for one or every thread of a trace.

    Args:
        args: Parsed command-line namespace. Reads ``args.path`` (a
            one-element list holding the trace path), ``args.tid`` (the
            thread id to analyse; a negative value selects every ``vtid``
            found in the trace) and, when present, ``args.log_scale``
            (forwarded to the histogram printer; treated as False if the
            attribute is missing).
    """
    path = args.path[0]

    if args.tid < 0:
        # First pass over the trace: collect every distinct thread id.
        # Values are converted to plain ints so no trace field object is
        # kept around, and sorted so the report order is deterministic.
        vtids = sorted({
            int(msg.event["vtid"])
            for msg in bt2.TraceCollectionMessageIterator(path)
            if type(msg) is bt2._EventMessageConst
        })
    else:
        vtids = [args.tid]

    log_scale = getattr(args, "log_scale", False)

    for tid in vtids:
        analyze_trace(path, tid, log_scale=log_scale)


def analyze_trace(path, tid, log_scale=False):
    """Print send/recv packet statistics of one thread of a trace.

    Iterates over the whole trace at *path* and considers only the events
    whose ``vtid`` equals *tid* and whose name is
    ``jami:ice_transport_send`` or ``jami:ice_transport_recv``. For each
    direction it counts the packets, averages their ``packet_length`` and
    builds a histogram indexed by packet length. Nothing is printed when
    the thread has no such event.

    Args:
        path: Trace path handed to ``bt2.TraceCollectionMessageIterator``.
        tid: Thread id to match against each event's ``vtid``.
        log_scale: Forwarded to ``print_hist`` as its third argument.
    """
    prefix = "jami:ice_transport_"

    # Per-direction accumulators, keyed by the event-name suffix.
    counts = {"send": 0, "recv": 0}
    totals = {"send": 0, "recv": 0}
    hists = {"send": [], "recv": []}

    for msg in bt2.TraceCollectionMessageIterator(path):

        if type(msg) is not bt2._EventMessageConst:
            continue

        event = msg.event
        name = event.name

        if event["vtid"] != tid:
            continue

        if not name.startswith(prefix):
            continue

        direction = name[len(prefix):]

        if direction not in hists:
            continue

        # Plain int: usable as a list index and safe to accumulate.
        pkt_len = int(event["packet_length"])

        # Grow the histogram so that index `pkt_len` exists, then count.
        hist = hists[direction]
        if pkt_len >= len(hist):
            hist.extend([0] * (pkt_len + 1 - len(hist)))
        hist[pkt_len] += 1

        counts[direction] += 1
        totals[direction] += pkt_len

    if not (hists["send"] or hists["recv"]):
        return

    directions = (("Send", "send"), ("Recv", "recv"))

    print(f"Thread: {tid}")
    print("{:^20s}\t{:^20s}\t{:^20s}".format("Direction", "Count", "Average size (bytes)"))
    for label, key in directions:
        n = counts[key]
        # Truncated average; 0 when no packet was seen in this direction.
        average = totals[key] // n if n else 0
        print("{:^20s}\t{:^20d}\t{:^20d}".format(label, n, average))
    print("")

    for label, key in directions:
        if hists[key]:
            print_hist(label, hists[key], log_scale)

    print("")


def print_hist(name, bins, log_scale=False):
    """Print an 80-column text histogram of *bins* to standard output.

    Each printed row covers the half-open index range ``lo -> hi`` of
    *bins*, shows the sum of the counts in that range and a bar of ``*``
    whose length (at most 45) is proportional to that sum relative to the
    total of all bins.

    Args:
        name: Title shown in the first column of the header.
        bins: Sequence of integer counts, indexed by value.
        log_scale: When true, the upper bound doubles from row to row
            (0->1, 1->2, 2->4, 4->8, ...); otherwise every row covers
            exactly one index.
    """
    total_count = sum(bins)

    print('=' * 80)
    print("{:^22s}|{:^10s}|{:^45s}|".format(name, "count", "distribution"))
    print('-' * 80)

    lo, hi = 0, 1

    while lo < len(bins):
        count = sum(bins[lo:hi])
        # Guard against an all-zero histogram: draw empty bars instead of
        # dividing by zero.
        bar_len = 45 * count // total_count if total_count else 0
        print("{:>9d} -> {:<9d}|{:^10d}|{:45s}|".format(lo, hi, count, "*" * bar_len))
        lo = hi
        hi = 2 * hi if log_scale else hi + 1
    print('=' * 80)

if __name__ == "__main__":

    # Script entry point: build the command line, parse it, run the report.
    cli = argparse.ArgumentParser(
        description="Show ICE transport statistics.",
    )

    # Positional arguments: the trace to read, then an optional thread id.
    # TID defaults to -1, which main() treats as "every thread" through its
    # `args.tid < 0` check.
    cli.add_argument(
        "path",
        metavar="TRACE",
        type=str,
        nargs=1,
    )
    cli.add_argument(
        "tid",
        metavar="TID",
        type=int,
        nargs='?',
        default=-1,
    )

    # Options.
    cli.add_argument("--log-scale", action="store_true")
    # NOTE(review): --scale is accepted but `args.scale` is not read anywhere
    # in the visible code - confirm whether it is still needed.
    cli.add_argument("--scale", type=int, default=1)

    args = cli.parse_args()

    main(args)
