# Copyright 2022 RnD Center "ELVEES", JSC

import argparse
from collections import namedtuple
from .task import Task
from .entry import Entry, EntryList, Interval
from .ftrace import Ftrace

import os
import sys
import re

__all__ = ["Task", "Entry", "EntryList", "Interval", "Ftrace"]

# TODO: Support thresholds filtering

# Clock names probed by main() under /sys/kernel/debug/clk/<name>/clk_rate
# to read the actual DSP clock rate.
MCOM03_CLK_NAME = "sdr_pll2"
SOLARIS_CLK_NAME = "quelcore_0_clkf"
# Fallback DSP clock rate used to convert DSP ticks to microseconds when
# neither clk_rate file exists; main() overwrites it otherwise.
# NOTE(review): this comment previously read "531 MHz" while the value is
# 459 MHz -- confirm which one is intended.
DSP_FREQ = 459_000_000  # 459 MHz

# One traced function: index of its entry in the entry list, its timestamp,
# its name and its duration.
TierFunction = namedtuple("TierFunction", "global_id, timestamp, name, duration")
# Running per-syscall statistics accumulated by syscall_trace().
StatisticTuple = namedtuple("StatisticTuple", "calls, total_overhead, min, avg, max")

# Event text of the form "<prefix>: <name><rest>", used by syscall_trace().
# `name` is a run of non-whitespace characters following ": "; at least one
# more character must follow it on the same line for the pattern to match.
SYSCALL_PATTERN = re.compile(
    r"""
        ^
        (?P<prefix>.+)
        :\s
        (?P<name>\S+) # syscall_name
        .+$
        """,
    re.X | re.M,
)

# Event text of the form "<prefix>: vaddr_mmu_dsp=<vaddr> size=<size>".
# main() parses the `size` group as a hexadecimal integer.
MMU_MAP_PATTERN = re.compile(
    r"""
    ^
    (?P<prefix>.+)
    :\svaddr_mmu_dsp=
    (?P<vaddr_mmu_dsp>.+) # vaddr_mmu_dsp
    \ssize=
    (?P<size>.+) # size
    $
    """,
    re.X | re.M,
)

# Event text of the form
# "<prefix>: DSP uptime: <ticks> kernel_name: _<kernel_name>".
# The underscore right after "kernel_name: " is consumed by the pattern and
# is not part of the `kernel_name` group. kernel_trace() parses `ticks`
# with int() (base 10).
DSP_UPTIME_PATTERN = re.compile(
    r"""
    ^
    (?P<prefix>.+) # elcore50_uptime
    :\sDSP\suptime:\s
    (?P<ticks>.+) # ticks
    \skernel_name:\s_
    (?P<kernel_name>.+) # kernel_name
    $
    """,
    re.X | re.M,
)


# Event text of the form "<prefix>: type: <type> size: <size>".
# Used by main() for both buffer-create and buffer-release events; `size`
# is parsed there as a hexadecimal integer.
BUF_PATTERN = re.compile(
    r"""
    ^
    (?P<prefix>.+) # prefix
    :\stype:\s
    (?P<type>.+) # type
    \ssize:\s
    (?P<size>.+) # size
    $
    """,
    re.X | re.M,
)

# Event text of the form "<prefix>: size=<size>", used by main() for
# buffer-sync events; `size` is parsed there as a hexadecimal integer.
SYNC_PATTERN = re.compile(
    r"""
    ^
    (?P<prefix>.+) # prefix
    :\ssize=
    (?P<size>.+) # size
    $
    """,
    re.X | re.M,
)

# Module-level state shared by main() and the classes/functions in this
# file. These names are only bound once main() runs.
# NOTE(review): a `global` statement at module scope has no effect; the two
# lines below merely list the names that main() assigns.
global args
global trace, kernel_total_time, duration

# Registries filled as a side effect of constructing Function and
# Tier1FunctionList objects.
# The *_funcs lists fed from events hold Function objects whose direct
# child entry carries the corresponding elcore50_* event.
syscall_funcs = []
mmu_funcs = []
# TierFunction tuples whose name contains "elcore50_irq".
irq_funcs = []
dsp_uptime_funcs = []
create_buf_funcs = []
sync_buf_funcs = []
release_buf_funcs = []
# TierFunction tuples for job-instance enqueue / run / release functions;
# main() consumes this list in groups of three.
kernels = []


