import pytest
from pytest_mock.plugin import MockerFixture

from sqlmesh.core.snapshot import SnapshotId
from sqlmesh.utils.concurrency import (
    NodeExecutionFailedError,
    concurrent_apply_to_snapshots,
)


@pytest.mark.parametrize("tasks_num", [1, 2])
def test_concurrent_apply_to_snapshots(mocker: MockerFixture, tasks_num: int):
    """Each snapshot is applied exactly once, parents before children, with no errors or skips.

    DAG under test: ``a`` and ``b`` are roots, ``c`` depends on ``a`` and ``b``,
    and ``d`` depends on ``b`` and ``c``.
    """

    def make_snapshot(name, identifier, parents):
        # A Mock stands in for a Snapshot; only ``snapshot_id`` and ``parents`` are set.
        snapshot = mocker.Mock()
        snapshot.snapshot_id = SnapshotId(name=name, identifier=identifier)
        snapshot.parents = parents
        return snapshot

    snapshot_a = make_snapshot("model_a", "snapshot_a", [])
    snapshot_b = make_snapshot("model_b", "snapshot_b", [])
    snapshot_c = make_snapshot(
        "model_c", "snapshot_c", [snapshot_a.snapshot_id, snapshot_b.snapshot_id]
    )
    snapshot_d = make_snapshot(
        "model_d", "snapshot_d", [snapshot_b.snapshot_id, snapshot_c.snapshot_id]
    )

    applied = []

    errors, skipped = concurrent_apply_to_snapshots(
        [snapshot_a, snapshot_b, snapshot_c, snapshot_d],
        applied.append,
        tasks_num,
    )

    assert len(applied) == 4

    # The two roots have no ordering constraint between them, so either may come first.
    roots = (snapshot_a, snapshot_b)
    assert applied[0] in roots
    assert applied[1] in roots

    # ``c`` needs both roots and ``d`` needs ``c``, so their positions are fixed.
    assert applied[2:] == [snapshot_c, snapshot_d]

    assert not errors
    assert not skipped


@pytest.mark.parametrize("tasks_num", [1, 2])
def test_concurrent_apply_to_snapshots_exception(mocker: MockerFixture, tasks_num: int):
    """A failing apply function surfaces as NodeExecutionFailedError.

    ``raise_on_error`` is not passed here, so this exercises its default behavior.
    """
    snapshots = []
    for name, identifier in (("model_a", "snapshot_a"), ("model_b", "snapshot_b")):
        snapshot = mocker.Mock()
        snapshot.snapshot_id = SnapshotId(name=name, identifier=identifier)
        snapshot.parents = []
        snapshots.append(snapshot)

    def always_fail(_snapshot):
        raise RuntimeError("fail")

    with pytest.raises(NodeExecutionFailedError):
        concurrent_apply_to_snapshots(snapshots, always_fail, tasks_num)


@pytest.mark.parametrize("tasks_num", [1, 2])
def test_concurrent_apply_to_snapshots_return_failed_skipped(mocker: MockerFixture, tasks_num: int):
    """With raise_on_error=False, failures are returned and their descendants are skipped.

    DAG under test: a chain ``a -> b -> c`` and an independent chain ``d -> e``.
    Only ``model_a`` fails. ``b`` and ``c`` must be reported as skipped and never
    applied. The unrelated ``d -> e`` branch must still be applied.
    """
    snapshot_a = mocker.Mock()
    snapshot_a.snapshot_id = SnapshotId(name="model_a", identifier="snapshot_a")
    snapshot_a.parents = []

    snapshot_b = mocker.Mock()
    snapshot_b.snapshot_id = SnapshotId(name="model_b", identifier="snapshot_b")
    snapshot_b.parents = [snapshot_a.snapshot_id]

    snapshot_c = mocker.Mock()
    snapshot_c.snapshot_id = SnapshotId(name="model_c", identifier="snapshot_c")
    snapshot_c.parents = [snapshot_b.snapshot_id]

    snapshot_d = mocker.Mock()
    snapshot_d.snapshot_id = SnapshotId(name="model_d", identifier="snapshot_d")
    snapshot_d.parents = []

    snapshot_e = mocker.Mock()
    snapshot_e.snapshot_id = SnapshotId(name="model_e", identifier="snapshot_e")
    snapshot_e.parents = [snapshot_d.snapshot_id]

    # Snapshots that were applied successfully, in the order they completed.
    applied = []

    def apply_or_fail_on_a(snapshot):
        if snapshot.snapshot_id.name == "model_a":
            raise RuntimeError("fail")
        applied.append(snapshot)

    errors, skipped = concurrent_apply_to_snapshots(
        [snapshot_a, snapshot_b, snapshot_c, snapshot_d, snapshot_e],
        apply_or_fail_on_a,
        tasks_num,
        raise_on_error=False,
    )

    assert len(errors) == 1
    assert errors[0].node == snapshot_a.snapshot_id

    assert skipped == [snapshot_b.snapshot_id, snapshot_c.snapshot_id]

    # Skipped snapshots must never reach the apply function. The failure in
    # ``a`` must not stop the independent ``d -> e`` branch, and ``e`` must
    # still run after its parent ``d``.
    assert applied == [snapshot_d, snapshot_e]
