from os.path import basename

import psycopg2
import psycopg2.extras
from django.contrib.gis.geos import GEOSGeometry
from django.db import transaction
from django.utils.translation import ugettext as _
from psycopg2 import sql
from rest_framework.serializers import (
    BooleanField,
    CharField,
    ChoiceField,
    IntegerField,
    ModelSerializer,
    SerializerMethodField,
    SlugField,
    ValidationError,
)

from .models import (
    CommandSource,
    CSVSource,
    Field,
    GeoJSONSource,
    GeometryTypes,
    PostGISSource,
    ShapefileSource,
    Source,
    WMTSSource,
)


class PolymorphicModelSerializer(ModelSerializer):
    """ModelSerializer that dispatches to a serializer chosen by model type.

    Each subclass whose ``Meta.model`` is set registers itself in
    ``type_class_map`` under the model's class name. On input, that name is
    read from ``data[type_field]`` to pick the serializer; on output it is
    written into the representation under ``type_field``.
    """

    # Key of the payload / representation entry carrying the model class name.
    type_field = "_type"
    # Registry: model class name -> serializer class. One dict shared by the
    # whole hierarchy, filled by __init_subclass__.
    type_class_map = {}

    def __new__(cls, *args, **kwargs):
        """
        Return the correct serializer given the type provided in type_field

        ``many=True`` builds a list serializer (as DRF does). When ``data`` is
        given, construction is redirected to the serializer registered for
        ``data[type_field]``; ``ValidationError`` is raised if it is unknown.
        """
        if kwargs.pop("many", False):
            return cls.many_init(*args, **kwargs)

        if "data" in kwargs:
            data_type = kwargs["data"].get(cls.type_field)
            serializer = cls.get_serializer_from_type(data_type)
            if serializer is not cls:
                # NOTE: when ``serializer`` subclasses ``cls``, Python runs
                # __init__ again on the returned instance with the original
                # call arguments.
                return serializer(*args, **kwargs)

        return super().__new__(cls, *args, **kwargs)

    def __init_subclass__(cls, **kwargs):
        """ Create a registry of all subclasses of the current class """
        super().__init_subclass__(**kwargs)
        # Intermediate bases with no model (e.g. ``Meta.model = None``) are
        # not registered.
        model = getattr(getattr(cls, "Meta", None), "model", None)
        if model:
            cls.type_class_map[model.__name__] = cls

    @classmethod
    def get_serializer_from_type(cls, data_type):
        """
        Returns the serializer class from datatype

        ``data_type`` is a model class name as registered in
        ``type_class_map``. Raises ``ValidationError`` keyed on ``type_field``
        when it is unknown (including ``None``).
        """
        if data_type in cls.type_class_map:
            return cls.type_class_map[data_type]
        raise ValidationError({cls.type_field: f"{data_type}'s type is unknown"})

    def to_representation(self, obj):
        """Serialize ``obj`` with the serializer registered for its class and
        add the class name under ``type_field``."""
        serializer = self.get_serializer_from_type(obj.__class__.__name__)

        if serializer is self.__class__:
            data = {
                k: v
                for k, v in super().to_representation(obj).items()
                if k not in obj.polymorphic_internal_model_fields
            }
        else:
            # Delegate while sharing this serializer's context (request, view,
            # ...) so list and detail representations are built alike.
            data = serializer(context=self.context).to_representation(obj)

        data[self.type_field] = obj.__class__.__name__

        return data

    def to_internal_value(self, data):
        """Validate ``data`` and keep its raw ``type_field`` value in the
        result so that ``create`` can dispatch on it."""
        data_type = data.get(self.type_field)

        validated_data = super().to_internal_value(data)

        validated_data[self.type_field] = data_type

        return validated_data

    @transaction.atomic
    def create(self, validated_data):
        """Create the instance with the serializer registered for
        ``validated_data[type_field]``.

        Raises ``ValidationError`` when the type is unknown.
        """
        data_type = validated_data.pop(self.type_field, None)
        serializer_class = self.get_serializer_from_type(data_type)

        if serializer_class is self.__class__:
            return super().create(validated_data)

        # The delegate's create() chain comes back to this method, which pops
        # type_field again: restore it, otherwise the lookup gets None and
        # raises "None's type is unknown".
        validated_data[self.type_field] = data_type
        return serializer_class(context=self.context).create(validated_data)


class FieldSerializer(ModelSerializer):
    """Serialize a ``Field`` of a source.

    ``name`` and ``sample`` are read-only; the owning ``source`` is excluded
    from the representation.
    """

    class Meta:
        model = Field
        read_only_fields = ("name", "sample", "source")
        exclude = ("source",)