def parse_args(argv=None):
    """Parse the command line of the ftrace parser.

    Args:
        argv: Argument list to parse. ``None`` lets argparse fall back to
            ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace`` with two attributes: ``ifile``
        (path of the ftrace log) and ``output`` (a writable text file
        object, ``sys.stdout`` unless ``-o/--output`` is given).
    """
    usage_example = (
        "Example:\n"
        "ftrace-parser -o /tmp/result /tmp/xyram.ftrace"
    )
    cli = argparse.ArgumentParser(
        description="Parse ftrace log returned by elcore-prof.py",
        usage="%(prog)s [options] ifile",
        epilog=usage_example,
        # Keep the line break inside the epilog when help is rendered.
        formatter_class=argparse.RawTextHelpFormatter,
    )
    cli.add_argument("ifile", help="Source file with ftrace log")

    output_option = {
        "type": argparse.FileType("w"),
        "default": sys.stdout,
        "help": "Write profiling results to file. Default: stdout",
    }
    cli.add_argument("-o", "--output", **output_option)

    return cli.parse_args(argv)


class Function:
    """Node of the call tree rebuilt from a flat list of trace entries.

    Attributes:
        func: ``TierFunction``-like tuple (``global_id``, ``timestamp``,
            ``name``, ``duration``); ``global_id`` indexes into ``list``.
        list: The flat sequence of trace entries. Each entry is read for
            ``depth``, ``pid``, ``event``, ``timestamp``, ``function`` and
            ``duration``.
        event: Event text of the last direct child entry that carries an
            event, or ``None``.
        next: Direct child ``Function`` objects, longest duration first.
        name: Always initialised to an empty string.

    Constructing an instance recursively constructs all of its children
    and, as a side effect, appends to the module-level registries
    (``syscall_funcs``, ``irq_funcs``, ``kernels`` and the others).
    """

    def __init__(self, func_tuple, list):
        self.func = func_tuple
        self.list = list
        # Has to exist before the scan below, which may overwrite it.
        self.event = None
        self.next = self.getNextTierFunctions()
        self.name = ""

        if len(self.next) != 1 or self.list[self.func.global_id].depth != 1:
            return

        # A depth-1 function with exactly one child is displayed as
        # "parent(child)".
        only_child = self.next[0]
        self.func = TierFunction(
            self.func.global_id,
            self.func.timestamp,
            self.func.name + "({})".format(only_child.func.name),
            self.func.duration,
        )

    def _register_event(self):
        """Append this function to the registry matching ``self.event``."""
        event = self.event
        if "elcore50_syscall" in event:
            syscall_funcs.append(self)
            return
        if "elcore50_mmu_map" in event:
            mmu_funcs.append(self)
            return
        if "elcore50_uptime" in event:
            dsp_uptime_funcs.append(self)
            return
        if "elcore50_buf_create" in event:
            create_buf_funcs.append(self)
            return
        if "elcore50_buf_release" in event:
            release_buf_funcs.append(self)
            return
        if "elcore50_buf_sync" in event:
            sync_buf_funcs.append(self)

    def getNextTierFunctions(self):
        """Collect the direct children of this function.

        Scans backwards from this function's entry to the closest earlier
        entry with the same depth and pid, then inspects every entry in
        between that is exactly one level deeper and has the same pid.
        Entries carrying an event set ``self.event`` and register ``self``;
        all other entries become child ``Function`` objects.

        Returns:
            The child ``Function`` objects sorted by duration, descending.
        """
        own_index = self.func.global_id
        own_entry = self.list[own_index]
        own_depth = own_entry.depth
        child_depth = own_depth + 1

        # Locate the previous sibling (same depth, same pid); the entries
        # after it and before our own entry are the candidates.
        # NOTE(review): this presumably relies on entries being recorded at
        # function exit, so that children precede their parent -- confirm.
        boundary = own_index - 1
        while boundary >= 0:
            candidate = self.list[boundary]
            if candidate.depth == own_depth and candidate.pid == own_entry.pid:
                break
            boundary -= 1

        children = []
        for index in range(boundary + 1, own_index):
            entry = self.list[index]
            if entry.depth != child_depth or entry.pid != own_entry.pid:
                continue

            if entry.event is not None:
                self.event = entry.event
                self._register_event()
                continue

            child = TierFunction(
                index, entry.timestamp, entry.function, entry.duration
            )
            # Build the subtree first: nested functions register themselves
            # in the global lists before this child does.
            children.append(Function(child, self.list))
            # NOTE(review): Tier1FunctionList tests for the narrower
            # "elcore50_enqueue_job_inst" -- confirm which is intended.
            if "elcore50_irq" in child.name:
                irq_funcs.append(child)
            elif any(
                marker in child.name
                for marker in (
                    "elcore50_enqueue",
                    "elcore50_job_inst_release",
                    "elcore50_job_inst_run",
                )
            ):
                kernels.append(child)

        return sorted(children, key=lambda node: node.func.duration, reverse=True)

    def output_recursive(self):
        """Render this function and its whole subtree, one line per function.

        Each line shows the name (prefixed with one "-> " per level below
        depth 1), the duration and its percentage of the module-level
        ``kernel_total_time``, followed by the event text if there is one.
        """
        depth = self.list[self.func.global_id].depth
        label = "-> " * (depth - 1) + f"{self.func.name}:"
        share = self.func.duration * 100 / kernel_total_time
        own_line = "{:<50} {:.4f} us ({:.2f}%)".format(
            label, self.func.duration, share
        )
        if self.event is not None:
            own_line += " {}".format(self.event)

        rendered = [own_line]
        rendered.extend(child.output_recursive() for child in self.next)
        return "\n".join(rendered)

    def __repr__(self):
        return self.output_recursive()


