# ruff: noqa: ARG001

from functools import partial

import numpy as np
import pytest
from numpy.testing import assert_almost_equal, assert_array_equal
from scipy import linalg
from sklearn.datasets import load_iris
from sklearn.linear_model import Lasso, LogisticRegression
from sklearn.linear_model._coordinate_descent import _alpha_grid
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from sklearn.utils.estimator_checks import (
    ignore_warnings,
    parametrize_with_checks,
)

from nilearn._utils.estimator_checks import (
    check_estimator,
    nilearn_check_estimator,
    return_expected_failed_checks,
)
from nilearn._utils.tags import SKLEARN_LT_1_6
from nilearn.decoding._utils import adjust_screening_percentile
from nilearn.decoding.space_net import (
    SpaceNetClassifier,
    SpaceNetRegressor,
    _crop_mask,
    _EarlyStoppingCallback,
    _space_net_alpha_grid,
    _univariate_feature_screening,
    path_scores,
)
from nilearn.decoding.space_net_solvers import (
    graph_net_logistic,
    graph_net_squared_loss,
)
from nilearn.decoding.tests._testing import create_graph_net_simulation_data
from nilearn.decoding.tests.test_same_api import to_niimgs
from nilearn.image import get_data

# Parametrization values shared by several tests below.
IS_CLASSIF = [True, False]

PENALTY = ["graph-net", "tv-l1"]

# ``path_scores`` with ``is_classif`` fixed, for the logistic (classification)
# and squared-loss (regression) variants.
logistic_path_scores = partial(path_scores, is_classif=True)
squared_loss_path_scores = partial(path_scores, is_classif=False)

# Estimator instances fed to the scikit-learn / nilearn compliance checks.
ESTIMATORS_TO_CHECK = [
    SpaceNetClassifier(standardize="zscore_sample"),
    SpaceNetRegressor(standardize="zscore_sample"),
]

if not SKLEARN_LT_1_6:

    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Run scikit-learn's estimator checks (when SKLEARN_LT_1_6 is False).

        Checks listed by ``return_expected_failed_checks`` are expected
        to fail.
        """
        check(estimator)

else:

    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):
        """Run the scikit-learn checks that must pass (older scikit-learn)."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK, valid=False),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):
        """Run the scikit-learn checks known to fail (marked as xfail)."""
        check(estimator)


@pytest.mark.slow
@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(estimators=ESTIMATORS_TO_CHECK),
)
def test_check_estimator_nilearn(estimator, check, name):
    """Run the nilearn-specific estimator checks on each SpaceNet model."""
    check(estimator)


@pytest.mark.parametrize("is_classif", IS_CLASSIF)
@pytest.mark.parametrize("l1_ratio", [0.5, 0.99])
@pytest.mark.parametrize("n_alphas", range(1, 10))
def test_space_net_alpha_grid(
    rng, is_classif, l1_ratio, n_alphas, n_samples=4, n_features=3
):
    """Check the size and upper bound of the SpaceNet alpha grid.

    The grid must contain exactly ``n_alphas`` values and its largest
    value must be ``max(|X^T y|) / l1_ratio``. For ``n_alphas == 1`` these
    two assertions together pin the single grid value to that bound.
    """
    X = rng.standard_normal((n_samples, n_features))
    y = np.arange(n_samples)

    alpha_max = np.max(np.abs(np.dot(X.T, y))) / l1_ratio

    alphas = _space_net_alpha_grid(
        X, y, n_alphas=n_alphas, l1_ratio=l1_ratio, logistic=is_classif
    )

    assert len(alphas) == n_alphas
    assert_almost_equal(alphas.max(), alpha_max)


def test_space_net_alpha_grid_same_as_sk():
    """Compare the SpaceNet alpha grid with scikit-learn's ``_alpha_grid``.

    On iris, the SpaceNet grid should equal scikit-learn's grid (computed
    without intercept) multiplied by the number of samples.
    """
    X, y = load_iris(return_X_y=True)
    n_alphas = 5

    space_net_grid = _space_net_alpha_grid(X, y, n_alphas=n_alphas)
    sklearn_grid = _alpha_grid(X, y, n_alphas=n_alphas, fit_intercept=False)

    assert_almost_equal(space_net_grid, X.shape[0] * sklearn_grid)


