#!/usr/bin/python2

import re
import math
import cairo

class ReadFile:
    """Line-at-a-time file reader with one line of lookahead.

    Use as a context manager. Once entered, ``line`` holds the current line
    (newline included) and ``lineno`` counts the calls made to getnext().
    At end of file ``line`` is the empty string and the object tests false,
    so ``while rf:`` walks the whole file.
    """

    def __init__(self, filename):
        self.filename = filename

    def __enter__(self):
        self.f = open(self.filename, 'r')
        self.lineno = 0
        # prime the lookahead so .line is valid as soon as the block starts
        self.getnext()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.f.close()

    def __nonzero__(self):
        # readline() returns "" only at end of file
        return self.line != ""

    # Python 3 looks up __bool__ rather than __nonzero__; without this alias
    # every instance would test true there and "while rf:" would never end
    __bool__ = __nonzero__

    def getnext(self):
        """Advance to the next line, updating ``line`` and ``lineno``."""
        self.lineno += 1
        self.line = self.f.readline()

def read_times(rf):
    """Read consecutive lines of whitespace-separated integers from rf.

    Starts at rf's current line and consumes lines for as long as they begin
    with an optionally signed integer followed by a space, leaving rf
    positioned on the first line that does not.

    Returns a list of every integer read, in reverse order of appearance.
    Raises ValueError if an accepted line contains a non-integer field.
    """
    times = []

    while re.match('[+-]?[0-9]+ ', rf.line):
        # split() with no argument collapses runs of whitespace, so a doubled
        # space cannot produce an empty field for int() to choke on
        times.extend(int(field) for field in rf.line.split())
        rf.getnext()

    times.reverse()
    return times

class Thread:
    """One profiled thread and the events recorded against it.

    all_events holds every event of the thread; workwait_events,
    memory_events and other_events hold the same events split by kind.
    The lists start empty and are filled as events are created.
    """

    # number handed to the next Thread constructed
    thread_number = 0

    def __init__(self, thread_name):
        # no one cares about the thread address: drop a " (0x...)" part
        # from the middle of the name. Raw string, so "\(" reaches the regex
        # engine without being treated as a string escape.
        match = re.match(r'(.*) \(0x.*?\) (.*)', thread_name)
        if match:
            thread_name = match.group(1) + " " + match.group(2)

        self.thread_name = thread_name

        # take the next sequence number from the class-wide counter
        self.thread_number = Thread.thread_number
        Thread.thread_number += 1

        self.all_events = []
        self.workwait_events = []
        self.memory_events = []
        self.other_events = []

# every Event constructed, across all threads
all_events = []

class Event:
    """One event recorded on a gate of a thread.

    An event is classified as at most one of memory (gate_location is
    "memory"), work (gate_name contains "work") or wait (gate_name contains
    "wait"), tested in that order. For memory events the stop argument
    carries a size: start and stop are both set to the start time and the
    value is stored in ``size``. Other events have no ``size`` attribute.

    Constructing an Event appends it to the module-level all_events list, to
    thread.all_events, and to exactly one of thread.workwait_events,
    thread.memory_events or thread.other_events.
    """

    def __init__(self, thread, gate_location, gate_name, start, stop):
        self.thread = thread
        self.gate_location = gate_location
        self.gate_name = gate_name

        # gate_name is None for a gate that has no name; classify against an
        # empty string so re.match() cannot raise TypeError
        name = gate_name if gate_name is not None else ""

        self.work = False
        self.wait = False
        self.memory = False
        if gate_location == "memory":
            self.memory = True
        elif re.match('.*work.*', name):
            self.work = True
        elif re.match('.*wait.*', name):
            self.wait = True

        self.start = start
        if self.memory:
            # instantaneous: the "stop" slot holds the size instead of a time
            self.stop = start
            self.size = stop
        else:
            self.stop = stop

        thread.all_events.append(self)
        all_events.append(self)
        if self.wait or self.work:
            thread.workwait_events.append(self)
        elif self.memory:
            thread.memory_events.append(self)
        else:
            thread.other_events.append(self)

input_filename = 'vips-profile.txt'