class SourceSerializer(PolymorphicModelSerializer):
    """Base serializer for ``Source`` models.

    Exposes the source's ``fields`` (nested, optional), a computed ``status``
    and a read-only ``slug``. Creation and update both refresh the source's
    fields through ``_update_fields``.
    """

    fields = FieldSerializer(many=True, required=False)
    status = SerializerMethodField()
    slug = SlugField(max_length=255, read_only=True)

    class Meta:
        fields = "__all__"
        model = Source

    def _update_fields(self, source):
        """Run the source's ``update_fields`` sync method.

        Returns ``source`` when the run's ``result`` is truthy, otherwise
        raises ``ValidationError``.
        """
        if source.run_sync_method("update_fields", success_state="NEED_SYNC").result:
            return source
        raise ValidationError("Fields update failed")

    @transaction.atomic
    def create(self, validated_data):
        """Create the source (in a transaction), then refresh its fields."""
        # Fields can't be defined at source creation
        validated_data.pop("fields", None)
        source = super().create(validated_data)
        return self._update_fields(source)

    @transaction.atomic
    def update(self, instance, validated_data):
        """Update the source, refresh its fields, then apply the per-field
        configuration sent in the raw payload.

        Runs in a single transaction. Entries whose ``name`` matches no field
        of the source are ignored. Raises ``ValidationError`` when the fields
        refresh fails or a field configuration is invalid.
        """
        # "fields" is optional (required=False), so it is absent whenever the
        # client omits it: don't assume the key exists.
        validated_data.pop("fields", None)

        source = super().update(instance, validated_data)

        self._update_fields(source)

        for field_data in self.get_initial().get("fields", []):
            try:
                field = source.fields.get(name=field_data.get("name"))
            except Field.DoesNotExist:
                continue

            serializer = FieldSerializer(instance=field, data=field_data)
            if not serializer.is_valid():
                raise ValidationError("Field configuration is not valid")
            serializer.save()

        return source

    def get_status(self, instance):
        """Return ``instance.get_status()`` for the ``status`` field."""
        return instance.get_status()


class PostGISSourceSerializer(SourceSerializer):
    """Serializer for ``PostGISSource``.

    Validation connects to the PostgreSQL server described by the ``db_*``
    values, runs the user query limited to one row, and checks (or
    auto-detects) the geometry column.
    """

    id_field = CharField(required=False)
    geom_field = CharField(required=False, allow_null=True)

    def _get_connection(self, data):
        """Open a connection from the ``db_*`` keys of ``data`` and return a
        ``RealDictCursor`` on it.

        The caller owns the connection and must close it
        (``cursor.connection.close()``).
        """
        conn = psycopg2.connect(
            user=data.get("db_username"),
            password=data.get("db_password"),
            host=data.get("db_host"),
            port=data.get("db_port", 5432),
            dbname=data.get("db_name"),
        )
        return conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)

    def _first_record(self, data):
        """Return the first row of ``data["query"]`` as a dict-like row, or
        ``None`` when the query yields no row.

        The connection is always closed before returning or raising.
        """
        cursor = self._get_connection(data)
        try:
            # NOTE(review): data["query"] is embedded verbatim as SQL, i.e. the
            # user's SQL is executed with the user-provided credentials.
            query = sql.SQL("SELECT * FROM ({}) q LIMIT 1").format(
                sql.SQL(data["query"])
            )
            cursor.execute(query)
            return cursor.fetchone()
        finally:
            # Closing the connection also closes the cursor and discards the
            # read transaction; without it every validation leaked a connection.
            cursor.connection.close()

    def _validate_geom(self, data, first_record=None):
        """ Validate that geom_field exists else try to find it in source

        When ``geom_field`` is not set, the first column whose value parses
        as a geometry of type ``geom_type`` is stored in ``data["geom_field"]``.
        ``first_record`` may be passed to reuse an already-fetched row;
        otherwise it is fetched with ``_first_record``.

        Returns ``data``; raises ``ValidationError`` when no matching geometry
        column is found or the given ``geom_field`` is not in the record.
        """
        if first_record is None:
            first_record = self._first_record(data)
        # A query without rows gives nothing to inspect: use an empty record so
        # the ValidationErrors below are raised instead of AttributeError /
        # TypeError on None.
        first_record = first_record or {}

        if data.get("geom_field") is None:
            for column, value in first_record.items():
                try:
                    geom = GEOSGeometry(value)
                except Exception:
                    # Best effort: columns that aren't geometries are skipped.
                    continue
                if geom.geom_typeid == data.get("geom_type"):
                    data["geom_field"] = column
                    break
            else:
                geomtype_name = GeometryTypes(data.get("geom_type")).name
                raise ValidationError(f"No geom field found of type {geomtype_name}")
        elif data.get("geom_field") not in first_record:
            raise ValidationError("Field does not exist in source")

        return data

    def _validate_query_connection(self, data):
        """ Check if connection informations are valid or not, trying to
        connect to the Pg server and executing the query

        Returns the first record of the query (an empty dict when the query
        yields no row) so callers don't have to query the server again.
        """
        try:
            record = self._first_record(data)
        except Exception as err:
            raise ValidationError(
                "Connection informations or query are not valid"
            ) from err
        return record or {}

    def validate(self, data):
        """Check the connection/query, then the geometry column, using a
        single round-trip to the PostgreSQL server."""
        first_record = self._validate_query_connection(data)
        data = self._validate_geom(data, first_record)

        return super().validate(data)

    class Meta:
        model = PostGISSource
        fields = "__all__"
        extra_kwargs = {"db_password": {"write_only": True}}


