## tests of find_zero interface
using Roots
using Test
using ForwardDiff;
# Test-only convenience: `f'` denotes the derivative of the function `f`, computed
# with ForwardDiff at `float(x)`. Restricted to `Function` so this file does not
# pirate `adjoint` for arbitrary (`Any`) argument types.
Base.adjoint(f::Function) = x -> ForwardDiff.derivative(f, float(x))

# for a user-defined method
import Roots.Accessors
import Roots.Accessors: @reset
# Marker type for a user-defined root-finding method; the "find_zero internals"
# testset adds a `Roots.update_state` method for it and checks `find_zero` uses it.
struct Order3_Test <: Roots.AbstractSecantMethod
end

## Test the interface
@testset "find_zero interface tests" begin
    meths = [
        Order0(),
        Order1(),
        Roots.Order1B(),
        Roots.King(),
        Order2(),
        Roots.Steffensen(),
        Roots.Order2B(),
        Roots.Esser(),
        Order5(),
        Roots.KumarSinghAkanksha(),
        Order8(),
        Roots.Thukral8(),
        Order16(),
        Roots.Thukral16(),
        Roots.LithBoonkkampIJzerman(3, 0),
        Roots.LithBoonkkampIJzerman(4, 0),
        Roots.Sidi(2),
    ]

    ## different types of initial values
    for m in meths
        @test find_zero(sin, 3, m) ≈ pi
        @test find_zero(sin, 3.0, m) ≈ pi
        @test find_zero(sin, big(3), m) ≈ pi
        @test find_zero(sin, big(3.0), m) ≈ pi
        @test find_zero(sin, π, m) ≈ pi
        @test find_zero(x -> x^2 - 2.0f0, 2.0f0, m) ≈ sqrt(2) # issue 421
    end
    # a Float32 problem with no real zero: `solve` reports failure as NaN.
    # This check uses the default method and does not depend on `m`, so run it once.
    @test isnan(solve(ZeroProblem(x -> x^2 + 2, 0.5f0)))

    ## defaults for method argument
    @test find_zero(sin, 3.0) ≈ pi    # order0()
    @test @inferred(find_zero(sin, (3, 4))) ≈ π   # Bisection()
    @test @inferred(find_zero(sin, [3, 4])) ≈ π   # Bisection()

    ## test tolerance arguments
    ## xatol, xrtol, atol, rtol, maxevals, strict
    fn, xstar = x -> sin(x) - x + 1, 1.9345632107520243
    x0, M = 20.0, Order2()
    @test find_zero(fn, x0, M) ≈ xstar   # needs 16 iterations, 33 fn evaluations, difference is exact

    # test of maxevals
    @test_throws Roots.ConvergenceFailed find_zero(fn, x0, M, maxevals=2)

    # tolerance on f, atol, rtol: f(x) ~ 0  (M is still Order2())
    h = 1e-2
    rt = find_zero(fn, x0, M, atol=h, rtol=0.0)
    @test abs(fn(rt)) > h^2 / 100
    rt = find_zero(fn, x0, M, atol=0.0, rtol=h)
    @test abs(fn(rt)) > h^2 / 100

    ## test of tolerances xatol, xrtol with bisection
    a, b = 1.5, 2.0
    h = 1e-6
    M = Roots.Bisection()
    tracks = Roots.Tracks(Float64, Float64)
    if VERSION >= v"1.6.0"
        @inferred(find_zero(fn, (a, b), M, tracks=tracks, xatol=h, xrtol=0.0))
        # the final recorded bracket must be within xatol but not much tighter
        u, v = tracks.abₛ[end]
        @test h >= abs(u - v) >= h / 2
    end

    ## test of strict
    fn, x0 = x -> cos(x) - 1, pi / 4
    @test fn(find_zero(fn, x0, Order5())) <= 1e-8
    @test_throws Roots.ConvergenceFailed find_zero(fn, x0, Order5(), strict=true)

    # xn increment needs atol setting for zeros near 0.0 if strict=true
    M = Order1()
    fn = x -> x * exp(x) + nextfloat(0.0)
    @test_throws Roots.ConvergenceFailed find_zero(
        fn,
        1.0,
        M,
        atol=0.0,
        rtol=0.0,
        strict=true,
        xatol=0.0,
    )
    @test abs(find_zero(fn, 1.0, M, atol=0.0, rtol=0.0, strict=true)) <= eps()

    ## test of extreme values for fn, bisection
    ## (previously these comparisons were computed but never asserted)
    c = pi
    fn = x -> Inf * sign(x - c)
    @test @inferred(find_zero(fn, (-Inf, Inf))) ≈ c

    fn = x -> Inf * x / abs(x) # f(0) is NaN: bisection should stop there
    @test @inferred(find_zero(fn, (-Inf, Inf))) ≈ 0

    bracketing_meths = (
        Roots.Bisection(),
        Roots.A42(),
        Roots.AlefeldPotraShi(),
        Roots.Brent(),
        Roots.Ridders(),
        Roots.ITP(),
        Roots.FalsePosition(),
        Roots.FalsePosition(2),
    )

    # test flexibility in interval specification:
    # type stability is checked for the default (bisection) method ...
    @test @inferred(find_zero(sin, (3, 4))) ≈ pi
    @test @inferred(find_zero(sin, [3, 4])) ≈ pi
    @test @inferred(find_zero(sin, 3:4)) ≈ pi
    @test @inferred(find_zero(sin, SomeInterval(3, 4))) ≈ pi
    @test @inferred(find_zero(sin, range(3, stop=4, length=20))) ≈ pi
    # ... and every bracketing method must accept each interval specification
    for M in bracketing_meths
        @test find_zero(sin, (3, 4), M) ≈ pi
        @test find_zero(sin, [3, 4], M) ≈ pi
        @test find_zero(sin, 3:4, M) ≈ pi
        @test find_zero(sin, SomeInterval(3, 4), M) ≈ pi
        @test find_zero(sin, range(3, stop=4, length=20), M) ≈ pi
    end

    # test issue when f is not type stable (returns an Int on one branch and a
    # Float64 on the other); named so it does not reuse the tolerance name `h`
    piecewise(x) = x < 2000 ? -1000 : -1000 + 0.1 * (x - 2000)
    a, b, xᵅ = 0, 20_000, 12_000
    for M in bracketing_meths
        @test find_zero(piecewise, (a, b), M) ≈ xᵅ
    end
