import time, random, six, asyncio, datetime, traceback
from typing import List, Dict, Optional
from paramiko.client import SSHClient, AutoAddPolicy
from paramiko import RSAKey
from .do import (
    CommandContextDO,
    CommandHostDO,
    CommandSettingDetailDO,
    CommandResponseDO
)

class Session:
    """A cached SSH connection: the client, the host it reaches, and an expiry deadline."""

    def __init__(self, client, host:CommandHostDO, timeout=3600) -> None:
        self.client = client
        self.host = host
        # Absolute wall-clock deadline (epoch seconds) after which the
        # session is considered stale by is_expired().
        self.expired_time = timeout + time.time()

    def is_expired(self)->bool:
        """Return True once the deadline computed at construction has passed."""
        now = time.time()
        return now > self.expired_time

    async def execute(self, command, timeout):
        """Run ``command`` via ``client.exec_command`` and return its result.

        NOTE: nothing in this coroutine is awaited, so ``exec_command`` runs
        synchronously on the event-loop thread.
        """
        result = self.client.exec_command(command, timeout=timeout)
        return result

    def close(self):
        """Close the underlying SSH client."""
        self.client.close()

class TaskRunner:
    """Book-keeping for one host's progress through a list of commands.

    ``status`` is None when fresh, 'running' while a command is executing, or
    'pending' while waiting until ``restart_time`` to retry a failed attempt.
    """

    def __init__(
        self,
        id: int,
        task: asyncio.Task,
        session: "Session",
        command_index: int
    ) -> None:
        self.id = id
        # Task executing the current command; callers replace it on each
        # retry / next command.
        self.task = task
        self.session = session
        # Index of the command currently being executed.
        self.command_index = command_index
        # Response accumulated for this host (set via set_response or by the caller).
        self.response = None
        # Naive UTC timestamp of when this runner was created.
        self.start_time = datetime.datetime.utcnow()
        self.end_time = None
        self.status = None
        # Number of failed attempts of the current command.
        self.failed_time = 0
        # Epoch time after which a pending retry may start; None when no
        # retry is scheduled (previously left undefined until set_pending).
        self.restart_time = None

    def reset(self):
        """Prepare for the next command: clear failure count and retry schedule, mark running."""
        self.status = None
        self.failed_time = 0
        self.restart_time = None
        self.set_running()

    def set_response(self, response):
        """Store the response object for this runner."""
        self.response = response

    def set_running(self):
        """Mark the runner as executing a command."""
        self.status = 'running'

    def is_pending(self) -> bool:
        """Return True while the runner is waiting to retry."""
        return self.status == 'pending'

    def set_pending(self, delay: float):
        """Schedule a retry ``delay`` seconds from now.

        Raises:
            ValueError: if ``delay`` is greater than 10 seconds.
        """
        if delay > 10:
            raise ValueError('Greater then 10s delay time is not supported!')
        self.status = 'pending'
        self.restart_time = time.time() + delay

    def is_ready_restart(self) -> bool:
        """Return True once the scheduled retry time has passed (False if none is scheduled)."""
        return self.restart_time is not None and time.time() > self.restart_time


