# -*- coding: utf-8 -*-
import sys
import os
import traceback
import socket, ssl
import random
from optparse import OptionParser 
import signal

from utils import validate_user, UserInvalid, send_contents, receive_contents
from msg.stream import read_msg, write_msg, accept_msg, EndOfStream, BadMessage
from msg.message import *
from repo import Repo
from database import Database
from config import Config

class SSLServer:
    """
    Single-threaded TLS server exchanging annotation files with clients.

    Connections are accepted one at a time in `run`; each connection is
    authenticated with `validate_user`, carries exactly one request (see
    `serve_client`) and is closed afterwards.  File contents are kept in a
    Subversion working copy (`Repo`); file assignments are tracked by a
    `Database` whose path is supplied by that working copy.
    """

    def __init__(self, host, port, backlog, cert_file, key_file,
            repository, svn_login, svn_passwd, pass_file,
            users_file, file_exts, anno_per_file, log=None):
        """
        Initialize SSL server.

        params:
        =======
        host : str
            Host name.
        port : int
            Port number.
        backlog : int
            Maximum number of waiting connections.
        cert_file : path
            Public PEM certificate.
        key_file : path
            Private key.
        pass_file : path
            File with client passwords.
        users_file : path
            File with additional users configuration.
        repository : path
            Subversion working copy.
        svn_login : str
            Subversion login.
        svn_passwd : str
            Subversion password.
        file_exts : [str]
            List of file extensions.
        anno_per_file : int
            Number of normal annotators per file (1 or 2)
        log : path or None
            File to which log messages are appended; standard output
            is used when None.
        """
        self.cert_file = cert_file
        self.key_file = key_file
        self.pass_file = pass_file
        self.users_file = users_file
        self.wc = Repo(repository, svn_login, svn_passwd)
        self.file_exts = file_exts
        if log is not None:
            self.log = open(log, "a")
        else:
            self.log = sys.stdout
        self.bound = bind_socket(host, port, backlog=backlog)
        # Address of the client currently being served, None when idle.
        self.serving = None
        self.anno_per_file = anno_per_file

    # ------------------------------------------------------------------
    # Logging helpers
    # ------------------------------------------------------------------

    def _log(self, *parts):
        """Write *parts*, space-separated, as one line to the log and flush."""
        line = " ".join("%s" % (part,) for part in parts)
        self.log.write(line + "\n")
        # Flush immediately so entries are not lost in the buffer if the
        # process is killed.
        self.log.flush()

    def _log_traceback(self):
        """Log the traceback of the exception currently being handled."""
        self._log(traceback.format_exc().strip())

    # ------------------------------------------------------------------
    # Connection handling
    # ------------------------------------------------------------------

    def run(self):
        """
        Accept and serve clients one at a time, forever.

        Any failure while serving a single client is logged, the
        connection is closed and the loop moves on to the next client.
        """
        while True:
            newsocket, fromaddr = self.bound.accept()
            self.serving = fromaddr
            connstream = None
            try:
                # NOTE(review): PROTOCOL_TLSv1 pins TLS 1.0, which is
                # deprecated; kept unchanged because existing clients may
                # only speak this protocol version.
                connstream = ssl.wrap_socket(newsocket,
                                             server_side=True,
                                             certfile=self.cert_file,
                                             keyfile=self.key_file,
                                             ssl_version=ssl.PROTOCOL_TLSv1)

                login = validate_user(connstream, self.pass_file)
                self.serve_client(connstream, login)
            except UserInvalid as err:
                self._log("AUTH ERROR:", err)
            except BadMessage as info:
                self._log("BAD MESSAGE ERROR:", info)
                self._log_traceback()
            except EndOfStream:
                self._log("UNEXPECTED END OF STREAM:")
                self._log_traceback()
            except Exception:
                self._log("UNEXPECTED ERROR:")
                self._log_traceback()
            finally:
                self._close_connection(newsocket, connstream)
                self.serving = None

    def _close_connection(self, newsocket, connstream):
        """Shut down and close the TLS stream (if any) and the raw socket."""
        try:
            if connstream is not None:
                connstream.shutdown(socket.SHUT_RDWR)
        except Exception:
            self._log("UNEXPECTED ERROR:")
            self._log_traceback()
        finally:
            if connstream is not None:
                connstream.close()
            # Close the accepted socket too: when the TLS handshake fails
            # connstream is None and the descriptor would otherwise leak.
            newsocket.close()

    def exit(self):
        """
        Exit the process unless a client is currently being served.

        The command-line entry point installs this as the SIGINT handler.
        While a client is connected only a notice is printed and the
        server keeps running (presumably so a request is not interrupted
        half-way).
        """
        if self.serving is not None:
            print ("Client from %s connected to the server"
                  % str(self.serving))
        else:
            sys.exit(0)

    def serve_client(self, connstream, login):
        """
        Read one request from *connstream* and process it for *login*.

        If processing fails, local changes in the working copy are
        reverted and the exception is re-raised.
        """
        msg = read_msg(connstream)
        try:
            self.process_msg(msg, connstream, login)
        except:
            self.wc.revert()
            raise

    def process_msg(self, msg, stream, login):
        """
        Dispatch a single request to its handler.

        Dispatch is on the exact type of *msg* (subclasses are not
        matched); messages of any other type are only logged.
        """
        msg_type = type(msg)
        if msg_type is UploadRequest:
            self.process_upload_msg(msg, stream, login, checkin=False)
        elif msg_type is CheckinRequest:
            self.process_upload_msg(msg, stream, login, checkin=True)
        elif msg_type is CheckoutRequest:
            self.process_checkout_msg(msg, stream, login)
        elif msg_type is DownloadRequest:
            self.process_download_msg(msg, stream, login)
        elif msg_type is DownloadPrimRequest:
            self.process_download_prim_msg(msg, stream, login)
        elif msg_type is ReturnRequest:
            self.process_return_msg(msg, stream, login)
        elif msg_type is StatsRequest:
            self.process_stats_msg(msg, stream, login)
        else:
            self._log(login, "sent:", msg)

    # ------------------------------------------------------------------
    # Shared request helpers
    # ------------------------------------------------------------------

    def _is_adjudicator(self, login, usr_cfg=None):
        """
        Return True if *login* is listed in ``auth.adjudicators``.

        *usr_cfg* is the users configuration; it is loaded from
        `users_file` when not given.  A missing key means no adjudicators.
        """
        if usr_cfg is None:
            usr_cfg = Config(self.users_file)
        try:
            return login in usr_cfg["auth.adjudicators"].split()
        except KeyError:
            return False

    @staticmethod
    def _limit(usr_cfg, login, kind):
        """
        Return the integer limit ``limits.<login>.<kind>``, falling back to
        the global ``limits.<kind>``.

        Raises KeyError when neither key is configured.
        """
        try:
            return int(usr_cfg["limits.%s.%s" % (login, kind)])
        except KeyError:
            return int(usr_cfg["limits.%s" % kind])

    @staticmethod
    def _allowed_count(requested, owned, limit):
        """Clamp *requested* so that *owned* + result does not exceed *limit*."""
        if requested + owned > limit:
            return max(0, limit - owned)
        return requested

    def _filter_contents(self, file_id, login, contents):
        """
        Drop (and log) entries of *contents* whose key is not an allowed
        file extension.  *contents* is modified in place and returned.
        """
        # Iterate over a snapshot of the keys: entries are deleted below.
        for key in list(contents):
            if key not in self.file_exts:
                self._log("Skipped file uploaded by user " + login
                          + ". Filename:" + file_id + key)
                del contents[key]
        return contents

    def _accepted_uploads(self, msg, stream, login, db, user_adj):
        """
        Run the per-file handshake of an upload-style request.

        For each of the ``msg.get_number()`` announced files, read its id
        and adjudication flag and answer Ko (file skipped) or Ok.  Accepted
        files have their contents received, filtered by extension and
        yielded as ``(file_id, is_adj, contents)``.  The caller must store
        each file before asking for the next one, since that is the order
        in which the protocol exchanges messages.
        """
        for _ in range(msg.get_number()):
            file_id = accept_msg(stream, Message).get_contents()
            is_adj = accept_msg(stream, BoolMessage).get_boolean()
            if is_adj and not user_adj:
                write_msg(stream, KoMessage())
                self._log("Upload error - user " + login + " tried to upload super annotated file " + file_id + " but is not a super annotator")
                continue
            error = db.upload_prevention(file_id, login, is_adj)
            if error is not None:
                write_msg(stream, KoMessage())
                self._log("Upload error by user " + login + ". Details: " + error)
                continue

            write_msg(stream, OkMessage())
            contents = self._filter_contents(file_id, login,
                                             receive_contents(stream))
            yield file_id, is_adj, contents

    # ------------------------------------------------------------------
    # Request handlers
    # ------------------------------------------------------------------

    def process_return_msg(self, msg, stream, login):
        """
        Handle a return request: receive a reason and the returned files,
        record the return in the database and commit the working copy.
        """
        reason = accept_msg(stream, Message).get_contents().decode("utf-8")
        user_adj = self._is_adjudicator(login)

        db = Database(self.wc.db_path(), self.anno_per_file)
        for file_id, is_adj, contents in self._accepted_uploads(
                msg, stream, login, db, user_adj):
            if not is_adj:
                idx = db.return_file(file_id, login, reason)
                self.wc.upload(file_id, idx, contents)
            else:
                db.return_file_prim(file_id, login, reason)
                self.wc.upload_prim(file_id, contents)

        db.save()
        accept_msg(stream, ClientDone)
        self.wc.commit("return request from %s"
                      % (login))
        write_msg(stream, ServerDone())

    def process_upload_msg(self, msg, stream, login, checkin=False):
        """
        Handle an upload (``checkin=False``) or checkin (``checkin=True``)
        request and commit the received files to the working copy.

        Only a checkin updates and saves the database.
        """
        user_adj = self._is_adjudicator(login)

        db = Database(self.wc.db_path(), self.anno_per_file)
        for file_id, is_adj, contents in self._accepted_uploads(
                msg, stream, login, db, user_adj):
            if not is_adj:
                if checkin:
                    idx = db.upload(file_id, login)
                else:
                    idx = db.upload_id(file_id, login)
                self.wc.upload(file_id, idx, contents)
            else:
                if checkin:
                    db.upload_prim(file_id, login)
                self.wc.upload_prim(file_id, contents)

        if checkin:
            db.save()
        accept_msg(stream, ClientDone)
        self.wc.commit("%s request from %s"
                      % ("checkin" if checkin else "upload", login))
        write_msg(stream, ServerDone())

    def process_checkout_msg(self, msg, stream, login):
        """
        Send the user every file they currently own: first the normal
        annotation files, then the adjudication files together with the
        annotators' versions and, if present, the adjudicated version.
        """
        db = Database(self.wc.db_path(), self.anno_per_file)

        ann_files = db.owns_normal(login)
        write_msg(stream, NumMessage(len(ann_files)))
        for file_id in ann_files:
            write_msg(stream, Message(file_id))
            idx = db.upload_id(file_id, login)
            contents = self.wc.checkout(file_id, idx, self.file_exts)
            send_contents(stream, contents)

        adj_files = db.owns_super(login)
        write_msg(stream, NumMessage(len(adj_files)))
        write_msg(stream, NumMessage(self.anno_per_file))
        for file_id in adj_files:
            write_msg(stream, Message(file_id))

            contents = self.wc.download_prim(file_id, self.file_exts, self.anno_per_file)
            for conts in contents:
                send_contents(stream, conts)

            conts3 = self.wc.checkout_prim(file_id, self.file_exts)
            if conts3 is None:
                write_msg(stream, BoolMessage(False))
            else:
                write_msg(stream, BoolMessage(True))
                send_contents(stream, conts3)

        accept_msg(stream, ClientDone)
        write_msg(stream, ServerDone())

    def process_stats_msg(self, msg, stream, login):
        """Send the number of files the user has finished."""
        db = Database(self.wc.db_path(), self.anno_per_file)

        write_msg(stream, NumMessage(db.finished_count(login)))

        accept_msg(stream, ClientDone)
        write_msg(stream, ServerDone())

    def process_download_msg(self, msg, stream, login):
        """
        Assign new files for normal annotation to the user and send them.

        The requested number is capped by the user's annotation limit
        (``limits.<login>.annotation`` or ``limits.annotation``, which must
        be configured).  Files from the database's "fixed" list are
        preferred over the others.
        """
        usr_cfg = Config(self.users_file)
        db = Database(self.wc.db_path(), self.anno_per_file)

        owns_num = len(db.owns_normal(login))
        limit = self._limit(usr_cfg, login, "annotation")
        down_num = self._allowed_count(msg.get_number(), owns_num, limit)

        for_ann_fixed, for_ann = db.for_annotation(login)
        if len(for_ann) + len(for_ann_fixed) == 0:
            self._log("User " + login + " has no more files to download for annotation.")

        down_files = sample(for_ann_fixed, down_num)
        down_num = max(0, down_num - len(down_files))
        down_files = down_files + sample(for_ann, down_num)

        write_msg(stream, NumMessage(len(down_files)))

        for file_id in down_files:
            fixed_idx = db.download(file_id, login)
            write_msg(stream, Message(file_id))
            if fixed_idx is None:
                contents = self.wc.download(file_id, self.file_exts)
            else:
                contents = self.wc.checkout(file_id, fixed_idx, self.file_exts)
            send_contents(stream, contents)
        db.save()
        accept_msg(stream, ClientDone)
        self.wc.commit("download request from %s" % login)
        write_msg(stream, ServerDone())

    def process_download_prim_msg(self, msg, stream, login):
        """
        Assign new files for adjudication to an adjudicator and send them.

        Non-adjudicators receive a Ko message and nothing else.  The
        requested number is capped by ``limits.<login>.adjudication`` or
        ``limits.adjudication`` (which must be configured).
        """
        usr_cfg = Config(self.users_file)
        if not self._is_adjudicator(login, usr_cfg):
            write_msg(stream, KoMessage())
            return
        write_msg(stream, OkMessage())
        db = Database(self.wc.db_path(), self.anno_per_file)

        owns_num = len(db.owns_super(login))
        limit = self._limit(usr_cfg, login, "adjudication")
        down_num = self._allowed_count(msg.get_number(), owns_num, limit)

        for_adj_fixed, for_adj = db.for_adjudication(login)
        if len(for_adj) + len(for_adj_fixed) == 0:
            self._log("User " + login + " has no more files to download for superannotation.")

        down_files = sample(for_adj_fixed, down_num)
        down_num = max(0, down_num - len(down_files))
        down_files = down_files + sample(for_adj, down_num)

        write_msg(stream, NumMessage(len(down_files)))
        write_msg(stream, NumMessage(self.anno_per_file))
        for file_id in down_files:
            write_msg(stream, Message(file_id))
            fixed = db.download_prim(file_id, login)

            contents = self.wc.download_prim(file_id, self.file_exts, self.anno_per_file)
            for conts in contents:
                send_contents(stream, conts)

            if fixed:
                write_msg(stream, BoolMessage(True))
                contents = self.wc.checkout_prim(file_id, self.file_exts)
                send_contents(stream, contents)
            else:
                write_msg(stream, BoolMessage(False))

        db.save()
        accept_msg(stream, ClientDone)
        self.wc.commit("download' request from %s" % login)
        write_msg(stream, ServerDone())