def test_early_stopping_callback_object(rng, n_samples=10, n_features=30):
    """Fuzz ``_EarlyStoppingCallback`` with an evolving weight vector.

    The aim is to execute every line of the callback's code at some point.
    Over 50 iterations, one more coefficient of ``w`` is switched on per
    iteration, the previous coefficient is occasionally flipped ("jitter"),
    and from iteration 21 onward ``w`` is zeroed at the end of every
    iteration ("restart"). The callback must record one test score per call.
    """
    X_test = rng.standard_normal((n_samples, n_features))
    y_test = np.dot(X_test, np.ones(n_features))
    w = np.zeros(n_features)
    escb = _EarlyStoppingCallback(X_test, y_test, False)
    for counter in range(50):
        k = min(counter, n_features - 1)
        w[k] = 1

        # jitter: flip the previous coefficient with probability ~0.1
        if k > 0 and rng.random() > 0.9:
            w[k - 1] = 1 - w[k - 1]

        escb({"w": w, "counter": counter})
        assert len(escb.test_scores) == counter + 1

        # restart: reset the weights in place
        if counter > 20:
            w *= 0.0


def test_screening_space_net():
    """Check the screening percentile correction for a small mask.

    The mask used here is very small by brain-size standards, so the
    screening percentile corrected for the mask size should be 100%,
    whatever the verbosity level.
    """
    size = 4
    X_, *_ = create_graph_net_simulation_data(
        snr=1.0, n_samples=10, size=size, n_points=5, random_state=42
    )
    _, mask = to_niimgs(X_, [size] * 3)

    for verbose in [0, 1]:
        with pytest.warns(UserWarning):
            screening_percentile = adjust_screening_percentile(
                10, mask, verbose
            )
        # Verbosity must not affect the corrected value.
        assert screening_percentile == 100

    with pytest.warns(UserWarning):
        screening_percentile = adjust_screening_percentile(10, mask)
    assert screening_percentile == 100


def test_logistic_path_scores():
    """Check output sizes of ``path_scores`` with the logistic solver.

    Train and test indices both cover all iris samples. One test score is
    expected per alpha, and the best weight vector has one entry more than
    there are features (presumably the intercept; confirm in path_scores).
    """
    iris = load_iris()
    X, y = iris.data, iris.target
    mask_img = to_niimgs(X, [2, 2, 2])[1]
    mask = get_data(mask_img).astype(bool)
    alphas = [1.0, 0.1, 0.01]
    train = np.arange(len(X))
    test = np.arange(len(X))

    results = logistic_path_scores(
        graph_net_logistic, X, y, mask, alphas, 0.5, train, test, {}
    )
    first_two = results[:2]
    all_scores, best_w = first_two
    scores = all_scores[0]

    assert len(scores) == len(alphas)
    assert len(best_w) == X.shape[1] + 1


def test_squared_loss_path_scores():
    """Check output sizes of ``path_scores`` with the squared-loss solver.

    Train and test indices both cover all iris samples. One test score is
    expected per alpha, and the best weight vector has one entry more than
    there are features (presumably the intercept; confirm in path_scores).
    """
    iris = load_iris()
    X, y = iris.data, iris.target
    mask_img = to_niimgs(X, [2, 2, 2])[1]
    mask = get_data(mask_img).astype(bool)
    alphas = [1.0, 0.1, 0.01]
    train = np.arange(len(X))
    test = np.arange(len(X))

    results = squared_loss_path_scores(
        graph_net_squared_loss, X, y, mask, alphas, 0.5, train, test, {}
    )
    first_two = results[:2]
    all_scores, best_w = first_two
    scores = all_scores[0]

    assert len(scores) == len(alphas)
    assert len(best_w) == X.shape[1] + 1