# Parse the profile: a sequence of "thread: <name>" records, each followed by
# any number of "gate: <location>[: <name>]" records with optional "start:"
# and "stop:" lists of times. Every start/stop pair becomes one Event.
thread_id = 0
threads = []
n_events = 0
print('reading from %s' % input_filename)
with ReadFile(input_filename) as rf:
    while rf:
        # skip blank lines and "#" comment lines
        if rf.line.rstrip() == "" or rf.line[0] == "#":
            rf.getnext()
            continue

        match = re.match('thread: (.*)', rf.line)
        if not match:
            # report the line and move past it: there is no match object to
            # take a thread name from
            print('parse error line %d, expected "thread"' % rf.lineno)
            rf.getnext()
            continue

        # a running index is appended to the name read from the file
        thread_name = match.group(1) + " " + str(thread_id)
        thread_id += 1
        thread = Thread(thread_name)
        threads.append(thread)
        rf.getnext()

        while True:
            match = re.match('^gate: (.*?)(: (.*))?$', rf.line)
            if not match:
                break
            gate_location = match.group(1)
            # None when the gate line has no ": <name>" part
            gate_name = match.group(3)
            rf.getnext()

            # a gate without a "start:" list has no events; try the next gate
            if not re.match('start:', rf.line):
                continue
            rf.getnext()

            start = read_times(rf)

            if not re.match('stop:', rf.line):
                continue
            rf.getnext()

            stop = read_times(rf)

            if len(start) != len(stop):
                print('start and stop length mismatch')

            # zip() stops at the shorter list, so unpaired times are dropped
            for a, b in zip(start, stop):
                Event(thread, gate_location, gate_name, a, b)
                n_events += 1

# Put every event list into start-time order. list.sort() is stable, so
# events with equal start times keep the order in which they were loaded.
for thread in threads:
    for events in (thread.all_events, thread.workwait_events,
                   thread.memory_events, thread.other_events):
        events.sort(key=lambda event: event.start)

all_events.sort(key=lambda event: event.start)

print('loaded %d events' % n_events)

# move time axis to secs of computation
ticks_per_sec = 1000000.0

if not all_events:
    # with no events there is no time range to rebase and nothing to draw
    raise SystemExit('no events loaded, nothing to plot')

# earliest start and latest stop over all events, still in ticks
# (last_time never goes below 0)
first_time = min(event.start for event in all_events)
last_time = max(0, max(event.stop for event in all_events))

# rebase every event so the earliest start is 0, and convert ticks to seconds
for event in all_events:
    event.start = (event.start - first_time) / ticks_per_sec
    event.stop = (event.stop - first_time) / ticks_per_sec

last_time = (last_time - first_time) / ticks_per_sec
first_time = 0

print('total time = %s' % last_time)

# calculate some simple stats: for each thread, the span it was alive, the
# time spent in wait and work events, and its running and peak memory totals
for thread in threads:
    earliest = last_time
    latest = 0
    waiting = 0
    working = 0
    held = 0
    held_peak = 0

    for event in thread.all_events:
        earliest = min(earliest, event.start)
        latest = max(latest, event.stop)

        duration = event.stop - event.start
        if event.wait:
            waiting += duration
        if event.work:
            working += duration

        if event.memory:
            # sizes accumulate into a running total; remember its high point
            held += event.size
            held_peak = max(held_peak, held)

    thread.start = earliest
    thread.stop = latest
    thread.wait = waiting
    thread.work = working
    thread.mem = held
    thread.peak_mem = held_peak

    thread.alive = latest - earliest

    # hide very short-lived threads
    thread.hide = thread.alive < 0.01

# one table row per visible thread: lifetime, share of it spent waiting,
# working and unaccounted for, then final and peak memory in MB
print('name\t\talive\twait%\twork%\tunkn%\tmemory\tpeakm')
for thread in threads:
    if not thread.hide:
        wait_percent = 100 * thread.wait / thread.alive
        work_percent = 100 * thread.work / thread.alive
        unkn_percent = 100 - 100 * (thread.work + thread.wait) / thread.alive

        # every cell ends in a tab, so the cells are simply run together
        row = [
            '%13s\t%6.2g\t' % (thread.thread_name, thread.alive),
            '%.3g\t%.3g\t%.3g\t' % (wait_percent, work_percent, unkn_percent),
            '%.3g\t' % (float(thread.mem) / (1024 * 1024)),
            '%.3g\t' % (float(thread.peak_mem) / (1024 * 1024)),
        ]
        print(''.join(row))

# replay the memory events of all threads in order, tracking the running
# total and the highest value it reaches
mem = 0
peak_mem = 0
for event in all_events:
    if not event.memory:
        continue
    mem += event.size
    peak_mem = max(peak_mem, mem)

print('peak memory = %.3g MB' % (float(peak_mem) / (1024 * 1024)))
if mem:
    # a non-zero total after the last event is reported as a leak
    print('leak! final memory = %.3g MB' % (float(mem) / (1024 * 1024)))