class CommandClient:
    """Run a sequence of shell commands on many hosts over SSH.

    SSH sessions are cached in ``self.sessions`` (keyed by host IP) and reused
    by later ``connect`` calls until they expire or their transport drops.
    """

    # Separator placed between the outputs/commands of successive attempts
    # when they are accumulated into a single response.
    _SEPARATOR = '\n' + '-' * 20 + '\n'

    # Upper bound on scheduler ticks in run_tasks (each tick sleeps 0.01s).
    _MAX_TICKS = 500000

    def __init__(self, clear_prob=0.1, session_timeout=6000, forks=10) -> None:
        """
        Args:
            clear_prob: probability that a ``clear()`` call actually sweeps
                expired sessions.
            session_timeout: seconds a cached session stays valid; also passed
                as the SSH connect timeout.
            forks: default maximum number of hosts worked on concurrently.
        """
        self.sessions: Dict[str, Session] = {}
        self.clear_prob = clear_prob
        self.session_timeout = session_timeout
        self.forks = forks

    def clear(self):
        """With probability ``clear_prob``, close and drop every expired session."""
        if random.random() > self.clear_prob:
            return
        # Iterate over a snapshot: deleting from a dict while iterating it
        # raises RuntimeError.
        for ip, session in list(self.sessions.items()):
            if session.is_expired():
                session.close()
                del self.sessions[ip]

    @staticmethod
    def _is_alive(session) -> bool:
        """Return True when the session's SSH transport exists and is still active."""
        transport = session.client.get_transport()
        return transport is not None and transport.is_active()

    def connect(self, host: "CommandHostDO", key_filename: str) -> "Session":
        """Return a usable session for ``host``, opening a new SSH connection if needed.

        A cached session is reused only if it has not expired and its
        transport is still active; otherwise it is closed and replaced.

        NOTE(review): the cache is keyed by IP only, so a cached session is
        reused even if ``host.port``/``host.username`` differ — confirm intent.

        Raises:
            Whatever ``SSHClient.connect`` raises (auth, socket, host-key errors).
        """
        self.clear()
        ip, port, username = host.ip, host.port, host.username
        cached = self.sessions.get(ip)
        if cached is not None:
            if not cached.is_expired() and self._is_alive(cached):
                return cached
            # Stale or disconnected: drop it so a fresh session replaces it.
            cached.close()
            del self.sessions[ip]
        client = SSHClient()
        try:
            client.load_system_host_keys()
            client.set_missing_host_key_policy(AutoAddPolicy())
            client.connect(
                hostname=ip,
                port=port,
                username=username,
                key_filename=key_filename,
                timeout=self.session_timeout
            )
        except Exception:
            # Don't leak the half-initialised client (and its socket/thread).
            client.close()
            raise
        session = Session(client=client, host=host, timeout=self.session_timeout)
        self.sessions[ip] = session
        return session

    @staticmethod
    def _create_ssh_keypair(comment=None, bits=4096):
        """Generate a fresh RSA SSH keypair.

        Args:
            comment: trailing comment of the public key line
                (defaults to "Generated by ssh-cmd").
            bits: RSA key size.

        Returns:
            dict with 'private_key' (text written by ``RSAKey.write_private_key``)
            and 'public_key' ("<key name> <base64> <comment>").
        """
        if comment is None:
            comment = "Generated by ssh-cmd"
        key = RSAKey.generate(bits)
        keyout = six.StringIO()
        key.write_private_key(keyout)
        private_key = keyout.getvalue()
        public_key = '{} {} {}'.format(key.get_name(), key.get_base64(), comment)
        return {
            'private_key': private_key,
            'public_key': public_key,
        }

    def _fail_runner(self, runner, cmd, exception):
        """Build the final 'failed' response for a runner that hit an exception."""
        response = runner.response
        if response is None:
            return CommandResponseDO(
                status='failed',
                ip=runner.session.host.ip,
                stdout=None,
                stderr=None,
                cmd=cmd,
                exception=exception
            )
        response.status = 'failed'
        response.cmd += self._SEPARATOR + cmd
        response.exception = exception
        return response

    def _record_output(self, runner, command, stdout, stderr, end_time, status):
        """Fold one finished attempt's output into ``runner.response`` and return it."""
        delta_time = str(end_time - runner.start_time)
        response = runner.response
        if response is None:
            response = CommandResponseDO(
                status=status,
                ip=runner.session.host.ip,
                stdout=stdout,
                stderr=stderr,
                cmd=command.command,
                start_time=runner.start_time.isoformat(),
                end_time=end_time.isoformat(),
                delta_time=delta_time,
            )
            runner.response = response
        else:
            response.status = status
            response.stdout += self._SEPARATOR + stdout
            response.stderr += self._SEPARATOR + stderr
            response.cmd += self._SEPARATOR + command.command
            response.end_time = end_time.isoformat()
            response.delta_time = delta_time
        return response

    def _start_command(self, runner, command):
        """Schedule ``command`` on the runner's own session.

        Returns a final 'failed' response if the task could not be created,
        otherwise None.
        """
        try:
            runner.task = asyncio.create_task(
                runner.session.execute(command.command, command.timeout)
            )
        except Exception:
            return self._fail_runner(runner, command.command, traceback.format_exc())
        return None

    def _advance_runner(self, runner, commands):
        """Advance one runner by at most one step.

        Returns the runner's final response once it has finished (all commands
        succeeded, or one failed for good); None while work is still in flight.
        """
        command = commands[runner.command_index]

        if runner.is_pending():
            if not runner.is_ready_restart():
                return None
            # Retry delay elapsed: re-run the same command on *this runner's*
            # session (the old code used whichever session was popped last).
            runner.restart_time = None
            runner.set_running()
            return self._start_command(runner, command)

        if not runner.task.done():
            return None

        try:
            _, stdout_file, stderr_file = runner.task.result()
            channel = stdout_file.channel
            if not channel.exit_status_ready():
                # Still running remotely: poll again next tick instead of
                # blocking the whole event loop in recv_exit_status().
                # NOTE(review): output is only drained after exit, so a command
                # that fills the channel window may stall until the tick limit.
                return None
            return_code = channel.recv_exit_status()
            stdout = stdout_file.read().decode('utf-8', errors='replace')
            stderr = stderr_file.read().decode('utf-8', errors='replace')
        except Exception:
            # exec_command or reading its output failed (e.g. transport dropped).
            return self._fail_runner(runner, command.command, traceback.format_exc())
        end_time = datetime.datetime.utcnow()

        # paramiko returns -1 when the server sent no exit status (e.g. the
        # command was killed by a signal), so any non-zero code is a failure.
        if return_code != 0:
            runner.failed_time += 1
            # NOTE(review): `failed_time >= retries` makes `retries` the total
            # number of attempts rather than extra re-tries; confirm intent.
            if runner.failed_time >= command.retries:
                return self._record_output(runner, command, stdout, stderr, end_time, 'failed')
            runner.set_pending(command.delay)
            return None

        self._record_output(runner, command, stdout, stderr, end_time, 'succeed')
        runner.command_index += 1
        if runner.command_index >= len(commands):
            return runner.response
        runner.reset()
        return self._start_command(runner, commands[runner.command_index])

    async def run_tasks(
        self,
        commands: List["CommandSettingDetailDO"],
        sessions: List["Session"],
        forks: int
    ) -> List["CommandResponseDO"]:
        """Run ``commands`` in order on every session, at most ``forks`` at a time.

        Each session gets one runner that executes the commands sequentially;
        a command exiting non-zero is retried after ``command.delay`` seconds
        until ``command.retries`` failures are counted, and the runner stops at
        the first command that fails for good.

        Args:
            commands: commands to run, in order, on every host.
            sessions: sessions to run them on; the list is consumed (popped
                from the end).
            forks: maximum number of concurrently active runners (values
                below 1 are treated as 1).

        Returns:
            One response per session: the accumulated output of its commands,
            with status 'succeed' if all succeeded, otherwise 'failed'.
            Empty when ``commands`` is empty.
        """
        ret_resp = []
        if not commands:
            return ret_resp
        forks = max(1, forks)
        runners: Dict[int, TaskRunner] = {}
        for i in range(self._MAX_TICKS):
            if len(runners) < forks and sessions:
                runner = TaskRunner(id=i, task=None, session=sessions.pop(), command_index=0)
                failure = self._start_command(runner, commands[0])
                if failure is None:
                    runners[i] = runner
                else:
                    ret_resp.append(failure)
            for runner_idx in list(runners):
                finished = self._advance_runner(runners[runner_idx], commands)
                if finished is not None:
                    ret_resp.append(finished)
                    del runners[runner_idx]
            if not sessions and not runners:
                return ret_resp
            await asyncio.sleep(0.01)

        # Tick budget exhausted: report everything still outstanding instead of
        # implicitly returning None (which made execute() crash on extend()).
        reason = 'run_tasks gave up: iteration limit reached before the command finished'
        for runner in runners.values():
            ret_resp.append(
                self._fail_runner(runner, commands[runner.command_index].command, reason)
            )
        for session in sessions:
            ret_resp.append(
                CommandResponseDO(
                    status='failed',
                    ip=session.host.ip,
                    stdout=None,
                    stderr=None,
                    cmd=commands[0].command,
                    exception=reason
                )
            )
        return ret_resp

    def execute(
        self,
        commands: List["CommandSettingDetailDO"],
        hosts: List["CommandHostDO"],
        key_filename: str,
        context: Optional["CommandContextDO"] = None,
    ) -> List["CommandResponseDO"]:
        """Connect to every host and run ``commands`` on the reachable ones.

        Hosts whose connection fails get an 'unreachable' response carrying
        the traceback. Uses ``context.forks`` when a context is given,
        otherwise ``self.forks``.

        NOTE: uses ``asyncio.run``, so it cannot be called from inside a
        running event loop.
        """
        forks = context.forks if context else self.forks
        sessions: List[Session] = []
        responses = []
        for host in hosts:
            try:
                sessions.append(self.connect(host, key_filename))
            except Exception:
                responses.append(
                    CommandResponseDO(
                        status='unreachable',
                        ip=host.ip,
                        stdout=None,
                        stderr=None,
                        cmd=None,
                        exception=traceback.format_exc()
                    )
                )
        responses.extend(asyncio.run(self.run_tasks(commands, sessions, forks)))
        return responses

    def close(self):
        """Close every cached session and empty the cache so none is reused."""
        for session in self.sessions.values():
            session.close()
        self.sessions.clear()
    
        