class FileSourceSerializer(SourceSerializer):
    """Common base for sources built from an uploaded ``file``.

    ``Meta.model`` is ``None``: this class is only a base, concrete
    subclasses bind it to a model.
    """

    filename = SerializerMethodField()

    class Meta:
        model = None

    def to_internal_value(self, data):
        # Presumably the payload carries a list of uploads: keep the first one.
        uploads = data.get("file", [])
        if len(uploads) > 0:
            data["file"] = data["file"][0]

        return super().to_internal_value(data)

    def get_filename(self, instance):
        """Return the base name of the stored file, or None without a file."""
        stored_file = instance.file
        return basename(stored_file.name) if stored_file else None


class GeoJSONSourceSerializer(FileSourceSerializer):
    """Serializer for ``GeoJSONSource``; the uploaded ``file`` is write-only."""

    class Meta:
        fields = "__all__"
        model = GeoJSONSource
        extra_kwargs = dict(file=dict(write_only=True))


class ShapefileSourceSerializer(FileSourceSerializer):
    """Serializer for ``ShapefileSource``; the uploaded ``file`` is write-only."""

    class Meta:
        fields = "__all__"
        model = ShapefileSource
        extra_kwargs = dict(file=dict(write_only=True))


class CommandSourceSerializer(SourceSerializer):
    """Serializer for ``CommandSource``; ``command`` cannot be set via the API."""

    class Meta:
        fields = "__all__"
        model = CommandSource
        extra_kwargs = dict(command=dict(read_only=True))


class WMTSSourceSerialize(SourceSerializer):
    """Serializer for ``WMTSSource``.

    ``minzoom`` / ``maxzoom`` accept 0-24 (defaults 0 and 24, null allowed);
    ``geom_type`` defaults to ``GeometryTypes.Undefined.value``.
    """

    minzoom = IntegerField(min_value=0, max_value=24, allow_null=True, default=0)
    maxzoom = IntegerField(min_value=0, max_value=24, allow_null=True, default=24)
    # NOTE(review): other serializers treat geom_type as a GeometryTypes value
    # (CSVSourceSerializer uses a ChoiceField); a CharField renders it as a
    # string in the API output — confirm against the model field.
    geom_type = CharField(
        required=False, allow_null=True, default=GeometryTypes.Undefined.value
    )

    class Meta:
        model = WMTSSource
        fields = "__all__"


# Correctly spelled name, consistent with the other ``*Serializer`` classes;
# the historical misspelled name above is kept for existing imports.
WMTSSourceSerializer = WMTSSourceSerialize