# does a list of events contain an overlap?
# assume the list of events has been sorted by start time
def events_overlap(events):
    """Return True if any two events overlap for a non-zero length of time.

    events must be sorted by start time. Events that only touch end to
    start, and zero-duration events lying inside another event, do not count
    as overlapping.
    """
    # latest stop time among the events visited so far
    latest_stop = None

    for event in events:
        # we can't just test for stop1 > start2 since one (or both) events
        # might have duration zero. With the list sorted by start, this event
        # overlaps an earlier one exactly when it has some duration itself
        # and an earlier event is still running when it starts. Comparing
        # against the latest stop seen, rather than only the previous event,
        # also finds overlaps with zero-duration events in between.
        if (latest_stop is not None
                and event.stop > event.start
                and latest_stop > event.start):
            return True

        if latest_stop is None or event.stop > latest_stop:
            latest_stop = event.stop

    return False

# do the events on two gates overlap?
def gates_overlap(events, gate_name1, gate_name2):
    """Return True if the events of two gates overlap in time.

    Selects from events those whose gate_name equals gate_name1 or
    gate_name2, orders them by start time and tests the merged list with
    events_overlap().
    """
    wanted = (gate_name1, gate_name2)
    merged = [event for event in events if event.gate_name in wanted]

    # stable sort: events with equal start times keep their relative order
    merged.sort(key=lambda event: event.start)

    return events_overlap(merged)

# allocate a y position for each gate
total_y = 0
for thread in threads:
    if thread.hide:
        continue

    # first row of this thread in the whole diagram
    thread.total_y = total_y

    # gate name -> row within this thread
    gate_positions = {}

    # first pass .. move work and wait events to y == 0
    if events_overlap(thread.workwait_events):
        print('gate overlap on thread %s' % (thread.thread_name,))
        # list every pair of neighbouring events that overlap for a non-zero
        # length of time
        for event1, event2 in zip(thread.workwait_events,
                                  thread.workwait_events[1:]):
            overlap_start = max(event1.start, event2.start)
            overlap_stop = min(event1.stop, event2.stop)
            if overlap_stop - overlap_start > 0:
                print('overlap:')
                for event in (event1, event2):
                    print('event %s %s starts at %s stops at %s' %
                          (event.gate_location, event.gate_name,
                           event.start, event.stop))

    # work, wait and memory events all sit on the thread's first row
    for event in thread.workwait_events + thread.memory_events:
        gate_positions[event.gate_name] = 0
        event.y = 0
        event.total_y = total_y

    # second pass: move all other events to non-overlapping ys
    y = 1
    for event in thread.other_events:
        if event.gate_name not in gate_positions:
            # look at all the ys we've allocated previously and see if we can
            # add this gate to one of them
            for gate_y in range(1, y):
                found_overlap = False
                for gate_name in gate_positions:
                    if gate_positions[gate_name] != gate_y:
                        continue

                    if gates_overlap(thread.other_events,
                                     event.gate_name, gate_name):
                        found_overlap = True
                        break

                if not found_overlap:
                    gate_positions[event.gate_name] = gate_y
                    break

            # failure? add a new y
            if event.gate_name not in gate_positions:
                gate_positions[event.gate_name] = y
                y += 1

        event.y = gate_positions[event.gate_name]

    # third pass: flip the order of the ys to get the lowest-level ones at the
    # top, next to the wait/work line
    for event in thread.other_events:
        event.y = y - event.y
        event.total_y = total_y + event.y

    # the next thread starts below the y rows (0 .. y - 1) used by this one
    total_y += y

# layout constants
PIXELS_PER_SECOND, PIXELS_PER_GATE = 1000, 20
LEFT_BORDER = 130
BAR_HEIGHT = 5
MEM_HEIGHT = 100

# the drawing is as wide as the left border plus the whole run, and as tall
# as all the gate rows plus the memory graph, each with a little slack
WIDTH = 20 + int(LEFT_BORDER + last_time * PIXELS_PER_SECOND)
HEIGHT = 30 + MEM_HEIGHT + int(total_y * PIXELS_PER_GATE)

output_filename = "vips-profile.svg"
print('%s %s' % ('writing to', output_filename))

surface = cairo.SVGSurface(output_filename, WIDTH, HEIGHT)
ctx = cairo.Context(surface)

# fill the whole drawing with the dark blue background
ctx.set_source_rgba(0.0, 0.0, 0.3, 1.0)
ctx.rectangle(0, 0, WIDTH, HEIGHT)
ctx.fill()

# font used for every label drawn afterwards
ctx.select_font_face('Sans')
ctx.set_font_size(15)

