# -*- coding: utf-8 -*-
import struct

from message import decode

# Public API. accept_msg is a public helper and BadMessage is the exception it
# raises, so both are exported alongside the other entry points; otherwise
# star-importers could call accept_msg but not catch its error.
__all__ = ["read_msg", "write_msg", "accept_msg", "EndOfStream", "BadMessage"]

class EndOfStream(Exception):
    """Raised when the stream yields no more data mid-read.

    read_all raises this when a read returns an empty result while bytes
    of the current frame are still outstanding.
    """

class BadMessage(Exception):
    """Raised by accept_msg when the decoded message is of an unexpected class."""

def read_msg(connstream):
    """Read one framed message from *connstream* and decode it.

    The (type, content) pair produced by read_encoded is handed straight to
    ``message.decode``; whatever that returns is returned here.
    """
    return decode(*read_encoded(connstream))

def accept_msg(connstream, cls):
    """Read the next message and require it to be exactly of class *cls*.

    :param connstream: stream handed to read_msg.
    :param cls: the exact message class expected; subclasses are rejected.
    :returns: the decoded message.
    :raises BadMessage: if the decoded message's type is not *cls*.
    Errors raised by read_msg (e.g. EndOfStream from the framing layer)
    propagate unchanged.
    """
    msg = read_msg(connstream)
    # Strict exact-type match, as before (not isinstance). Types are compared
    # by identity, the idiomatic form for exact type checks (pycodestyle E721).
    if type(msg) is not cls:
        raise BadMessage("Expected %s, got %s" %
                         (cls.__name__, type(msg).__name__))
    return msg

def write_msg(connstream, msg):
    """Encode *msg* and send it on *connstream* as one framed message.

    ``msg.encode()`` must produce the positional arguments that follow the
    stream in write_encoded, i.e. a (type, content) pair.
    """
    write_encoded(connstream, *msg.encode())

def read_encoded(connstream):
    """Read one raw frame and return it as a ``(type, content)`` tuple.

    Wire order: a 1-byte type, a 4-byte length, then that many content bytes.
    """
    msg_type = read_type(connstream)
    # read_length runs before read_content consumes the payload.
    payload = read_content(connstream, read_length(connstream))
    return msg_type, payload

def write_encoded(connstream, _type, content):
    """Send one raw frame: 1-byte type, 4-byte length of *content*, then *content*."""
    write_type(connstream, _type)
    payload_len = len(content)
    write_length(connstream, payload_len)
    write_content(connstream, content)

def read_type(connstream):
    """Read the frame's type field: one unsigned byte, returned as an int."""
    (type_code,) = struct.unpack("!B", read_all(connstream, 1))
    return type_code

def write_type(connstream, _type):
    """Write the frame's type field as one unsigned byte (struct raises if out of range)."""
    write_all(connstream, struct.pack("!B", _type))

def read_length(connstream):
    """Read the frame's length field: a 4-byte big-endian unsigned int."""
    (size,) = struct.unpack("!L", read_all(connstream, 4))
    return size

def write_length(connstream, n):
    """Write *n* as a 4-byte big-endian unsigned int (struct raises if out of range)."""
    write_all(connstream, struct.pack("!L", n))

def read_content(connstream, length):
    """Read exactly *length* payload bytes from *connstream*."""
    payload = read_all(connstream, length)
    return payload

def write_content(connstream, content):
    """Send the payload bytes *content* in full."""
    return write_all(connstream, content)

def write_all(connstream, data):
    """Write every byte of *data* to *connstream*.

    ``connstream.write`` may report a short write. The remainder is
    re-sent until everything has gone out, rather than failing on the
    first partial write.

    :raises Exception: "Did not send all data !" if a write call makes no
        progress (returns 0 or None).
    """
    total = len(data)
    sent = 0
    while sent < total:
        # Slice only after a partial write; the common case sends data as-is.
        chunk = data[sent:] if sent else data
        written = connstream.write(chunk)
        if not written:
            raise Exception("Did not send all data !")
        sent += written

def read_all(connstream, k, bufsize=1024):
    """Read exactly *k* bytes from *connstream* and return them as bytes.

    Data is requested in chunks of at most *bufsize* bytes. Chunks are
    collected in a list and joined once, which avoids quadratic string
    concatenation and works with both Python 2 ``str`` and Python 3
    ``bytes`` streams. The old ``""`` accumulator broke on Python 3 bytes.

    :param connstream: object with a ``read(n)`` method.
    :param k: number of bytes to read; ``k <= 0`` returns ``b""``.
    :param bufsize: maximum bytes requested per read call; must be positive.
    :raises ValueError: if *bufsize* is not positive.
    :raises EndOfStream: if a read returns no data before *k* bytes arrive.
    """
    if bufsize <= 0:
        raise ValueError("bufsize must be positive")
    chunks = []
    remaining = k
    while remaining > 0:
        part = connstream.read(min(remaining, bufsize))
        if not part:
            # An empty read means the stream has no more data for this frame.
            raise EndOfStream
        chunks.append(part)
        remaining -= len(part)
    return b"".join(chunks)

