Skip to content

Commit 7f2da5b

Browse files
authored
Merge pull request #54 from JeffFessler/master
Support differing x,y eltypes for dwt
2 parents d044058 + 72cbb3d commit 7f2da5b

6 files changed

Lines changed: 51 additions & 33 deletions

File tree

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ julia:
66
- 1.0
77
- 1.1
88
- 1.2
9+
- 1.4
910
- nightly
1011
matrix:
1112
allow_failures:

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
name = "Wavelets"
22
uuid = "29a6e085-ba6d-5f35-a997-948ac2efa89a"
33
author = ["Gudmundur Adalsteinsson "]
4-
version = "0.9.0"
4+
version = "0.9.1"
5+
6+
[compat]
7+
DSP = "0.5.1"
8+
FFTW = "0.2.4"
9+
Reexport = "0.2.0"
10+
SpecialFunctions = "0.7.1"
11+
julia = " 1"
512

613
[deps]
714
DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"

src/mod/Transforms.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ const DWTArray = AbstractArray
1111
const WPTArray = AbstractVector
1212
const ValueType = Union{AbstractFloat, Complex}
1313

14+
const FVector = StridedVector # e.g., work space vectors
15+
1416
# DWT
1517

1618
"""
@@ -114,8 +116,8 @@ for (Xwt, Xwt!, _Xwt!, fw) in ((:dwt, :dwt!, :_dwt!, true),
114116
y = Array{T}(undef, size(x))
115117
return ($_Xwt!)(y, x, filter, L, $fw)
116118
end
117-
function ($Xwt!)(y::DWTArray{T}, x::DWTArray{T}, filter::OrthoFilter,
118-
L::Integer=maxtransformlevels(x)) where T<:ValueType
119+
function ($Xwt!)(y::DWTArray{<:ValueType}, x::DWTArray{<:ValueType}, filter::OrthoFilter,
120+
L::Integer=maxtransformlevels(x))
119121
return ($_Xwt!)(y, x, filter, L, $fw)
120122
end
121123
# lifting
@@ -485,7 +487,10 @@ end # for
485487

486488
# Array with shared memory
487489
function unsafe_vectorslice(A::Array{T}, i::Int, n::Int) where T
488-
return unsafe_wrap(Array, pointer(A, i), n)::Vector{T}
490+
return unsafe_wrap(Array, pointer(A, i), n)::Vector{T}
491+
end
492+
function unsafe_vectorslice(A::StridedArray{T}, i::Int, n::Int) where T
493+
return @view A[i:(i-1+n)]
489494
end
490495

491496
# linear indices of start of rows/cols/planes

src/mod/Util.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,15 +363,15 @@ function merge!(b::AbstractArray{T}, ib::Integer, incb::Integer, a::AbstractVect
363363
end
364364

365365

366-
function stridedcopy!(b::AbstractVector{T}, a::AbstractArray{T}, ia::Integer, inca::Integer, n::Integer) where T<:Number
366+
function stridedcopy!(b::AbstractVector{<:Number}, a::AbstractArray{<:Number}, ia::Integer, inca::Integer, n::Integer)
367367
@assert ia+(n-1)*inca <= length(a) && n <= length(b)
368368

369369
@inbounds for i = 1:n
370370
b[i] = a[ia + (i-1)*inca]
371371
end
372372
return b
373373
end
374-
function stridedcopy!(b::AbstractArray{T}, ib::Integer, incb::Integer, a::AbstractVector{T}, n::Integer) where T<:Number
374+
function stridedcopy!(b::AbstractArray{<:Number}, ib::Integer, incb::Integer, a::AbstractVector{<:Number}, n::Integer)
375375
@assert ib+(n-1)*incb <= length(b) && n <= length(a)
376376

377377
@inbounds for i = 1:n

src/mod/transforms_filter.jl

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
# DWT
1111
# 1-D
1212
# writes to y
13-
function _dwt!(y::AbstractVector{T}, x::AbstractVector{T},
14-
filter::OrthoFilter, L::Integer, fw::Bool) where T<:Number
15-
si = Vector{T}(undef, length(filter)-1) # tmp filter vector
13+
function _dwt!(y::AbstractVector{Ty}, x::AbstractVector{Tx},
14+
filter::OrthoFilter, L::Integer, fw::Bool) where {Tx<:Number, Ty<:Number}
15+
T = promote_type(Tx, Ty)
16+
si = Vector{T}(undef, length(filter)-1) # tmp filter vector
1617
scfilter, dcfilter = WT.makereverseqmfpair(filter, fw, T)
1718
return _dwt!(y, x, filter, L, fw, dcfilter, scfilter, si)
1819
end
19-
function _dwt!(y::AbstractVector{T}, x::AbstractVector{T},
20+
function _dwt!(y::AbstractVector{<:Number}, x::AbstractVector{<:Number},
2021
filter::OrthoFilter, L::Integer, fw::Bool,
2122
dcfilter::Vector{T}, scfilter::Vector{T},
2223
si::Vector{T}, snew::Vector{T} = Vector{T}(undef, ifelse(L>1, length(x)>>1, 0))) where T<:Number
@@ -35,7 +36,7 @@ function _dwt!(y::AbstractVector{T}, x::AbstractVector{T},
3536
if L == 0
3637
return copyto!(y,x)
3738
end
38-
s = x # s is currect scaling coefs location
39+
s = x # s is current scaling coefs location
3940
filtlen = length(filter)
4041

4142
lrange = 1:L
@@ -59,10 +60,10 @@ function _dwt!(y::AbstractVector{T}, x::AbstractVector{T},
5960
end
6061
return y
6162
end
62-
function unsafe_dwt1level!(y::AbstractVector{T}, x::AbstractVector{T},
63+
function unsafe_dwt1level!(y::AbstractVector{<:Number}, x::AbstractVector{<:Number},
6364
filter::OrthoFilter, fw::Bool,
64-
dcfilter::Vector{T}, scfilter::Vector{T},
65-
si::Vector{T}) where T<:Number
65+
dcfilter::FVector{T}, scfilter::FVector{T},
66+
si::FVector{T}) where T<:Number
6667
n = length(x)
6768
l = 1
6869
filtlen = length(filter)
@@ -81,11 +82,11 @@ function unsafe_dwt1level!(y::AbstractVector{T}, x::AbstractVector{T},
8182
return y
8283
end
8384

84-
function dwt_transform_strided!(y::Array{T}, x::AbstractArray{T},
85+
function dwt_transform_strided!(y::AbstractArray{<:Number}, x::AbstractArray{<:Number},
8586
msub::Int, nsub::Int, stride::Int, idx_func::Function,
86-
tmpvec::Vector{T}, tmpvec2::Vector{T},
87+
tmpvec::FVector{T}, tmpvec2::FVector{T},
8788
filter::OrthoFilter, fw::Bool,
88-
dcfilter::Vector{T}, scfilter::Vector{T}, si::Vector{T}) where T<:Number
89+
dcfilter::FVector{T}, scfilter::FVector{T}, si::FVector{T}) where T<:Number
8990
for i=1:msub
9091
xi = idx_func(i)
9192
stridedcopy!(tmpvec, x, xi, stride, nsub)
@@ -94,11 +95,11 @@ function dwt_transform_strided!(y::Array{T}, x::AbstractArray{T},
9495
end
9596
end
9697

97-
function dwt_transform_cols!(y::Array{T}, x::AbstractArray{T},
98+
function dwt_transform_cols!(y::AbstractArray{<:Number}, x::AbstractArray{<:Number},
9899
msub::Int, nsub::Int, idx_func::Function,
99-
tmpvec::Vector{T},
100+
tmpvec::FVector{T},
100101
filter::OrthoFilter, fw::Bool,
101-
dcfilter::Vector{T}, scfilter::Vector{T}, si::Vector{T}) where T<:Number
102+
dcfilter::FVector{T}, scfilter::FVector{T}, si::FVector{T}) where T<:Number
102103
for i=1:nsub
103104
xi = idx_func(i)
104105
copyto!(tmpvec, 1, x, xi, msub)
@@ -109,16 +110,17 @@ end
109110

110111
# 2-D
111112
# writes to y
112-
function _dwt!(y::Matrix{T}, x::AbstractMatrix{T},
113-
filter::OrthoFilter, L::Integer, fw::Bool) where T<:Number
113+
function _dwt!(y::AbstractMatrix{Ty}, x::AbstractMatrix{Tx},
114+
filter::OrthoFilter, L::Integer, fw::Bool) where {Tx<:Number, Ty<:Number}
114115
m, n = size(x)
116+
T = promote_type(Tx, Ty)
115117
si = Vector{T}(undef, length(filter)-1) # tmp filter vector
116118
tmpbuffer = Vector{T}(undef, max(n<<1, m)) # tmp storage vector
117119
scfilter, dcfilter = WT.makereverseqmfpair(filter, fw, T)
118120

119121
return _dwt!(y, x, filter, L, fw, dcfilter, scfilter, si, tmpbuffer)
120122
end
121-
function _dwt!(y::Matrix{T}, x::AbstractMatrix{T},
123+
function _dwt!(y::AbstractMatrix{<:Number}, x::AbstractMatrix{<:Number},
122124
filter::OrthoFilter, L::Integer, fw::Bool,
123125
dcfilter::Vector{T}, scfilter::Vector{T},
124126
si::Vector{T}, tmpbuffer::Vector{T}) where T<:Number
@@ -187,16 +189,17 @@ end
187189

188190
# 3-D
189191
# writes to y
190-
function _dwt!(y::Array{T, 3}, x::AbstractArray{T, 3},
191-
filter::OrthoFilter, L::Integer, fw::Bool) where T<:Number
192+
function _dwt!(y::AbstractArray{Ty, 3}, x::AbstractArray{Tx, 3},
193+
filter::OrthoFilter, L::Integer, fw::Bool) where {Tx<:Number, Ty<:Number}
192194
m, n, d = size(x)
195+
T = promote_type(Tx, Ty)
193196
si = Vector{T}(undef, length(filter)-1) # tmp filter vector
194197
tmpbuffer = Vector{T}(undef, max(m, n<<1, d<<1)) # tmp storage vector
195198
scfilter, dcfilter = WT.makereverseqmfpair(filter, fw, T)
196199

197200
return _dwt!(y, x, filter, L, fw, dcfilter, scfilter, si, tmpbuffer)
198201
end
199-
function _dwt!(y::Array{T, 3}, x::AbstractArray{T, 3},
202+
function _dwt!(y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3},
200203
filter::OrthoFilter, L::Integer, fw::Bool,
201204
dcfilter::Vector{T}, scfilter::Vector{T},
202205
si::Vector{T}, tmpbuffer::Vector{T}) where T<:Number
@@ -329,7 +332,7 @@ function _wpt!(y::AbstractVector{T}, x::AbstractVector{T}, filter::OrthoFilter,
329332
Lfw = (fw ? Lmax-L : L-1)
330333
nj = detailn(n, Lfw)
331334
treeind = 2^(Lfw)-1
332-
dx = unsafe_vectorslice(snew, 1, nj)
335+
dx = first ? x : unsafe_vectorslice(snew, 1, nj) # dx will be overwritten if first
333336

334337
while ix <= n
335338
if tree[treeind+k]
@@ -381,9 +384,9 @@ end
381384
# x : filter convolved with x[ix:ix+nx-1], where nx=nout*2 (shifted by shift)
382385
# ss : shift downsampling
383386
# based on Base.filt
384-
function filtdown!(f::Vector{T}, si::Vector{T},
385-
out::AbstractVector{T}, iout::Integer, nout::Integer,
386-
x::AbstractVector{T}, ix::Integer,
387+
function filtdown!(f::AbstractVector{T}, si::AbstractVector{T},
388+
out::AbstractVector{<:Number}, iout::Integer, nout::Integer,
389+
x::AbstractVector{<:Number}, ix::Integer,
387390
shift::Integer=0, ss::Bool=false) where T<:Number
388391
nx = nout<<1
389392
silen = length(si)
@@ -462,8 +465,8 @@ end
462465
# ss : shift upsampling
463466
# based on Base.filt
464467
function filtup!(add2out::Bool, f::Vector{T}, si::Vector{T},
465-
out::AbstractVector{T}, iout::Integer, nout::Integer,
466-
x::AbstractVector{T}, ix::Integer,
468+
out::AbstractVector{<:Number}, iout::Integer, nout::Integer,
469+
x::AbstractVector{<:Number}, ix::Integer,
467470
shift::Integer=0, ss::Bool=false) where T<:Number
468471
nx = nout>>1
469472
silen = length(si)

src/mod/transforms_lifting.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ end
7979
# tmp: size at least n>>2
8080
# oopc: use oop computation, if false iy and incy are assumed to be 1
8181
# oopv: the out of place location
82-
function unsafe_dwt1level!(y::AbstractArray{T}, iy::Integer, incy::Integer, oopc::Bool, oopv::Vector{T}, scheme::GLS, fw::Bool, stepseq::Vector, norm1::T, norm2::T, tmp::Vector{T}) where T<:Number
82+
function unsafe_dwt1level!(y::AbstractArray{T}, iy::Integer, incy::Integer, oopc::Bool,
83+
oopv::FVector{T}, scheme::GLS, fw::Bool, stepseq::FVector, norm1::T, norm2::T,
84+
tmp::FVector{T}) where T<:Number
8385
if !oopc
8486
oopv = y
8587
end

0 commit comments

Comments
 (0)