Skip to content

Commit aa6edcf

Browse files
authored
Simplify memcpy handling (#897)
1 parent cee5fed commit aa6edcf

File tree

2 files changed

+17
-69
lines changed

2 files changed

+17
-69
lines changed

src/array.jl

Lines changed: 8 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,6 @@ Base.convert(::Type{T}, x::T) where T <: ROCArray = x
194194

195195
## memory operations
196196

197-
# TODO rework, to pass pointers, instead of accessing .mem
198-
199197
function Base.copyto!(
200198
dest::Array{T}, d_offset::Integer,
201199
source::ROCArray{T}, s_offset::Integer, amount::Integer;
@@ -204,16 +202,10 @@ function Base.copyto!(
204202
amount == 0 && return dest
205203
@boundscheck checkbounds(dest, d_offset + amount - 1)
206204
@boundscheck checkbounds(source, s_offset + amount - 1)
207-
208-
@debug "[gpu -> cpu] T=$T, shape=$(size(dest))"
209205
stm = stream()
210-
Mem.download!(
211-
pointer(dest, d_offset),
212-
Mem.view(convert(Mem.AbstractAMDBuffer, source.buf[]),
213-
(source.offset + s_offset - 1) * aligned_sizeof(T)),
214-
amount * aligned_sizeof(T); stream=stm)
206+
Mem.memcpy!(pointer(dest, d_offset), pointer(source, s_offset), amount * aligned_sizeof(T); stream=stm)
215207
async || synchronize(stm)
216-
dest
208+
return dest
217209
end
218210

219211
function Base.copyto!(
@@ -223,13 +215,8 @@ function Base.copyto!(
223215
amount == 0 && return dest
224216
@boundscheck checkbounds(dest, d_offset + amount - 1)
225217
@boundscheck checkbounds(source, s_offset + amount - 1)
226-
227-
@debug "[cpu -> gpu] T=$T, shape=$(size(dest))"
228-
Mem.upload!(
229-
Mem.view(convert(Mem.AbstractAMDBuffer, dest.buf[]),
230-
(dest.offset + d_offset - 1) * aligned_sizeof(T)),
231-
pointer(source, s_offset), amount * aligned_sizeof(T); stream=stream())
232-
dest
218+
Mem.memcpy!(pointer(dest, d_offset), pointer(source, s_offset), amount * aligned_sizeof(T); stream=stream())
219+
return dest
233220
end
234221

235222
function Base.copyto!(
@@ -239,19 +226,14 @@ function Base.copyto!(
239226
amount == 0 && return dest
240227
@boundscheck checkbounds(dest, d_offset + amount - 1)
241228
@boundscheck checkbounds(source, s_offset + amount - 1)
242-
Mem.transfer!(
243-
Mem.view(convert(Mem.AbstractAMDBuffer, dest.buf[]),
244-
(dest.offset + d_offset - 1) * aligned_sizeof(T)),
245-
Mem.view(convert(Mem.AbstractAMDBuffer, source.buf[]),
246-
(source.offset + s_offset - 1) * aligned_sizeof(T)),
247-
amount * aligned_sizeof(T); stream=stream())
248-
dest
229+
Mem.memcpy!(pointer(dest, d_offset), pointer(source, s_offset), amount * aligned_sizeof(T); stream=stream())
230+
return dest
249231
end
250232

251233
function Base.copy(X::ROCArray{T}) where T
252234
Xnew = ROCArray{T}(undef, size(X))
253235
copyto!(Xnew, 1, X, 1, length(X))
254-
Xnew
236+
return Xnew
255237
end
256238

257239
function Base.unsafe_wrap(
@@ -344,21 +326,14 @@ function Base.resize!(A::ROCVector{T}, n::Integer) where T
344326
# if A.buf.host_ptr != C_NULL
345327
# throw(ArgumentError("Cannot resize an unowned `ROCVector`"))
346328
# end
347-
348-
# TODO: add additional space to allow for quicker resizing
349329
n == length(A) && return A
350330

351331
maxsize = n * aligned_sizeof(T)
352332
bufsize = Base.isbitsunion(T) ? (maxsize + n) : maxsize
353333
new_buf = Mem.HIPBuffer(bufsize; stream=stream())
354334

355335
copy_size = min(length(A), n) * aligned_sizeof(T)
356-
if copy_size > 0
357-
Mem.transfer!(new_buf, convert(Mem.AbstractAMDBuffer, A.buf[]),
358-
copy_size; stream=stream())
359-
end
360-
361-
# Free old buffer.
336+
copy_size > 0 && Mem.memcpy!(new_buf, pointer(A), copy_size; stream=stream())
362337
unsafe_free!(A)
363338

364339
A.buf = DataRef(pool_free, Managed(new_buf))

src/runtime/memory/hip.jl

Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -88,21 +88,16 @@ function free(buf::HIPBuffer; stream::HIP.HIPStream)
8888
return
8989
end
9090

91-
function upload!(dst::HIPBuffer, src::Ptr, bytesize::Int; stream::HIP.HIPStream)
91+
function memcpy!(dst, src, bytesize::Int; stream::HIP.HIPStream)
9292
bytesize == 0 && return
93-
HIP.hipMemcpyHtoDAsync(dst, src, bytesize, stream)
94-
return
95-
end
96-
97-
function download!(dst::Ptr, src::HIPBuffer, bytesize::Int; stream::HIP.HIPStream)
98-
bytesize == 0 && return
99-
HIP.hipMemcpyDtoHAsync(dst, src, bytesize, stream)
100-
return
101-
end
102-
103-
function transfer!(dst::HIPBuffer, src::HIPBuffer, bytesize::Int; stream::HIP.HIPStream)
104-
bytesize == 0 && return
105-
HIP.hipMemcpyDtoDAsync(dst, src, bytesize, stream)
93+
dst_type = attributes(convert(Ptr{Cvoid}, dst)).type
94+
src_type = attributes(convert(Ptr{Cvoid}, src)).type
95+
kind = if src_type == HIP.hipMemoryTypeDevice
96+
dst_type == HIP.hipMemoryTypeDevice ? HIP.hipMemcpyDeviceToDevice : HIP.hipMemcpyDeviceToHost
97+
else
98+
dst_type == HIP.hipMemoryTypeDevice ? HIP.hipMemcpyHostToDevice : HIP.hipMemcpyHostToHost
99+
end
100+
HIP.memcpy(dst, src, bytesize, kind, stream)
106101
return
107102
end
108103

@@ -165,28 +160,6 @@ function view(buf::HostBuffer, bytesize::Int)
165160
buf.bytesize - bytesize, buf.own)
166161
end
167162

168-
upload!(dst::HostBuffer, src::Ptr, sz::Int; stream::HIP.HIPStream) =
169-
HIP.memcpy(dst, src, sz, HIP.hipMemcpyHostToHost, stream)
170-
171-
upload!(dst::HostBuffer, src::HIPBuffer, sz::Int; stream::HIP.HIPStream) =
172-
HIP.memcpy(dst, src, sz, HIP.hipMemcpyDeviceToHost, stream)
173-
174-
download!(dst::Ptr, src::HostBuffer, sz::Int; stream::HIP.HIPStream) =
175-
HIP.memcpy(dst, src, sz, HIP.hipMemcpyHostToHost, stream)
176-
177-
download!(dst::HIPBuffer, src::HostBuffer, sz::Int; stream::HIP.HIPStream) =
178-
HIP.memcpy(dst, src, sz, HIP.hipMemcpyHostToDevice, stream)
179-
180-
transfer!(dst::HostBuffer, src::HostBuffer, sz::Int; stream::HIP.HIPStream) =
181-
HIP.memcpy(dst, src, sz, HIP.hipMemcpyHostToHost, stream)
182-
183-
# download!(::Ptr, ::HIPBuffer)
184-
transfer!(dst::HostBuffer, src::HIPBuffer, sz::Int; stream::HIP.HIPStream) =
185-
HIP.memcpy(dst, src, sz, HIP.hipMemcpyDeviceToHost, stream)
186-
187-
# upload!(::HIPBuffer, ::Ptr)
188-
transfer!(dst::HIPBuffer, src::HostBuffer, sz::Int; stream::HIP.HIPStream) =
189-
HIP.memcpy(dst, src, sz, HIP.hipMemcpyHostToDevice, stream)
190163

191164
Base.convert(::Type{Ptr{T}}, buf::HostBuffer) where T = convert(Ptr{T}, buf.ptr)
192165

0 commit comments

Comments
 (0)