#!/usr/bin/python

"""
Hacky script to build OSM user statistics.

Takes a bz2-compressed OSM file and builds per-user statistics from it.
If PyQt4 is available, also generates a map (userstat.svg) showing each
user's average mapping center.

Released to Public Domain by Andreas Stricker <andy@knitter.ch>
"""

import bz2
import calendar
import math
import sys
import time
import xml.parsers.expat

try:
    import PyQt4
    import PyQt4.QtCore as core
    import PyQt4.QtGui as gui
    import PyQt4.QtSvg as svg
    Qt = PyQt4.QtCore.Qt
    hasQt = True
except ImportError:
    hasQt = False

class OSMParser(object):
    """Collect per-user node statistics from OSM XML.

    After parsing, ``self.users`` maps each user name to a dict with keys:

    * ``name``: the user name,
    * ``count``: number of that user's nodes,
    * ``latsum`` / ``lonsum``: sums of node coordinates; divide by ``count``
      for the user's average position,
    * ``first`` / ``last``: UTC epoch seconds of midnight on the earliest /
      latest day seen in the nodes' ``timestamp`` attributes, or None if
      none of the user's nodes carried a timestamp.

    ``minLat``/``maxLat``/``minLon``/``maxLon`` hold the bounding box of all
    counted nodes.
    """

    def __init__(self):
        self.parser = xml.parsers.expat.ParserCreate()
        self.parser.StartElementHandler = self._startElement
        self.users = dict()
        # Start outside any valid coordinate range so the first node
        # overwrites both bounds.
        self.minLat = 360
        self.maxLat = -360
        self.minLon = 360
        self.maxLon = -360
        # 'YYYY-MM-DD' -> epoch seconds; many nodes share the same day, and
        # strptime is too slow to call once per node on a large extract.
        self._dayCache = {}

    def parseBZ2File(self, filename):
        """Parse the bz2-compressed OSM XML file *filename*."""
        f = bz2.BZ2File(filename, 'r')
        try:
            self.parser.ParseFile(f)
        finally:
            f.close()

    def _startElement(self, name, attrs):
        """Expat start-element handler: account every attributed <node>."""
        if name != 'node' or 'user' not in attrs:
            return
        # Nodes without coordinates (e.g. deleted nodes in history dumps)
        # cannot contribute to an average position; skip them instead of
        # failing with KeyError.
        if 'lat' not in attrs or 'lon' not in attrs:
            return
        lat = float(attrs['lat'])
        lon = float(attrs['lon'])

        user = attrs['user']
        u = self.users.get(user)
        if u is None:
            u = self.users[user] = {
                'name': user,
                'first': None,
                'last': None,
                'latsum': 0,
                'lonsum': 0,
                'count': 0,
            }
        u['count'] += 1
        u['latsum'] += lat
        u['lonsum'] += lon

        self.minLat = min(self.minLat, lat)
        self.maxLat = max(self.maxLat, lat)
        self.minLon = min(self.minLon, lon)
        self.maxLon = max(self.maxLon, lon)

        ts = attrs.get('timestamp')
        if ts is not None:
            day = self._dayStamp(ts)
            if u['first'] is None or u['first'] > day:
                u['first'] = day
            if u['last'] is None or u['last'] < day:
                u['last'] = day

    def _dayStamp(self, timestamp):
        """Return UTC epoch seconds for midnight of *timestamp*'s date.

        Uses calendar.timegm rather than time.mktime so that the value
        round-trips through the time.gmtime() call in _formatDay. mktime
        interprets the date as local time, which shifts the printed date by
        one day in time zones east of UTC.
        """
        day = timestamp.split('T', 1)[0]
        stamp = self._dayCache.get(day)
        if stamp is None:
            stamp = calendar.timegm(time.strptime(day, '%Y-%m-%d'))
            self._dayCache[day] = stamp
        return stamp

    @staticmethod
    def _formatDay(stamp):
        """Format a UTC epoch *stamp* as YYYY-MM-DD; None becomes ''."""
        if stamp is None:
            # time.gmtime(None) would silently report the current date.
            return u''
        return time.strftime('%Y-%m-%d', time.gmtime(stamp))

    def stat(self, stream):
        """Write one line per user to *stream* as UTF-8 encoded bytes.

        Each line is ``name, count, first, last, avg_lat, avg_lon``. Users
        are written sorted by name. first/last are YYYY-MM-DD (UTC) and are
        left empty when none of the user's nodes had a timestamp.

        Note: fields are joined with ", " without quoting, so a user name
        that contains a comma makes its line ambiguous.
        """
        # Python 3 text streams (e.g. sys.stdout) take bytes only through
        # their .buffer; Python 2 files and byte streams take bytes directly.
        out = getattr(stream, 'buffer', stream)
        if out is not stream:
            # Keep anything already written to the text layer in order.
            stream.flush()
        for name in sorted(self.users):
            u = self.users[name]
            try:
                line = u"%s, %d, %s, %s, %g, %g\n" % (
                    u['name'],
                    u['count'],
                    self._formatDay(u['first']),
                    self._formatDay(u['last']),
                    u['latsum'] / u['count'],
                    u['lonsum'] / u['count'])
                out.write(line.encode('UTF-8', 'replace'))
            except UnicodeError as e:
                sys.stderr.write("Error: %s at %r\n" % (e, name))

