from typing import Dict

from pyspark.sql import SparkSession, Column, DataFrame

# noinspection PyUnresolvedReferences
from pyspark.sql.functions import col, lit

from spark_auto_mapper.automappers.automapper import AutoMapper
from spark_auto_mapper.expression_comparer import assert_expressions_are_equal
from spark_auto_mapper.helpers.automapper_helpers import AutoMapperHelpers as A


def test_auto_mapper_boolean(spark_session: SparkSession) -> None:
    """Check that ``A.boolean`` yields boolean casts for a column and a literal.

    A two-row ``patients`` view is registered and mapped onto a ``members``
    view. The test then verifies the generated column expressions, the
    values in the transformed rows, and the resulting column dtypes.
    """
    # Arrange
    patient_rows = [
        (1, "Qureshi", "Imran", "0"),
        (2, "Vidal", "Michael", "1"),
    ]
    patient_columns = ["member_id", "last_name", "first_name", "my_age"]
    patients_df = spark_session.createDataFrame(patient_rows, patient_columns)
    patients_df.createOrReplaceTempView("patients")

    source_df: DataFrame = spark_session.table("patients")

    # The destination view starts out holding only the key column.
    members_df = source_df.select("member_id")
    members_df.createOrReplaceTempView("members")

    # Act
    base_mapper = AutoMapper(
        view="members", source_view="patients", keys=["member_id"]
    )
    mapper = base_mapper.columns(
        age=A.boolean(A.column("my_age")),
        is_active=A.boolean("False"),
    )

    assert isinstance(mapper, AutoMapper)
    column_specs: Dict[str, Column] = mapper.get_column_specs(source_df=source_df)
    for spec_name, spec_expression in column_specs.items():
        print(f"{spec_name}: {spec_expression}")

    # The source column is expected to be cast to boolean and aliased.
    expected_age = col("b.my_age").cast("boolean").alias("age")
    assert_expressions_are_equal(column_specs["age"], expected_age)

    # The string literal is expected to be cast to boolean and aliased.
    expected_is_active = lit("False").cast("boolean").alias("is_active")
    assert_expressions_are_equal(column_specs["is_active"], expected_is_active)

    result_df: DataFrame = mapper.transform(df=members_df)

    # Assert
    result_df.printSchema()
    result_df.show()

    # Row filter -> expected (age, is_active) values for the matching member.
    expected_by_filter = {
        "member_id == 1": (False, False),
        "member_id == 2": (True, False),
    }
    for row_filter, expected_values in expected_by_filter.items():
        matching_rows = (
            result_df.where(row_filter).select("age", "is_active").collect()
        )
        assert matching_rows[0][:] == expected_values

    column_types = dict(result_df.dtypes)
    assert column_types["age"] == "boolean"
    assert column_types["is_active"] == "boolean"