@pytest.mark.parametrize("l1_ratio", [0.99])
@pytest.mark.parametrize("debias", [True])
def test_tv_regression_simple(rng, l1_ratio, debias):
    """Smoke test: fit a TV-L1 SpaceNetRegressor on a small synthetic volume.

    The true weights are 1 on a small block of a (4, 4, 4) volume and 0
    elsewhere; ``y`` is ``X @ w_true`` with no extra noise.
    """
    shape = (4, 4, 4)
    w_true = np.zeros(shape)
    w_true[2:3, 1:2, -2:] = 1
    n_samples, n_voxels = 10, np.prod(shape)

    # Each sample is (1 + w_true) plus standard Gaussian noise.
    X = 1.0 + w_true.ravel() + rng.standard_normal((n_samples, n_voxels))
    y = np.dot(X, w_true.ravel())
    imgs, mask = to_niimgs(X, shape)

    model = SpaceNetRegressor(
        mask=mask,
        alphas=[0.1, 1.0],
        l1_ratios=l1_ratio,
        penalty="tv-l1",
        max_iter=10,
        debias=debias,
        standardize="zscore_sample",
    )
    model.fit(imgs, y)


@pytest.mark.parametrize("l1_ratio", [-2, 2])
@pytest.mark.parametrize("estimator", [SpaceNetClassifier, SpaceNetRegressor])
def test_base_estimator_invalid_l1_ratio(rng, l1_ratio, estimator):
    """Check that an out-of-range ``l1_ratios`` value raises a ValueError.

    Both -2 and 2 lie outside the accepted interval. Note that the upper
    bound itself is accepted: ``l1_ratios=1.0`` (pure Lasso) is a valid
    setting, so the interval is not the open interval (0, 1).
    """
    dim = (4, 4, 4)
    W_init = np.zeros(dim)
    W_init[2:3, 1:2, -2:] = 1
    n = 10
    p = np.prod(dim)
    X = np.ones((n, 1)) + W_init.ravel()
    X += rng.standard_normal((n, p))
    y = np.dot(X, W_init.ravel())
    X, _ = to_niimgs(X, dim)

    with pytest.raises(ValueError, match="l1_ratio must be in the interval"):
        estimator(l1_ratios=l1_ratio).fit(X, y)


def test_space_net_classifier_invalid_loss(rng):
    """Check validation of the ``loss`` parameter of SpaceNetClassifier.

    "logistic" and "mse" must both fit without error, while any other
    value (here "bar") must raise a ValueError.
    """
    iris = load_iris()
    X, y = iris.data, iris.target
    y = 2 * (y > 0) - 1  # binarize to {-1, 1}
    X_, mask = to_niimgs(X, (2, 2, 2))

    # Parameters shared by every classifier built in this test.
    common_params = {
        "mask": mask,
        "alphas": 1.0 / 0.01 / X.shape[0],
        "tol": 1e-10,
        "standardize": False,
        "screening_percentile": 100.0,
    }

    for valid_loss in ("logistic", "mse"):
        SpaceNetClassifier(loss=valid_loss, **common_params).fit(X_, y)

    with pytest.raises(ValueError, match="'loss' must be one of"):
        SpaceNetClassifier(loss="bar", **common_params).fit(X_, y)


@pytest.mark.parametrize("penalty_wrong_case", ["Graph-Net", "TV-L1"])
@pytest.mark.parametrize("estimator", [SpaceNetClassifier, SpaceNetRegressor])
def test_string_params_case(rng, penalty_wrong_case, estimator):
    """Check that ``penalty`` is case-sensitive.

    Upper-case spellings of valid penalty names must raise a ValueError.
    """
    shape = (4, 4, 4)
    w_true = np.zeros(shape)
    w_true[2:3, 1:2, -2:] = 1
    n_samples, n_voxels = 10, np.prod(shape)

    X = 1.0 + w_true.ravel() + rng.standard_normal((n_samples, n_voxels))
    y = np.dot(X, w_true.ravel())
    imgs = to_niimgs(X, shape)[0]

    with pytest.raises(ValueError, match="'penalty' must be one of"):
        estimator(penalty=penalty_wrong_case).fit(imgs, y)


