#!/usr/bin/env python3

import argparse
import os
# import mysql.connector
# import mariadb
from loguru import logger as log
import sys
import pymysql


# Tool version string; interpolated into the argparse description shown by --help.
version='0.0.4'

def get_mysql_socket(defsocket):
    """Pick the MySQL unix socket path to use.

    Returns the first well-known socket location that exists on disk;
    when none of them exists, *defsocket* is returned unchanged (it may
    be None).
    """
    well_known = ('/var/run/mysqld/mysqld.sock',)
    present = (candidate for candidate in well_known if os.path.exists(candidate))
    return next(present, defsocket)

def get_args():
    """Parse the command line of the MySQL grep utility.

    Connection defaults come from the environment:

    * ``MYSQL_HOST`` (fallback ``localhost``)
    * ``MYSQL_USER`` (fallback: the current OS user, if it can be determined)
    * ``MYSQL_PASS``
    * ``MYSQL_SOCKET`` (passed to get_mysql_socket() as the fallback path)

    Returns:
        argparse.Namespace with the attributes host, socket, user, password,
        tables, verbose, suppress, like, regexp, limit, types, all, dbname
        and needle.
    """
    # Imported locally so this function stays self-contained.
    import getpass

    parser = argparse.ArgumentParser(description='MySQL grep utility, version: {}'.format(version))

    def_socket = get_mysql_socket(os.getenv('MYSQL_SOCKET'))

    # Resolve the default user lazily: os.getlogin() raises OSError when
    # there is no controlling terminal (cron, systemd, some containers),
    # and as an eager default argument it would do so even with MYSQL_USER
    # set. getpass.getuser() consults the environment first and is only
    # called when MYSQL_USER is absent.
    def_user = os.getenv('MYSQL_USER')
    if def_user is None:
        try:
            def_user = getpass.getuser()
        except (ImportError, KeyError, OSError):
            # No way to determine the OS user; -u must then be given explicitly.
            def_user = None

    parser.add_argument('--host', default=os.getenv('MYSQL_HOST', 'localhost'))
    parser.add_argument('--socket', default=def_socket, help='Path to MySQL socket file')
    parser.add_argument('-u', '--user', default=def_user)
    parser.add_argument('-p', '--password', default=os.getenv('MYSQL_PASS'))
    parser.add_argument('-t', '--tables', nargs='+', default=list(), help="tables (default: all)")
    parser.add_argument('-v', '--verbose', default=False, action='store_true', help="Verbose mode")
    parser.add_argument('--suppress', default=False, action='store_true', help="Do not print value (only table, field name and primary key)")
    parser.add_argument('--like', default=False, action='store_true', help="use SQL LIKE instead of =")
    parser.add_argument('--regexp', default=False, action='store_true', help="use SQL REGEXP instead of =")
    parser.add_argument('--limit', type=int)
    parser.add_argument('--types', default='TND', help='which types to process. T for text/char/varchar, N for any int/decimal, D for date. Def: TND')
    parser.add_argument('--all', default=False, action='store_true', help="display ALL fields from matching rows")

    # Positional arguments: database to search and the value to look for.
    parser.add_argument('dbname')
    parser.add_argument('needle')

    return parser.parse_args()

# Column base types searched for each letter accepted by --types.
_TYPE_GROUPS = (
    ('T', frozenset(['varchar', 'char', 'text'])),
    ('N', frozenset(['int', 'smallint', 'decimal'])),
    ('D', frozenset(['date', 'datetime'])),
)


def _quote_ident(name):
    """Return *name* as a backtick-quoted MySQL identifier (embedded backticks doubled)."""
    return '`{}`'.format(str(name).replace('`', '``'))


def _quote_table(name):
    """Return *name* quoted for use as a table reference.

    A name that already contains a backtick is assumed to be quoted by the
    caller and is passed through untouched; otherwise each dot-separated
    part is quoted, so ``db.table`` keeps working.
    """
    if '`' in name:
        return name
    return '.'.join(_quote_ident(part) for part in name.split('.'))


