using StatsBase
using Test, Random, StableRNGs, OffsetArrays

# Fix the global RNG state so every statistical check in this file is reproducible.
Random.seed!(1234)

# Size used by the large-sample checks below (number of draws / runs).
n = 100_000

# Check that the `rng` argument of a sampling function is honored:
# (a) passing the same seeded RNG twice gives identical results
#     (repeatability);
# (b) omitting the RNG is the same as passing `Random.GLOBAL_RNG`, and the
#     same as passing `Random.default_rng()` (Julia >= 1.3).
# `non_rng_args` are the remaining positional arguments of `func`.
function test_rng_use(func, non_rng_args...)
    # Some samplers fill a preallocated output array in place and return it.
    # Give every call its own deep copy, so the comparisons cannot pass merely
    # because two results alias the same array.
    fresh() = deepcopy(non_rng_args)
    # Reseed the global RNG to a fixed state, then evaluate `thunk`.
    seeded(thunk) = (Random.seed!(47); thunk())

    # (a) repeatability
    @test func(MersenneTwister(1), fresh()...) == func(MersenneTwister(1), fresh()...)

    # (b) the implicit RNG is Random.GLOBAL_RNG / Random.default_rng()
    expected = seeded(() -> func(fresh()...))
    @test expected == seeded(() -> func(Random.GLOBAL_RNG, fresh()...))
    @test expected == seeded(() -> func(Random.default_rng(), fresh()...))
end

#### sample with replacement

"""
    check_sample_wrep(a, vrgn, ptol; ordered=false, rev=false)

Check a sample `a` (vector or matrix) drawn *with* replacement from the integer
range `vrgn[1]:vrgn[2]`.

- Every value of `a` must lie in that range.
- If `ordered`, `a` must be sorted (descending when `rev`). Otherwise `a` must
  *not* be sorted, which guards against an unordered sampler that returns
  sorted output.
- If `ptol > 0`, the empirical proportions of each value must match the
  uniform distribution within absolute tolerance `ptol`. The ordered case
  checks the whole of `a`; the unordered case checks each column separately,
  one column per sampler run. `ptol == 0` skips the frequency check in both
  cases.
"""
function check_sample_wrep(a::AbstractArray, vrgn, ptol::Real;
                           ordered::Bool=false, rev::Bool=false)
    vmin, vmax = vrgn
    amin, amax = extrema(a)
    @test vmin <= amin <= amax <= vmax
    k = vmax - vmin + 1
    p0 = fill(1/k, k)
    if ordered
        @test issorted(a; rev=rev)
        if ptol > 0
            @test isapprox(proportions(a, vmin:vmax), p0, atol=ptol)
        end
    else
        @test !issorted(a; rev=rev)
        # Previously this ran unconditionally, so ptol == 0 meant atol == 0
        # and always failed; gate it like the ordered branch does.
        if ptol > 0
            # `eachcol` yields the whole vector once when `a` is a Vector.
            for col in eachcol(a)
                @test isapprox(proportions(col, vmin:vmax), p0, atol=ptol)
            end
        end
    end
    return nothing
end

import StatsBase: direct_sample!

# `direct_sample!` fills all three columns of a preallocated n×3 matrix. Each
# source (ranges, and a plain vector) covers ten consecutive integers.
for (src, lims) in ((1:10, (1, 10)), (3:12, (3, 12)), ([11:20;], (11, 20)))
    local draws = direct_sample!(src, zeros(Int, n, 3))
    check_sample_wrep(draws, lims, 5.0e-3; ordered=false)
end

test_rng_use(direct_sample!, 1:10, zeros(Int, 6))

a = sample(3:12, n)
check_sample_wrep(a, (3, 12), 5.0e-3; ordered=false)

