#!/usr/bin/env python3

import argparse
import os
import mysql.connector
from contextlib import ContextDecorator
from loguru import logger as log
import sys


# Tool version; shown in the argparse description.
version = "0.0.2"

class MySQLizer(ContextDecorator):
    """Context manager that opens a MySQL connection with two prepared cursors.

    On entry it connects through the UNIX socket when ``socket`` is truthy,
    otherwise over TCP to ``host``. The connection is exposed as ``cnx`` and
    the cursors as ``cursor`` and ``cursor2``. On exit both cursors and the
    connection are closed; exceptions from the ``with`` body are not
    suppressed.
    """

    def __init__(self, db, host, socket, user, password):
        """Store connection parameters; nothing is opened until ``__enter__``.

        Args:
            db: Name of the database to select.
            host: Server host name, used only when ``socket`` is falsy.
            socket: Path to the MySQL UNIX socket, or a falsy value.
            user: MySQL user name.
            password: MySQL password (may be None).
        """
        self.db = db
        self.host = host
        self.socket = socket
        self.user = user
        self.password = password
        # Populated by __enter__; pre-set so __exit__ is safe on a partially
        # initialised instance.
        self.cnx = None
        self.cursor = None
        self.cursor2 = None

    def __enter__(self):
        """Open the connection and both prepared cursors, then return self."""
        if self.socket:
            log.debug("Connect to socket {}".format(self.socket))
            self.cnx = mysql.connector.connect(
                unix_socket=self.socket,
                user=self.user,
                passwd=self.password,
                database=self.db,
            )
        else:
            log.debug("Connect to host {}".format(self.host))
            self.cnx = mysql.connector.connect(
                host=self.host,
                user=self.user,
                passwd=self.password,
                database=self.db,
            )

        try:
            self.cursor = self.cnx.cursor(prepared=True)
            self.cursor2 = self.cnx.cursor(prepared=True)
        except Exception:
            # __exit__ is not invoked when __enter__ raises, so release the
            # freshly opened connection here before propagating the error.
            self.cnx.close()
            raise

        log.debug("Connected!")
        return self

    def __exit__(self, *exc):
        """Close both cursors and the connection.

        Returns:
            False, so any exception raised inside the ``with`` body propagates.
        """
        log.debug("Close mysql connection")
        try:
            for cur in (self.cursor, self.cursor2):
                if cur is not None:
                    cur.close()
        finally:
            # Always release the connection, even if closing a cursor failed.
            if self.cnx is not None:
                self.cnx.close()
        return False

def get_mysql_socket(defsocket):
    """Pick the default MySQL UNIX socket path.

    Returns the first well-known socket location that exists on this host.
    If none exists, ``defsocket`` is returned unchanged (it may be None).
    Note that an existing well-known socket takes precedence over
    ``defsocket``.
    """
    known_locations = ('/var/run/mysqld/mysqld.sock',)
    existing = (path for path in known_locations if os.path.exists(path))
    return next(existing, defsocket)

def _default_user():
    """Return the default MySQL user name, or None if it cannot be determined.

    ``MYSQL_USER`` wins when set. Otherwise the login name is used.
    ``os.getlogin()`` raises OSError when the process has no controlling
    terminal (cron, containers, some CI runners), so fall back to
    ``getpass.getuser()`` instead of aborting argument parsing.
    """
    import getpass

    env_user = os.getenv('MYSQL_USER')
    if env_user is not None:
        return env_user
    try:
        return os.getlogin()
    except OSError:
        pass
    try:
        return getpass.getuser()
    except (KeyError, ImportError, OSError):
        # No way to guess a user name; the user can still pass -u/--user.
        return None


def get_args():
    """Parse the command line and return the argparse namespace.

    Defaults for host, socket, user and password are taken from the
    ``MYSQL_HOST``, ``MYSQL_SOCKET``, ``MYSQL_USER`` and ``MYSQL_PASS``
    environment variables. The socket default additionally goes through
    ``get_mysql_socket``.

    Returns:
        argparse.Namespace with attributes: host, socket, user, password,
        tables, verbose, suppress, like, regexp, limit, types, all, dbname,
        needle.
    """
    parser = argparse.ArgumentParser(description='MySQL grep utility, version: {}'.format(version))

    def_socket = get_mysql_socket(os.getenv('MYSQL_SOCKET'))

    parser.add_argument('--host', default=os.getenv('MYSQL_HOST', 'localhost'))
    parser.add_argument('--socket', default=def_socket, help='Path to MySQL socket file')
    # The default is computed by a helper so a missing controlling terminal
    # cannot crash the parser before -u/--user is even looked at.
    parser.add_argument('-u', '--user', default=_default_user())
    parser.add_argument('-p', '--password', default=os.getenv('MYSQL_PASS'))
    parser.add_argument('-t', '--tables', nargs='+', default=list(), help="tables (default: all)")
    parser.add_argument('-v', '--verbose', default=False, action='store_true', help="Verbose mode")
    parser.add_argument('--suppress', default=False, action='store_true', help="Do not print value (only table, field name and primary key)")
    parser.add_argument('--like', default=False, action='store_true', help="use SQL LIKE instead of =")
    parser.add_argument('--regexp', default=False, action='store_true', help="use SQL REGEXP instead of =")
    parser.add_argument('--limit', type=int)
    parser.add_argument('--types', default='TND', help='which types to process. T for text/char/varchar, N for any int/decimal, D for date. Def: TND')
    parser.add_argument('--all', default=False, action='store_true', help="display ALL fields from matching rows")

    parser.add_argument('dbname')
    parser.add_argument('needle')

    return parser.parse_args()

