import mariadb
import select
import subprocess
import socket
from time import sleep
import logging
import configparser
import sched, time
import threading

# Logging verbosity is read from the shared logs-config.ini: the
# 'galera-replication' option of the [General] section must hold an integer
# logging level (e.g. 10 == logging.DEBUG), which is applied to the
# recovery log file below.
logconfig = configparser.RawConfigParser()
logconfig.read('/home/akkadianuser/logs-config.ini')
loglevel = logconfig.get('General', 'galera-replication')
log_level = int(loglevel)

logging.basicConfig(
    filename='/home/akkadianuser/logs/galera_failover_recovery.log',
    level=log_level,
    format='%(asctime)s %(levelname)-8s %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)


# Helper that monitors a two-node MariaDB Galera cluster and drives its recovery.
class MariadbGaleraHelper:
    """Monitor and recover a two-node MariaDB Galera cluster.

    Local operations run ``systemctl`` / Galera tooling directly on this host.
    Operations on the peer node are sent as plain-text command names over TCP
    to the "akkadian-executor" service listening on the peer (port 65432).
    """

    def __init__(self, user, password):
        self.user = user
        self.password = password
        # Local connection/cursor; populated by connect().
        self.conn = None
        self.cur = None
        self.checker_sch = sched.scheduler(time.time, time.sleep)
        # Pending recovery-check event, kept so it can be cancelled.
        self._checker_event = None
        # Thread currently draining ``checker_sch`` (if any).
        self._checker_thread = None

    # ------------------------------------------------------------------
    # Local database connection and status queries
    # ------------------------------------------------------------------

    def connect(self):
        """Open a connection and cursor to the local MariaDB server (localhost:3306).

        Returns:
            True on success, False if the driver raised ``mariadb.Error``.
        """
        logging.debug("Trying to connect")
        try:
            self.conn = mariadb.connect(
                user=self.user,
                password=self.password,
                host="localhost",
                port=3306
            )
            # Get Cursor
            self.cur = self.conn.cursor()
            return True
        except mariadb.Error as e:
            logging.debug("Error connecting to MariaDB Platform: %s", e)
            return False

    def disconnect(self):
        """Close the local connection if one is open."""
        conn = getattr(self, 'conn', None)
        if conn is not None:
            conn.close()
            self.conn = None
            self.cur = None

    def _ensure_connection(self):
        """Reconnect when the local connection is missing or no longer alive.

        The recovery paths stop/restart the local server, which kills the
        connection used for status queries; without this every later query
        would fail.

        Returns:
            True if a usable connection is available afterwards.
        """
        conn = getattr(self, 'conn', None)
        if conn is not None:
            try:
                conn.ping()
                return True
            except mariadb.Error:
                logging.debug("MariaDB connection lost, reconnecting")
                try:
                    conn.close()
                except mariadb.Error:
                    pass
        return self.connect()

    def _query_status_value(self, query, default):
        """Run a ``SHOW ... STATUS LIKE`` query and return the last Value column.

        Returns *default* when the query yields no rows, or when the server
        cannot be queried (no cursor yet, or a driver error such as a dead
        connection while MariaDB is down).
        """
        value = default
        try:
            self.cur.execute(query)
            for _name, row_value in self.cur:
                value = row_value
        except (AttributeError, mariadb.Error) as err:
            logging.debug("Query %r failed: %s", query, err)
        return value

    def getClusterStatus(self):
        """Return ``wsrep_cluster_status`` or ``"undetermined"`` if unavailable."""
        logging.debug("Getting cluster status")
        status = self._query_status_value(
            "SHOW GLOBAL STATUS LIKE 'wsrep_cluster_status';", "undetermined")
        logging.debug("cluster status: %s", status)
        return status

    def setBootstrapVariable(self):
        """Set ``pc.bootstrap=YES`` and ``pc.ignore_sb=TRUE`` on the local provider."""
        try:
            self.cur.execute("SET GLOBAL wsrep_provider_options='pc.bootstrap=YES';")
            self.cur.execute("SET GLOBAL wsrep_provider_options='pc.ignore_sb=TRUE';")
        except mariadb.OperationalError as e:
            logging.debug("Error setting variables into MariaDB: %s", e)

    def getClusterSize(self):
        """Return ``wsrep_cluster_size`` as the server's string, or ``''`` if unavailable."""
        size = self._query_status_value(
            "SHOW GLOBAL STATUS LIKE 'wsrep_cluster_size';", '')
        logging.debug("Cluster size: %s", size)
        return size

    def nodeLastCommit(self):
        """Return ``wsrep_last_committed`` as the server's string, or ``''`` if unavailable."""
        last_commit = self._query_status_value(
            "SHOW STATUS LIKE 'wsrep_last_committed';", '')
        logging.debug("Last commit: %s", last_commit)
        return last_commit

    # ------------------------------------------------------------------
    # Local service control
    # ------------------------------------------------------------------

    @staticmethod
    def _run_command(args):
        """Run the argv list *args* and return ``(stdout, stderr, returncode)``.

        A missing executable is reported as return code 127 (what the shell
        reports for "command not found") instead of raising.
        """
        try:
            cp = subprocess.run(args, universal_newlines=True,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except OSError as err:
            logging.debug("Cannot run %s: %s", args, err)
            return '', str(err), 127
        return cp.stdout, cp.stderr, cp.returncode

    def statusMariadb(self):
        """Return ``systemctl status mariadb`` as (stdout, stderr, returncode)."""
        result = self._run_command(['systemctl', 'status', 'mariadb'])
        logging.debug("mariadb status:")
        logging.debug("%s", result[0])
        return result

    def stopMariadb(self):
        """Run ``systemctl stop mariadb``; return (stdout, stderr, returncode)."""
        return self._run_command(['systemctl', 'stop', 'mariadb'])

    def startMariadb(self):
        """Run ``systemctl start mariadb``; return (stdout, stderr, returncode)."""
        return self._run_command(['systemctl', 'start', 'mariadb'])

    def bootstrapMariadb(self):
        """Run ``galera_new_cluster``; return (stdout, stderr, returncode)."""
        return self._run_command(['galera_new_cluster'])

    def forceBootstap(self):
        """Set ``safe_to_bootstrap`` to 1 in grastate.dat (via sudo sed)."""
        return self._run_command(['sudo', 'sed', '-i', '/safe_to_bootstrap/s/0/1/',
                                  '/var/lib/mysql/grastate.dat'])

    def killMariadb(self):
        """Run the local kill_mariadb.sh script; return (stdout, stderr, returncode)."""
        return self._run_command(
            ['sh', '/home/akkadianuser/scripts/ha/replication_new_galera_approach/kill_mariadb.sh'])

    # ------------------------------------------------------------------
    # Remote (peer) control through akkadian-executor
    # ------------------------------------------------------------------

    @classmethod
    def sendRemoteCommand(cls, remote, msg, port=65432):
        """Send command *msg* to the executor on *remote* and return its reply.

        Up to 11 receive attempts are made, one second apart; only the first
        non-empty chunk (at most 1024 bytes) is returned, decoded as UTF-8
        (undecodable bytes, e.g. a character cut at the chunk boundary, are
        replaced). Connection errors and socket timeouts (10 s) propagate.
        """
        # TODO: what to do if akkadian-executor is down????
        server_address = (remote, port)
        logging.debug('connecting to {} port {}'.format(*server_address))
        logging.debug('Sending: %s', msg)

        data = b''
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.settimeout(10)
            # Inside the try so the socket is closed even if connect fails.
            sock.connect(server_address)

            message = str(msg).encode()
            logging.debug('sending {!r}'.format(message))
            sock.sendall(message)
            for _attempt in range(11):
                try:
                    sleep(1)
                    data = sock.recv(1024)
                except KeyboardInterrupt:
                    break
                if data:
                    break
        finally:
            logging.debug('closing socket')
            sock.close()

        return data.decode('utf-8', errors='replace')

    def stopMariaDbRemote(self, remote):
        self.sendRemoteCommand(remote, "stopMariadb")

    def startMariaDbRemote(self, remote):
        self.sendRemoteCommand(remote, "startMariadb")

    def bootstrapNode(self, remote):
        self.sendRemoteCommand(remote, "bootstrapNode")

    def getLastCommitRemote(self, remote):
        return self.sendRemoteCommand(remote, "mariadbLastCommit")

    def getMariaDbStatusRemote(self, remote):
        return self.sendRemoteCommand(remote, "statusMariadb")

    def forceBootstrapOnRemote(self, remote):
        return self.sendRemoteCommand(remote, "forceBootstrap")

    def killMariadbRemote(self, remote):
        return self.sendRemoteCommand(remote, "killMariadb")

    def isClientSideWorking(self, remote):
        return self.sendRemoteCommand(remote, "failoverClientSideStatus")

    def restartClientSideRemote(self, remote):
        return self.sendRemoteCommand(remote, "restartClientSide")

    def disableMariadbRemote(self, remote):
        self.sendRemoteCommand(remote, "disableMariadb")

    def enableMariadbRemote(self, remote):
        self.sendRemoteCommand(remote, "enableMariadb")

    # ------------------------------------------------------------------
    # Periodic recovery loop
    # ------------------------------------------------------------------

    def schedule_myself_until_recover(self, remote_node, delay=15):
        """Schedule a recovery check against *remote_node* after *delay* seconds."""
        try:
            self._checker_event = self.checker_sch.enter(
                delay, 1, self._scheduled_check, (remote_node,))
            # When called from inside a running check we are already on the
            # scheduler thread; its run() loop picks up the new event.
            if threading.current_thread() is not getattr(self, '_checker_thread', None):
                self._checker_thread = threading.Thread(target=self.checker_sch.run)
                self._checker_thread.start()
        except RuntimeError:
            logging.debug("checkStatus cannot be setup...")

    def _scheduled_check(self, remote_node):
        """Scheduler action: run one check, and retry later if it raises."""
        try:
            self.checkStatusAndIfNeededRestartClusterWhenAllNodesUp(remote_node)
        except Exception:
            # An unreachable executor or DB error must not end the recovery
            # loop for good; log it and try again on the next interval.
            logging.exception("Recovery check failed, rescheduling")
            self.schedule_myself_until_recover(remote_node)

    def cancel_schedule_myself(self):
        """Cancel the pending recovery check, if any."""
        event = getattr(self, '_checker_event', None)
        self._checker_event = None
        if event is None:
            logging.debug('Trying to cancel unexisting event: cancel_schedule_myself')
            return
        try:
            self.checker_sch.cancel(event)
        except ValueError:
            # Already executed (e.g. cancelled from inside the running check).
            logging.debug('Trying to cancel unexisting event: cancel_schedule_myself')

    @staticmethod
    def _parse_seqno(value):
        """Return *value* as an int, or None if it is empty or not numeric."""
        try:
            return int(str(value).strip())
        except ValueError:
            return None

    def checkStatusAndIfNeededRestartClusterWhenAllNodesUp(self, remote_node):
        """Inspect both nodes, repair the cluster when needed, and reschedule.

        The check keeps rescheduling itself (every 15 s by default) until it
        finds both services running with a connected two-node cluster.
        """
        cancel_sch = False

        logging.debug("Trying to recover mariadb with two nodes if needed")
        if "active (running)" in self.statusMariadb()[0]:
            logging.debug("local mariadb is running")
            remote_status = self.getMariaDbStatusRemote(remote_node)

            if "active (running)" in remote_status:
                logging.debug("remote mariadb is running")
                cancel_sch = self._recover_with_both_running(remote_node)
            elif "activating (start)" in remote_status:
                logging.debug("remote node mariabd status is stuck into activating, trying to recover")
                self._recover_remote_stuck_activating(remote_node)
            else:
                logging.debug("remote node is fully down, start it normally")
                self._recover_remote_down(remote_node)
        else:
            logging.debug("local mariadb is down, trying to recover")
            self._recover_local_down(remote_node)

        if not cancel_sch:
            logging.debug("Setting up scheduler...")
            self.schedule_myself_until_recover(remote_node)

    def _recover_with_both_running(self, remote_node):
        """Both services run: re-form the cluster if it is split or disconnected.

        Returns True when the cluster is healthy and checking can stop.
        """
        self._ensure_connection()
        cluster_size = self.getClusterSize()
        cluster_status = self.getClusterStatus()

        if cluster_status != "Disconnected" and cluster_size == "2":
            logging.debug("Everything is working... ending the recovery by now")
            self.cancel_schedule_myself()
            return True

        logging.debug("local cluster status is disconnected or local cluster size is not accurate")
        raw_local = self.nodeLastCommit()
        raw_remote = self.getLastCommitRemote(remote_node)
        # Compare numerically: as strings '9' > '10'.
        # NOTE(review): assumes the executor replies with the bare number.
        local_commit = self._parse_seqno(raw_local)
        remote_commit = self._parse_seqno(raw_remote)
        if local_commit is None or remote_commit is None:
            logging.debug("Cannot compare last commits (local=%r, remote=%r)", raw_local, raw_remote)
            return False

        # Stop both nodes so only the bootstrapped node forms the new primary.
        self.stopMariaDbRemote(remote_node)
        self.stopMariadb()
        sleep(3)

        if local_commit >= remote_commit:
            # Equal commits: either node is safe; bootstrap here rather than
            # leaving both nodes stopped.
            logging.debug("local cluster last commit is bigger (or equal), so bootstrap this node")
            self.bootstrapMariadb()
            sleep(3)
            self.startMariaDbRemote(remote_node)
        else:
            logging.debug("remote cluster last commit is bigger, so bootstrap remote")
            self.bootstrapNode(remote_node)
            sleep(3)
            self.startMariadb()
        return False

    def _recover_remote_stuck_activating(self, remote_node):
        """Peer is stuck in "activating": force-bootstrap here, then start the peer."""
        self.killMariadbRemote(remote_node)
        sleep(1)
        self.stopMariadb()
        sleep(1)

        self.forceBootstap()
        self.bootstrapMariadb()
        sleep(3)

        if "active (running)" in self.statusMariadb()[0]:
            self.startMariaDbRemote(remote_node)

    def _recover_remote_down(self, remote_node):
        """Peer is down: bootstrap here and let the peer join; fall back to the peer."""
        # NOTE(review): disableMariadb/enableMariadb are executor commands,
        # presumably `systemctl disable/enable` on the peer -- confirm.
        self.disableMariadbRemote(remote_node)
        self.killMariadbRemote(remote_node)
        sleep(3)
        self.stopMariadb()
        self.bootstrapMariadb()
        sleep(3)

        if "active (running)" in self.statusMariadb()[0]:
            # This node is primary again: start the peer normally so it joins.
            self.startMariaDbRemote(remote_node)
        else:
            # Local bootstrap failed: try bootstrapping the peer instead.
            self.bootstrapNode(remote_node)
            sleep(3)
            if "active (running)" not in self.getMariaDbStatusRemote(remote_node):
                self.forceBootstap()
                self.bootstrapMariadb()
                self.startMariaDbRemote(remote_node)
        # Always undo the disable above so the peer is not left disabled.
        self.enableMariadbRemote(remote_node)

    def _recover_local_down(self, remote_node):
        """Local service is down: join a running peer, or force-bootstrap here."""
        self.killMariadb()
        remote_status = self.getMariaDbStatusRemote(remote_node)
        if "active (running)" in remote_status:
            logging.debug("Remote is active")
            self.startMariadb()
            sleep(3)
            if "active (running)" not in self.statusMariadb()[0]:
                # it did not work, try forcing bootstrap
                self.killMariadb()
                self.killMariadbRemote(remote_node)

                self.forceBootstap()
                self.bootstrapMariadb()
                sleep(3)
                self.startMariaDbRemote(remote_node)
        else:
            logging.debug("mariadb remote is down")
            # none is up
            self.killMariadb()
            self.killMariadbRemote(remote_node)

            sleep(5)

            self.forceBootstap()
            self.bootstrapMariadb()

            sleep(5)

            self.startMariaDbRemote(remote_node)
