# http://pymotw.com/2/SocketServer/
import pickle
import socket
import socketserver
import sys
from collections import defaultdict
from typing import Set, DefaultDict, List

from lockmanager import LockManager
from message import *
from paxos_utils import *
from timers import P2ATimer, setTimer, HeartBeatCheckTimer, HeartBeatTimer, LeaderElectionTimer


# Paxos servers
class Paxos(socketserver.UDPServer):
    def __init__(self, address, servers) -> None:
        """Initialise replica state, bind the UDP socket and arm the heartbeat timers.

        Args:
            address: This replica's own (ip, port) tuple; also the address the
                UDP server is bound to.
            servers: Addresses of every replica. An address's position in this
                list is the leader index used by the per-index lists below and
                by ``slotToLeaderIndex``.
        """
        # Identity and membership
        self.address: Address = address
        self.servers: List[Address] = servers
        self.n_server: int = len(self.servers)

        # Per-leader-index state: entry i refers to self.servers[i]
        self.is_leader: List[bool] = [False] * self.n_server  # do we coordinate index i?
        self.ballot: BallotNumber = BallotNumber(1, self.address)  # our own ballot number
        self.highest_ballot_seen: List[BallotNumber] = []  # highest ballot seen per index
        # Slot-keyed bookkeeping, slot number -> BallotValuePair
        self.proposals: Dict[int, BallotValuePair] = dict()  # what we sent out in P2A
        self.accepted: Dict[int, BallotValuePair] = dict()  # what we accepted (P2A), reported in P1B
        self.log: Dict[int, BallotValuePair] = dict()  # chosen values
        self.leader_recent_ping: List[bool] = [False] * self.n_server  # heartbeat seen since last check?

        # Proposal phase
        self.slot_in: int = 0  # next slot we will propose into
        self.slot_out: int = 0  # first slot not yet executed
        self.slot_to_acceptors: DefaultDict[int, Set[Address]] = defaultdict(set)  # slot -> acceptors that replied P2B

        # Leader election
        self.voters: Set[Address] = set()  # addresses that voted yes
        self.p1b_replies: Dict[int, BallotValuePair] = dict()  # accepted values reported by yes-voters

        # Lock manager application
        self.lock_manager: LockManager = LockManager()

        # Default leadership at start-up: every replica coordinates its own index
        # and starts proposing at the slot equal to that index.
        for index, _ in enumerate(self.servers):
            if self.address == self.servers[index]:
                self.is_leader[index] = True
                self.slot_in = index
            self.highest_ballot_seen.append(BallotNumber(1, self.servers[index]))

        print("Finished init paxos", file=sys.stdout)
        print(f"servers: {self.servers}", file=sys.stdout)
        print(f"address: {self.address}", file=sys.stdout)
        print(f"highest ballot seen: {[str(ballot) for ballot in self.highest_ballot_seen]}", file=sys.stdout)
        print(f"is_leader: {self.is_leader}\n", file=sys.stdout)

        # Bind the socket only after the state above exists, then start the timers.
        socketserver.UDPServer.__init__(self, address, PaxosHandler)

        setTimer(HeartBeatTimer(), HeartBeatTimer.HEARTBEAT_RETRY_MILLIS, self.onHeartBeatTimer)
        # The check timer is handed the live highest_ballot_seen list (not a copy).
        setTimer(
            HeartBeatCheckTimer(self.highest_ballot_seen),
            HeartBeatTimer.HEARTBEAT_RETRY_MILLIS,
            self.onHeartBeatCheckTimer,
        )

    def handlePaxosRequest(self, paxos_req, sender):
        """Propose a client request in the next slot this replica owns.

        The request is recorded as our proposal and as self-accepted for
        ``slot_in``, a P2A is sent to every other replica, a retry timer is
        armed, and ``slot_in`` is advanced to the next owned slot.

        Args:
            paxos_req: The client request to be proposed.
            sender: Address the request came from (unused here).
        """
        # Change: all replicas accept client requests directly; nothing is forwarded.
        print(f"Leader at {self.address} Handling paxos request", file=sys.stdout)

        slot = self.slot_in
        bvp = BallotValuePair(BallotNumber(self.ballot.seq_num, self.address), paxos_req)
        self.proposals[slot] = bvp
        # Fix: the acceptance must be recorded under the slot being proposed.
        # It was previously keyed by slot_out, which is a different slot
        # whenever execution lags behind proposing.
        self.accepted[slot] = bvp

        p2a = P2A(self.address, BallotNumber(self.ballot.seq_num, self.address), slot, paxos_req)

        # Move slot_in to the next slot we own before sending anything out.
        self.incrementSlotIn()

        # P2A to every other replica
        for acceptor in self.servers:
            if acceptor != self.address:
                self.send_msg(p2a, acceptor)

        # We accept our own proposal
        self.slot_to_acceptors[p2a.slot_num].add(self.address)

        setTimer(P2ATimer(p2a), P2ATimer.P2A_RETRY_MILLIS, self.onP2ATimer)

        if len(self.slot_to_acceptors[p2a.slot_num]) > (len(self.servers) / 2):
            # Our own vote is already a majority (single-replica group): commit now.
            self.log[p2a.slot_num] = bvp

            self.executeLog("handlePaxosRequest")

    def handleP1A(self, p1a, sender):
        """Handle a phase-1a (leader election) message.

        When the incoming ballot is higher than what we have seen, adopt it,
        step down if we consider ourselves leader, answer with a P1B carrying
        our accepted values, and arm a heartbeat check for the new ballot.

        NOTE(review): this handler treats ``highest_ballot_seen``,
        ``is_leader`` and ``leader_recent_ping`` as single values, while
        ``__init__`` creates them as per-server lists. Running it replaces
        those lists with scalars. It looks like it predates the per-slot
        leader change -- confirm and port before re-enabling election.

        Args:
            p1a: The P1A message; ``p1a.ballot_num`` is the candidate's ballot.
            sender: Address the P1B reply is sent to.
        """
        # NOTE(review): compares the whole per-index list against one ballot.
        if self.highest_ballot_seen < p1a.ballot_num:
            self.highest_ballot_seen = p1a.ballot_num

            # If we are leader, make this a follower (since the one sending P1A thinks they are leader)
            # NOTE(review): is_leader is a list after __init__, so this test is
            # true for any non-empty server list.
            if self.is_leader:
                print(f"{self.address} demotes itself from leader")
                self.is_leader = False

            # Reply with the ballot we just adopted and everything we have accepted so far.
            p1b = P1B(self.address, self.highest_ballot_seen, self.accepted)
            self.send_msg(p1b, sender)

            # Start watching for heartbeats under the new ballot, at twice the normal check interval.
            setTimer(HeartBeatCheckTimer(p1a.ballot_num), HeartBeatCheckTimer.HEARTBEAT_CHECK_RETRY_MILLIS * 2, self.onHeartBeatCheckTimer)
            self.leader_recent_ping = True

    def handleP1B(self, p1b, sender):
        """Handle a phase-1b vote and take over as leader once enough votes arrive.

        A vote counts only when it is for our current ballot and that ballot is
        still the highest we have seen. Accepted values reported by voters are
        merged per slot, keeping the one with the highest ballot. Once elected,
        every reported slot that is not yet chosen is re-proposed under our
        ballot; owned slots skipped on the way are filled with no-ops.

        NOTE(review): ``is_leader`` is created as a per-server list in
        ``__init__``, so the guard below is true for any non-empty server list
        and the handler currently returns immediately. The ballot comparison
        against ``highest_ballot_seen`` has the same scalar-vs-list mismatch.
        Port both to per-slot state before re-enabling election.

        Args:
            p1b: The P1B message (``accepted_ballot``, ``accepted``).
            sender: Address of the voter.
        """
        if self.is_leader:
            return
        # If receive majority of response from acceptors with its ballot, becomes leader
        if self.ballot == p1b.accepted_ballot and self.ballot == self.highest_ballot_seen:
            # Fix: add the voter's address as ONE element. set.update(sender)
            # iterated the (ip, port) tuple and inserted ip and port separately,
            # so each voter was counted twice and the entries were not addresses.
            self.voters.add(sender)
            for slot, new_bvp in p1b.accepted.items():
                if slot not in self.p1b_replies:
                    self.p1b_replies[slot] = new_bvp
                else:
                    cur_bvp = self.p1b_replies[slot]
                    if new_bvp.ballot_num >= cur_bvp.ballot_num:
                        self.p1b_replies[slot] = new_bvp
        # voters never contains ourselves, hence the "- 1" in the majority test.
        if len(self.voters) > (len(self.servers) / 2 - 1):
            # This server is elected as leader
            if not self.is_leader:
                print(f"{self.address} becomes leader")
            self.is_leader = True
            # Must update its state with accepted values from acceptors
            for slot in self.p1b_replies:
                value = self.p1b_replies[slot]
                if self.status(slot) != PaxosLogSlotStatus.CHOSEN:
                    self.proposals[slot] = BallotValuePair(BallotNumber(self.ballot.seq_num, self.address), value.value)
                    # Changed: instead of slot_in = max(slot_in, slot + 1), walk
                    # slot_in forward and skip every owned slot passed on the way.
                    while self.slot_in <= slot:
                        # Propose no-op (SKIP message) and send it straight away for the slot we skip
                        skip = P2A(self.address, BallotNumber(self.ballot.seq_num, self.address), self.slot_in, None)
                        for acceptor in self.servers:
                            if acceptor != self.address:
                                self.send_msg(skip, acceptor)
                        bvp = BallotValuePair(BallotNumber(self.ballot.seq_num, self.address), None)
                        self.log[self.slot_in] = bvp
                        self.incrementSlotIn()
                    bvp = self.proposals[slot]

                    p2a = P2A(self.address, bvp.ballot_num, slot, bvp.value)
                    for acceptor in self.servers:
                        if acceptor != self.address:
                            self.send_msg(p2a, acceptor)

                    self.accepted[slot] = BallotValuePair(BallotNumber(self.ballot.seq_num, self.address), value.value)
                    self.slot_to_acceptors[p2a.slot_num].add(self.address)

                    setTimer(P2ATimer(p2a), P2ATimer.P2A_RETRY_MILLIS, self.onP2ATimer)

    def handleP2A(self, p2a, sender):
        # print("p2a - 0\n", file=sys.stdout)
        # print(f"highest_ballot_seen: {self.highest_ballot_seen}, p2a.ballot_num: {p2a.ballot_num}\n", file=sys.stdout)
            # No longer drop immediately if not leader
            # if self.highest_ballot_seen != p2a.ballot_num:
            #     # It's not the leader, drop it
            #     return
        # print("p2a - 1\n", file=sys.stdout)
        # Skip messages are learned imediately
        if p2a.value is not None and self.highest_ballot_seen[self.slotToLeaderIndex(p2a.slot_num)] != p2a.ballot_num:
            # Drop if value is not Skip and is proposed by non-coordinator for that slot 
            print(f"Dropped because {p2a.ballot_num} is not from {self.highest_ballot_seen[self.slotToLeaderIndex(p2a.slot_num)]}")
            return

        # If it is a skip and from the coordinator, immediately learns it (put it in their log)
        if p2a.value is None and self.highest_ballot_seen[self.slotToLeaderIndex(p2a.slot_num)] == p2a.ballot_num:
            # Learn it immediately if value is None and proposed by the leader
            bvp = BallotValuePair(BallotNumber(p2a.ballot_num.seq_num, p2a.addr), None)
            self.log[p2a.slot_num] = bvp
            print(f"Learned skip for {p2a}")
            return

        if p2a.slot_num in self.accepted:
            bvp = self.accepted[p2a.slot_num]
            if bvp.ballot_num <= p2a.ballot_num:
                # p2a ballot is higher or equal
                bvp = BallotValuePair(p2a.ballot_num, p2a.value)
                self.accepted[p2a.slot_num] = bvp
            else:
                # Don't do anything
                return
        else:
            # Have not accepted anything, then accept it
            bvp = BallotValuePair(p2a.ballot_num, p2a.value)
            self.accepted[p2a.slot_num] = bvp

        # print("p2a - 2\n")
        # self.slot_in = max(self.slot_in, p2a.slot_num + 1)
        # Changed: Instead of doing max of slot_in, do while loop:
        while self.slot_in <= p2a.slot_num:
            skip = P2A(self.address, BallotNumber(self.ballot.seq_num, self.address), self.slot_in, None)
            for acceptor in self.servers:
                if acceptor != self.address:
                    self.send_msg(skip, acceptor)
            bvp = BallotValuePair(BallotNumber(self.ballot.seq_num, self.address), None)
            self.log[self.slot_in] = bvp
            self.incrementSlotIn()

        p2b = P2B(self.address, p2a.ballot_num, p2a.slot_num)
        self.send_msg(p2b, sender)
        # print("p2a - 3\n")

    def handleP2B(self, p2b, sender):
        if not self.is_leader:
            # Not leader, drop message
            return

        # check if it is still consistent with our proposal
        bvp = self.proposals[p2b.slot_num]
        if bvp.ballot_num != p2b.ballot_num:
            # No longer in proposal
            return

        # Keep track of who have accepted
        self.slot_to_acceptors[p2b.slot_num].add(sender)
        if len(self.slot_to_acceptors[p2b.slot_num]) > (len(self.servers) / 2):
            # Majority accepted, can put into log
            self.log[p2b.slot_num] = bvp
            self.executeLog("P2B")

    def handleLeaderHeartbeat(self, heartbeat, address):
        for i in range(len(heartbeat.leader_slot)):
            if not self.is_leader[i]:
                if self.highest_ballot_seen[i] > heartbeat.ballot_num:
                # It's not "leader" heartbeat
                    return
                newLeaderSeen = False
                if self.highest_ballot_seen[i] < heartbeat.ballot_num:
                    self.is_leader[i] = False
                    self.highest_ballot_seen[i] = heartbeat.ballot_num
                    newLeaderSeen = True
                # Replace log with the bigger log slot
                for slot in heartbeat.log:
                    if slot not in self.log:
                        self.log[slot] = heartbeat.log[slot]
                    else:
                        bvp = self.log[slot]
                        new_bvp = heartbeat.log[slot]
                        if new_bvp > bvp:
                            self.log[slot] = new_bvp
                self.executeLog("Handle Leader Heartbeat")

                self.leader_recent_ping[i] = True
                if newLeaderSeen:
                    # Exponential backoff
                    setTimer(HeartBeatCheckTimer(self.highest_ballot_seen), HeartBeatCheckTimer.HEARTBEAT_CHECK_RETRY_MILLIS * 2, self.onHeartBeatCheckTimer)

    """
    Timer Handlers
    Argument 1 needs to be a Timer
    """

    def onP2ATimer(self, p2a_timer: P2ATimer):
        # print(f"{p2a_timer}: Callback", file=sys.stdout)
        # If not leader then stop timer
        if self.is_leader and not self.status(p2a_timer.p2a.slot_num) == PaxosLogSlotStatus.CHOSEN:
            for acceptor_addr in self.servers:
                if acceptor_addr != self.address:
                    self.send_msg(p2a_timer.p2a, acceptor_addr)
            setTimer(p2a_timer, P2ATimer.P2A_RETRY_MILLIS, self.onP2ATimer)

    def onHeartBeatTimer(self, heartbeat_timer: HeartBeatTimer):
        """Execute what can be executed, send a heartbeat to every peer, re-arm.

        Changed: every replica sends heartbeats, since every replica
        coordinates its own slots.

        Args:
            heartbeat_timer: The timer instance to re-arm.
        """
        self.executeLog("HB timer")

        peers = [addr for addr in self.servers if addr != self.address]
        for peer in peers:
            # The heartbeat carries our log, our ballot and the indices we lead.
            beat = LeaderHeartbeat(self.address, self.log, self.ballot, self.is_leader)
            self.send_msg(beat, peer)

        setTimer(heartbeat_timer, HeartBeatTimer.HEARTBEAT_RETRY_MILLIS, self.onHeartBeatTimer)

    def onHeartBeatCheckTimer(self, heartbeat_check_timer: HeartBeatCheckTimer):
        # Change: Since everyone is leader, heartbeat is now being sent from all leader, we 
        # need to keep track which heartbeat is last heard, and if one leader is dead, we take
        # over that leader slot. 
        
        # print(f"{heartbeat_check_timer}: Callback", file=sys.stdout)
        for i in range(len(heartbeat_check_timer.ballot_num)):
            if not self.is_leader[i]:
                if heartbeat_check_timer.ballot_num[i] == self.highest_ballot_seen[i]:
                    # Check if the leader alive or not
                    if not self.leader_recent_ping[i]:
                        # Leader is dead
                        # Just for randomization for contention issue
                        # TODO: Uncomment below for the electing leader, will need to add index 
                        # to know which slot are we trying to become the leader of. 
                        # self.__electLeader()
                        return
                    self.leader_recent_ping[i] = False
                    setTimer(heartbeat_check_timer, HeartBeatCheckTimer.HEARTBEAT_CHECK_RETRY_MILLIS, self.onHeartBeatCheckTimer)

    def onLeaderElectionTimer(self, leader_election_timer: LeaderElectionTimer):
        # print(f"{leader_election_timer}: Callback", file=sys.stdout)
        if self.highest_ballot_seen == self.ballot and not self.is_leader:
            for acceptor_addr in self.servers:
                if acceptor_addr != self.address:
                    self.send_msg(leader_election_timer.p1a, acceptor_addr)

            setTimer(leader_election_timer, LeaderElectionTimer.LEADER_ELECTION_TIMER, self.onLeaderElectionTimer)

    def __electLeader(self):
        """Start an election: reset vote state, raise our ballot, broadcast a P1A.

        NOTE(review): the only call site in this file (in
        ``onHeartBeatCheckTimer``) is commented out. The body also treats
        ``highest_ballot_seen`` as a single ballot, while ``__init__`` creates
        it as a per-server list; the assignment below would replace that list.
        Needs an index parameter and porting before it is re-enabled.
        """
        print(f"{self.address} Detected leader is dead, try to get ourself to become leader")
        # Try to elect ourself as the leader
        # Try to get elected as leader at the beginning of time
        # Forget votes from any earlier attempt and seed the replies with what we accepted ourselves.
        self.voters.clear()
        self.p1b_replies.clear()
        self.p1b_replies.update(self.accepted)

        # Increase ballot until higher than the highest we saw before electing
        # NOTE(review): the loop stops as soon as ballot is no longer "<", so it
        # can end with ballot equal to (not strictly above) the highest seen.
        while self.ballot < self.highest_ballot_seen:
            self.ballot.increaseBallot()

        # The P1A carries a fresh BallotNumber built from our sequence number and address.
        p1a: P1A = P1A(self.address, BallotNumber(self.ballot.seq_num, self.address))
        self.highest_ballot_seen = p1a.ballot_num

        # P1A
        for acceptor_addr in self.servers:
            if acceptor_addr != self.address:
                self.send_msg(p1a, acceptor_addr)
        # Keep resending the P1A until the election is decided (see onLeaderElectionTimer).
        setTimer(LeaderElectionTimer(p1a), LeaderElectionTimer.LEADER_ELECTION_TIMER, self.onLeaderElectionTimer)

    def status(self, log_slot_num) -> PaxosLogSlotStatus:
        """Return CHOSEN when the slot has a log entry, EMPTY otherwise."""
        is_chosen = log_slot_num in self.log
        return PaxosLogSlotStatus.CHOSEN if is_chosen else PaxosLogSlotStatus.EMPTY

    def executeLog(self, context):
        """Execute chosen slots in order and fill holes in slots we coordinate.

        Walks ``slot_out`` .. ``slot_in - 1``. Chosen slots are executed
        strictly in order (execution stops at the first gap); the result of
        each non-no-op command is sent back to the requesting client. For an
        empty slot that we coordinate and have not proposed anything for, a
        no-op is proposed so the gap can close.

        Args:
            context: Free-form label of the caller, for debugging only.
        """
        found_gap = False
        for j in range(self.slot_out, self.slot_in):
            status = self.status(j)
            if status == PaxosLogSlotStatus.CHOSEN:
                if found_gap:
                    # Chosen, but an earlier slot is still open: cannot execute yet.
                    continue
                bvp = self.log[j]
                # execute and reply to client, skips no-op
                if bvp.value is not None:
                    lock_res = self.lock_manager.execute(bvp.value.lock_command, bvp.value.addr)

                    self.socket.sendto(pickle.dumps(PaxosResult(self.address, lock_res, bvp.value.lock_command)), bvp.value.addr)
                    if self.is_leader[self.slotToLeaderIndex(j)]:
                        self.lock_manager.lockstatus()

                self.slot_out += 1
                continue

            found_gap = True
            if status != PaxosLogSlotStatus.EMPTY:
                continue
            if not self.is_leader[self.slotToLeaderIndex(j)]:
                continue
            # Fix: never replace an in-flight proposal with a no-op. This
            # method runs on every heartbeat and P2B, so a client request that
            # was proposed but not yet chosen used to be overwritten here by a
            # no-op under the same ballot, silently dropping the request.
            # Pending proposals are retransmitted by their P2A timer instead.
            if j in self.proposals:
                continue

            p2a = P2A(self.address, BallotNumber(self.ballot.seq_num, self.address), j, None)
            for tmp_server in self.servers:
                if tmp_server != self.address:
                    self.send_msg(p2a, tmp_server)
            self.slot_to_acceptors[j].add(self.address)
            self.proposals[j] = BallotValuePair(BallotNumber(self.ballot.seq_num, self.address), p2a.value)
            self.accepted[j] = BallotValuePair(BallotNumber(self.ballot.seq_num, self.address), p2a.value)

    def slotToLeaderIndex(self, slot) -> int:
        return slot % self.n_server

    def incrementSlotIn(self):
        # Increment it into the first number greater than slot_out that we own
        # TODO: Can probably be optimized
        for i in range(1, self.n_server + 1):
            temp = self.slot_in + i
            if self.is_leader[self.slotToLeaderIndex(temp)]:
                self.slot_in = temp
                return

    # Serialize obj, and send the message.
    # This function will not wait for reply (communication between paxos nodes)
    def send_msg(self, obj, dest_address: Address):
        """Pickle ``obj`` and send it to ``dest_address`` over the server's UDP socket.

        Fire-and-forget: no reply is awaited (replica-to-replica traffic).
        """
        self.socket.sendto(pickle.dumps(obj), dest_address)