@pytest.mark.parametrize("l1_ratio", [0.01, 0.5, 0.99])
def test_tv_regression_3d_image_doesnt_crash(rng, l1_ratio):
    """Smoke test: TV-L1 regression on a non-cubic (3, 4, 5) volume."""
    shape = (3, 4, 5)
    w_true = np.zeros(shape)
    w_true[2:3, 3:, 1:3] = 1

    n_samples = 10
    n_voxels = shape[0] * shape[1] * shape[2]
    X = 1.0 + w_true.ravel() + rng.standard_normal((n_samples, n_voxels))
    y = np.dot(X, w_true.ravel())
    imgs, mask = to_niimgs(X, shape)

    regressor = SpaceNetRegressor(
        mask=mask,
        alphas=1.0,
        l1_ratios=l1_ratio,
        penalty="tv-l1",
        max_iter=10,
        standardize="zscore_sample",
    )
    regressor.fit(imgs, y)


@pytest.mark.slow
def test_graph_net_classifier_score():
    """Check that ``score`` equals ``accuracy_score`` of ``predict``."""
    iris = load_iris()
    X = iris.data
    y = 2 * (iris.target > 0) - 1  # binarize to {-1, 1}
    imgs, mask = to_niimgs(X, (2, 2, 2))

    classifier = SpaceNetClassifier(
        mask=mask,
        alphas=1.0 / 0.01 / X.shape[0],
        l1_ratios=1.0,
        tol=1e-10,
        standardize=False,
        screening_percentile=100.0,
    )
    classifier.fit(imgs, y)

    reported = classifier.score(imgs, y)
    expected = accuracy_score(y, classifier.predict(imgs))
    assert reported == expected


def test_log_reg_vs_graph_net_two_classes_iris(
    C=0.01, tol=1e-10, zero_thr=1e-4
):
    """Compare SpaceNetClassifier in its pure-Lasso limit to L1 LogReg.

    With ``l1_ratios=1.0`` (pure Lasso), SpaceNetClassifier (fitted here
    with ``penalty="tv-l1"``) is compared with scikit-learn's
    ``LogisticRegression`` using an L1 penalty of matching strength
    (``alphas = 1 / C / n_samples``) on a 2-class iris task.

    Both the supports (coefficients below ``zero_thr`` in absolute value)
    and the predictions must agree.
    """
    iris = load_iris()
    X, y = iris.data, iris.target
    y = 2 * (y > 0) - 1  # binarize to {-1, 1}
    X_, mask = to_niimgs(X, (2, 2, 2))

    space_net = SpaceNetClassifier(
        mask=mask,
        alphas=1.0 / C / X.shape[0],
        l1_ratios=1.0,
        tol=tol,
        max_iter=1000,
        penalty="tv-l1",
        standardize=False,
        screening_percentile=100.0,
    ).fit(X_, y)

    sklogreg = LogisticRegression(
        penalty="l1", fit_intercept=True, solver="liblinear", tol=tol, C=C
    ).fit(X, y)

    # compare supports
    assert_array_equal(
        np.abs(space_net.coef_) < zero_thr,
        np.abs(sklogreg.coef_) < zero_thr,
    )

    # compare predictions
    assert_array_equal(space_net.predict(X_), sklogreg.predict(X))