def draw_event(ctx, event):
    """Draw one event on the cairo context ctx.

    The event becomes a filled bar on row event.total_y, running from its
    start to its stop time. Wait events are red, work events green and all
    other gate events blue with the gate name centred below the bar in
    yellow. A memory event is drawn as a white tick 1 unit wide and half a
    bar high, placed BAR_HEIGHT lower than a normal bar.
    """
    left = event.start * PIXELS_PER_SECOND + LEFT_BORDER
    # "//" keeps these offsets whole numbers whichever division rules are in
    # force; "/" would give 2.5 rather than 2 under true division
    top = event.total_y * PIXELS_PER_GATE + BAR_HEIGHT // 2
    width = (event.stop - event.start) * PIXELS_PER_SECOND
    height = BAR_HEIGHT

    if event.memory:
        width = 1
        height //= 2
        top += BAR_HEIGHT

    ctx.rectangle(left, top, width, height)

    if event.wait:
        ctx.set_source_rgb(0.9, 0.1, 0.1)
    elif event.work:
        ctx.set_source_rgb(0.1, 0.9, 0.1)
    elif event.memory:
        ctx.set_source_rgb(1.0, 1.0, 1.0)
    else:
        ctx.set_source_rgb(0.1, 0.1, 0.9)

    ctx.fill()

    if not (event.wait or event.work or event.memory):
        # a gate may have no name; there is then nothing to write, and None
        # must not be handed to cairo
        label = event.gate_name or ""
        if label:
            twidth = ctx.text_extents(label)[2]
            ctx.move_to(left + width / 2 - twidth / 2, top + 3 * BAR_HEIGHT)
            ctx.set_source_rgb(1.00, 0.83, 0.00)
            ctx.show_text(label)

# draw each visible thread: a white separator line along the top of its
# first row, its name at the left edge, then all of its events
for thread in threads:
    if not thread.hide:
        lane_top = thread.total_y * PIXELS_PER_GATE

        ctx.rectangle(0, lane_top, WIDTH, 1)
        ctx.set_source_rgb(1.00, 1.00, 1.00)
        ctx.fill()

        # only the text height is needed, to place the baseline
        name_height = ctx.text_extents(thread.thread_name)[3]
        ctx.move_to(0, name_height + lane_top + BAR_HEIGHT / 2)
        ctx.set_source_rgb(1.00, 1.00, 1.00)
        ctx.show_text(thread.thread_name)

        for event in thread.all_events:
            draw_event(ctx, event)

# the memory graph sits directly below the last gate row
memory_y = total_y * PIXELS_PER_GATE

label = "memory"
xbearing, ybearing, twidth, theight, xadvance, yadvance = \
        ctx.text_extents(label)
ctx.move_to(0, memory_y + theight + 8)
ctx.set_source_rgb(1.00, 1.00, 1.00)
ctx.show_text(label)

# trace the running memory total over time, scaled so that peak_mem spans
# the full MEM_HEIGHT; the line starts on the baseline at time 0
mem = 0
ctx.move_to(LEFT_BORDER, memory_y + MEM_HEIGHT)

for event in all_events:
    if not event.memory:
        continue

    mem += event.size

    left = LEFT_BORDER + event.start * PIXELS_PER_SECOND
    if peak_mem:
        # floor division: whole-unit heights under either division rule
        rise = MEM_HEIGHT * mem // peak_mem
    else:
        # the total never rose above zero, so there is no scale to divide
        # by; keep the line on the baseline
        rise = 0
    top = memory_y + MEM_HEIGHT - rise

    ctx.line_to(left, top)

ctx.set_line_width(1)
ctx.set_source_rgb(1.00, 1.00, 1.00)
ctx.stroke()

# the time axis runs along the bottom of the memory graph
axis_y = total_y * PIXELS_PER_GATE + MEM_HEIGHT

ctx.rectangle(LEFT_BORDER, axis_y, last_time * PIXELS_PER_SECOND, 1)
ctx.set_source_rgb(1.00, 1.00, 1.00)
ctx.fill()

label = "time"
xbearing, ybearing, twidth, theight, xadvance, yadvance = \
        ctx.text_extents(label)
ctx.move_to(0, axis_y + theight + 8)
ctx.set_source_rgb(1.00, 1.00, 1.00)
ctx.show_text(label)

# one labelled tick every tenth of a second; range() needs an integer step,
# hence "//" rather than "/"
for t in range(0, int(last_time * PIXELS_PER_SECOND), PIXELS_PER_SECOND // 10):
    left = t + LEFT_BORDER
    top = axis_y

    ctx.rectangle(left, top, 1, 5)
    ctx.set_source_rgb(1.00, 1.00, 1.00)
    ctx.fill()

    # t is an offset in drawing units; label the tick with its time in seconds
    label = str(float(t) / PIXELS_PER_SECOND)
    xbearing, ybearing, twidth, theight, xadvance, yadvance = \
            ctx.text_extents(label)
    ctx.move_to(left - twidth / 2, top + theight + 8)
    ctx.set_source_rgb(1.00, 1.00, 1.00)
    ctx.show_text(label)

surface.finish()