# Column base types searched for each letter accepted by --types.
# The N class covers the whole integer family, as the --types help promises.
_TYPE_CLASSES = {
    'T': ('varchar', 'char', 'text'),
    'N': ('int', 'smallint', 'decimal', 'tinyint', 'mediumint', 'bigint'),
    'D': ('date', 'datetime'),
}


def _quote_ident(name):
    """Return ``name`` as a backtick-quoted MySQL identifier."""
    return '`{}`'.format(str(name).replace('`', '``'))


def _quote_table(name):
    """Quote a user-supplied table name, keeping ``db.table`` and pre-quoted forms."""
    if '`' in name:
        # Already quoted by the caller; leave it exactly as given.
        return name
    return '.'.join(_quote_ident(part) for part in name.split('.'))


def _is_searchable(type_decl, wanted):
    """Tell whether a DESC ``Type`` value belongs to a class selected in ``wanted``."""
    # "int(11) unsigned" and "int unsigned" (MySQL 8) both reduce to "int".
    words = type_decl.split('(')[0].split()
    base = words[0].lower() if words else ''
    return any(
        letter in wanted and base in names
        for letter, names in _TYPE_CLASSES.items()
    )


def main():
    """Search the selected tables of a MySQL database for the needle.

    For every table (all tables when ``--tables`` is not given) the columns
    whose type matches ``--types`` are compared against the needle using
    ``=``, ``LIKE`` or ``REGEXP``. Each matching row is printed as
    ``table(pk=value)/field value``. ``--limit`` caps the total number of
    printed rows across all tables and fields.
    """
    args = get_args()

    if not args.verbose:
        log.remove()
        log.add(sys.stderr, level="INFO")

    log.debug("DEBUG message")

    needle = args.needle
    # --limit is a global budget: it shrinks as rows are printed.
    limited = bool(args.limit)
    remaining = args.limit

    # The comparison operator does not depend on the table or field.
    if args.like:
        cmp_op = 'LIKE'
    elif args.regexp:
        cmp_op = 'REGEXP'
    else:
        cmp_op = '='

    with MySQLizer(args.dbname, args.host, args.socket, args.user, args.password) as dbc:

        # Pairs of (name to display, name as written in SQL).
        if args.tables:
            targets = [(name, _quote_table(name)) for name in args.tables]
        else:
            # No tables given: discover every table in the database.
            targets = []
            dbc.cursor.execute("SHOW TABLES;")
            for tname in dbc.cursor:
                log.debug("Detected table {}".format(tname[0]))
                targets.append((tname[0], _quote_ident(tname[0])))

        for table, sql_table in targets:
            log.debug("Examine table: {}".format(table))
            dbc.cursor.execute("DESC {}".format(sql_table))
            pkfield = None
            field_list = []
            for ftuple in dbc.cursor:
                field = dict(zip(dbc.cursor.column_names, ftuple))
                log.debug(field)

                if field['Key'] == 'PRI':
                    pkfield = field['Field']

                if _is_searchable(field['Type'], args.types):
                    field_list.append(field['Field'])

            for field in field_list:
                log.debug("Grepping table {} field {}".format(table, field))
                params = (needle,)
                log.debug(params)

                sql_field = _quote_ident(field)
                if args.all:
                    flist = '*'
                elif pkfield:
                    flist = '{}, {}'.format(_quote_ident(pkfield), sql_field)
                else:
                    flist = sql_field

                limitstr = " LIMIT {}".format(remaining) if limited else ""

                # Identifiers are quoted above; the needle is bound as a parameter.
                query = "SELECT {} FROM {} WHERE {} {} %s {}".format(
                    flist, sql_table, sql_field, cmp_op, limitstr)
                log.debug(query)

                dbc.cursor.execute(query, params)
                for row in dbc.cursor:
                    result = dict(zip(dbc.cursor.column_names, row))

                    if pkfield:
                        idstr = '({}={})'.format(pkfield, result[pkfield])
                    else:
                        idstr = ''

                    line = "{}{}/{} ".format(table, idstr, field)

                    if args.all:
                        line += str(result)
                    elif not args.suppress:
                        line += str(result[field])

                    print(line)
                    if limited:
                        remaining -= 1

                # Checked only after the result set is fully consumed, so the
                # cursor is never abandoned with unread rows.
                if limited and remaining <= 0:
                    log.debug("Limit reached, stopping")
                    return


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()