def test_lasso_vs_graph_net():
    """Compare Graph-Net in its pure-Lasso limit with scikit-learn's Lasso.

    With ``l1_ratios=1`` (pure Lasso), the Graph-Net regressor should reach
    a performance close (2 decimals) to that of ``Lasso`` on the same
    simulated data.
    """
    size = 4
    X_arr, y, _, _ = create_graph_net_simulation_data(
        snr=1.0, n_samples=10, size=size, n_points=5, random_state=10
    )
    imgs, mask = to_niimgs(X_arr, [size] * 3)

    lasso = Lasso(max_iter=100, tol=1e-8)
    graph_net = SpaceNetRegressor(
        mask=mask,
        alphas=1.0 * X_arr.shape[0],
        l1_ratios=1,
        penalty="graph-net",
        max_iter=100,
        standardize="zscore_sample",
    )
    lasso.fit(X_arr, y)
    graph_net.fit(imgs, y)

    # NOTE(review): lasso_perf adds an L1 penalty term to the data-fit term,
    # whereas graph_net_perf only measures the data fit; confirm this
    # asymmetry is intended.
    lasso_residuals = np.dot(X_arr, lasso.coef_) - y
    lasso_perf = 0.5 / y.size * linalg.norm(lasso_residuals) ** 2 + np.sum(
        np.abs(lasso.coef_)
    )
    graph_net_residuals = graph_net.predict(imgs) - y
    graph_net_perf = 0.5 * (graph_net_residuals**2).mean()

    assert_almost_equal(graph_net_perf, lasso_perf, decimal=2)


def test_crop_mask(rng):
    """Check that ``_crop_mask`` keeps every voxel and crops tightly.

    A random mask covering roughly 30% of the (2, 3, 4) corner sub-box of a
    (3, 4, 5) volume is built. Cropping must keep the same number of
    in-mask voxels and return an array no larger than that sub-box.
    """
    mask = np.zeros((3, 4, 5), dtype=bool)
    box = mask[:2, :3, :4]  # a view: writing to ``box`` writes to ``mask``
    # rng.random draws from [0, 1), so a 0.3 threshold selects ~30% of the
    # box; a threshold above 1 would select every voxel of the box.
    box[rng.random(box.shape) < 0.3] = True
    # _crop_mask rejects empty masks, so guarantee at least one voxel
    # whatever the random draw.
    box[0, 0, 0] = True
    idx = np.where(mask)

    assert idx[1].max() < 3
    tight_mask = _crop_mask(mask)
    assert mask.sum() == tight_mask.sum()
    assert np.prod(tight_mask.shape) <= np.prod(box.shape)


