#! /usr/bin/env python3

import argparse
import flowpy
import matplotlib
import tifffile
import numpy as np
import sys
from tqdm import tqdm
matplotlib.use('Qt5Agg')

from matplotlib.backends.backend_qt5agg import (
        FigureCanvasQTAgg as FigureCanvas,
        NavigationToolbar2QT as NavigationToolbar
)
from matplotlib.figure import Figure
from matplotlib.widgets import Slider
from PyQt5 import QtCore, QtWidgets
from pathlib import Path

try:
    import ffmpeg
except ImportError:
    ffmpeg = None


def main():
    """Run the viewer application.

    Parses the optional positional ``file_paths`` from the command line,
    creates the Qt application and the main window, forwards the paths to
    the window when at least one was given, shows the window and finally
    exits the process with the return code of the Qt event loop.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("file_paths", nargs="*")
    options = cli.parse_args()

    application = QtWidgets.QApplication([""])
    window = DOFQTWindow()

    flow_files = options.file_paths
    if flow_files:
        window.set_flow_source(flow_files)
    window.show()

    exit_code = application.exec_()
    sys.exit(exit_code)


class DOFQTWindow(QtWidgets.QMainWindow):
    """Main window of the viewer.

    Hosts a ``DOFMainWidget`` as its central widget and a "File" menu whose
    single entry opens the export dialog.

    Signals:
        flowSourceChanged(list): emitted by :meth:`set_flow_source` with the
            list of input paths.
        exportRequested(str): emitted by :meth:`export_flow_dialog` with the
            output path chosen by the user.
    """
    flowSourceChanged = QtCore.pyqtSignal(list, name="flowSourceChanged")
    exportRequested = QtCore.pyqtSignal(str, name="exportRequested")

    def __init__(self):
        super().__init__()

        self.setAttribute(QtCore.Qt.WA_DeleteOnClose)
        self.setWindowTitle("DOFReader")

        # The central widget connects itself to this window's signals, so it
        # must be created with this window as its parent.
        self.main_widget = DOFMainWidget(parent=self)
        self.main_widget.setFocus()
        self.setCentralWidget(self.main_widget)

        menu_bar = self.menuBar()
        file_menu = menu_bar.addMenu("&File")
        file_menu.addAction("&Export flow...", self.export_flow_dialog)

    def set_flow_source(self, file_path):
        """Announce a new flow source through ``flowSourceChanged``.

        Args:
            file_path: a sequence of paths, or a single ``str`` / ``Path``
                which is wrapped into a one-element list.
        """
        if isinstance(file_path, (str, Path)):
            file_path = [file_path]
        # The signal is declared with a ``list`` argument, so other
        # sequences (e.g. tuples) are converted before emitting.
        self.flowSourceChanged.emit(list(file_path))

    def export_flow_dialog(self):
        """Ask for an export path and emit ``exportRequested`` with it.

        Nothing is emitted when the dialog returns an empty file name,
        which is what happens when the user cancels it.
        """
        filename, _ = QtWidgets.QFileDialog.getSaveFileName(
            self,
            "Choose export path",
            "",
            ""
        )
        if not filename:
            return
        self.exportRequested.emit(filename)


class DOFMainWidget(QtWidgets.QWidget):
    """Central widget: toolbar, canvas, frame slider and plot options.

    The widget must be created with a parent that exposes the
    ``flowSourceChanged`` and ``exportRequested`` signals, because it
    connects to them in ``__init__``.

    Signals:
        flowScaleChanged(float): handed to ``PlotOptions`` at construction.
        frameChanged(np.ndarray): emitted with the flow frame to display.
    """
    flowScaleChanged = QtCore.pyqtSignal(float, name="flowScaleChanged")
    frameChanged = QtCore.pyqtSignal(np.ndarray, name="frameChanged")

    def __init__(self, parent=None):
        super().__init__(parent)

        # Currently loaded flow sequence, indexed by frame along axis 0;
        # None until a source has been opened.
        self.flow_sequence = None

        self.plot_options = PlotOptions(self.flowScaleChanged, parent=self)
        self.canvas = MatplotlibCanvas(self.plot_options.getValue(), parent=self)
        self.toolbar = NavigationToolbar(self.canvas, self)
        self.slider_box = FrameSliderBox(self)

        # NOTE(review): this attribute shadows QWidget.layout(); kept under
        # the same name so existing users of ``widget.layout`` keep working.
        self.layout = QtWidgets.QVBoxLayout(self)
        self.layout.addWidget(self.toolbar)
        self.layout.addWidget(self.canvas)
        self.layout.addWidget(self.slider_box)
        self.layout.addWidget(self.plot_options)

        self.plot_options.valueChanged.connect(self.canvas.handle_plot_options_changed)
        self.slider_box.slider.valueChanged.connect(self.handle_cursor_changed)
        self.parent().flowSourceChanged.connect(self.handle_flow_source_changed)
        self.parent().exportRequested.connect(self.handle_export_requested)

    def handle_flow_source_changed(self, path):
        """Open the flow files in ``path`` and display the first frame.

        The slider is moved back to 1 and its maximum is set to the number
        of frames of the new sequence.
        """
        self.flow_sequence = FlowOpener.open(path)
        self.slider_box.slider.setValue(1)
        self.slider_box.slider.setMaximum(self.flow_sequence.shape[0])
        self.emit_frame_changed(self.flow_sequence[0])

    def handle_cursor_changed(self, cursor_value):
        """Update the slider label and display frame ``cursor_value`` (1-based)."""
        self.slider_box.slider_label.setText(str(cursor_value))
        if self.flow_sequence is None:
            # Nothing loaded yet: there is no frame to display.
            return
        self.emit_frame_changed(self.flow_sequence[cursor_value - 1])

    def handle_export_requested(self, filename):
        """Render every frame and pipe it to ffmpeg to write ``filename``.

        The export is aborted with a printed message when ffmpeg-python is
        not installed, when ``filename`` is empty or when no flow sequence
        is loaded. When auto scale is enabled, the maximum radius is
        recomputed for each frame.
        """
        if ffmpeg is None:
            print("Export aborted, ffmpeg-python not found")
            return

        if not filename:
            print("Export aborted, no output file selected")
            return

        if self.flow_sequence is None:
            print("Export aborted, no flow sequence loaded")
            return

        print("Exporting frames to " + filename)

        _, height, width, _ = self.flow_sequence.shape

        ffmpeg_process = (
            ffmpeg
            .input('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width, height))
            .output(filename, pix_fmt='yuv420p')
            .overwrite_output()
            .run_async(pipe_stdin=True)
        )

        plot_parameters = self.plot_options.getValue()
        flowpy_options = MatplotlibCanvas.get_flowpy_options(plot_parameters)

        try:
            for frame in tqdm(self.flow_sequence):
                if plot_parameters["auto_scale"]:
                    flowpy_options["flow_max_radius"] = flowpy.get_flow_max_radius(frame)
                rendered_flow = flowpy.flow_to_rgb(frame, **flowpy_options)
                ffmpeg_process.stdin.write(rendered_flow.astype(np.uint8).tobytes())
        finally:
            # Always close the pipe and reap the process, even if rendering
            # or writing a frame failed, so ffmpeg is not left running.
            ffmpeg_process.stdin.close()
            ffmpeg_process.wait()

        print("Export finished")

    def emit_frame_changed(self, flow):
        """Store the frame's auto-scale radius in the plot options, then emit ``frameChanged``."""
        auto_scale_radius = flowpy.get_flow_max_radius(flow)
        self.plot_options.setAutoScaleRadius(auto_scale_radius)
        self.frameChanged.emit(flow)