end

@testset "non simple zeros" begin
    # Methods that take a tuple of the function and its derivatives and are run
    # here on a zero of multiplicity two.
    multiplicity_methods = (
        Roots.Order1B(),
        Roots.Order2B(),
        Roots.Schroder(),
        Roots.Thukral2B(),
        Roots.Thukral3B(),
        Roots.Thukral4B(),
        Roots.Thukral5B(),
    )

    # `base` has a simple zero; squaring it makes that zero a double zero.
    base(x) = exp(x) + x - 2
    squared(x) = base(x)^2
    xstart = 1 / 4

    # reference root, found from the simple-zero function
    root = find_zero(base, xstart)

    # the function followed by its first six derivatives (via `'`)
    derivs = (
        squared,
        squared',
        squared'',
        squared''',
        squared'''',
        squared''''',
        squared'''''',
    )

    for method in multiplicity_methods
        @test isapprox(find_zero(derivs, xstart, method), root; atol=1e-6)
    end
end

@testset "find_zero internals" begin

    ##  init_state method: a state built for A42 is converted with
    ##  `init_state(M, state, F)` and then solved by each bracketing method
    g1 = x -> x^5 - x - 1
    x0_, xstar_ = (1.0, 2.0), 1.1673039782614187
    M = Roots.A42()
    G1 = Roots.Callable_Function(M, g1)
    state = @inferred(Roots.init_state(M, G1, x0_))
    for M in (Roots.A42(), Roots.Bisection(), Roots.FalsePosition())
        Gₘ = Roots.Callable_Function(M, G1)
        stateₘ = @inferred(Roots.init_state(M, state, Gₘ))
        @test @inferred(solve(M, Gₘ, stateₘ)) ≈ xstar_
    end

    # iterator interface (ZeroProblem, solve; init, solve!)
    meths = [
        Order0(),
        Order1(),
        Roots.Order1B(),
        Roots.King(),
        Order2(),
        Roots.Steffensen(),
        Roots.Order2B(),
        Roots.Esser(),
        Order5(),
        Roots.KumarSinghAkanksha(),
        Order8(),
        Roots.Thukral8(),
        Order16(),
        Roots.Thukral16(),
    ]
    g1(x) = x^5 - x - 1
    x0_, xstar_ = 1.16, 1.1673039782614187
    fx = ZeroProblem(g1, x0_)
    for M in meths
        @test solve(fx, M) ≈ xstar_
        P = init(fx, M)
        @test solve!(P) ≈ xstar_
    end

    # solve and parameters
    # should be positional, but named supported for now
    g2 = (x, p) -> cos(x) - x / p
    fx = ZeroProblem(g2, (0, pi / 2))
    @test solve(fx, 2) ≈ @inferred(find_zero(x -> cos(x) - x / 2, (0, pi / 2)))
    @test solve(fx, p=2) ≈ @inferred(find_zero(x -> cos(x) - x / 2, (0, pi / 2)))
    @test @inferred(solve(fx, p=3)) ≈ @inferred(find_zero(x -> cos(x) - x / 3, (0, pi / 2)))
    g3 = (x, p) -> cos(x) + p[1] * x - p[2]
    fx = ZeroProblem(g3, (0, pi / 2))
    @test @inferred(solve(fx, p=[-1 / 10, 1 / 10])) ≈
          @inferred(find_zero(x -> cos(x) - x / 10 - 1 / 10, (0, pi / 2)))

    ### issue 321, solve and broadcasting
    myfun(x, p) = x * sin(x) - p
    prob = ZeroProblem(myfun, (0.0, 2.0))
    ps = (1 / 4, 1 / 2, 3 / 4, 1)
    as = (0.5111022402679033, 0.7408409550954906, 0.9333080372907439, 1.1141571408719302)
    @test all(solve.(prob, Bisection(), ps) .≈ as)

    ## test with early evaluation of bracket
    f = x -> sin(x)
    xs = (3.0, 4.0)
    fxs = f.(xs)
    M = Bisection()
    state = @inferred(Roots.init_state(M, f, xs..., fxs..., m=3.5, fm=f(3.5)))
    @test @inferred(solve!(init(M, f, state))) ≈ π

    ## hybrid: run bisection with a loose tolerance (xatol = 1/2) by iterating
    ## manually, then hand its final state to Order1 to finish the job
    g1 = x -> exp(x) - x^4
    x0_, xstar_ = (5.0, 20.0), 8.613169456441398
    M = Roots.Bisection()
    G1 = Roots.Callable_Function(M, g1)
    state = @inferred(Roots.init_state(M, G1, x0_))
    options = Roots.init_options(M, state, xatol=1 / 2)
    ZPI = @inferred(init(M, G1, state, options))
    ϕ = iterate(ZPI)
    while ϕ !== nothing
        _, st = ϕ
        state, _ = st   # keep the most recent bisection state
        ϕ = iterate(ZPI, st)
    end

    N = Roots.Order1() # switch to N
    G2 = Roots.Callable_Function(N, G1)
    stateₙ = Roots.init_state(N, state, G2)
    options = Roots.init_options(N, stateₙ)
    x = solve(N, G2, stateₙ, options)
    @test x ≈ xstar_

    ## test creation of new methods
    ## a user-defined method only needs an `update_state` method:
    ## xn - f/f' - f'' * f^2 / 2(f')^3 = xn - r1 - r1^2/r2 is third order
    function Roots.update_state(
        M::Order3_Test,
        f,
        o::Roots.AbstractUnivariateZeroState{T,S},
        options,
        l=Roots.NullTracks(),
    ) where {T,S}
        # xn - f/f' - f'' * f^2 / 2(f')^3 = xn - r1 - r1^2/r2 is third order
        xn_1, xn = o.xn0, o.xn1
        fxn_1, fxn = o.fxn0, o.fxn1

        # secant step from the last two iterates
        f_10 = (fxn - fxn_1) / (xn - xn_1)
        xn1::T = xn - fxn / f_10
        fxn1::S = f(xn1)

        f01 = (fxn1 - fxn) / (xn1 - xn)

        # a degenerate divided difference: signal the solver to stop
        if isnan(f_10) || iszero(f_10) || isnan(f01) || iszero(f01)
            return (o, true)
        end

        r1 = fxn1 / f01
        r2 = f01 / ((f01 - f_10) / (xn1 - xn_1))

        wn = xn1 - r1 - r1^2 / r2
        fwn::S = f(wn)

        # the state is immutable; @reset builds an updated copy
        @reset o.xn0 = xn
        @reset o.xn1 = wn
        @reset o.fxn0 = fxn
        @reset o.fxn1 = fwn

        return (o, false)
    end

    g1 = x -> exp(x) - x^4
    @test find_zero(g1, 8.3, Order3_Test()) ≈ find_zero(g1, 8.3, Order1())

    # test many different calling styles
    f(x) = (sin(x), sin(x) / cos(x)) # x -> (f(x), f(x)/f′(x))
    x0 = (3, 4)
    g(x, p) = begin
        fx = cos(x) - x / p
        (fx, fx / (-sin(x) - 1 / p))
    end
    x0a = (0.0, pi / 2)
    α₂, α₃ = 1.0298665293222589, 1.1701209500026262
    @test find_zero(f, x0) ≈ π
    @test find_zero(f, first(x0)) ≈ π
    @test find_zero(g, x0a, p=2) ≈ α₂
    @test find_zero(g, first(x0a), p=2) ≈ α₂
    Z = ZeroProblem(f, x0)
    Za = ZeroProblem(g, x0a)
    @test solve(Z) ≈ π
    @test solve(Za, 3) ≈ α₃
    @test solve(Za, p=2) ≈ α₂
    @test solve!(init(Z)) ≈ π
    @test solve!(init(Za, 3)) ≈ α₃
    @test solve!(init(Za, p=3)) ≈ α₃
    Ms = (Roots.Secant(), Roots.Bisection(), Roots.Newton())
    for M in Ms
        @test find_zero(f, x0, M) ≈ π
        @test solve(Z, M) ≈ π
        @test solve!(init(Z, M)) ≈ π
        @test find_zero(g, x0a, M, p=2) ≈ α₂
        @test solve(Za, M, 2) ≈ α₂
        @test solve(Za, M, p=2) ≈ α₂
        @test solve!(init(Za, M, 2)) ≈ α₂
    end

    ## test broadcasting semantics with ZeroProblem
    ## This assume parameters can be passed in a positional manner, a
    ## style which is discouraged, as it is confusing
    Z = ZeroProblem((x, p) -> cos(x) - x / p, pi / 4)
    @test all(solve.(Z, (1, 2)) .≈ (solve(Z, 1), solve(Z, 2)))
end

@testset "find_zero issue tests" begin

    ## Misc tests
    Ms = [Order0(), Order1(), Order2(), Order5(), Order8(), Order16()]

    ## issues with starting near a maxima. Some bounce out of it, but
    ## one would expect all to have issues
    fn, xstar = x -> x^3 + 4x^2 - 10, 1.365230013414097
    for M in [Order1(), Roots.Order1B(), Order2(), Roots.Order2B(), Order5()]
        @test_throws Roots.ConvergenceFailed find_zero(fn, -1.0, M)
    end
    for M in [Order0(), Roots.Thukral8(), Roots.Thukral16()]
        @test find_zero(fn, -1.0, M) ≈ xstar
    end

    ## non-simple root
    ## convergence can depend on relaxed convergence checked after an issue
    fn, xstar, x0 = x -> cos(x) - 1, 0.0, 0.1
    for M in Ms
        xn = find_zero(fn, x0, M)
        @test abs(fn(xn)) <= 1e-10
    end
    for M in [Roots.Order1B(), Order2(), Roots.Order2B(), Order5(), Order8(), Order16()]
        @test_throws Roots.ConvergenceFailed find_zero(fn, x0, M, strict=true)
    end

    ## issue with large steps
    fn, x0 = x -> x^20 - 1, 0.5
    for M in Ms[2:end] # not 0, as it uses bracket
        @test_throws Roots.ConvergenceFailed find_zero(fn, x0, M)
    end

    ## issue with large f''
    fn, x0 = cbrt, 1.0
    for M in [Order1(), Order2(), Order5()]
        @test_throws Roots.ConvergenceFailed find_zero(fn, x0, M)
    end
    ### these stop but only because rtol is used for checking f(xn) ~ 0;
    ### the returned value is far from the true zero at 0.0
    for M in [Roots.Thukral8(), Roots.Thukral16()]
        @test abs(find_zero(fn, x0, M)) >= 100
    end

    ## similar (http://people.sc.fsu.edu/~jburkardt/cpp_src/test_zero/test_zero.html)
    # piecewise linear with a steeper middle segment through the zero at 0
    function newton_baffler(x)
        a = 1 / 10
        m, b = 1 / 4, 1 / 8

        if x < -a
            m * x - b
        elseif x > a
            m * x + b
        else
            (m * a + b) / a * (x + a) + (-m * a - b)
        end
    end
    for M in
        (Order0(), Order1(), Roots.Order1B(), Order2(), Roots.Order2B(), Order5(), Order8())
        @test abs(find_zero(newton_baffler, 1.0, M)) <= 1e-15
    end
    for M in (Roots.KumarSinghAkanksha(), Roots.Thukral8(), Roots.Thukral16())
        @test_throws Roots.ConvergenceFailed find_zero(newton_baffler, 1.0, M)
    end

    ## Closed issues ###
    ## issue tests: put in tests to ensure closed issues don't reappear.

    ## issue #94; tolerances not matching documentation
    function test_94()
        g, T = 1.62850, 14.60000
        α, t1, tf = 0.00347, 40.91375, 131.86573
        y, ya, yf = 0.0, 9000.0, 10000.0
        vy = sqrt(2g * (ya - y))
        θ0, θ1 = atan(α * tf), atan(α * (tf - t1))
        I_sintan(x) = tan(x) / 2cos(x) - atanh(tan(x / 2))
        I_sintan(x, y) = I_sintan(y) - I_sintan(x)
        function lhs(θ)
            tRem = (vy - T / α * (sec(θ1) - sec(θ))) / g
            val = -yf + y + vy * tRem - 0.5g * tRem^2 - T / α^2 * I_sintan(θ, θ1)
            val
        end

        M = Roots.FalsePosition()
        x0 = [atan(α * tf), atan(α * (tf - t1))]
        # an initial state is needed only to size the Tracks object
        F = Roots.Callable_Function(M, lhs, nothing) #Roots.DerivativeFree(lhs)
        state = Roots.init_state(M, F, x0)
        l = Roots.Tracks(state)
        solve(ZeroProblem(lhs, x0), M; tracks=l)
        @test l.steps <= 45 # 15
    end
    test_94()

    ## Issue with quad_step after truncated M-step PR #140
    @test find_zero(x -> tanh(x) - tan(x), 7.36842, Order0()) ≈ 7.068582745628732

    ## Use tolerance on f, not x with bisection
    atol = 0.01
    if VERSION >= v"1.6.0"
        u = @inferred(find_zero(sin, (3, 4), atol=atol))
        @test atol >= abs(sin(u)) >= atol^2

        ## issue #159 bracket with zeros should be found
        @test @inferred(find_zero(x -> x + 1, (-1, 1))) == -1
    end

    ## issue #178 passing through method
    @test fzero(sin, 3, 4, Roots.Brent()) ≈ π

    ## issue #188 with A42
    f = let a = 0.18
        x -> x * (1 - x^2) / ((x^2 + a^2) * (1 + a^2 * x^2))
    end
    r = 0.05
    xs = (r + 1e-12, 1.0)
    @test find_zero(x -> f(r) - f(x), xs, Roots.A42()) ≈ 0.4715797678171889

    ## issue #336 verbose=true with complex values
    ## just test that this does not error
    for M in (Order1(), Roots.Newton())
        T = Complex{Float64}
        tracks = Roots.Tracks(T, T)
        find_zero((sin, cos), 1.0 + 1.0im, M; tracks=tracks)
        Roots.show_tracks(IOBuffer(), tracks, M)
    end

    ## Issue #343 non-type stable f (Int on one branch, Float64 on the other);
    ## given its own name rather than reusing `f` from issue #188
    flog(t) = t <= 0 ? -1 : log(t)
    for M in (
        Roots.Order0(),
        Roots.Order1(),
        Roots.Order2(),
        Roots.Order5(),
        Roots.Order8(),
        Roots.Order16(),
    )
        @test find_zero(flog, 3, M) ≈ 1
    end
end

# A callable (non-`Function`) object evaluating x^5 - x - 1, used to check that
# `find_zero` accepts arbitrary callables.
struct _SampleCallableObject end

function (::_SampleCallableObject)(x)
    return x^5 - x - 1
end

@testset "find_zero with other callable types" begin
    xstar = 1.1673039782614187   # real zero of x^5 - x - 1
    methods_to_check = [
        Order0(),
        Order1(),
        Roots.Order1B(),
        Order2(),
        Roots.Order2B(),
        Order5(),
        Order8(),
        Order16(),
    ]

    for method in methods_to_check
        # a callable struct instance
        @test find_zero(_SampleCallableObject(), 1.1, method) ≈ xstar

        # `Cnt` is defined elsewhere in the test suite; presumably it wraps a
        # function and records the number of calls in `cnt` — verify there
        counted = Cnt(x -> x^5 - x - 1)
        @test find_zero(counted, 1.1, method) ≈ xstar
        @test counted.cnt <= 30
    end
end

@testset "function evaluations" begin
    # Wrap `f` so that every call increments a counter. Returns the wrapped
    # function and the `Ref` holding the count. An explicit `Ref` avoids reading
    # the closure's `Core.Box` internals (`F.cnt.contents`) for a reassigned
    # captured variable.
    function counted(f)
        ncalls = Ref(0)
        F = x -> begin
            ncalls[] += 1
            f(x)
        end
        return F, ncalls
    end

    # as of v"1.3.0", no more maxfnevals for stopping, just maxevals
    # this is an alternative: step the solver and give up (returning NaN) once
    # at least `maxfnevals` function evaluations have been made
    function fz(f, x0::Number, M; maxfnevals=10, kwargs...)
        F, ncalls = counted(f)
        ZPI = init(ZeroProblem(F, x0), M; kwargs...)
        x = NaN * float(x0)
        ϕ = iterate(ZPI)
        while ϕ !== nothing
            x, st = ϕ
            ncalls[] >= maxfnevals && return NaN * float(x0)
            ϕ = iterate(ZPI, st)
        end
        return x
    end
    f(x) = x^20 - 1
    x0 = 0.9
    M = Order1()
    @test isnan(fz(f, x0, M))  # takes 19 fn evals, not 10

    # test that for update state, fnevals are correctly counted for simpler
    # methods
    fn = (x) -> sin(x)
    x0 = (3, 4)
    M = Order1()
    state = Roots.init_state(M, Roots.Callable_Function(M, fn), x0)
    options = Roots.init_options(M, state)

    for M in (
        Order1(),
        Order2(),
        Order5(),
        Order8(),
        Order16(),
        Roots.Order1B(),
        Roots.Order2B(),
        Roots.Bisection(),
        Roots.Brent(),
        Roots.Ridders(),
        Roots.ITP(),
        Roots.A42(),
        Roots.AlefeldPotraShi(),
    )

        # test initial count
        g, ncalls = counted(fn)
        G = Roots.Callable_Function(M, g)
        Roots.init_state(M, G, x0)
        @test ncalls[] ≤ Roots.initial_fncalls(M)

        # test update state: the calls we observe must match what Tracks records
        g, ncalls = counted(fn)
        stateₘ = Roots.init_state(M, state, Roots.Callable_Function(M, fn))
        G = Roots.Callable_Function(M, g)
        l = Roots.Tracks(Float64, Float64)
        Roots.update_state(M, G, stateₘ, options, l)
        @test ncalls[] == l.fncalls
    end
end

@testset "_extrema" begin
    # a two-element tuple or vector yields its sorted endpoints as Float64;
    # the type-stability checks are only run on Julia 1.6 and later
    if VERSION >= v"1.6.0"
        expected = (0.0, Float64(π))
        for endpoints in ((π, 0), [π, 0])
            @test @inferred(Roots._extrema(endpoints)) === expected
        end
    end

    # a scalar, or endpoints that coincide, are rejected
    for invalid in (π, (π, π), [π, π])
        @test_throws ArgumentError Roots._extrema(invalid)
    end
end

@testset "sensitivity" begin
    # Issue #349: solve/find_zero must be differentiable (via ForwardDiff)
    # with respect to the parameter p, whether p is passed to `solve`,
    # positionally to `find_zero`, or as the keyword `p`
    if VERSION >= v"1.9.0-"
        f(x, p) = cos(x) - first(p) * x
        x₀ = (0, pi / 2)
        F(p) = solve(ZeroProblem(f, x₀), Bisection(), p)
        G(p) = find_zero(f, x₀, Bisection(), p)
        H(p) = find_zero(f, x₀, Bisection(); p=p)

        # dα/dp at p = 1, where α(p) solves cos(α) = p⋅α:
        # dα/dp = -α / (p + sin(α))
        ∂ = -0.4416107917053284
        for Φ in (F, G, H)
            @test ForwardDiff.derivative(Φ, 1.0) ≈ ∂
            # only first(p) enters f, so the first gradient entry equals ∂
            @test ForwardDiff.gradient(Φ, [1.0, 2])[1] ≈ ∂
        end
    end
end

@testset "bracketing_atol" begin
    ## issue #457: the `atol` argument is honored for bracketing problems,
    ## including brackets with an infinite endpoint
    f(x) = x^2 - 4
    @test find_zero(f, (0, Inf)) ≈ 2 # 2.0 correct

    # (bracket, atol, value expected from the default bracketing method)
    cases = (
        ((0, Inf), 1, 1.9997558593749998),
        ((0, Inf), 1e-5, 1.9999998807907102),
        ((0, 8), 1, 1.99609375),
        ((0, 8), 1e-3, 2.0000152587890625),
    )
    for (bracket, tol, expected) in cases
        @test find_zero(f, bracket, atol=tol) ≈ expected
    end
end

@testset "similar methods" begin
    # Run `find_zero(sin, 3.0, method)` and return the recorded Tracks.
    function recorded(method)
        tracks = Roots.Tracks()
        find_zero(sin, 3.0, method; tracks=tracks)
        return tracks
    end

    sidi_tracks = recorded(Roots.Sidi(1))
    secant_tracks = recorded(Roots.Secant())

    # Sidi(1) is expected to produce the same recorded iterates as the secant
    # method; the first two entries are skipped (drop x₀x₁ ordering)
    @test sidi_tracks.xfₛ[3:end] == secant_tracks.xfₛ[3:end]
end