@pytest.mark.parametrize("is_classif", IS_CLASSIF)
def test_univariate_feature_screening(
    rng, is_classif, dim=(11, 12, 13), n_samples=10
):
    """Check that screening outputs are consistent with the kept support.

    The screened data, the screened mask and the support must all agree on
    the number of kept features, which cannot exceed the original count.
    """
    n_voxels = np.prod(dim)
    mask = rng.random(dim) > 100.0 / n_voxels
    assert mask.sum() >= 100.0

    # put spatial structure
    mask[dim[0] // 2, dim[1] // 3 :, -dim[2] // 2 :] = 1

    n_features = mask.sum()
    X = rng.standard_normal((n_samples, n_features))
    w = rng.standard_normal(n_features)
    w[rng.random(n_features) > 0.8] = 0.0
    y = X.dot(w)

    X_screened, mask_screened, support = _univariate_feature_screening(
        X, y, mask, is_classif, 20.0
    )
    n_kept = support.sum()

    assert X_screened.shape[1] == n_kept
    assert mask_screened.sum() == n_kept
    assert n_kept <= n_features


@pytest.mark.parametrize("is_classif", IS_CLASSIF)
def test_space_net_alpha_grid_pure_spatial(rng, is_classif):
    """Check that a purely spatial penalty (l1_ratio=0) yields no NaN alpha."""
    X = rng.standard_normal((10, 100))
    y = np.arange(X.shape[0])

    alphas = _space_net_alpha_grid(X, y, l1_ratio=0.0, logistic=is_classif)
    has_nan = np.isnan(alphas).any()

    assert not has_nan


@pytest.mark.parametrize("mask_empty", [np.array([]), np.zeros((2, 2, 2))])
def test_crop_mask_empty_mask(mask_empty):
    """Check that cropping a mask without any voxel raises a ValueError."""
    expected_message = r"Empty mask:."
    with pytest.raises(ValueError, match=expected_message):
        _crop_mask(mask_empty)


@pytest.mark.parametrize("model", [SpaceNetRegressor, SpaceNetClassifier])
def test_space_net_one_alpha_no_crash(model):
    """Regression test: fitting with very short alpha grids must not crash.

    Covers ``n_alphas=1``, then ``n_alphas=2`` with ``alphas=None``.
    """
    iris = load_iris()
    imgs, mask = to_niimgs(iris.data, [2, 2, 2])
    y = iris.target

    for grid_params in ({"n_alphas": 1}, {"n_alphas": 2, "alphas": None}):
        estimator = model(
            mask=mask, standardize="zscore_sample", **grid_params
        )
        estimator.fit(imgs, y)


@pytest.mark.parametrize("model", [SpaceNetRegressor, SpaceNetClassifier])
def test_checking_inputs_length(model):
    """Check that X and y with different numbers of samples are rejected."""
    iris = load_iris()
    X = iris.data
    imgs, mask = to_niimgs(X, (2, 2, 2))

    # Binarize to {-1, 1}, then drop the last ten samples so that y is
    # shorter than X.
    y_short = (2 * (iris.target > 0) - 1)[:-10]

    with pytest.raises(ValueError, match="inconsistent numbers of samples"):
        estimator = model(
            mask=mask,
            alphas=1.0 / 0.01 / X.shape[0],
            l1_ratios=1.0,
            tol=1e-10,
            screening_percentile=100.0,
            standardize="zscore_sample",
        )
        estimator.fit(imgs, y_short)


def test_targets_in_y_space_net_regressor():
    """Check that a constant target (a single unique value) raises an error."""
    iris = load_iris()
    constant_y = np.ones(iris.target.shape)
    imgs, mask = to_niimgs(iris.data, (2, 2, 2))

    regressor = SpaceNetRegressor(mask=mask, standardize="zscore_sample")

    expected_message = "The given input y must have at least 2 targets"
    with pytest.raises(ValueError, match=expected_message):
        regressor.fit(imgs, constant_y)


@ignore_warnings
@pytest.mark.slow
@pytest.mark.parametrize("estimator", [SpaceNetRegressor, SpaceNetClassifier])
# TODO: cv=LeaveOneGroupOut() is not covered because it fails with
# "ValueError: The 'groups' parameter should not be None."
@pytest.mark.parametrize("cv", [8, KFold(n_splits=5), None])
def test_cross_validation(estimator, cv):
    """Check that fit and predict work with several cross-validation schemes.

    With ``cv=None`` the fitted model must expose 5 folds in ``cv_``, and
    classifiers must reach an accuracy above 0.7 on the training data.
    """
    iris = load_iris()
    imgs, mask = to_niimgs(iris.data, [2, 2, 2])
    y = iris.target

    model = estimator(mask=mask, cv=cv)
    model.fit(imgs, y)
    y_pred = model.predict(imgs)

    if cv is None:
        assert len(model.cv_) == 5

    if isinstance(model, SpaceNetClassifier):
        assert accuracy_score(y, y_pred) > 0.7


# ------------------------ surface tests ------------------------------------ #


@pytest.mark.parametrize("surf_mask_dim", [1, 2])
@pytest.mark.parametrize("model", [SpaceNetRegressor, SpaceNetClassifier])
def test_space_net_not_implemented_surface_objects(
    surf_mask_dim, surf_mask_1d, surf_mask_2d, surf_img_2d, model
):
    """Check that SpaceNet refuses surface data with NotImplementedError.

    Both a surface mask and surface images without mask must be rejected.
    """
    y = np.ones((5,))
    if surf_mask_dim == 1:
        surf_mask = surf_mask_1d
    else:
        surf_mask = surf_mask_2d()

    # First with an explicit surface mask, then without any mask.
    for init_params in ({"mask": surf_mask}, {}):
        with pytest.raises(NotImplementedError):
            model(**init_params).fit(surf_img_2d(5), y)