class MatplotlibCanvas(FigureCanvas):
    """Qt canvas with two axes: the rendered flow and the calibration pattern.

    The canvas connects to its parent's ``frameChanged`` signal and reads
    the current plot options from ``parent().plot_options``, so it must be
    created with a parent providing both.
    """

    def __init__(self, plot_options, parent=None):
        """Create the figure and both axes.

        Args:
            plot_options: accepted for interface compatibility; the options
                are read from the parent at every redraw instead.
            parent: parent widget exposing ``frameChanged`` and
                ``plot_options``.
        """
        fig = Figure(figsize=(5, 4), dpi=100)

        FigureCanvas.__init__(self, fig)
        FigureCanvas.setSizePolicy(self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)
        self.fig = fig
        self.setParent(parent)

        self.ax_im, self.ax_cal = fig.subplots(1, 2)
        # Flow currently displayed; None until the first frame is received.
        self.flow = None
        # Artists created by the last redraw, kept so they can be removed.
        self.flow_image = None
        self.arrows = None
        self.calibration_image = None
        self.circle = None

        self.parent().frameChanged.connect(self.handle_flow_changed)

    def clean_canvas(self):
        """Remove the artists created by the previous redraw, if any."""
        for attribute in ("flow_image", "arrows", "calibration_image", "circle"):
            artist = getattr(self, attribute)
            if artist is not None:
                artist.remove()
                setattr(self, attribute, None)

    @staticmethod
    def get_flowpy_options(plot_options):
        """Build the keyword arguments passed to flowpy from the plot options.

        Returns:
            A new dict with ``"background"`` and ``"flow_max_radius"``. The
            radius is ``auto_scale_radius`` when auto scale is enabled and
            that radius is positive, otherwise ``max_radius``.
        """
        flowpy_options = {}

        flowpy_options["background"] = plot_options["background"]
        if plot_options["auto_scale"] and plot_options["auto_scale_radius"] > 0:
            flowpy_options["flow_max_radius"] = plot_options["auto_scale_radius"]
        else:
            flowpy_options["flow_max_radius"] = plot_options["max_radius"]

        return flowpy_options

    def update_rendered_flow(self):
        """Redraw the current flow with the current plot options.

        Does nothing when no flow has been received yet, which happens when
        a plot option is changed before any file is loaded.
        """
        if self.flow is None:
            return

        # Imported explicitly rather than relying on ``matplotlib.gridspec``
        # having been loaded as a side effect of another import.
        from matplotlib.gridspec import GridSpec

        self.clean_canvas()

        plot_options = self.parent().plot_options.getValue()
        height, width, _ = self.flow.shape

        grid_spec = GridSpec(1, 2, width_ratios=[1, height/width])
        self.ax_im.set_position(grid_spec[0].get_position(self.fig))
        self.ax_cal.set_position(grid_spec[1].get_position(self.fig))

        flowpy_options = self.get_flowpy_options(plot_options)
        rendered_flow = flowpy.flow_to_rgb(self.flow, **flowpy_options)
        self.flow_image = self.ax_im.imshow(rendered_flow)

        if plot_options["show_arrows"]:
            self.arrows = flowpy.attach_arrows(self.ax_im, self.flow, scale_units="xy", scale=1.0)
        flowpy.attach_coord(self.ax_im, self.flow)
        # Both returned artists are kept so the next redraw can remove them;
        # discarding the first one let calibration images accumulate on
        # ax_cal. NOTE(review): assumes the first value is a matplotlib
        # artist with remove() - confirm against the flowpy documentation.
        self.calibration_image, self.circle = flowpy.attach_calibration_pattern(
            self.ax_cal, **flowpy_options
        )

        self.draw()

    def handle_flow_changed(self, flow):
        """Store ``flow`` as the displayed frame and redraw."""
        self.flow = flow
        self.update_rendered_flow()

    def handle_plot_options_changed(self, _):
        """Redraw after a plot option change; the emitted value is ignored."""
        self.update_rendered_flow()