class Tier1FunctionList(list):
    """List of ``Function`` objects, one per depth-1 entry of the trace.

    Building the list also builds the complete call tree below every
    depth-1 function and registers IRQ handlers and job-instance functions
    in the module-level ``irq_funcs`` / ``kernels`` lists. After
    construction the list is sorted by duration, longest first.

    Args:
        list: Flat entry list handed to every ``Function`` for lookups by
            ``global_id``.
    """

    def __init__(self, list):
        super().__init__()
        self.list = list
        level = 1
        # NOTE(review): the entries are taken from the module-level `trace`
        # while lookups go through the `list` argument; the indices only
        # agree when the caller passes `trace.entries` (as main() does).
        #
        # enumerate() yields the real position of each entry. The previous
        # `trace.entries.index(entry)` rescanned the list for every match
        # and returned the first *equal* entry, not necessarily this one.
        for global_id, entry in enumerate(trace.entries):
            if entry.depth != level:
                continue
            tierfunc = TierFunction(
                global_id, entry.timestamp, entry.function, entry.duration
            )
            self.append(Function(tierfunc, self.list))
            if "elcore50_irq" in tierfunc.name:
                irq_funcs.append(tierfunc)
            elif "elcore50_enqueue_job_inst" in tierfunc.name:
                kernels.append(tierfunc)
            elif "elcore50_job_inst_release" in tierfunc.name:
                kernels.append(tierfunc)
            elif "elcore50_job_inst_run" in tierfunc.name:
                kernels.append(tierfunc)
        # Index 3 of a TierFunction is its duration.
        self.sort(key=lambda func: func.func[3], reverse=True)

def syscall_trace(syscalls):
    """Print per-syscall call counts and overhead statistics.

    Args:
        syscalls: Iterable of objects exposing ``event`` (text matched
            against ``SYSCALL_PATTERN``) and ``func.duration``. Items whose
            event does not match the pattern are ignored.

    Nothing is printed when no item matches. Output goes to ``args.output``,
    one line per syscall name in order of first appearance.
    """
    stats = {}
    for call in syscalls:
        duration = call.func.duration
        parsed = SYSCALL_PATTERN.match(call.event)
        if parsed is None:
            continue
        syscall_name = parsed.group("name")

        # Neutral starting point for a syscall seen for the first time.
        previous = stats.get(
            syscall_name, StatisticTuple(0, 0, sys.maxsize, 0, -1)
        )
        calls = previous.calls + 1
        total_overhead = previous.total_overhead + duration
        stats[syscall_name] = StatisticTuple(
            calls,
            total_overhead,
            min(previous.min, duration),
            total_overhead / calls,
            max(previous.max, duration),
        )

    if stats:
        print("Syscall statistics info:", file=args.output)
    for syscall_name, (calls, total, minv, avg, maxv) in stats.items():
        print(
            f"{syscall_name}: calls={calls} overheads(us)=(total:{round(total, 1)} "
            f"min:{round(minv, 1)} max:{round(maxv, 1)} avg:{round(avg, 1)})",
            file=args.output,
        )


