diff --git a/src/Symbolics/drop_powers.jl b/src/Symbolics/drop_powers.jl index 6919cc5..55def58 100644 --- a/src/Symbolics/drop_powers.jl +++ b/src/Symbolics/drop_powers.jl @@ -13,6 +13,11 @@ julia>drop_powers((x+y)^2 + (x+y)^3, [x,y], 3) x^2 + y^2 + 2*x*y ``` """ +function drop_powers(expr::BasicSymbolic, vars::Vector{<:BasicSymbolic}, deg::Int) + return unwrap(drop_powers(Num(expr), Num.(vars), deg)) +end + +# Num fallback: the actual implementation function drop_powers(expr::Num, vars::Vector{Num}, deg::Int) Symbolics.@variables ϵ subs_expr = deepcopy(expr) @@ -43,19 +48,25 @@ function drop_powers(eqs::Vector{Equation}, var::Vector{Num}, deg::Int) Equation(drop_powers(eq.lhs, var, deg), drop_powers(eq.rhs, var, deg)) for eq in eqs ] end +function drop_powers(expr::BasicSymbolic, var::BasicSymbolic, deg::Int) + drop_powers(expr, [var], deg) +end drop_powers(expr, var::Num, deg::Int) = drop_powers(expr, [var], deg) -drop_powers(x, vars, deg::Int) = drop_powers(wrap(x), vars, deg) -# ^ TODO: in principle `drop_powers` should get in BasicSymbolic and have a fallback for Num +drop_powers(x, vars, deg::Int) = drop_powers(Num(x), Num.(vars), deg) "Return the highest power of `y` occurring in the term `x`." -function max_power(x::Num, y::Num) +function max_power(x::BasicSymbolic, y::BasicSymbolic) terms = get_all_terms(x) powers = power_of.(terms, y) return maximum(powers) end +# Num fallback: unwrap to BasicSymbolic +max_power(x::Num, y::Num) = max_power(unwrap(x), unwrap(y)) +max_power(x::BasicSymbolic, y::Num) = max_power(x, unwrap(y)) max_power(x::Vector{Num}, y::Num) = maximum(max_power.(x, y)) max_power(x::Complex, y::Num) = maximum(max_power.([x.re, x.im], y)) +max_power(x::Num, y::BasicSymbolic) = max_power(unwrap(x), y) max_power(x, t) = max_power(wrap(x), wrap(t)) "Return the power of `y` in the term `x`" diff --git a/test/symbolics.jl b/test/symbolics.jl index 3044f5d..5ffbf0c 100644 --- a/test/symbolics.jl +++ b/test/symbolics.jl @@ -1,6 +1,6 @@ using Test using Symbolics -using SymbolicUtils: Fixpoint, Prewalk, PassThrough, BasicSymbolic +using SymbolicUtils: Fixpoint, Prewalk, PassThrough, BasicSymbolic, unwrap using QuestBase: @eqtest @@ -59,6 +59,20 @@ end @eqtest drop_powers([a^2 + a + b, b], a, 2) == [a + b, b] @eqtest drop_powers([a^2 + a + b, b], [a, b], 2) == [a + b, b] + + @testset "BasicSymbolic" begin + using SymbolicUtils: @syms + + @syms a_bs b_bs + + @test max_power(a_bs^2 + b_bs, a_bs) == 2 + @test max_power(a_bs * ((a_bs + b_bs)^4)^2 + a_bs, a_bs) == 9 + + @test isequal(drop_powers(a_bs^2 + b_bs, a_bs, 2), b_bs) + @test isequal(drop_powers((a_bs + b_bs)^2, a_bs, 1), b_bs^2) + @test isequal(drop_powers((a_bs + b_bs)^2, [a_bs, b_bs], 2), unwrap(Num(0))) + @test drop_powers(a_bs^2 + b_bs, a_bs, 2) isa BasicSymbolic + end end @testset "trig_to_exp and exp_to_trig" begin