# -*- coding: utf-8 -*-
import sys
import os
import shutil
import tempfile
import socket
import ssl
import pprint
from optparse import OptionParser
import traceback
import re

import i18n
_ = i18n.language.ugettext

from dfs.msg.stream import write_msg, read_msg, accept_msg
from dfs.msg.message import *
from dfs.utils import check_sum, send_contents, receive_contents

from utils import normal_files, super_files

def init_stream(host, port, cert, log=None):
    """Open a TLS-wrapped TCP connection to the server.

    host, port -- address passed to ``connect``.
    cert       -- path given to ``ssl.wrap_socket`` as ``ca_certs``; the
                  server certificate is required (``CERT_REQUIRED``).
    log        -- optional object forwarded to ``log_msg``.

    Returns the connected SSL socket.  Any error raised while wrapping or
    connecting is propagated, after the sockets created here are closed so
    a failed attempt does not leak a descriptor.
    """
    log_msg(_("Connecting server..."), log)
    raw_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    stream = None
    try:
        # NOTE(review): TLSv1 is pinned here; changing it requires the
        # server side to agree, so it is left untouched.
        stream = ssl.wrap_socket(raw_sock, ca_certs=cert, cert_reqs=ssl.CERT_REQUIRED, ssl_version=ssl.PROTOCOL_TLSv1)
        stream.connect((host, port))
    except BaseException:
        # Close both objects: depending on the Python version the wrapper
        # either shares or takes over the underlying descriptor.
        if stream is not None:
            stream.close()
        raw_sock.close()
        raise
    return stream

def upload(stream, files, src_dir, file_exts, checkin=False, log=None):
    """Send the given files to the server as an upload or as a checkin.

    Each entry of ``files`` is a path relative to ``src_dir``.  An entry
    with a directory component is split into ``file_id`` (the directory
    part) and a choice (the base name); its contents are read from that
    directory and the server is told so with a true ``BoolMessage``.  An
    entry without a directory component is its own ``file_id``.

    Entries that the server does not acknowledge with an ``OkMessage`` are
    logged and skipped.  When ``checkin`` is true, every accepted entry is
    removed locally once the server has answered ``ServerDone``: a plain
    entry loses one file per extension in ``file_exts``, an entry with a
    directory component loses the whole ``file_id`` directory.
    """

    def has_dir_part(path):
        return bool(os.path.dirname(path))

    log_msg(_("Sending request..."), log)
    request_type = CheckinRequest if checkin else UploadRequest
    write_msg(stream, request_type(len(files)))

    rejected = set()
    for path in files:
        log_msg(_("Uploading") + " " + path + "...", log)

        nested = has_dir_part(path)
        if nested:
            file_id, choice = os.path.split(path)
            contents = read_contents(os.path.join(src_dir, file_id), choice, file_exts)
        else:
            file_id = path
            contents = read_contents(src_dir, path, file_exts)

        write_msg(stream, Message(file_id))
        write_msg(stream, BoolMessage(nested))
        reply = read_msg(stream)
        if type(reply) != OkMessage:
            log_msg(_("Upload failed: server doesn't accept this file from you"), log)
            rejected.add(path)
            continue
        send_contents(stream, contents)

    write_msg(stream, ClientDone())
    accept_msg(stream, ServerDone)

    if checkin:
        # Checked-in files no longer belong to the local working copy.
        for path in files:
            if path in rejected:
                continue
            if has_dir_part(path):
                file_id = os.path.split(path)[0]
                log_msg(_("Deleting") + " " + file_id + "...", log)
                shutil.rmtree(os.path.join(src_dir, file_id))
            else:
                log_msg(_("Deleting") + " " + path + "...", log)
                for ext in file_exts:
                    os.remove(os.path.join(src_dir, path + ext))

    log_msg(_("Done"), log)

def download(stream, n, dest_dir, tmp_dir, log=None):
    """Request ``n`` files from the server and store them in ``dest_dir``.

    The number of files actually received is the one the server announces
    in its ``NumMessage`` reply.  Every file is first written to
    ``tmp_dir``; the files are moved to ``dest_dir`` only after the server
    has answered ``ServerDone``.
    """
    log_msg(_("Sending request..."), log)
    write_msg(stream, DownloadRequest(n))
    announced = accept_msg(stream, NumMessage).get_number()
    if announced == 0:
        log_msg(_("No files to download"), log)

    received = []
    for index in range(announced):
        file_id = accept_msg(stream, Message).get_contents()
        log_msg(_("Downloading") + " " + file_id + "...", log)
        contents = receive_contents(stream)
        save_contents(contents, tmp_dir, file_id)
        received.extend(file_id + ext for ext in contents.keys())

    write_msg(stream, ClientDone())
    accept_msg(stream, ServerDone)

    for file_name in received:
        shutil.move(os.path.join(tmp_dir, file_name),
                    os.path.join(dest_dir, file_name))
    log_msg(_("Done"), log)