# Ordered sampling across element types, from forward and reversed ranges.
# Values are converted back to Int before checking.
for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int})
    r = rev ? reverse(3:12) : (3:12)
    r = T === Int ? r : T.(r)
    aa = Int.(sample(r, n; ordered=true))
    check_sample_wrep(aa, (3, 12), 5.0e-3; ordered=true, rev=rev)

    # A short sample is checked only for range and order (ptol = 0).
    aa = Int.(sample(r, 10; ordered=true))
    check_sample_wrep(aa, (3, 12), 0; ordered=true, rev=rev)
end

# `_storeindices` is expected to return false for BigFloat.
@test StatsBase._storeindices(1, 1, BigFloat) == false

test_rng_use(sample, 1:10, 10)

@testset "sampling pairs" begin
    rng = StableRNG(1)

    # Fixed-seed regression values. The draws share `rng`, so their order matters:
    # integer population sizes first, then explicit vectors.
    for (population, expected) in ((2, (2, 1)),
                                   (10, (5, 6)),
                                   ([3, 4, 2, 6, 8], (3, 8)),
                                   ([1, 2], (1, 2)))
        @test samplepair(rng, population) === expected
    end

    # The pair's element type follows the integer type of the population size.
    pair = samplepair(rng, UInt128(2))
    @test extrema(pair) == (1, 2)
    @test eltype(pair) === UInt128
end

test_rng_use(samplepair, 1000)

#### sample without replacement

"""
    check_sample_norep(a, vrgn, ptol; ordered=false, rev=false)

Check samples drawn *without* replacement from the integer range
`vrgn[1]:vrgn[2]`. Each column of `a` holds one run of the sampler; a Vector
counts as a single run.

- Every value of `a` must lie in the range.
- Every column must contain distinct values.
- If `ordered`, every column must be sorted (descending when `rev`).
- If `ptol > 0`, value frequencies must match the uniform distribution within
  absolute tolerance `ptol`. The ordered case checks over all of `a`; the
  unordered case checks per output position (each row of `a`, across runs).
"""
function check_sample_norep(a::AbstractArray, vrgn, ptol::Real;
                            ordered::Bool=false, rev::Bool=false)
    vmin, vmax = vrgn
    amin, amax = extrema(a)
    @test vmin <= amin <= amax <= vmax
    k = vmax - vmin + 1

    # Use `@test` rather than `@assert`: a violation is then reported as a test
    # failure instead of aborting the whole file, and cannot be compiled away.
    # A single aggregate test avoids one record per column.
    @test all(allunique, eachcol(a))
    if ordered
        @test all(col -> issorted(col; rev=rev), eachcol(a))
    end

    if ptol > 0
        p0 = fill(1/k, k)
        if ordered
            @test isapprox(proportions(a, vmin:vmax), p0, atol=ptol)
        else
            # each output position should be uniform across runs
            for row in eachrow(a)
                @test isapprox(proportions(row, vmin:vmax), p0, atol=ptol)
            end
        end
    end
    return nothing
end

import StatsBase: knuths_sample!, fisher_yates_sample!, self_avoid_sample!
import StatsBase: seqsample_a!, seqsample_c!, seqsample_d!

# Each column of a 5×n matrix is one independent draw of 5 distinct values
# from 3:12. The `seqsample_*` samplers are checked as producing sorted output.
# The tuple order fixes the order of global-RNG use.
for (sampler!, issequential) in ((knuths_sample!, false),
                                 (fisher_yates_sample!, false),
                                 (self_avoid_sample!, false),
                                 (seqsample_a!, true),
                                 (seqsample_c!, true),
                                 (seqsample_d!, true))
    local draws = zeros(Int, 5, n)
    foreach(col -> sampler!(3:12, col), eachcol(draws))
    check_sample_norep(draws, (3, 12), 5.0e-3; ordered=issequential)

    test_rng_use(sampler!, 1:10, zeros(Int, 6))
end

