Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/Symbolics/drop_powers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ julia>drop_powers((x+y)^2 + (x+y)^3, [x,y], 3)
x^2 + y^2 + 2*x*y
```
"""
function drop_powers(expr::BasicSymbolic, vars::Vector{<:BasicSymbolic}, deg::Int)
return unwrap(drop_powers(Num(expr), Num.(vars), deg))
end

# Num fallback: the actual implementation
function drop_powers(expr::Num, vars::Vector{Num}, deg::Int)
Symbolics.@variables ϵ
subs_expr = deepcopy(expr)
Expand Down Expand Up @@ -43,19 +48,25 @@ function drop_powers(eqs::Vector{Equation}, var::Vector{Num}, deg::Int)
Equation(drop_powers(eq.lhs, var, deg), drop_powers(eq.rhs, var, deg)) for eq in eqs
]
end
function drop_powers(expr::BasicSymbolic, var::BasicSymbolic, deg::Int)
drop_powers(expr, [var], deg)
end
drop_powers(expr, var::Num, deg::Int) = drop_powers(expr, [var], deg)
drop_powers(x, vars, deg::Int) = drop_powers(wrap(x), vars, deg)
# ^ TODO: in principle `drop_powers` should get in BasicSymbolic and have a fallback for Num
drop_powers(x, vars, deg::Int) = drop_powers(Num(x), Num.(vars), deg)

"Return the highest power of `y` occurring in the term `x`."
function max_power(x::Num, y::Num)
function max_power(x::BasicSymbolic, y::BasicSymbolic)
terms = get_all_terms(x)
powers = power_of.(terms, y)
return maximum(powers)
end

# Num fallback: unwrap to BasicSymbolic
max_power(x::Num, y::Num) = max_power(unwrap(x), unwrap(y))
max_power(x::BasicSymbolic, y::Num) = max_power(x, unwrap(y))
max_power(x::Vector{Num}, y::Num) = maximum(max_power.(x, y))
max_power(x::Complex, y::Num) = maximum(max_power.([x.re, x.im], y))
max_power(x::Num, y::BasicSymbolic) = max_power(unwrap(x), y)
max_power(x, t) = max_power(wrap(x), wrap(t))

"Return the power of `y` in the term `x`"
Expand Down
16 changes: 15 additions & 1 deletion test/symbolics.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Test
using Symbolics
using SymbolicUtils: Fixpoint, Prewalk, PassThrough, BasicSymbolic
using SymbolicUtils: Fixpoint, Prewalk, PassThrough, BasicSymbolic, unwrap

using QuestBase: @eqtest

Expand Down Expand Up @@ -59,6 +59,20 @@ end

@eqtest drop_powers([a^2 + a + b, b], a, 2) == [a + b, b]
@eqtest drop_powers([a^2 + a + b, b], [a, b], 2) == [a + b, b]

@testset "BasicSymbolic" begin
using SymbolicUtils: @syms

@syms a_bs b_bs

@test max_power(a_bs^2 + b_bs, a_bs) == 2
@test max_power(a_bs * ((a_bs + b_bs)^4)^2 + a_bs, a_bs) == 9

@test isequal(drop_powers(a_bs^2 + b_bs, a_bs, 2), b_bs)
@test isequal(drop_powers((a_bs + b_bs)^2, a_bs, 1), b_bs^2)
@test isequal(drop_powers((a_bs + b_bs)^2, [a_bs, b_bs], 2), unwrap(Num(0)))
@test drop_powers(a_bs^2 + b_bs, a_bs, 2) isa BasicSymbolic
end
end

@testset "trig_to_exp and exp_to_trig" begin
Expand Down
Loading