def download_prim(stream, n, dest_dir, tmp_dir, log=None):
    """Request ``n`` groups of annotated files for adjudication.

    Returns early, after logging a failure, when the server does not answer
    the request with an ``OkMessage``.  Otherwise the server announces how
    many file ids follow and how many annotations come with each id.  For
    every id a directory is created in ``tmp_dir``; the annotations are
    saved in it under the names "A", "B", ...  If the server then sends a
    true ``BoolMessage``, one more set of contents is saved as "A" (when
    there is one annotation per file) or as "Super".  The directories are
    moved to ``dest_dir`` after the server has answered ``ServerDone``.
    """
    log_msg(_("Sending request..."), log)
    write_msg(stream, DownloadPrimRequest(n))

    if type(read_msg(stream)) != OkMessage:
        log_msg(_("FAILED: you don't have adjudicator privileges"), log)
        return

    announced = accept_msg(stream, NumMessage).get_number()
    anno_per_file = accept_msg(stream, NumMessage).get_number()
    if announced == 0:
        log_msg(_("No files to download"), log)

    extra_name = "A" if anno_per_file == 1 else "Super"

    received = []
    for index in range(announced):
        file_id = accept_msg(stream, Message).get_contents()
        log_msg(_("Downloading") + " " + file_id + "...", log)

        staging_dir = os.path.join(tmp_dir, file_id)
        os.mkdir(staging_dir)
        received.append(file_id)

        for slot in range(anno_per_file):
            save_contents(receive_contents(stream), staging_dir, chr(ord('A') + slot))

        if accept_msg(stream, BoolMessage).get_boolean() == True:
            save_contents(receive_contents(stream), staging_dir, extra_name)

    write_msg(stream, ClientDone())
    accept_msg(stream, ServerDone)

    for file_id in received:
        shutil.move(os.path.join(tmp_dir, file_id),
                    os.path.join(dest_dir, file_id))
    log_msg(_("Done"), log)

def check_stats(stream, log=None):
    """Ask the server for the number of finished files and log it."""
    log_msg(_("Sending request..."), log)
    write_msg(stream, StatsRequest())

    finished = accept_msg(stream, NumMessage).get_number()
    log_msg(_("Number of finished files: ") + str(finished), log)

    # Close the exchange before reporting success.
    write_msg(stream, ClientDone())
    accept_msg(stream, ServerDone)
    log_msg(_("Done"), log)
    
def checkout(stream, dest_dir, tmp_dir, exts, log=None):
    """Replace the local files in ``dest_dir`` with the server's copy.

    The server sends two batches.  The first is a list of plain files,
    saved in ``tmp_dir`` as ``file_id + ext``.  The second is a list of
    file ids that each get a directory in ``tmp_dir`` holding the
    annotations "A", "B", ... and, when the server sends a true
    ``BoolMessage``, one more set of contents named "A" (one annotation
    per file) or "Super".

    After the server has answered ``ServerDone`` the local files reported
    by ``normal_files`` and the top-level directories of the paths reported
    by ``super_files`` are deleted; an error during that deletion is logged
    and does not stop the checkout.  Finally everything received is moved
    from ``tmp_dir`` to ``dest_dir``.
    """
    log_msg(_("Sending request..."), log)
    write_msg(stream, CheckoutRequest())

    received = []

    # First batch: plain files.
    plain_count = accept_msg(stream, NumMessage).get_number()
    for index in range(plain_count):
        file_id = accept_msg(stream, Message).get_contents()
        log_msg(_("Downloading") + " " + file_id + "...", log)

        contents = receive_contents(stream)
        save_contents(contents, tmp_dir, file_id)
        received.extend(file_id + ext for ext in contents.keys())

    # Second batch: one directory per file id.
    dir_count = accept_msg(stream, NumMessage).get_number()
    anno_per_file = accept_msg(stream, NumMessage).get_number()
    extra_name = "A" if anno_per_file == 1 else "Super"

    for index in range(dir_count):
        file_id = accept_msg(stream, Message).get_contents()
        log_msg(_("Downloading") + " " + file_id + "...", log)

        staging_dir = os.path.join(tmp_dir, file_id)
        os.mkdir(staging_dir)

        for slot in range(anno_per_file):
            save_contents(receive_contents(stream), staging_dir, chr(slot + ord('A')))

        if accept_msg(stream, BoolMessage).get_boolean() == True:
            save_contents(receive_contents(stream), staging_dir, extra_name)

        received.append(file_id)

    write_msg(stream, ClientDone())
    accept_msg(stream, ServerDone)

    log_msg(_("Deleting local files..."), log)
    try:
        # Plain files: one file per known extension.
        for name in normal_files(dest_dir, exts):
            for ext in exts:
                os.remove(os.path.join(dest_dir, name + ext))

        # Directory entries: remove each top-level directory once.
        top_dirs = set(path.split(os.path.sep)[0]
                       for path in super_files(dest_dir, exts))
        for top_dir in top_dirs:
            shutil.rmtree(os.path.join(dest_dir, top_dir))

    except Exception as ex:
        log_msg(_("Error occured:") + unicode(str(ex), errors='replace'), log)

    log_msg(_("Saving dowloaded files..."), log)
    for name in received:
        shutil.move(os.path.join(tmp_dir, name),
                    os.path.join(dest_dir, name))

    log_msg(_("Done"), log)