# Plain (unweighted) sampling without replacement. Each case draws only 5
# values, so no frequency check is made (ptol = 0).
for (src, isordered, isrev) in ((3:12, false, false),
                                (3:12, true, false),
                                (reverse(3:12), true, true))
    local s = sample(src, 5; replace=false, ordered=isordered)
    check_sample_norep(s, (3, 12), 0; ordered=isordered, rev=isrev)
end

# tests of multidimensional sampling
let m = sample(3:12, (2, 2); replace=false)
    check_sample_norep(m, (3, 12), 0; ordered=false)
end

@test sample(1:1, (2, 2); replace=true) == ones(Int, 2, 2)

# Weighted sampling without replacement. Only items 7:10 have non-zero weight
# (1, 2, 3, 4), so every draw must come from 7:10.
# Draws are concatenated with `reduce(vcat, ...)`, which has a specialized
# method for vectors of vectors. This avoids splatting tens of thousands of
# arrays into one varargs `vcat` call.
a = [1:10;]
wv = Weights([zeros(6); 1:4])

# k = 1: each item's frequency is proportional to its weight
x = reduce(vcat, [sample(a, wv, 1, replace=false) for _ in 1:100000])
@test minimum(x) == 7
@test maximum(x) == 10
@test maximum(abs, proportions(x) .- (1:4)/10) < 0.01

# k = 2: expected share of each item among all drawn values
# (its inclusion probability divided by 2)
x = reduce(vcat, [sample(a, wv, 2, replace=false) for _ in 1:50000])
exact2 = [0.117261905, 0.220634921, 0.304166667, 0.357936508]
@test minimum(x) == 7
@test maximum(x) == 10
@test maximum(abs, proportions(x) .- exact2) < 0.01

# k = 4: all four positive-weight items are drawn every time
x = reduce(vcat, [sample(a, wv, 4, replace=false) for _ in 1:10000])
@test minimum(x) == 7
@test maximum(x) == 10
@test maximum(abs, proportions(x) .- 0.25) == 0

# requesting more items than there are positive weights is an error
@test_throws DimensionMismatch sample(a, wv, 5, replace=false)

# a negative weight is rejected
wv = Weights([zeros(5); 1:4; -1])
@test_throws ErrorException sample(a, wv, 1, replace=false)

#### weighted sampling with dimension

# Weights are respected: with weight 0 on `1`, every entry must be `2`.
@test sample([1, 2], Weights([0, 1]), (2, 2)) == fill(2, 2, 2)

let wm = sample(collect(1:4), Weights(1:4), (2, 2), replace=false)
    @test size(wm) == (2, 2)            # correct shape
    @test length(unique(wm)) == 4       # no duplicates in elements
end


#### check that sample and sample! do the same thing
# Check that `sample` (allocating) and `sample!` (in-place) produce the same
# draws from the same global-RNG state, for the keyword options in `kws`.
function test_same(; kws...)
    # The weights are drawn once, before seeding, and shared by both calls.
    wv = Weights(rand(20))

    Random.seed!(1)
    allocated = sample(1:20, wv, 10; kws...)

    inplace = zeros(Int, 10)
    Random.seed!(1)
    sample!(1:20, wv, inplace; kws...)

    @test allocated == inplace
end

# Default keywords first, then every replace/ordered combination (in order).
for kws in (NamedTuple(),
            (replace=true,), (replace=false,),
            (replace=true, ordered=true), (replace=false, ordered=true),
            (replace=true, ordered=false), (replace=false, ordered=false))
    test_same(; kws...)
end

@testset "validation of inputs" begin
    samplers = (sample!, knuths_sample!, fisher_yates_sample!, self_avoid_sample!,
                seqsample_a!, seqsample_c!, seqsample_d!)
    for f in samplers
        x, y = rand(10), rand(10)
        ox, oy = OffsetArray(x, -4:5), OffsetArray(y, -4:5)

        # An offset (non-1-based) input or output array is rejected.
        for (src, dst) in ((ox, y), (x, oy), (ox, oy))
            @test_throws ArgumentError f(src, dst)
        end

        # An output that shares memory with the input is rejected...
        @test_throws ArgumentError f(x, x)
        @test_throws ArgumentError f(view(x, 2:4), view(x, 3:5))
        # ...but disjoint views of the same parent array are accepted.
        f(view(x, 2:4), view(x, 5:6))
    end