class CSVSourceSerializer(FileSourceSerializer):
    """Serializer for ``CSVSource``.

    The CSV parsing options are declared as serializer fields for the API,
    but ``to_internal_value`` moves them into the ``settings`` dict that is
    saved on the source, and ``validate`` checks that the coordinate options
    required by ``coordinates_field`` are present.
    """

    coordinate_reference_system = CharField(required=True)
    encoding = CharField(required=True)
    field_separator = CharField(required=True)
    decimal_separator = CharField(required=True)
    char_delimiter = CharField(required=True)
    coordinates_field = CharField(required=True)
    number_lines_to_ignore = IntegerField(required=True)

    use_header = BooleanField(required=False, default=False)
    ignore_columns = BooleanField(required=False, default=False)
    latitude_field = CharField(required=False)
    longitude_field = CharField(required=False)
    latlong_field = CharField(required=False)
    coordinates_field_count = CharField(required=False)
    coordinates_separator = CharField(required=False)
    geom_type = ChoiceField(
        default=GeometryTypes.Point.value, choices=GeometryTypes.choices()
    )

    class Meta:
        model = CSVSource
        fields = "__all__"
        extra_kwargs = {
            "file": {"write_only": True},
        }

    def to_internal_value(self, data):
        """Validate ``data`` and fold the CSV options into ``settings``.

        ``coordinates_field`` selects the coordinate options that are copied:
        ``latlong_field`` / ``coordinates_field_count`` /
        ``coordinates_separator`` for ``"one_column"``, ``latitude_field`` /
        ``longitude_field`` for ``"two_columns"``. Missing optional options
        are stored as ``None`` so that ``validate`` can report them.
        """
        validated_data = super().to_internal_value(data)
        # settings does not exist if no group is specified at creation; also
        # fall back to a fresh dict when it is None.
        settings = validated_data.get("settings") or {}
        settings.update(
            {
                "coordinate_reference_system": validated_data.pop(
                    "coordinate_reference_system"
                ),
                "encoding": validated_data.pop("encoding"),
                "field_separator": validated_data.pop("field_separator"),
                "decimal_separator": validated_data.pop("decimal_separator"),
                "char_delimiter": validated_data.pop("char_delimiter"),
                "coordinates_field": validated_data.get("coordinates_field"),
                "number_lines_to_ignore": validated_data.pop("number_lines_to_ignore"),
                "use_header": validated_data.pop("use_header"),
                "ignore_columns": validated_data.pop("ignore_columns"),
            }
        )
        # The options below are required=False, so they are absent from
        # validated_data when not sent: pop with a default instead of raising
        # KeyError, and let validate() emit the proper error message.
        if validated_data.get("coordinates_field") == "one_column":
            settings.update(
                {
                    "latlong_field": validated_data.pop("latlong_field", None),
                    "coordinates_field_count": validated_data.pop(
                        "coordinates_field_count", None
                    ),
                    "coordinates_separator": validated_data.pop(
                        "coordinates_separator", None
                    ),
                }
            )

        elif validated_data.get("coordinates_field") == "two_columns":
            settings.update(
                {
                    "latitude_field": validated_data.pop("latitude_field", None),
                    "longitude_field": validated_data.pop("longitude_field", None),
                }
            )
        validated_data.pop("coordinates_field")
        validated_data["settings"] = settings
        return validated_data

    def to_representation(self, obj):
        """Drop the coordinate options that don't apply to the source's
        ``coordinates_field`` mode."""
        data = super().to_representation(obj)
        # DRF skips non-required fields whose attribute is missing, so the
        # keys may be absent: pop with a default.
        if data.get("coordinates_field") == "one_column":
            data.pop("latitude_field", None)
            data.pop("longitude_field", None)

        if data.get("coordinates_field") == "two_columns":
            data.pop("latlong_field", None)
            data.pop("coordinates_field_count", None)
            data.pop("coordinates_separator", None)
        return data

    def validate(self, data):
        """Check that the coordinate options required by
        ``settings["coordinates_field"]`` are set.

        Raises ``ValidationError`` for a missing option or when
        ``coordinates_field`` is neither ``"one_column"`` nor ``"two_columns"``.
        """
        validated_data = super().validate(data)
        if data["settings"]["coordinates_field"] == "one_column":
            if not data["settings"].get("latlong_field"):
                raise ValidationError(
                    _(
                        "latlong_field must be defined when coordinates are set to one column"
                    )
                )
            if not data["settings"].get("coordinates_field_count"):
                raise ValidationError(
                    _(
                        "Coordinates order must be specified when coordinates are set to one column"
                    )
                )
            if not data["settings"].get("coordinates_separator"):
                raise ValidationError(
                    _(
                        "Coordinates separator must be specified when coordinates are set to one column"
                    )
                )
        elif data["settings"]["coordinates_field"] == "two_columns":
            if not data["settings"].get("latitude_field"):
                raise ValidationError(
                    _(
                        "Latitude field must be specified when coordinates are set to two columns"
                    )
                )
            if not data["settings"].get("longitude_field"):
                raise ValidationError(
                    _(
                        "Longitude field must be specified when coordinates are set to two columns"
                    )
                )
        else:
            raise ValidationError(_("Incorrect value for coordinates field"))
        return validated_data
