from typing import Literal, List

import more_itertools
import pydantic
import numpy as np

from classiq_interface.hybrid.problem_input import ProblemInput, OptimizationProblemName

# Number of decimal places that plot time stamps, and the differences between
# consecutive time stamps, are rounded to before they are compared.
_TOLERANCE_DECIMALS = 6


# One plot of the MHT problem: X/Y coordinates, a time stamp and an id.
# (Documented with comments rather than a class docstring so the model's
# generated schema is left untouched.)
class PlotData(pydantic.BaseModel):
    # We are currently ignoring units. This might need to be handled in the future
    x: float = pydantic.Field(description="The X coordinate of this plot")
    y: float = pydantic.Field(description="The Y coordinate of this plot")
    t: float = pydantic.Field(description="The time stamp of this plot")
    # Constrained to integers >= 0 by conint(ge=0).
    plot_id: pydantic.conint(ge=0) = pydantic.Field(
        description="The plot ID of this plot"
    )


# Input parameters of the MHT QAOA optimization problem.
#
# The description lives in a comment rather than a class docstring on purpose:
# pydantic copies a model's docstring into the generated schema description,
# and this block must not alter the schema.
class MhtQaoaInput(ProblemInput):
    name: Literal[OptimizationProblemName.MhtQaoa] = pydantic.Field(
        default=OptimizationProblemName.MhtQaoa,
        description="Name of optimization problem.",
    )
    reps: pydantic.conint(ge=1) = pydantic.Field(
        default=3, description="Number of QAOA layers."
    )
    plot_list: List[PlotData] = pydantic.Field(
        description="The list of (x,y,t) plots of the MHT problem."
    )
    misdetection_maximum_time_steps: pydantic.conint(ge=0) = pydantic.Field(
        default=0,
        description="The maximum number of time steps a target might be misdetected.",
    )
    penalty_energy: float = pydantic.Field(
        default=2,
        description="Penalty energy for invalid solutions. The value affects "
        "the converges rate. Small positive values are preferred",
    )
    three_local_coeff: float = pydantic.Field(
        default=0,
        description="Coefficient for the 3-local terms in the Hamiltonian. It is related to the angular acceleration.",
    )
    is_penalty: bool = pydantic.Field(
        default=True, description="Build Pubo using penalty terms"
    )
    max_velocity: float = pydantic.Field(
        default=0, description="Max allowed velocity for a segment"
    )

    def is_valid_cost(self, cost: float) -> bool:
        """Report whether ``cost`` is acceptable; every cost is accepted."""
        return True

    @pydantic.validator("plot_list")
    def round_plot_list_times_and_validate(cls, plot_list):
        """Normalize the plots' time stamps and validate the plot list.

        The time stamp of every plot is rounded in place to
        ``_TOLERANCE_DECIMALS`` decimals, so the returned list holds the
        rounded values.

        Returns:
            The same ``plot_list`` object, with rounded time stamps.

        Raises:
            ValueError: If two plots share a plot ID, if there are fewer than
                two distinct time stamps, or if the distinct time stamps are
                not equally spaced (after rounding).
        """
        MhtQaoaInput._check_all_ids_are_distinct(plot_list)
        MhtQaoaInput._round_to_tolerance_decimals(plot_list)

        time_stamps = sorted({plot.t for plot in plot_list})

        # With fewer than two distinct time stamps there is no difference to
        # compare. Such input was always rejected; report the real reason
        # instead of claiming that the differences are unequal.
        if len(time_stamps) < 2:
            raise ValueError(
                "The plot list must contain at least two distinct time stamps"
            )

        # Differences are rounded as well, so floating-point noise in the
        # subtraction does not make equal steps look different.
        time_diff_set = {
            np.round(later - earlier, decimals=_TOLERANCE_DECIMALS)
            for earlier, later in zip(time_stamps, time_stamps[1:])
        }

        if len(time_diff_set) != 1:
            raise ValueError("The time difference between each time stamp is not equal")

        return plot_list

    @staticmethod
    def _round_to_tolerance_decimals(plot_list: List[PlotData]):
        """Round each plot's time stamp, in place, to ``_TOLERANCE_DECIMALS`` decimals."""
        for plot in plot_list:
            # np.round returns a numpy scalar; convert back to a built-in float
            # so the field keeps holding the type it is declared with.
            plot.t = float(np.round(plot.t, decimals=_TOLERANCE_DECIMALS))

    @staticmethod
    def _check_all_ids_are_distinct(plot_list: List[PlotData]):
        """Raise ``ValueError`` if any two plots share the same ``plot_id``."""
        if not more_itertools.all_unique(plot.plot_id for plot in plot_list):
            raise ValueError("Plot IDs should be unique.")