end

@testset "issue #872" begin
    # A draw of 2 values must contain exactly two elements of type `S`,
    # each lying between S(1) and S(10).
    function check_draw(samp, S)
        @test all(v -> v isa S, samp)
        @test all(v -> S(1) <= v <= S(10), samp)
        @test length(samp) == 2
    end

    for T in [Int8, Int16, Int32, Int64, Int128, BigInt], f in [identity, unsigned]
        # skipped: BigInt has no unsigned counterpart
        T == BigInt && f == unsigned && continue
        S = f(T)
        # The type of the second argument should not affect the return type.
        check_draw(sample(S(1):S(10), S(2); replace=false, ordered=false), S)
        check_draw(sample(S(1):S(10), 2; replace=false, ordered=false), S)
        check_draw(sample(1:10, S(2); replace=false, ordered=false), Int)
    end
end

# Custom weights type with no `values` field. `YAUnitWeights(n)` represents n
# weights that all equal 1, implemented only through the methods below.
struct YAUnitWeights <: AbstractWeights{Int, Int, Vector{Int}}
    n::Int
end

# Shape queries all derive from the stored count `n`.
Base.length(wv::YAUnitWeights) = wv.n
Base.size(wv::YAUnitWeights) = (wv.n,)
Base.axes(wv::YAUnitWeights) = (Base.OneTo(wv.n),)
Base.isempty(wv::YAUnitWeights) = wv.n == 0

# Every weight is 1, so the total equals the count.
Base.sum(wv::YAUnitWeights) = wv.n

# Every in-bounds index holds weight 1.
Base.getindex(wv::YAUnitWeights, i::Int) = (@boundscheck checkbounds(wv, i); 1)

@testset "issue #950" begin
    # Reseed the global RNG to a fixed state, then evaluate `thunk`. Two calls
    # wrapped this way therefore see identical random streams.
    seeded(thunk) = (Random.seed!(123); thunk())

    # Sampling with unit weights behaves the same as sampling without weights
    xs = seeded(() -> sample(1:100, uweights(100), 10; replace=false))
    @test xs == seeded(() -> sample(1:100, 10; replace=false))

    x = seeded(() -> sample(uweights(100)))
    @test x == seeded(() -> sample(1:100))

    xs = seeded(() -> direct_sample!(1:100, uweights(100), Vector{Int}(undef, 10)))
    @test xs == seeded(() -> direct_sample!(1:100, Vector{Int}(undef, 10)))

    # Errors
    @test_throws DimensionMismatch("Number of samples (100) and sample weights (99) must be equal.") sample(1:100, uweights(99), 10; replace=false)
    @test_throws DimensionMismatch("Number of samples (80) and sample weights (53) must be equal.") direct_sample!(1:80, uweights(53), Vector{Int}(undef, 10))

    # Custom unit weights don't error and behave the same as sampling with `Weights`
    xs = seeded(() -> sample(1:100, YAUnitWeights(100), 10; replace=false))
    @test xs == seeded(() -> sample(1:100, weights(ones(Int, 100)), 10; replace=false))
    for f in (StatsBase.efraimidis_a_wsample_norep!, StatsBase.efraimidis_ares_wsample_norep!, StatsBase.efraimidis_aexpj_wsample_norep!)
        xs = seeded(() -> f(1:100, YAUnitWeights(100), Vector{Int}(undef, 10)))
        @test xs == seeded(() -> f(1:100, weights(ones(Int, 100)), Vector{Int}(undef, 10)))
    end
end