def return_files(stream, files, reason, src_dir, file_exts, log=None):
    """Give the listed files back to the server, stating ``reason``.

    The request carries the number of files followed by ``reason``.  Each
    entry of ``files`` is a path relative to ``src_dir``; an entry with a
    directory component is split into ``file_id`` (the directory part) and
    a choice (the base name) and is announced with a true ``BoolMessage``.

    Entries the server does not acknowledge with an ``OkMessage`` are
    logged and skipped.  After the server has answered ``ServerDone`` every
    accepted entry is removed locally: a plain entry loses one file per
    extension in ``file_exts``, an entry with a directory component loses
    the whole ``file_id`` directory.
    """

    def has_dir_part(path):
        return bool(os.path.dirname(path))

    log_msg(_("Sending request..."), log)
    write_msg(stream, ReturnRequest(len(files)))
    write_msg(stream, Message(reason))

    rejected = set()
    for path in files:
        log_msg(_("Uploading") + " " + path + "...", log)

        nested = has_dir_part(path)
        if nested:
            file_id, choice = os.path.split(path)
            contents = read_contents(os.path.join(src_dir, file_id), choice, file_exts)
        else:
            file_id = path
            contents = read_contents(src_dir, path, file_exts)

        write_msg(stream, Message(file_id))
        write_msg(stream, BoolMessage(nested))
        reply = read_msg(stream)
        if type(reply) != OkMessage:
            log_msg(_("Upload failed: server doesn't accept this file from you"), log)
            rejected.add(path)
            continue
        send_contents(stream, contents)

    write_msg(stream, ClientDone())
    accept_msg(stream, ServerDone)

    # Returned files are dropped from the local working copy.
    for path in files:
        if path in rejected:
            continue
        if has_dir_part(path):
            file_id = os.path.split(path)[0]
            log_msg(_("Deleting") + " " + file_id + "...", log)
            shutil.rmtree(os.path.join(src_dir, file_id))
        else:
            log_msg(_("Deleting") + " " + path + "...", log)
            for ext in file_exts:
                os.remove(os.path.join(src_dir, path + ext))

    log_msg(_("Done"), log)
    
def save_contents(contents, dest_dir, file_id):
    """Write each ``ext -> data`` pair of ``contents`` to its own file.

    The file for an extension is ``dest_dir/<file_id><ext>``; it is opened
    in text write mode, so an existing file of that name is overwritten.
    """
    for suffix, payload in contents.iteritems():
        target_path = os.path.join(dest_dir, file_id + suffix)
        with open(target_path, "w") as target:
            target.write(payload)

def read_contents(src_dir, file_id, file_exts):
    """Read the files in ``src_dir`` that belong to ``file_id``.

    A directory entry belongs to ``file_id`` when its name is ``file_id``
    followed by one of the extensions in ``file_exts``.  Entries that start
    with ``file_id`` but continue with anything else are reported on
    standard output and skipped without being opened.

    Returns a dict mapping each extension found to the file's contents.

    ``file_id`` is compared literally.  It used to be interpolated into a
    regular expression, so ids containing characters such as ``.`` or ``+``
    matched the wrong files or none at all.
    """
    result = {}
    for entry in os.listdir(src_dir):
        if not entry.startswith(file_id):
            continue
        ext = entry[len(file_id):]
        if ext not in file_exts:
            # Checked before opening, so unrelated entries sharing the
            # prefix (including directories) cannot make the read fail.
            print("Skipping file " + entry)
            continue
        with open(os.path.join(src_dir, entry)) as src:
            result[ext] = src.read()
    return result