class FrameSliderBox(QtWidgets.QGroupBox):
    """Group box titled "Frame index" with a label next to a horizontal slider.

    The label starts at "1" and the slider starts with both its minimum and
    its maximum set to 1.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setTitle("Frame index")

        # Slider with tick marks above the groove.
        frame_slider = QtWidgets.QSlider(QtCore.Qt.Horizontal, parent)
        frame_slider.setTickPosition(QtWidgets.QSlider.TicksAbove)
        frame_slider.setMinimum(1)
        frame_slider.setMaximum(1)

        self.slider_label = QtWidgets.QLabel("1")
        self.slider = frame_slider

        # Label first, slider second, side by side.
        row = QtWidgets.QHBoxLayout()
        for widget in (self.slider_label, self.slider):
            row.addWidget(widget)
        self.setLayout(row)


class PlotOptions(QtWidgets.QGroupBox):
    """Group box with the rendering options of the flow.

    Holds three check boxes (background, arrows, auto scale) and a spin box
    for the maximum radius. The current state is kept in ``self.parameters``
    and emitted through ``valueChanged`` whenever a control changes.

    Signals:
        valueChanged(dict): emitted with the parameter dict.
    """
    valueChanged = QtCore.pyqtSignal(dict, name="valueChanged")
    default_parameters = {
        "background": "bright",
        "show_arrows": False,
        "auto_scale": True,
        "max_radius": 0.0,
        "auto_scale_radius": 0.0,
    }

    def __init__(self, flowScaleChanged, parent=None):
        """Build the controls and emit the initial state.

        Args:
            flowScaleChanged: signal connected to
                :meth:`handle_flow_scale_changed`.
            parent: parent widget.
        """
        super().__init__("Plot parameters", parent)
        self.parameters = self.default_parameters.copy()

        main_layout = QtWidgets.QHBoxLayout()

        self.background = QtWidgets.QCheckBox("Black background", self)
        self.background.setChecked(self.default_parameters["background"] == "dark")
        self.background.toggled.connect(self.handle_state_changed)

        self.arrows = QtWidgets.QCheckBox("arrows", self)
        self.arrows.setChecked(self.default_parameters["show_arrows"])
        self.arrows.toggled.connect(self.handle_state_changed)

        self.auto_scale = QtWidgets.QCheckBox("auto scale", self)
        # Use the declared "auto_scale" default rather than deriving it from
        # "max_radius"; both give True with the defaults above.
        self.auto_scale.setChecked(self.default_parameters["auto_scale"])
        self.auto_scale.toggled.connect(self.handle_state_changed)

        self.max_radius = QtWidgets.QDoubleSpinBox(self)
        self.max_radius.setEnabled(False)
        self.max_radius.setValue(0.1)
        self.max_radius.setRange(0.1, 1e3)
        self.max_radius.setSingleStep(0.1)
        self.max_radius.valueChanged.connect(self.handle_state_changed)
        flowScaleChanged.connect(self.handle_flow_scale_changed)

        main_layout.addWidget(self.background)
        main_layout.addWidget(self.arrows)
        main_layout.addWidget(self.auto_scale)
        main_layout.addWidget(self.max_radius)

        self.setLayout(main_layout)
        self.handle_state_changed()

    def synchronize_parameters(self):
        """Copy the state of the controls into ``self.parameters``."""
        self.parameters["background"] = "dark" if self.background.isChecked() else "bright"
        self.parameters["show_arrows"] = self.arrows.isChecked()
        self.parameters["auto_scale"] = self.auto_scale.isChecked()
        self.parameters["max_radius"] = self.max_radius.value()

    def setAutoScaleRadius(self, value):
        """Record the auto-scale radius and mirror it in the spin box.

        The spin box is only updated when auto scale is enabled, and its
        signals are blocked meanwhile so that no ``valueChanged`` is emitted
        for this programmatic update.
        """
        self.parameters["auto_scale_radius"] = value
        if self.parameters["auto_scale"]:
            # blockSignals leaves every connection in place (unlike an
            # argument-less disconnect()) and returns the previous state,
            # which is restored even if setValue raises.
            was_blocked = self.max_radius.blockSignals(True)
            try:
                self.max_radius.setValue(value)
            finally:
                self.max_radius.blockSignals(was_blocked)

    def handle_state_changed(self, *_args):
        """Refresh the parameters from the controls and emit ``valueChanged``.

        Connected to signals that carry a value (``toggled``,
        ``valueChanged``) and also called directly without arguments; the
        value is ignored since the controls are read again.
        """
        self.synchronize_parameters()
        # The radius can only be edited by hand when auto scale is off.
        self.max_radius.setEnabled(not self.parameters["auto_scale"])

        self.valueChanged.emit(self.getValue())

    def handle_flow_scale_changed(self, value):
        """Set the spin box to ``value``."""
        self.max_radius.setValue(value)

    def getValue(self):
        """Return the parameter dict (the live object, not a copy)."""
        return self.parameters


class FlowOpener():
    """Load a flow sequence from files."""

    @staticmethod
    def open(input_paths):
        """Read the given files into one array of flow frames.

        When the first path has a ``.tif`` / ``.tiff`` suffix, exactly two
        TIFF files are expected and each one supplies one entry of the last
        axis of the result. Otherwise every path is read with
        ``flowpy.flow_read`` and the results are stacked along a new first
        axis.

        Args:
            input_paths: non-empty sequence of file paths.

        Returns:
            The stacked array, indexed by frame along axis 0.

        Raises:
            ValueError: if ``input_paths`` is empty, or if TIFF input does
                not consist of exactly two files.
        """
        if not input_paths:
            raise ValueError("No input file provided")

        if Path(input_paths[0]).suffix.lower() in (".tiff", ".tif"):
            # Raised explicitly: an assert would be stripped under ``-O``.
            if len(input_paths) != 2:
                raise ValueError("Must provide two tiff files")

            return FlowOpener._stack_tiff_components(tifffile.imread(input_paths))

        return np.stack([flowpy.flow_read(path) for path in input_paths])

    @staticmethod
    def _stack_tiff_components(data):
        """Move the leading file axis of ``data`` to the last position.

        A 3-D input ``(2, H, W)`` is treated as a single frame and becomes
        ``(1, H, W, 2)``; a 4-D input ``(2, T, H, W)`` becomes
        ``(T, H, W, 2)``.
        """
        if data.ndim == 3:
            # Insert the frame axis AFTER the file axis. Wrapping the whole
            # array in a new leading axis would make the transpose below
            # return (2, H, W, 1) instead of (1, H, W, 2).
            data = data[:, np.newaxis]
        return data.transpose((1, 2, 3, 0))


# Start the viewer only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
