#!/usr/bin/env python

import argparse
import datetime
import getpass
import multiprocessing
import os
import psycopg2
import psycopg2.extras
import psycopg2.errorcodes
import quopri
import random
import re
import sys
import traceback

from multiprocessing import Process, JoinableQueue

class Message(dict):
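	"""Minimal ad-hoc parser for a single raw mbox message.
	
	Note: __init__ deliberately shadows the body/headers/parts methods with
	computed attributes of the same name, so those methods are effectively
	only callable during construction.
	"""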
	
	def __init__(self, message):
		self.message = message
		self.body    = self.body(message)
		self.headers = self.headers(message)
		self.parts   = self.parts(message)
	
	def __getitem__(self, key):
		return self.headers.get(key.lower())
	
	def __setitem__(self, key, value):
		self.headers.update({key.lower() : value})
	
	def __delitem__(self, key):
		if key.lower() in self.headers:
			del self.headers[key.lower()]
	
	def keys(self):
		return self.headers.keys()
	
	def get_body(self):
		return self.body

	def get_raw(self):
		return self.message
	
	def get_parts(self):
		return self.parts
	
	def get_headers(self):
		return self.headers

	def get_content_type(self):
		
		if 'content-type' in self.headers:
			return self.headers['content-type'].split(';')[0]
		else:
			return None
		
	def __repr__(self):
		return '%s %s' % (type(self).__name__, self.headers)
	
	def is_multipart(self):
		ctype = self.get_content_type()
		return ctype is not None and re.match('multipart/.*', ctype) is not None
	
	def part_boundary(self):
	
		if not self.is_multipart():
			return None
		else:
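			# e.g. 'Content-Type: multipart/mixed; boundary="xyz"' yields '--xyz'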
			r = re.match('.*boundary="?([^";]*)"?', self.headers['content-type'], re.IGNORECASE)
			if r:
				return '--' + r.group(1)

	# FIXME this keeps only the last value - needs to keep a list
	def headers(self, message):
		
		lines = message.split("\n")
		
		key = ''
		value = ''
		
		headers = {}
		
		for l in lines:
			if l == '':
				if key != '':
					headers.update({key.lower() : value})
				break
			
			r = re.match('([a-zA-Z0-9-]*):\s*(.*)', l)
			if r:
				if key != '':
					headers.update({key.lower() : value})
				
				key = r.group(1)
				value = r.group(2)
			else:
				value += ' ' + l.strip()
			
		r = re.match('^From .*@.*\s+([a-zA-Z]*\s+[a-zA-Z]*\s+[0-9]+ [0-9]+:[0-9]+:[0-9]+\s+[0-9]{4})$', lines[0])
		if r:
			headers.update({'message-date' : r.group(1)})
		
		r = re.match('^From bouncefilter\s+([a-zA-Z]*\s+[a-zA-Z]*\s+[0-9]+ [0-9]+:[0-9]+:[0-9]+\s+[0-9]{4})$', lines[0])
		if r:
			headers.update({'message-date' : r.group(1)})
		
		r = re.match('^From scrappy$', lines[0])
		if r:
			r = re.search('^([^+-]*) [+-].*$', headers.get('date', ''))
			if r:
				headers.update({'message-date' : r.group(1)})
			else:
				r = re.search('^([^\(]*) \(.*', headers.get('date', ''))
				if r:
					headers.update({'message-date' : r.group(1)})
				else:
					headers.update({'message-date' : headers.get('date', '')})
		
		return headers
	
	def body(self, message):
	
		lines = message.split("\n")
		body = ''
		in_body = False
		
		for l in lines:
			
			if in_body:
				body += l + "\n"
			
			if l == '':
				in_body = True
			
		return body.strip()

	def is_boundary(self, boundary, line):
		
		return (boundary == line) or ((boundary + '--') == line)
	
	def parts(self, message):
		
		# not a multipart message - the whole message is a part
		if not self.is_multipart():
			return [message]

		# split the message into parts
		else:
			boundary = self.part_boundary()
			lines    = self.body.split("\n")
			part     = ''
			parts    = []
			
			for l in lines:
				
				if self.is_boundary(boundary, l) and part != '':
					parts.append(Message(part))
					part = ''
				
				elif self.is_boundary(boundary, l):
					pass
				
				else:
					part += l + "\n"
			
			return parts
			
	def get_payload(self):
		
		payload = ''
		
		# if it's not a multi-part message, then just get the body
		if not self.is_multipart():
			
			# we accept only some content types (basically text/* and empty, which is assumed to be text/plain)
			if self.get_content_type() is None or self.get_content_type().lower() in ('text/plain', 'text/html'):
				
				# decode if needed (content-transfer-encoding values are case-insensitive)
				cte = (self['content-transfer-encoding'] or '').lower()
				if cte == 'base64':
					body = self.body.decode('base64')
				elif cte == 'quoted-printable':
					body = self.qpdecodebody(self.body)
				else:
					body = self.body
				
				# if it's text/html, strip the html tags
				if self.get_content_type() is not None and self.get_content_type().lower() == 'text/html':
					body = self.clean_html(body)
				
				return body.strip()
		else:
			
			# if it's a multi-part message, then try to get the text/plain part first
			for p in self.parts:
				if p.get_content_type() != None and p.get_content_type().lower() == 'text/plain':
					return p.get_payload()
				
			# no text/plain part, try to get any text/* part
			for p in self.parts:
				if p.get_content_type() != None and re.match('text/.*', p.get_content_type().lower()):
					return p.get_payload()
				
			# meh, nothing useful :-(
			return None

	def get_subject(self):
	
		if not self['subject']:
			return None
		
		subject = ''
		lines = self['subject'].split(' ')
		for l in lines:

			if (self.is_qpencoded(l)):
				subject += ' ' + self.qpdecode(l)
			elif (self.is_b64encoded(l)):
				subject += ' ' + self.b64decode(l)
			else:
				subject += ' ' + l
			
		return subject

	def is_b64encoded(self, line):
		
		return re.match('=\?(.*)\?B\?(.*)\?=', line, re.IGNORECASE) is not None

	def is_qpencoded(self, line):
		
		return re.match('=\?(.*)\?Q\?(.*)\?=', line, re.IGNORECASE) is not None
		
	def b64decode(self, line):
		
		lines = line.split(' ')
		output = ''
		
		for line in lines:
			r = re.match('=\?(.*)\?B\?(.*)\?=', line, re.IGNORECASE)
			
			if r:
				output += ' ' + r.group(2).decode('base64')
		
		return output
	
	def qpdecodebody(self, body):
		return quopri.decodestring(body)
	
	def qpdecode(self, line):
		
		lines = line.split(' ')
		output = ''
		
		for line in lines:
			r = re.match('=\?(.*)\?Q\?(.*)\?=', line, re.IGNORECASE)
			
			if r:
				output += ' ' + quopri.decodestring(r.group(2), header=True)
		
		return output
		
	def clean_html(self, raw_html):

		return re.sub(re.compile('<.*?>'),'', raw_html)
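
# Illustrative usage of Message (hypothetical raw message, for reference only):
#
#   raw = ('From alice@example.com Mon Jan  1 00:00:00 2001\n'
#          'Subject: hello\n'
#          'Content-Type: text/plain\n'
#          '\n'
#          'Hi there.\n')
#   msg = Message(raw)
#   msg['subject']       # -> 'hello'
#   msg.get_payload()    # -> 'Hi there.'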
	
def load_files(id, connstr, queue):
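	"""Worker loop: consume {'list' : name, 'files' : [paths]} tasks from the
	queue and load them; a False task is the "stop" sentinel (see main)."""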

	conn = psycopg2.connect(connstr)
	conn.set_client_encoding('LATIN1')
	
	psycopg2.extras.register_hstore(conn)
	cur = conn.cursor()
	
	c = 0
	n = 0
	start_time = datetime.datetime.now()
	
	try:
	
		while True:
			
			task = queue.get()
			
			if not task:
				
				print "worker %s : no more tasks, terminating" % (id,)
				queue.task_done()
				break
			
			for f in task['files']:
				
				(msgs,errs,dups) = load_mbox(task['list'], f, cur)
				conn.commit()
				
				n += msgs
				d = (datetime.datetime.now() - start_time).total_seconds()
				
				print "worker %s : %s - imported %s messages (%s errors, %s duplicate): %s msgs/sec" % (id, f, msgs, errs, dups, round((c+n)/d,1))
				
				# analyze only if we reached enough new messages (at least 250 and more than inserted so far)
				if (n >= c) and (n >= 250):
					c += n
					n = 0
					cur.execute('ANALYZE messages')
			
			queue.task_done()
			
	except KeyboardInterrupt:
		print "worker %s : process stopped by Ctrl-C" % (id,)
	
	finally:
		cur.close()
		conn.close()

def load_mbox(lname, filename, cur):
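	"""Load a single mbox file into the messages table; returns (n_msgs, n_errs, n_dups)."""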
	
	r = re.match('(.*)\.[0-9]+-[0-9]+', filename)
	if r:
		lst = r.group(1)
		if lst != lname:
			print "ERROR: mbox '%s' does not belong to list '%s'" % (filename, lname)
			return
	else:
		print "unknown list (expected %s): %s" % (filename, lname)
		return

	msgs = read_messages(filename)
	
	cur.execute("SAVEPOINT bulk_load")
	
	n_msgs = 0
	n_errs = 0
	n_dups = 0
	
	for m in msgs:
		
		# print m['message-date'],"\t",len(m.get_body()),'B',"\t",m.get_content_type(),"\t",len(m.get_parts()),"\t",m['subject'],"\t",m['from']
		
		try:
			
			mid = None
			if m['message-id']:
				r = re.search('<(.*)>.*', m['message-id'])
				if r:
					mid = r.group(1)
			
			date = m['message-date']

			subject = m.get_subject()
			if subject:
				subject = subject.strip()
			
			refs = None
			if m['references']:
				refs = re.findall('<([^>]*)>', m['references'])
			
			in_reply_to = None
			if m['in-reply-to']:
				in_reply_to = re.findall('<([^>]*)>', m['in-reply-to'])
			
			body = m.get_payload()
			
			# skip messages where no date could be parsed from the separator line
			if not date:
				continue
			
			cur.execute("INSERT INTO messages (list, message_id, in_reply_to, refs, sent, subject, author, body_plain, headers, raw_message) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)",
						(lst, mid, in_reply_to, refs, date, subject, m['from'], body, m.get_headers(), m.get_raw()))
			
			cur.execute("SAVEPOINT bulk_load")
			
			n_msgs += 1
			
		except psycopg2.DatabaseError as ex:
			
			# well, no matter what happened, rollback to the savepoint
			cur.execute("ROLLBACK TO bulk_load")
			
			if ex.pgcode == psycopg2.errorcodes.UNIQUE_VIOLATION:
				n_dups += 1
			else:
				print ex.pgcode,':',ex.pgerror
				n_errs += 1
	
	return (n_msgs,n_errs,n_dups)
	
def read_messages(filename):
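	"""Split an mbox file into Message objects on the 'From ' separator lines."""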
	
	messages = []
	message = ''
	
	with open(filename, 'r') as f:
		
		for line in f:
			
			if re.match('^From .*@.*\s+[a-zA-Z]*\s+[a-zA-Z]*\s+[0-9]+ [0-9]+:[0-9]+:[0-9]+\s+[0-9]{4}$', line):
				if message != '':
					messages.append(Message(message))
				message = line
			
			elif re.match('^From bouncefilter\s+[a-zA-Z]*\s+[a-zA-Z]*\s+[0-9]+ [0-9]+:[0-9]+:[0-9]+\s+[0-9]{4}$', line):
				if message != '':
					messages.append(Message(message))
				message = line
			
			elif re.match('^From scrappy$', line):
				if message != '':
					messages.append(Message(message))
				message = line
			
			else:
				message += line
	
	# don't forget the last message (unless the file was empty)
	if message != '':
		messages.append(Message(message))
	
	return messages

def get_space(files):
	'returns total size occupied by the files (from a list)'
	
	s = 0
	for f in files:
		s += os.path.getsize(f)
	
	return s

def build_parser():
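	"""Build the command-line argument parser for the loader."""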
	
	parser = argparse.ArgumentParser(description='Archie loader')
	
	parser.add_argument('files', metavar='FILE', type=str, nargs='+', help='mbox files to load')
	
	parser.add_argument('--workers', dest='nworkers', action='store', type=int,
						default=0, help='number of worker processes (default is number of cores)')
	
	parser.add_argument('--host', dest='hostname', action='store', type=str,
						default='localhost', help='database server hostname (localhost by default)')
	
	parser.add_argument('--db', dest='dbname', action='store', type=str,
						required=True, help='name of the database to test (required)')
	
	parser.add_argument('--port', dest='port', action='store', type=int,
						default=5432, help='database server port (5432 by default)')
	
	parser.add_argument('--user', dest='user', action='store', type=str,
						default=getpass.getuser(), help='db username (current OS user by default)')
	
	parser.add_argument('--password', dest='password', action='store', type=str,
						default=None, help='db password (empty by default)')
	
	return parser
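
# Example invocation (script and mbox file names are illustrative - the loader
# expects files named <list>.<digits>-<digits>):
#
#   python archie-load.py --db archie --user archie pgsql-hackers.2001-01 pgsql-hackers.2001-02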
	
if __name__ == '__main__':
	
	# build parser and parse arguments
	parser = build_parser()
	args = parser.parse_args()
	
	if args.password:
		connstr = 'host=%s dbname=%s user=%s port=%s password=%s' % (args.hostname,
			args.dbname, args.user, args.port, args.password)
	else:
		connstr = 'host=%s dbname=%s user=%s port=%s' % (args.hostname,
			args.dbname, args.user, args.port)
	
	# prepare list of files
	lists = {}
	for f in args.files:
		r = re.match('([a-z-]*)\.[0-9]+-[0-9]+', f)
		if r:
			mlist = r.group(1)
			lists.setdefault(mlist, []).append(f)
	
	# get occupied space for each list (assuming space ~ number of messages)
	# and process the largest lists first, to keep the workers evenly loaded
	keys = [{'key' : k, 'length' : get_space(lists[k])} for k in lists.keys()]
	keys = sorted(keys, key=lambda x: -x['length'])
	
	queue = JoinableQueue()
	
	for k in keys:
		queue.put({'list' : k['key'], 'files' : lists[k['key']]})
	
	# get number of workers
	nworkers = multiprocessing.cpu_count()
	if args.nworkers > 0:
		nworkers = args.nworkers
	
	if nworkers > len(lists):
		nworkers = len(lists)
	
	try:
	
		started = datetime.datetime.now()
	
		# start workers
		workers = []
		for i in xrange(nworkers):
			p = Process(target=load_files, args=(i, connstr, queue))
			workers.append(p)
			p.start()
			
			# for each worker, enqueue one False sentinel, meaning "stop"
			queue.put(False)
		
		# wait for the workers to finish
		queue.join()
		
		runtime = (datetime.datetime.now() - started).total_seconds()
		
		print "total runtime: %s seconds" % (round(runtime,1),)
	
	except KeyboardInterrupt:
		print "loading interrupted"