def run_auth(stream, login, passwd, log=None):
    """Send the login and the password and report the server's verdict.

    Returns True when the server answers with an ``OkMessage``; otherwise
    logs the failure and returns False.
    """
    log_msg(_("Authentication..."), log)
    for credential in (login, passwd):
        write_msg(stream, Message(credential))

    accepted = type(read_msg(stream)) == OkMessage
    if not accepted:
        log_msg(_("FAILED: incorrect login or password"), log)
    return accepted

def run_upload(login, passwd, host, port, cert, files, src_dir, file_exts, checkin=False, log=None):
    """Connect, authenticate and upload (or check in) ``files``.

    Returns False when any exception is raised along the way and True
    otherwise, including when authentication is refused.  Every exception,
    whatever its cause, is reported with the same "unable to connect"
    message and its traceback is printed.  When ``log`` is truthy, ``None``
    is put on it at the end (presumably an end-of-output marker for the
    consumer -- confirm against the caller).
    """
    try:
        stream = init_stream(host, port, cert, log=log)
        try:
            if run_auth(stream, login, passwd, log=log):
                upload(stream, files, src_dir, file_exts, checkin=checkin, log=log)
        finally:
            stream.close()
    except:
        log_msg(_("FAILED: unable to connect to server"), log)
        print(traceback.format_exc())
        succeeded = False
    else:
        succeeded = True

    if log:
        log.put(None)
    return succeeded

def run_download(login, passwd, host, port, cert, n, dest_dir, log=None):
    """Connect, authenticate and download ``n`` files into ``dest_dir``.

    A temporary directory is created for the transfer and removed again
    whether or not the transfer succeeds.  Returns False when any exception
    is raised along the way and True otherwise, including when
    authentication is refused.  Every exception is reported with the same
    "unable to connect" message and its traceback is printed.  When ``log``
    is truthy, ``None`` is put on it at the end.
    """
    try:
        stream = init_stream(host, port, cert, log=log)
        try:
            staging_dir = tempfile.mkdtemp()
            try:
                if run_auth(stream, login, passwd, log=log):
                    download(stream, n, dest_dir, staging_dir, log=log)
            finally:
                shutil.rmtree(staging_dir)
        finally:
            stream.close()
    except:
        log_msg(_("FAILED: unable to connect to server"), log)
        print(traceback.format_exc())
        succeeded = False
    else:
        succeeded = True

    if log:
        log.put(None)
    return succeeded

def run_check_stats(login, passwd, host, port, cert, log=None):
    """Connect, authenticate and query the statistics.

    Returns False when any exception is raised along the way and True
    otherwise, including when authentication is refused.  Every exception
    is reported with the same "unable to connect" message and its traceback
    is printed.  When ``log`` is truthy, ``None`` is put on it at the end.
    """
    try:
        stream = init_stream(host, port, cert, log=log)
        try:
            if run_auth(stream, login, passwd, log=log):
                check_stats(stream, log=log)
        finally:
            stream.close()
    except:
        log_msg(_("FAILED: unable to connect to server"), log)
        print(traceback.format_exc())
        succeeded = False
    else:
        succeeded = True

    if log:
        log.put(None)
    return succeeded

def run_download_prim(login, passwd, host, port, cert, n, dest_dir, log=None):
    """Connect, authenticate and download ``n`` groups for adjudication.

    A temporary directory is created for the transfer and removed again
    whether or not the transfer succeeds.  Returns False when any exception
    is raised along the way and True otherwise, including when
    authentication is refused.  Every exception is reported with the same
    "unable to connect" message and its traceback is printed.  When ``log``
    is truthy, ``None`` is put on it at the end.
    """
    try:
        stream = init_stream(host, port, cert, log=log)
        try:
            staging_dir = tempfile.mkdtemp()
            try:
                if run_auth(stream, login, passwd, log=log):
                    download_prim(stream, n, dest_dir, staging_dir, log=log)
            finally:
                shutil.rmtree(staging_dir)
        finally:
            stream.close()
    except:
        log_msg(_("FAILED: unable to connect to server"), log)
        print(traceback.format_exc())
        succeeded = False
    else:
        succeeded = True

    if log:
        log.put(None)
    return succeeded