def kernel_trace(enqueue, run, release):
    """Print the timing breakdown of one job instance to ``args.output``.

    Args:
        enqueue: ``TierFunction`` of the enqueue call.
        run: ``TierFunction`` of the ``elcore50_job_inst_run`` call.
        release: ``TierFunction`` of the release call; only its timestamp
            is used, as the upper bound for attributing syscalls.

    Raises:
        IndexError: No uptime event is registered for ``run``.
        Exception: The uptime event text does not match
            ``DSP_UPTIME_PATTERN``.
    """
    # Timestamps are scaled by 1e6 to be combined with the durations, which
    # are reported as microseconds.
    span = (run.timestamp - enqueue.timestamp) * 1e6
    time_in_queue = span - enqueue.duration
    total_time = span + run.duration
    # Workaround for bug: leaf timestamp sometimes points to end but not start function
    if time_in_queue < 0:
        time_in_queue += enqueue.duration
        total_time += enqueue.duration

    syscalls_per_kernel = [
        syscall
        for syscall in syscall_funcs
        if run.timestamp < syscall.func.timestamp < release.timestamp
    ]

    # First Function carrying an uptime event that belongs to this run.
    gl_run = next(
        (event for event in dsp_uptime_funcs if event.func.global_id == run.global_id),
        None,
    )
    if gl_run is None:
        # Same exception type the former `[...][0]` lookup produced, but
        # with a message that says what is actually missing.
        raise IndexError("No uptime event found for job instance run")
    match = re.match(DSP_UPTIME_PATTERN, gl_run.event)
    if match is None:
        raise Exception("Failed to parse uptime event")
    match_dict = match.groupdict()
    kernel_name = match_dict["kernel_name"]
    dsp_uptime = int(match_dict["ticks"]) * 1e6 / DSP_FREQ
    kernel_overhead = total_time - dsp_uptime

    # Accumulated duration of each direct child of the run function.
    func_dict = {}
    for func in gl_run.next:
        name = func.func.name
        func_dict[name] = func_dict.get(name, 0) + func.func.duration
    # NOTE(review): the DSP time is subtracted from event_handler,
    # presumably because that handler's duration includes the DSP run --
    # confirm. A trace without an event_handler child no longer aborts the
    # report with a KeyError.
    if "event_handler" in func_dict:
        func_dict["event_handler"] = round(func_dict["event_handler"] - dsp_uptime, 2)

    print(f"Function name: '{kernel_name}'", file=args.output)
    print(f"Total time: {round(total_time)} us", file=args.output)
    print(f"Time spend in the queue: {round(time_in_queue, 2)} us", file=args.output)
    print(f"Job_inst_run elapsed time: {round(run.duration, 2)} us", file=args.output)
    print(
        f"DSP elapsed time: {round(dsp_uptime, 2)} us "
        f"({round(dsp_uptime * 100 / total_time, 2)} %)",
        file=args.output,
    )
    print(
        f"Kernel overhead: {round(kernel_overhead, 2)} us "
        f"({round(kernel_overhead * 100 / total_time, 2)} %)",
        file=args.output,
    )
    for func_name, func_time in func_dict.items():
        overhead = round(func_time, 2)
        percent = round(func_time * 100 / run.duration, 2)
        print(f"Function {func_name} overhead {overhead} us ({percent} %)", file=args.output)
    syscall_trace(syscalls_per_kernel)


