import textwrap
import traceback
from collections import OrderedDict

import rows
from django.core.management.base import BaseCommand
from django.db import connections
from psycopg2.sql import SQL, Identifier
from tqdm import tqdm

import urlid_graph.settings as urlid_graph_settings
from urlid_graph.utils import DatabaseConnection, random_name, read_total_size, working


class Command(BaseCommand):
    """Load a relationship CSV into a temp table, then MERGE it into the graph.

    Workflow: create an UNLOGGED temp table with a SERIAL ``id`` column, bulk
    load the CSV into it with ``rows.utils.pgimport``, index it, then run a
    prepared Cypher ``MERGE`` statement over ``id`` ranges of ``--batch-size``
    rows, and finally drop the temp table (unless ``--no-drop-table``).
    """

    help = "Import relations to graph database"

    def add_arguments(self, parser):
        parser.add_argument("--batch-size", type=int, default=10000)
        parser.add_argument("--chunk-size", type=int, default=8388608)
        parser.add_argument("--debug", action="store_true")
        parser.add_argument("--disable-autovacuum", action="store_true")
        parser.add_argument("--no-drop-table", action="store_true")
        parser.add_argument("--sample-size", type=int, default=3000)
        parser.add_argument("relationship")
        parser.add_argument("input_filename")

    def handle(self, *args, **options):
        """Run the import and report success as a string.

        Errors raised by the import (or by the VACUUM that follows) are
        printed and turn the result into ``"False"``; otherwise ``"True"`` is
        returned. ``KeyboardInterrupt``/``SystemExit`` are NOT swallowed, so an
        interrupted run stops instead of continuing into VACUUM ANALYZE. The
        sync-commit / autovacuum settings are always restored in ``finally``.
        """
        debug = options.get("debug", False)
        disable_autovacuum = options["disable_autovacuum"]
        relationship = options.pop("relationship")
        input_filename = options.pop("input_filename")

        django_connection = connections[urlid_graph_settings.GRAPH_DATABASE]
        django_connection.connect()
        self.db = DatabaseConnection(connection=django_connection.connection, debug=debug)
        ok = True

        try:
            with working("Disabling sync commit"):
                self.db.disable_sync_commit()
            if disable_autovacuum:
                with working("Disabling autovacuum"):
                    self.db.disable_autovacuum()

            # TODO: `.lower()` is needed so rows' slug won't mess with the name
            temp_table_name = "tmp_" + random_name().lower()
            try:
                self.execute_import_data(relationship, input_filename, temp_table_name, *args, **options)
            except Exception:
                # Do not abort: batches that ran before the failure may already
                # have written vertices/edges, so still VACUUM ANALYZE below.
                traceback.print_exc()
                ok = False

            for relname in (f"{urlid_graph_settings.GRAPH_NAME}.ag_vertex", f"{urlid_graph_settings.GRAPH_NAME}.ag_edge"):
                with working(f"Running VACUUM ANALYZE on {relname}"):
                    self.db.vacuum_analyze(relname)

        except Exception:
            traceback.print_exc()
            ok = False

        finally:
            if disable_autovacuum:
                with working("Enabling autovacuum"):
                    self.db.enable_autovacuum()
            with working("Enabling sync commit"):
                self.db.enable_sync_commit()

        return str(ok)  # Used by import_data when calling this command programatically

    def create_properties(self, header, type_):
        """Build one ``ON <TYPE> SET r.<field> = row.<field>`` line per field.

        :param header: field names to copy from the temp-table row onto the edge.
        :param type_: ``"create"`` or ``"match"`` (upper-cased into the clause).
        :return: the clauses joined by newlines (empty string for no fields).
        """
        # TODO: use `rows.properties.{field}` instead
        return "\n".join(f"ON {type_.upper()} SET r.{field} = row.{field}" for field in header)

    def create_merge_query(self, schema, relation_name, temp_table_name):
        """Return the Cypher MERGE statement used to load one batch of rows.

        The statement takes two positional parameters, ``$1`` (exclusive lower
        bound) and ``$2`` (inclusive upper bound) on the temp table's ``id``.
        Both endpoint vertices are merged as ``object`` vertices keyed by
        ``uuid``, then the edge is merged and every non-id, non-node_uuid
        column is copied onto it on both create and match.
        """
        header = [field_name for field_name in schema.keys() if field_name != "id" and "node_uuid" not in field_name]
        props_create = self.create_properties(header, type_="create")
        props_match = self.create_properties(header, type_="match")
        # https://neo4j.com/developer/kb/understanding-how-merge-works/

        # TODO: may create a VLABEL for each entity and import the objects
        # using this VLABEL (not `object`) - it may not be possible depending
        # on the extraction file format.
        # Dedent the template *before* formatting: the substituted property
        # lines start at column 0, which would otherwise leave nothing to strip.
        template = textwrap.dedent(
            """
            LOAD FROM {temp_table_name} as row
            WITH row WHERE $1 < row.id AND row.id <= $2
            MERGE (obj1:object {{uuid: row.from_node_uuid}})
            MERGE (obj2:object {{uuid: row.to_node_uuid}})
            MERGE (obj1)-[r:{relation}]->(obj2)
            {create}
            {match}
            """
        )
        return template.format(
            relation=relation_name,
            temp_table_name=temp_table_name,
            create=props_create,
            match=props_match,
        )

    def create_table(self, schema, table_name):
        """Create ``table_name`` as an UNLOGGED table with an extra SERIAL ``id``.

        The table is created manually (instead of letting pgimport do it) so
        an ``id`` column can be injected to drive the batched MERGE.
        ``schema`` itself is not modified.
        """
        schema = schema.copy()  # Do not change original object
        schema["id"] = rows.fields.IntegerField
        sql = rows.utils.pg_create_table_sql(schema, table_name=table_name, unlogged=True)
        # TODO: find a cleaner way to make `id` SERIAL than patching the generated SQL
        sql = sql.replace('"id" BIGINT', '"id" SERIAL').replace("id BIGINT", "id SERIAL")
        self.db.execute_query(sql)

    def optimize_data_table(self, table_name):
        """Add a primary key on ``id`` and refresh statistics for ranged batch reads."""
        self.db.execute_queries(
            [
                f"ALTER TABLE {table_name} ADD PRIMARY KEY (id)",
                f"VACUUM ANALYZE {table_name}",
            ],
        )

    def ensure_graph_requisites(self, relation_name):
        """Create the edge label, the ``object`` vertex label and its uuid index if missing."""
        self.db.execute_query(
            SQL("CREATE ELABEL IF NOT EXISTS {}").format(Identifier(relation_name)),
        )
        self.db.execute_query(
            SQL("CREATE VLABEL IF NOT EXISTS {}").format(Identifier("object")),
        )
        self.db.execute_query(
            SQL("CREATE PROPERTY INDEX IF NOT EXISTS idx_vlabel_object_uuid ON object (uuid)"),
        )

    def import_relations(self, schema, relation_name, temp_table_name, n_rows, batch_size):
        """MERGE the temp table into the graph in ``id`` windows of ``batch_size`` rows.

        Assumes ids in the temp table are 1..n_rows (SERIAL column on a freshly
        created table) — TODO confirm pgimport never leaves gaps. The prepared
        statement is always deallocated and the progress bar always closed,
        even if a batch fails.
        """
        progress = tqdm(desc="Importing data to graph", total=n_rows)
        cypher_statement = self.create_merge_query(schema, relation_name, temp_table_name)
        print(cypher_statement)
        statement_name = f"urlid_tmp_rel_{random_name(10)}"
        try:
            self.db.execute_query(f"PREPARE {statement_name}(int, int) AS {cypher_statement}")
            try:
                # Windows (min_id, min_id + batch_size] cover 1..n_rows exactly;
                # no trailing empty EXECUTE and no EXECUTE at all when n_rows == 0.
                for min_id in range(0, n_rows, batch_size):
                    self.db.execute_query(f"EXECUTE {statement_name} ({min_id}, {min_id + batch_size})")
                    progress.update(min(batch_size, n_rows - min_id))
            finally:
                self.db.execute_query(f"DEALLOCATE {statement_name}")
        finally:
            progress.close()

    def execute_import_data(self, relation_name, input_filename, temp_table_name, *args, **options):
        """Load ``input_filename`` into ``temp_table_name`` and MERGE it into the graph.

        Reads ``chunk_size``, ``batch_size`` and ``no_drop_table`` from
        ``options``. Once the temp table has been created it is dropped in a
        ``finally`` (success or failure) unless ``no_drop_table`` is set, so a
        failed run does not leave an orphan table behind.
        """
        print(f"Importing {relation_name} from {input_filename}")

        rows_progress = rows.utils.ProgressBar(prefix="Importing data", pre_prefix="Detecting schema", unit="bytes")
        schema = OrderedDict(
            [
                ("from_node_uuid", rows.fields.UUIDField),
                ("to_node_uuid", rows.fields.UUIDField),
                ("relationship", rows.fields.TextField),
                ("properties", rows.fields.JSONField),
            ]
        )

        rows_progress.description = "Creating table"
        self.create_table(schema, temp_table_name)

        try:
            rows_progress.description = "Detecting file size"
            rows_progress.total = read_total_size(input_filename)
            rows_progress.description = "Detecting CSV dialect"
            import_result = rows.utils.pgimport(
                input_filename,
                encoding="utf-8",
                dialect="excel",
                table_name=temp_table_name,
                create_table=False,  # table already created above with the SERIAL id
                database_uri=urlid_graph_settings.GRAPH_DATABASE_URL,
                schema=schema,
                callback=rows_progress.update,
                chunk_size=options["chunk_size"],
            )
            total_imported, batch_size = import_result["rows_imported"], options["batch_size"]
            rows_progress.description = f"Imported {total_imported} rows to '{temp_table_name}'"
            rows_progress.close()

            with working("Optimizing table for reading"):
                self.optimize_data_table(temp_table_name)

            with working(f"Ensuring elabel {relation_name} exists"):
                self.ensure_graph_requisites(relation_name)

            self.import_relations(schema, relation_name, temp_table_name, total_imported, batch_size)

        finally:
            if not options["no_drop_table"]:
                with working(f"Deleting table '{temp_table_name}'"):
                    self.db.delete_table(temp_table_name)