class SvgGenerator(object):
    """Render each user's average mapping centre into an SVG file.

    Relies on the PyQt4 names imported at module level (svg, core, gui, Qt).

    filename: path of the SVG file to write (default "userstat.svg").
    width, height: SVG canvas size in pixels (default 1200 x 800).
    """

    def __init__(self, filename="userstat.svg", width=1200, height=800):
        self.svg = svg.QSvgGenerator()
        self.svg.setFileName(filename)
        self.svg.setSize(core.QSize(width, height))

    def exportSvg(self, p):
        """Draw the users collected by the OSMParser *p* and write the SVG.

        The painter is always ended, even if drawing raises, so the output
        file is finalised and the generator is not left mid-paint.
        """
        self.painter = gui.QPainter()
        self.painter.begin(self.svg)
        try:
            self.matrix = self._geoMatrix(p)
            self.painter.save()
            try:
                self._paint(p.users)
            finally:
                self.painter.restore()
        finally:
            self.painter.end()
            self.painter = None

    def _geoMatrix(self, p):
        """Return a QMatrix mapping (lon, lat) of *p*'s bounding box onto the canvas.

        minLon maps to the left edge and maxLat to the top edge, so north is up.
        """
        rect = core.QRectF(core.QPointF(p.minLon, p.minLat),
                           core.QPointF(p.maxLon, p.maxLat))
        # A single node (or nodes all on one meridian/parallel) gives a
        # zero-sized box; pad it so the scale factors below stay finite.
        if rect.width() == 0:
            rect.adjust(-0.5, 0, 0.5, 0)
        if rect.height() == 0:
            rect.adjust(0, -0.5, 0, 0.5)

        size = self.svg.size()
        matrix = gui.QMatrix()
        # flip y and move the origin so that y grows upwards from the bottom
        matrix.scale(1, -1)
        matrix.translate(0, -size.height())

        # scale and translate to convert lat/lon to screen coordinates
        pt = rect.topLeft() * -1
        matrix.scale(size.width() / rect.width(),
                     size.height() / rect.height())
        matrix.translate(pt.x(), pt.y())
        return matrix

    def _paint(self, users):
        """Draw a red dot (sized by log10 of the node count) and the name of each user."""
        p = self.painter
        p.setFont(gui.QFont("Helvetica", 6))

        red = gui.QColor(248, 32, 48, 192)
        ptpen = gui.QPen()
        ptpen.setColor(red)
        brush = gui.QBrush()
        brush.setColor(red)
        brush.setStyle(Qt.SolidPattern)
        p.setBrush(brush)

        # QPen.setColor colours the pen's brush too; modifying the value
        # returned by QPen.brush() would only change a temporary copy.
        txtpen = gui.QPen()
        txtpen.setColor(Qt.black)

        flags = Qt.AlignHCenter | Qt.AlignVCenter
        for name, u in users.items():
            lat = u['latsum'] / u['count']
            lon = u['lonsum'] / u['count']

            # diameter in pixels: 2 for up to 10 nodes, growing with log10(count)
            size = 2 * max(1, min(30, math.log(u['count'], 10)))

            point = self.matrix.map(core.QPointF(lon, lat))
            p.setPen(ptpen)
            p.drawEllipse(core.QRectF(point, core.QSizeF(size, size))
                          .translated(-size / 2, -size / 2))

            bound = p.boundingRect(core.QRectF(), flags, name)
            bound.moveCenter(point)
            p.setPen(txtpen)
            p.drawText(bound, flags, name)

def main(argv=None):
    """Parse an OSM extract, print per-user stats and, with PyQt4, draw the map.

    argv: command-line arguments (defaults to sys.argv). If argv[1] is
    given, it names the bz2-compressed OSM file to read; otherwise
    'switzerland.osm.bz2' is used.
    """
    if argv is None:
        argv = sys.argv
    filename = argv[1] if len(argv) > 1 else 'switzerland.osm.bz2'

    p = OSMParser()
    p.parseBZ2File(filename)
    p.stat(sys.stdout)
    if hasQt:
        # A QApplication must exist while fonts and painters are used; the
        # reference is held for the duration of the export.
        app = gui.QApplication(argv)
        svggen = SvgGenerator()
        svggen.exportSvg(p)

# Run the statistics only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