def _base_type(column_type):
    """Reduce a DESC ``Type`` value to its bare type name.

    Strips both the ``(length)`` suffix and trailing attributes, e.g.
    ``varchar(255)`` -> ``varchar`` and ``int unsigned`` -> ``int``.
    """
    words = column_type.split('(')[0].split()
    return words[0].lower() if words else ''


def _describe_table(db, table, types):
    """Return ``(pkfield, field_list)`` for *table*.

    *pkfield* is the name of the column flagged ``PRI`` (the last one seen if
    there are several, None if there is none); *field_list* holds the names
    of the columns whose type belongs to a group selected by *types*.
    """
    pkfield = None
    field_list = []

    with db.cursor() as c:
        c.execute("DESC {}".format(_quote_table(table)))
        for column in c:
            log.debug(column)

            if column['Key'] == 'PRI':
                pkfield = column['Field']

            base = _base_type(column['Type'])
            if any(flag in types and base in names for flag, names in _TYPE_GROUPS):
                field_list.append(column['Field'])
            else:
                log.debug("Skip field {}".format(column))

    return pkfield, field_list


def main():
    """Search every selected column of the selected tables for the needle.

    For each match a line ``table(pk=value)/field value`` is printed
    (the ``(pk=value)`` part is omitted for tables without a primary key;
    the value part depends on --all / --suppress).  With --limit the
    function stops after that many rows have been printed in total.
    ProgrammingError raised by a search query is logged and the column is
    skipped.
    """
    args = get_args()

    if not args.verbose:
        # Replace loguru's default sink with an INFO-level one.  In verbose
        # mode the default sink is kept as is; adding a second stderr sink
        # there would print every INFO+ message twice.
        log.remove()
        log.add(sys.stderr, level="INFO")

    log.debug("Verbose logging")

    tables = args.tables
    needle = args.needle
    remaining = args.limit  # None or 0 means "no limit"

    # The comparison operator does not depend on the table or column.
    if args.like:
        comparator = 'LIKE'
    elif args.regexp:
        comparator = 'REGEXP'
    else:
        comparator = '='

    with pymysql.connect(
                unix_socket=args.socket, host = args.host, user=args.user, passwd=args.password, database=args.dbname,
                    cursorclass=pymysql.cursors.DictCursor) as db:

        if not tables:
            # No tables given: search all tables of the database.
            with db.cursor(pymysql.cursors.Cursor) as c:
                c.execute("SHOW TABLES")
                tables = [t[0] for t in c.fetchall()]

        for table in tables:
            log.debug("Examine table: {}".format(table))
            pkfield, field_list = _describe_table(db, table, args.types)
            table_sql = _quote_table(table)

            with db.cursor() as c:
                for field in field_list:
                    log.debug("Grepping table {} field {}".format(table, field))
                    params = (needle,)
                    log.debug(params)

                    if args.all:
                        flist = '*'
                    elif pkfield:
                        flist = '{}, {}'.format(_quote_ident(pkfield), _quote_ident(field))
                    else:
                        # No primary key: select just the searched column
                        # (it is read back below as r[field]).
                        flist = _quote_ident(field)

                    limitstr = " LIMIT {}".format(remaining) if remaining else ""

                    query = "SELECT {} FROM {} WHERE {} {} %s{}".format(
                        flist, table_sql, _quote_ident(field), comparator, limitstr)
                    log.debug(query)

                    try:
                        c.execute(query, params)
                    except pymysql.err.ProgrammingError as e:
                        log.error("Problem while grepping table {} field {}: {}".format(table, field, e))
                        continue

                    for r in c:
                        if pkfield:
                            idstr = '({}={})'.format(pkfield, r[pkfield])
                        else:
                            idstr = ''

                        s = "{}{}/{} ".format(table, idstr, field)

                        if args.all:
                            s += str(r)
                        elif not args.suppress:
                            s += str(r[field])

                        print(s)

                        if remaining:
                            remaining -= 1
                            if remaining == 0:
                                # Limit exhausted: further queries could
                                # only return zero rows, so stop here.
                                return

# Run the search only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()