def _print_buffer_stats(header, funcs, pattern, action, output):
    """Print one throughput line per buffer event in *funcs*.

    Args:
        header: Section title, printed only when *funcs* is not empty.
        funcs: Objects exposing ``event`` and ``func.duration``.
        pattern: Compiled regex with a ``size`` group (hexadecimal) and
            optionally a ``type`` group. Non-matching events are skipped.
        action: Verb used in the line ("created", "mapped", ...).
        output: Writable text file object.
    """
    if funcs:
        print(header, file=output)
    for func in funcs:
        match = pattern.match(func.event)
        if match is None:
            continue
        fields = match.groupdict()
        elapsed = func.func.duration
        size_hex = fields["size"]
        # Bytes per microsecond is numerically MB/s. A zero duration is
        # reported as infinite throughput instead of raising
        # ZeroDivisionError.
        rate = round(int(size_hex, 16) / elapsed, 1) if elapsed else float("inf")
        subject = f"{fields['type']} buffer" if "type" in fields else "buffer"
        print(
            f"{subject} 0x{size_hex} bytes {action} per {elapsed} us ({rate} MB/s)",
            file=output,
        )


def _print_irq_stats(irqs, output):
    """Print count and total/min/max/avg overhead of *irqs*; no-op if empty."""
    if not irqs:
        return
    irq_number = len(irqs)
    irq_overhead_max = round(max(item.duration for item in irqs), 1)
    irq_overhead_min = round(min(item.duration for item in irqs), 1)
    irq_overhead_total = round(sum(float(item.duration) for item in irqs), 1)
    irq_overhead_avg = round(irq_overhead_total / irq_number, 1)
    print("\nIRQ statistics info:", file=output)
    print(
        f"IRQS: number={irq_number} overheads(us)=(total:{irq_overhead_total} "
        f"min:{irq_overhead_min} max:{irq_overhead_max} avg:{irq_overhead_avg})",
        file=output,
    )


def main():
    """Parse the ftrace log named on the command line and print the report.

    Binds the module-level ``args``, ``trace``, ``kernel_total_time`` and
    ``duration``, and replaces ``DSP_FREQ`` with the value read from
    debugfs when one of the known clk_rate files exists. The report
    consists of the buffer create/release/map/sync sections, IRQ
    statistics and one section per job instance. Empty sections are
    omitted.
    """
    global trace, kernel_total_time, duration, DSP_FREQ, args

    args = parse_args()
    try:
        trace = Ftrace(args.ifile)

        # Prefer the real clock rate of the first known clock that exists.
        for clk_name in (MCOM03_CLK_NAME, SOLARIS_CLK_NAME):
            clk_file = f"/sys/kernel/debug/clk/{clk_name}/clk_rate"
            if os.path.exists(clk_file):
                with open(clk_file, "r") as f:
                    DSP_FREQ = int(f.read())
                break

        # Calculate kernel running time
        kernel_total_time = sum(
            entry.duration for entry in trace.entries if entry.depth == 1
        )

        duration = trace.interval.duration * 1e6  # us

        # Building the list fills the module-level registries used below.
        tier1funcs = Tier1FunctionList(trace.entries)
        tier1funcs.sort(key=lambda func: func.func[1])  # by timestamp

        nkernels = sum(
            1 for func in tier1funcs if func.func.name == "elcore50_job_inst_run"
        )

        out = args.output
        _print_buffer_stats(
            "Create buffers info:", create_buf_funcs, BUF_PATTERN, "created", out
        )
        _print_buffer_stats(
            "\nRelease buffers info:", release_buf_funcs, BUF_PATTERN, "released", out
        )
        _print_buffer_stats(
            "\nMMU buffers info:", mmu_funcs, MMU_MAP_PATTERN, "mapped", out
        )
        _print_buffer_stats(
            "\nSync buffers info:", sync_buf_funcs, SYNC_PATTERN, "synced", out
        )

        # Process irq overheads
        _print_irq_stats(irq_funcs, out)

        print(file=out)

        # Job instances tracing
        # NOTE(review): assumes `kernels` holds one (enqueue, run, release)
        # triple per job instance, in that order -- confirm against the
        # order in which entries are registered.
        for i in range(nkernels):
            print("Job instance {} info:".format(i), file=out)
            kernel_trace(kernels[3 * i], kernels[3 * i + 1], kernels[3 * i + 2])
            print(file=out)
    finally:
        # Close a file opened through -o; never close the shared stdout.
        if args.output is not sys.stdout:
            args.output.close()


if __name__ == "__main__":
    main()