def run_checkout(login, passwd, host, port, cert, dest_dir, exts, log=None):
    """Connect, authenticate and check out the server's files.

    A temporary directory is created for the transfer and removed again
    whether or not the transfer succeeds.  Returns False when any exception
    is raised along the way and True otherwise, including when
    authentication is refused.  Every exception is reported with the same
    "unable to connect" message and its traceback is printed.  When ``log``
    is truthy, ``None`` is put on it at the end.
    """
    try:
        stream = init_stream(host, port, cert, log=log)
        try:
            staging_dir = tempfile.mkdtemp()
            try:
                if run_auth(stream, login, passwd, log=log):
                    checkout(stream, dest_dir, staging_dir, exts, log=log)
            finally:
                shutil.rmtree(staging_dir)
        finally:
            stream.close()
    except:
        log_msg(_("FAILED: unable to connect to server"), log)
        print(traceback.format_exc())
        succeeded = False
    else:
        succeeded = True

    if log:
        log.put(None)
    return succeeded

def run_return(login, passwd, host, port, cert, files, reason, src_dir, file_exts, log=None):
    """Connect, authenticate and return ``files`` to the server.

    Returns False when any exception is raised along the way and True
    otherwise, including when authentication is refused.  Every exception
    is reported with the same "unable to connect" message and its traceback
    is printed.  When ``log`` is truthy, ``None`` is put on it at the end.
    """
    try:
        stream = init_stream(host, port, cert, log=log)
        try:
            if run_auth(stream, login, passwd, log=log):
                return_files(stream, files, reason, src_dir, file_exts, log)
        finally:
            stream.close()
    except:
        log_msg(_("FAILED: unable to connect to server"), log)
        print(traceback.format_exc())
        succeeded = False
    else:
        succeeded = True

    if log:
        log.put(None)
    return succeeded

def log_msg(msg, log, new_line=True):
    """Write ``msg`` to standard output and, if given, to ``log``.

    A newline is appended unless ``new_line`` is false.  Standard output
    receives the text encoded as UTF-8; ``log`` receives the unencoded text
    through its ``put`` method and is skipped when it is falsy.
    """
    text = msg + "\n" if new_line else msg
    sys.stdout.write(text.encode('utf-8'))
    if log:
        log.put(text)

if __name__ == "__main__":
    # Command line entry point: parse the options, validate them and run
    # one of the upload / download / download2 commands.

    optparser = OptionParser(usage="""usage: %prog [options] CMD CMD-ARGS

    Command line client for files distribution system.""")

    optparser.add_option("--login", dest="login")
    optparser.add_option("--passwd", dest="passwd")
    optparser.add_option("--home", metavar="DIR", dest="home",
            help="Directory with users files.")
    # The run_* functions need the server address, the CA certificate and
    # (for uploads) the file extensions.  The script used to call them
    # without these arguments, so every command failed with a TypeError.
    optparser.add_option("--host", dest="host",
            help="Server host name or address.")
    optparser.add_option("--port", dest="port", type="int",
            help="Server port.")
    optparser.add_option("--cert", metavar="FILE", dest="cert",
            help="CA certificates file used to verify the server.")
    optparser.add_option("--ext", metavar="EXT", dest="exts", action="append",
            help="File extension to upload (repeat for several); "
                 "required by the upload command.")

    (options, args) = optparser.parse_args()

    def usage_error(message):
        # Same reporting as before: help text, blank line, message, exit 0.
        optparser.print_help()
        print("\n" + message)
        sys.exit(0)

    for name in ("login", "passwd", "home", "host", "port", "cert"):
        if getattr(options, name) is None:
            usage_error("--%s option is mandatory" % name)

    commands = ["upload", "download", "download2"]
    if len(args) < 1 or args[0] not in commands:
        optparser.print_help()
        print("\nCMD should be one of the following:")
        print(commands)
        sys.exit(0)

    cmd = args[0]
    if cmd == "upload":
        if not options.exts:
            usage_error("upload needs at least one --ext option")
        run_upload(options.login, options.passwd,
                options.host, options.port, options.cert,
                args[1:], options.home, options.exts)
    else:
        # Both download commands take the number of files as CMD-ARGS.
        try:
            count = int(args[1])
        except (IndexError, ValueError):
            usage_error("%s needs the number of files to download" % cmd)
        if cmd == "download":
            run_download(options.login, options.passwd,
                    options.host, options.port, options.cert,
                    count, options.home)
        else:
            run_download_prim(options.login, options.passwd,
                    options.host, options.port, options.cert,
                    count, options.home)
