#!/usr/bin/env python3

import re
import os
import sys
import time
import argparse
import logging
import mysql.connector
from collections import Counter


"""
Read a slow query log file and analyze the distribution of the SQL statements in it.


"""


# File name of this script without its directory; passed to argparse as the
# program name.
name = os.path.split(__file__)[1]

# Configure the root logger once at import time: INFO and above, with a
# timestamp and the level name in front of every message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)

def parser_cmd_args() -> argparse.Namespace:
    """
    Parse the command line arguments of the script.

    Returns:
        The parsed argparse.Namespace with two attributes:
        ``limit``   - int, how many of the most frequent tables to print
                      (default 7);
        ``sqlfile`` - str, path of the slow query log (default 'slow.log').
    """
    parser = argparse.ArgumentParser(name)
    parser.add_argument('--limit', type=int, default=7,
                        help='number of most frequent tables to print')
    # nargs='?' makes the positional optional so its default really applies;
    # without it argparse treats the argument as required and the default
    # is never used.
    parser.add_argument('sqlfile', type=str, nargs='?', default='slow.log',
                        help='slow query log file')
    return parser.parse_args()

class BaseAnalyze:
    """
    Common base class of every analyzer.

    Subclasses are expected to override both analyze() and __str__(); the
    versions here only raise NotImplementedError.
    """

    def analyze(self, line: str) -> None:
        """
        Consume one line of input and update the collected statistics.

        Raises:
            NotImplementedError: always, unless overridden in a subclass.
        """
        message = "请在子类中实现 analyze 方法"
        raise NotImplementedError(message)

    def __str__(self):
        """
        Return the formatted report of the collected statistics.

        Raises:
            NotImplementedError: always, unless overridden in a subclass.
        """
        message = "请在子类中实现 __str__ 方法"
        raise NotImplementedError(message)

class SqlDistributionAnalyze(BaseAnalyze):
    """
    Count how many input lines are select / insert / update / delete
    statements.

    Lines are bytes. A line is classified by the first of the compiled
    patterns that matches at the very start of the line (re.match), so a
    line with leading whitespace, or one that fits none of the patterns,
    is not counted.
    """

    def __init__(self):
        # Pre-seed all four kinds with zero so the report always lists them,
        # in this order, even when a kind never occurs.
        self._sql_type_counter = Counter(select=0, insert=0, update=0, delete=0)
        self.select = re.compile(rb'select ([\w\W]*) from ', re.IGNORECASE)
        self.insert = re.compile(rb"insert into ([\w\W]*) values", re.IGNORECASE)
        self.delete = re.compile(rb"delete from ([\w\W]*)", re.IGNORECASE)
        self.update = re.compile(rb"update ([\w\W]*) set", re.IGNORECASE)

    def analyze(self, line):
        """
        Classify ``line`` (bytes) as select, insert, delete or update and
        increment the matching counter; unmatched lines are ignored.
        """
        # Checked in this order; the first pattern that matches wins.
        candidates = (
            ('select', self.select),
            ('insert', self.insert),
            ('delete', self.delete),
            ('update', self.update),
        )
        for sql_type, pattern in candidates:
            if pattern.match(line) is not None:
                self._sql_type_counter[sql_type] += 1
                return

    def __str__(self):
        """
        Return the statement-type counts as a fixed-width text table.
        """
        rule = "-" * 48 + '\n'
        pieces = [rule, "{0:<47}\n".format("SQL出现频率如下:"), rule]
        pieces.extend(
            "{0:<24}|{1:<23}\n".format(sql_type, count)
            for sql_type, count in self._sql_type_counter.items()
        )
        # Closing rule followed by one empty line.
        pieces.append(rule + '\n')
        return "".join(pieces)

class TableDistributionAnalyze(BaseAnalyze):
    """
    Count how often each table name appears in the analyzed statements.

    Lines are bytes and patterns are matched at the start of the line
    (re.match). A table name is a run of word characters and dots, so a
    name wrapped in other characters is captured as empty and the line is
    not counted for that pattern.
    """

    def __init__(self, limit=7):
        """
        Args:
            limit: how many of the most frequent tables __str__ reports.
        """
        self.limit = limit
        self._table_counter = Counter()
        # The first group of the select pattern is greedy, so the captured
        # table is the word following the last " from " in the line.
        self.select = re.compile(rb"select ([\w\W]*) from ([.\w]*)", re.IGNORECASE)
        self.insert = re.compile(rb"insert([\s]*)into([\s]*)([.\w]*)", re.IGNORECASE)
        self.update = re.compile(rb"update([\s]*)([.\w]*)([\s]*)set", re.IGNORECASE)
        self.delete = re.compile(rb"delete([\s]*)from([\s]*)([.\w]*)", re.IGNORECASE)

    def analyze(self, line):
        """
        Extract the table that ``line`` (bytes) operates on and count it.

        The patterns are tried in the order select, insert, update, delete;
        the first one that yields a non-empty table name is counted and the
        rest are skipped.
        """
        # Each entry: (compiled pattern, index of the group holding the table).
        probes = (
            (self.select, 2),
            (self.insert, 3),
            (self.update, 2),
            (self.delete, 3),
        )
        for pattern, group_index in probes:
            found = pattern.match(line)
            if found is None:
                continue
            table = found.group(group_index)
            if table == b'':
                # Matched the statement shape but no table name was captured;
                # fall through to the next pattern.
                continue
            self._table_counter[table] += 1
            return

    def __str__(self):
        """
        Return the ``limit`` most frequent tables as a fixed-width text table.
        """
        rule = "-" * 48 + '\n'
        pieces = [rule, "{0:<47}\n".format("表名出现频率如下:"), rule]
        for table, count in self._table_counter.most_common(self.limit):
            # Keys are bytes captured by the regexes; decode for display.
            pieces.append("{0:<40}|{1:<7}\n".format(table.decode('utf8'), count))
        pieces.append(rule)
        return "".join(pieces)
        
#class RowsAnalyze(BaseAnalyze):
##Query_time: 0.101302  Lock_time: 0.000084 Rows_sent: 37  Rows_examined: 269513
#    pass

def main() -> None:
    """
    Analyze a slow query log and print the statement-type and table reports.

    Parses the command line, streams the log file line by line in binary
    mode, feeds every line that is not skipped to each analyzer, then prints
    each analyzer's report followed by the elapsed time.

    This is a function rather than a class: a class body executes as soon
    as the class statement is reached, which would run the whole analysis
    on import, not only when the script is executed.

    Exits the process with status 1 if the log file does not exist.
    """
    # Monotonic clock: the measured duration cannot be skewed by wall-clock
    # adjustments.
    starting = time.perf_counter()

    # Parse the command line arguments.
    args = parser_cmd_args()

    # Make sure the log file exists.
    if not os.path.isfile(args.sqlfile):
        logging.error("file %s not exists.", args.sqlfile)
        sys.exit(1)

    # Open in binary mode throughout to avoid decoding problems.
    distributions = [SqlDistributionAnalyze(), TableDistributionAnalyze(args.limit)]
    with open(args.sqlfile, 'rb') as sql_file_obj:
        for line in sql_file_obj:
            # Lines starting with '#' or 'SET' are not analyzed (presumably
            # slow-log header lines and SET statements).
            if line.startswith((b'#', b'SET')):
                continue
            for dis in distributions:
                dis.analyze(line)

    for dis in distributions:
        print(dis)

    ending = time.perf_counter()
    print(f"日志分析用时 {(ending - starting):.2f} s")




# Script entry point: main() is called only when this file is executed directly.
if __name__ == "__main__":
    main()

    