class PaxosHandler(socketserver.BaseRequestHandler):
    """
    Request handler for the Paxos UDP server.

    One instance is created per incoming datagram. ``handle()`` unpickles the
    datagram and routes it, by message type, to the matching ``handle*``
    method on the server: client requests (PaxosRequest) as well as the
    replica-to-replica messages (P1A, P1B, P2A, P2B, LeaderHeartbeat).
    """

    def handle(self):
        """Decode one datagram and dispatch it to the server."""
        # Client-to-server delivery is assumed to be exactly-once, so duplicate
        # requests (and thus proposing one request in two slots) are not handled.
        # NOTE(review): pickle.loads runs on bytes straight off the network;
        # only safe when every sender is trusted.
        payload = self.request[0].strip()
        message = pickle.loads(payload)

        # Ordered routing table: (message type, debug label or None, server method name)
        routes = (
            (PaxosRequest, "paxos request", "handlePaxosRequest"),
            (P1A, "p1a", "handleP1A"),
            (P1B, "p1b", "handleP1B"),
            (P2A, "p2a", "handleP2A"),
            (P2B, "p2b", "handleP2B"),
            (LeaderHeartbeat, None, "handleLeaderHeartbeat"),  # too chatty to log
        )

        for message_type, label, method_name in routes:
            if not isinstance(message, message_type):
                continue
            if label is not None:
                print(f"got {label} {message}", file=sys.stdout)  # debug
            # P2A replies go to the address carried in the message itself;
            # everything else uses the datagram's source address.
            origin = message.addr if message_type is P2A else self.client_address
            getattr(self.server, method_name)(message, origin)
            return

        # prob just ignore message
        print("unrecognized message", file=sys.stdout)  # debug


if __name__ == "__main__":
    # Usage: <script> HOST PORT [SERVER_HOST SERVER_PORT ...]
    # HOST/PORT is the address this replica binds to; the remaining arguments
    # are (host, port) pairs listing every replica in the group.
    HOST, PORT = sys.argv[1], int(sys.argv[2])
    addresses = [(sys.argv[i], int(sys.argv[i + 1])) for i in range(3, len(sys.argv), 2)]
    # Create the server bound to (HOST, PORT)
    server = Paxos((HOST, PORT), addresses)
    # Serve until the program is interrupted with Ctrl-C
    server.serve_forever()