def bind_socket(host, port, backlog=5):
    """
    Create a TCP socket bound to (*host*, *port*) and listening.

    params:
    =======
    host : str
        Host name or address to bind to.
    port : int
        Port number (0 lets the OS choose a free port).
    backlog : int
        Maximum number of waiting connections.

    If binding or listening fails the socket is closed before the error
    is re-raised, so no descriptor is leaked.
    """
    bound = socket.socket()
    try:
        bound.bind((host, port))
        bound.listen(backlog)
    except:
        bound.close()
        raise
    return bound

def sample(population, n):
    """
    Pick up to *n* random elements from *population*.

    When *population* has no more than *n* elements it is returned
    unchanged (the same object, not a copy); otherwise a new list of
    exactly *n* randomly chosen elements is returned.
    """
    if len(population) <= n:
        return population
    return random.sample(population, n)

if __name__ == "__main__":
    # Command-line entry point: the only argument is the server
    # configuration file.
    optparser = OptionParser(usage="""usage: %prog CONFIG""")
    (options, args) = optparser.parse_args()
    if len(args) != 1:
        # Wrong invocation: show usage on stderr and exit with the
        # conventional command-line usage-error status instead of 0.
        optparser.print_help(sys.stderr)
        sys.exit(2)

    cfg = Config(args[0])
    server = SSLServer(
            host=cfg["connection.host"],
            port=int(cfg["connection.port"]),
            backlog=int(cfg["connection.backlog"]),
            cert_file=cfg["connection.certfile"],
            key_file=cfg["connection.keyfile"],
            repository=cfg["svn.repository"],
            svn_login=cfg["svn.login"],
            svn_passwd=cfg["svn.passwd"],
            pass_file=cfg["users.passfile"],
            users_file=cfg["users.config"],
            file_exts=cfg["file_extensions"].split(),
            anno_per_file=int(cfg["anno_per_file"])
            )

    def handler(signum, frame):
        """SIGINT handler: let the server decide whether it may exit now."""
        server.exit()

    signal.signal(signal.SIGINT, handler)

    server.run()
