using QuantumCumulants
using ModelingToolkit
using OrdinaryDiffEq
using Test

@testset "two-level-laser" begin
    # Laser model: one cavity mode `a` coupled to N two-level atoms, each with
    # its own detuning Δ_i, coupling g_i, decay rate γ_i and pump rate ν_i; the
    # cavity decays at rate κ. The equations are built to second order, solved,
    # and then rebuilt via `complete!` with a phase-invariance filter to check
    # that both routes give the same dynamics.

    N = 10

    Δ = cnumbers((Symbol(:Δ_, i) for i = 1:N)...)
    g = cnumbers((Symbol(:g_, i) for i = 1:N)...)
    γ = cnumbers((Symbol(:γ_, i) for i = 1:N)...)
    # Pump rates are declared as real parameters (rnumbers) rather than complex.
    ν = rnumbers((Symbol(:ν_, i) for i = 1:N)...)
    @cnumbers κ

    h_cavity = FockSpace(:cavity)
    h_atoms = [NLevelSpace(Symbol(:atom, i), (:g, :e)) for i = 1:N]
    h = tensor(h_cavity, h_atoms...)

    a = Destroy(h, :a)
    # Atom k acts on subspace k + 1 of `h` (subspace 1 is the cavity).
    σ(i, j, k) = Transition(h, Symbol(:σ_, k), i, j, k+1)

    # Operators whose averages are evolved: photon number, excited-state
    # populations, atom-field terms a'σ_ge, and atom-atom terms σ_eg^i σ_ge^j
    # for every pair i < j.
    ops = begin
        ops_ = [a'*a; [σ(:e, :e, k) for k = 1:N]; [a'*σ(:g, :e, k) for k = 1:N]]
        for i = 1:N
            for j = (i+1):N
                push!(ops_, σ(:e, :g, i)*σ(:g, :e, j))
            end
        end
        ops_
    end

    # 1 (a'a) + N (σ_ee) + N (a'σ_ge) + N(N-1)/2 (atom pairs)
    n_eqs = div(N*(N-1), 2) + 2N + 1
    @test length(ops) == n_eqs

    H =
        sum(Δ[i]*σ(:e, :e, i) for i = 1:N) +
        sum(g[i]*(a'*σ(:g, :e, i) + a*σ(:e, :g, i)) for i = 1:N)

    # Jump operators: cavity loss, atomic decay (σ_ge) and incoherent pump (σ_eg),
    # paired element-wise with `rates`.
    J = [a; [σ(:g, :e, k) for k = 1:N]; [σ(:e, :g, k) for k = 1:N]]
    Jdagger = adjoint.(J)
    rates = [κ, γ..., ν...]
    he = meanfield(ops, H, J; Jdagger = Jdagger, rates = rates, simplify = true, order = 2)
    @test length(he.equations) == n_eqs

    # Averages that appear on the right-hand sides but have no equation of
    # their own are set to zero (presumably the phase-dependent terms).
    missed = find_missing(he)

    subs = Dict(missed .=> 0)

    he_nophase = substitute(he, subs)

    @test isempty(find_missing(he_nophase))

    # Smoke check: extracting the equations must not throw.
    eqs_mtk = equations(he_nophase)

    @named sys = System(he_nophase)

    u0 = zeros(ComplexF64, length(ops))
    p0 = [
        κ => 1,
        (γ .=> 0.25 .* ones(N))...,
        (ν .=> 4 .* ones(N))...,
        (g .=> 1.5 .* ones(N))...,
        (Δ .=> ones(N))...,
    ]
    dict = merge(Dict(unknowns(sys) .=> u0), Dict(p0...))
    prob = ODEProblem(sys, dict, (0.0, 10.0))

    sol = solve(prob, RK4())

    # Photon-number trajectory, looked up by operator rather than by position.
    n = sol[average(a'*a)]

    # Floating-point trajectories are compared approximately, not bitwise.
    @test get_solution(sol, σ(:g, :g, 1)) ≈ -sol[σ(:e, :e, 1)] .+ 1
    @test get_solution(sol, 2a'a + 3*σ(:e, :e, 1)) ≈ 2*sol[a'a] + 3*sol[σ(:e, :e, 1)]

    # Rebuild the same system with `complete!`, keeping only phase-invariant
    # averages. ϕ counts the net phase of an operator: +1 per creation-like
    # factor (a', σ_eg), -1 per annihilation-like factor (a, σ_ge), and 0 for
    # projectors σ_ii.
    ϕ(::Destroy) = -1
    ϕ(::Create) = 1
    ϕ(t::Transition) = t.i == t.j ? 0 : (t.i == :e ? 1 : -1)
    ϕ(q::QuantumCumulants.QMul) = sum(ϕ, q.args_nc; init = 0)
    ϕ(avg::Average) = ϕ(avg.arguments[1])
    phase_invariant(x) = iszero(ϕ(x))

    he_n = meanfield(a'*a, H, J; Jdagger = Jdagger, rates = rates)
    complete!(he_n; filter_func = phase_invariant)

    @test length(he_n.equations) == length(ops)
    @test isempty(find_missing(he_n))


    @named sys_comp = System(he_n)
    dict = merge(Dict(unknowns(sys_comp) .=> u0), Dict(p0...))
    prob_comp = ODEProblem(sys_comp, dict, (0.0, 10.0))

    sol_comp = solve(prob_comp, RK4())

    # Compare by operator so the check does not depend on how each system
    # happens to order its unknowns.
    @test n ≈ sol_comp[average(a'*a)]

end # testset
