summaryrefslogtreecommitdiff
path: root/parallel-libs
diff options
context:
space:
mode:
authorJason Henline <jhen@google.com>2016-08-24 16:58:20 +0000
committerJason Henline <jhen@google.com>2016-08-24 16:58:20 +0000
commit4604dd398111e1daee2096222726d53d6df69857 (patch)
treefabdd126b1e4bfd8702ea85682415fd8bcbad9e3 /parallel-libs
parent053eb6c38cce09117d0b4badb415496b3865c03d (diff)
[StreamExecutor] Executor add synchronous methods
Summary: Add Executor methods that block the host until completion. Since these methods are host-synchronous, they don't require Stream arguments. Reviewers: jlebar Subscribers: jprice, parallel_libs-commits Differential Revision: https://reviews.llvm.org/D23577
Diffstat (limited to 'parallel-libs')
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h77
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/Executor.h307
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h105
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/Stream.h290
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/Utils/Error.h6
-rw-r--r--parallel-libs/streamexecutor/lib/Utils/Error.cpp6
-rw-r--r--parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt10
-rw-r--r--parallel-libs/streamexecutor/lib/unittests/ExecutorTest.cpp451
-rw-r--r--parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp340
9 files changed, 1475 insertions, 117 deletions
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h b/parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h
index b3b0fd2faf2..45faf7b10f8 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h
@@ -18,9 +18,9 @@
/// and a byte count to tell how much memory is pointed to by that void*.
///
/// GlobalDeviceMemory<T> is a subclass of GlobalDeviceMemoryBase which keeps
-/// track of the type of element to be stored in the device array. It is similar
-/// to a pair of a T* pointer and an element count to tell how many elements of
-/// type T fit in the memory pointed to by that T*.
+/// track of the type of element to be stored in the device memory. It is
+/// similar to a pair of a T* pointer and an element count to tell how many
+/// elements of type T fit in the memory pointed to by that T*.
///
/// SharedDeviceMemoryBase is just the size in bytes of a shared memory buffer.
///
@@ -38,6 +38,7 @@
#ifndef STREAMEXECUTOR_DEVICEMEMORY_H
#define STREAMEXECUTOR_DEVICEMEMORY_H
+#include <cassert>
#include <cstddef>
namespace streamexecutor {
@@ -91,6 +92,71 @@ private:
size_t ByteCount; // Size in bytes of this allocation.
};
+template <typename ElemT> class GlobalDeviceMemory;
+
+/// Reference to a slice of device memory.
+///
+/// Contains a base memory handle, an element count offset into that base
+/// memory, and an element count for the size of the slice.
+template <typename ElemT> class GlobalDeviceMemorySlice {
+public:
+ /// Intentionally implicit so GlobalDeviceMemory<T> can be passed to functions
+ /// expecting GlobalDeviceMemorySlice<T> arguments.
+ GlobalDeviceMemorySlice(const GlobalDeviceMemory<ElemT> &Memory)
+ : BaseMemory(Memory), ElementOffset(0),
+ ElementCount(Memory.getElementCount()) {}
+
+ GlobalDeviceMemorySlice(const GlobalDeviceMemory<ElemT> &BaseMemory,
+ size_t ElementOffset, size_t ElementCount)
+ : BaseMemory(BaseMemory), ElementOffset(ElementOffset),
+ ElementCount(ElementCount) {
+ assert(ElementOffset + ElementCount <= BaseMemory.getElementCount() &&
+ "slicing past the end of a GlobalDeviceMemory buffer");
+ }
+
+ /// Gets the GlobalDeviceMemory backing this slice.
+ GlobalDeviceMemory<ElemT> getBaseMemory() const { return BaseMemory; }
+
+ /// Gets the offset of this slice from the base memory.
+ ///
+ /// The offset is measured in elements, not bytes.
+ size_t getElementOffset() const { return ElementOffset; }
+
+ /// Gets the number of elements in this slice.
+ size_t getElementCount() const { return ElementCount; }
+
+ /// Creates a slice of the memory with the first DropCount elements removed.
+ GlobalDeviceMemorySlice<ElemT> drop_front(size_t DropCount) const {
+ assert(DropCount <= ElementCount &&
+ "dropping more than the size of a slice");
+ return GlobalDeviceMemorySlice<ElemT>(BaseMemory, ElementOffset + DropCount,
+ ElementCount - DropCount);
+ }
+
+ /// Creates a slice of the memory with the last DropCount elements removed.
+ GlobalDeviceMemorySlice<ElemT> drop_back(size_t DropCount) const {
+ assert(DropCount <= ElementCount &&
+ "dropping more than the size of a slice");
+ return GlobalDeviceMemorySlice<ElemT>(BaseMemory, ElementOffset,
+ ElementCount - DropCount);
+ }
+
+ /// Creates a slice of the memory that chops off the first DropCount elements
+ /// and keeps the next TakeCount elements.
+ GlobalDeviceMemorySlice<ElemT> slice(size_t DropCount,
+ size_t TakeCount) const {
+ assert(DropCount + TakeCount <= ElementCount &&
+ "sub-slice operation overruns slice bounds");
+ return GlobalDeviceMemorySlice<ElemT>(BaseMemory, ElementOffset + DropCount,
+ TakeCount);
+ }
+
+private:
+ GlobalDeviceMemory<ElemT> BaseMemory;
+ size_t ElementOffset;
+ size_t ElementCount;
+};
+
/// Typed wrapper around the "void *"-like GlobalDeviceMemoryBase class.
///
/// For example, GlobalDeviceMemory<int> is a simple wrapper around
@@ -125,6 +191,11 @@ public:
/// allocation.
size_t getElementCount() const { return getByteCount() / sizeof(ElemT); }
+ /// Converts this memory object into a slice.
+ GlobalDeviceMemorySlice<ElemT> asSlice() {
+ return GlobalDeviceMemorySlice<ElemT>(*this);
+ }
+
private:
/// Constructs a GlobalDeviceMemory instance from an opaque handle and an
/// element count.
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Executor.h b/parallel-libs/streamexecutor/include/streamexecutor/Executor.h
index 0f0696279a3..ea4224eb30c 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/Executor.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/Executor.h
@@ -16,12 +16,12 @@
#define STREAMEXECUTOR_EXECUTOR_H
#include "streamexecutor/KernelSpec.h"
+#include "streamexecutor/PlatformInterfaces.h"
#include "streamexecutor/Utils/Error.h"
namespace streamexecutor {
class KernelInterface;
-class PlatformExecutor;
class Stream;
class Executor {
@@ -38,6 +38,311 @@ public:
Expected<std::unique_ptr<Stream>> createStream();
+ /// Allocates an array of ElementCount entries of type T in device memory.
+ template <typename T>
+ Expected<GlobalDeviceMemory<T>> allocateDeviceMemory(size_t ElementCount) {
+ return PExecutor->allocateDeviceMemory(ElementCount * sizeof(T));
+ }
+
+ /// Frees memory previously allocated with allocateDeviceMemory.
+ template <typename T> Error freeDeviceMemory(GlobalDeviceMemory<T> Memory) {
+ return PExecutor->freeDeviceMemory(Memory);
+ }
+
+ /// Allocates an array of ElementCount entries of type T in host memory.
+ ///
+ /// Host memory allocated by this function can be used for asynchronous memory
+ /// copies on streams. See Stream::thenCopyD2H and Stream::thenCopyH2D.
+ template <typename T> Expected<T *> allocateHostMemory(size_t ElementCount) {
+ return PExecutor->allocateHostMemory(ElementCount * sizeof(T));
+ }
+
+ /// Frees memory previously allocated with allocateHostMemory.
+ template <typename T> Error freeHostMemory(T *Memory) {
+ return PExecutor->freeHostMemory(Memory);
+ }
+
+ /// Registers a previously allocated host array of type T for asynchronous
+ /// memory operations.
+ ///
+ /// Host memory registered by this function can be used for asynchronous
+ /// memory copies on streams. See Stream::thenCopyD2H and Stream::thenCopyH2D.
+ template <typename T>
+ Error registerHostMemory(T *Memory, size_t ElementCount) {
+ return PExecutor->registerHostMemory(Memory, ElementCount * sizeof(T));
+ }
+
+ /// Unregisters host memory previously registered by registerHostMemory.
+ template <typename T> Error unregisterHostMemory(T *Memory) {
+ return PExecutor->unregisterHostMemory(Memory);
+ }
+
+ /// Host-synchronously copies a slice of an array of elements of type T from
+ /// host to device memory.
+ ///
+ /// Returns an error if ElementCount is too large for the source slice or the
+ /// destination.
+ ///
+ /// The calling host thread is blocked until the copy completes. Can be used
+ /// with any host memory, the host memory does not have to be allocated with
+ /// allocateHostMemory or registered with registerHostMemory. Does not block
+ /// any ongoing device calls.
+ template <typename T>
+ Error synchronousCopyD2H(GlobalDeviceMemorySlice<T> Src,
+ llvm::MutableArrayRef<T> Dst, size_t ElementCount) {
+ if (ElementCount > Src.getElementCount())
+ return make_error("copying too many elements, " +
+ llvm::Twine(ElementCount) +
+ ", from a device array of element count " +
+ llvm::Twine(Src.getElementCount()));
+ if (ElementCount > Dst.size())
+ return make_error(
+ "copying too many elements, " + llvm::Twine(ElementCount) +
+ ", to a host array of element count " + llvm::Twine(Dst.size()));
+ return PExecutor->synchronousCopyD2H(
+ Src.getBaseMemory(), Src.getElementOffset() * sizeof(T), Dst.data(), 0,
+ ElementCount * sizeof(T));
+ }
+
+ /// Similar to synchronousCopyD2H(GlobalDeviceMemorySlice<T>,
+ /// llvm::MutableArrayRef<T>, size_t) but does not take an element count
+ /// argument because it copies the entire source array.
+ ///
+ /// Returns an error if the Src and Dst sizes do not match.
+ template <typename T>
+ Error synchronousCopyD2H(GlobalDeviceMemorySlice<T> Src,
+ llvm::MutableArrayRef<T> Dst) {
+ if (Src.getElementCount() != Dst.size())
+ return make_error(
+ "array size mismatch for D2H, device source has element count " +
+ llvm::Twine(Src.getElementCount()) +
+ " but host destination has element count " + llvm::Twine(Dst.size()));
+ return synchronousCopyD2H(Src, Dst, Src.getElementCount());
+ }
+
+ /// Similar to synchronousCopyD2H(GlobalDeviceMemorySlice<T>,
+ /// llvm::MutableArrayRef<T>, size_t) but copies to a pointer rather than an
+ /// llvm::MutableArrayRef.
+ ///
+ /// Returns an error if ElementCount is too large for the source slice.
+ template <typename T>
+ Error synchronousCopyD2H(GlobalDeviceMemorySlice<T> Src, T *Dst,
+ size_t ElementCount) {
+ return synchronousCopyD2H(Src, llvm::MutableArrayRef<T>(Dst, ElementCount),
+ ElementCount);
+ }
+
+ /// Similar to synchronousCopyD2H(GlobalDeviceMemorySlice<T>,
+ /// llvm::MutableArrayRef<T>, size_t) but the source is a GlobalDeviceMemory
+ /// rather than a GlobalDeviceMemorySlice.
+ template <typename T>
+ Error synchronousCopyD2H(GlobalDeviceMemory<T> Src,
+ llvm::MutableArrayRef<T> Dst, size_t ElementCount) {
+ return synchronousCopyD2H(Src.asSlice(), Dst, ElementCount);
+ }
+
+ /// Similar to synchronousCopyD2H(GlobalDeviceMemorySlice<T>,
+ /// llvm::MutableArrayRef<T>) but the source is a GlobalDeviceMemory rather
+ /// than a GlobalDeviceMemorySlice.
+ template <typename T>
+ Error synchronousCopyD2H(GlobalDeviceMemory<T> Src,
+ llvm::MutableArrayRef<T> Dst) {
+ return synchronousCopyD2H(Src.asSlice(), Dst);
+ }
+
+ /// Similar to synchronousCopyD2H(GlobalDeviceMemorySlice<T>, T*, size_t) but
+ /// the source is a GlobalDeviceMemory rather than a GlobalDeviceMemorySlice.
+ template <typename T>
+ Error synchronousCopyD2H(GlobalDeviceMemory<T> Src, T *Dst,
+ size_t ElementCount) {
+ return synchronousCopyD2H(Src.asSlice(), Dst, ElementCount);
+ }
+
+ /// Host-synchronously copies a slice of an array of elements of type T from
+ /// device to host memory.
+ ///
+ /// Returns an error if ElementCount is too large for the source or the
+ /// destination.
+ ///
+ /// The calling host thread is blocked until the copy completes. Can be used
+ /// with any host memory, the host memory does not have to be allocated with
+ /// allocateHostMemory or registered with registerHostMemory. Does not block
+ /// any ongoing device calls.
+ template <typename T>
+ Error synchronousCopyH2D(llvm::ArrayRef<T> Src,
+ GlobalDeviceMemorySlice<T> Dst,
+ size_t ElementCount) {
+ if (ElementCount > Src.size())
+ return make_error(
+ "copying too many elements, " + llvm::Twine(ElementCount) +
+ ", from a host array of element count " + llvm::Twine(Src.size()));
+ if (ElementCount > Dst.getElementCount())
+ return make_error("copying too many elements, " +
+ llvm::Twine(ElementCount) +
+ ", to a device array of element count " +
+ llvm::Twine(Dst.getElementCount()));
+ return PExecutor->synchronousCopyH2D(Src.data(), 0, Dst.getBaseMemory(),
+ Dst.getElementOffset() * sizeof(T),
+ ElementCount * sizeof(T));
+ }
+
+ /// Similar to synchronousCopyH2D(llvm::ArrayRef<T>,
+ /// GlobalDeviceMemorySlice<T>, size_t) but does not take an element count
+ /// argument because it copies the entire source array.
+ ///
+ /// Returns an error if the Src and Dst sizes do not match.
+ template <typename T>
+ Error synchronousCopyH2D(llvm::ArrayRef<T> Src,
+ GlobalDeviceMemorySlice<T> Dst) {
+ if (Src.size() != Dst.getElementCount())
+ return make_error(
+ "array size mismatch for H2D, host source has element count " +
+ llvm::Twine(Src.size()) +
+ " but device destination has element count " +
+ llvm::Twine(Dst.getElementCount()));
+ return synchronousCopyH2D(Src, Dst, Dst.getElementCount());
+ }
+
+ /// Similar to synchronousCopyH2D(llvm::ArrayRef<T>,
+ /// GlobalDeviceMemorySlice<T>, size_t) but copies from a pointer rather than
+ /// an llvm::ArrayRef.
+ ///
+ /// Returns an error if ElementCount is too large for the destination.
+ template <typename T>
+ Error synchronousCopyH2D(T *Src, GlobalDeviceMemorySlice<T> Dst,
+ size_t ElementCount) {
+ return synchronousCopyH2D(llvm::ArrayRef<T>(Src, ElementCount), Dst,
+ ElementCount);
+ }
+
+ /// Similar to synchronousCopyH2D(llvm::ArrayRef<T>,
+ /// GlobalDeviceMemorySlice<T>, size_t) but the destination is a
+ /// GlobalDeviceMemory rather than a GlobalDeviceMemorySlice.
+ template <typename T>
+ Error synchronousCopyH2D(llvm::ArrayRef<T> Src, GlobalDeviceMemory<T> Dst,
+ size_t ElementCount) {
+ return synchronousCopyH2D(Src, Dst.asSlice(), ElementCount);
+ }
+
+ /// Similar to synchronousCopyH2D(llvm::ArrayRef<T>,
+ /// GlobalDeviceMemorySlice<T>) but the destination is a GlobalDeviceMemory
+ /// rather than a GlobalDeviceMemorySlice.
+ template <typename T>
+ Error synchronousCopyH2D(llvm::ArrayRef<T> Src, GlobalDeviceMemory<T> Dst) {
+ return synchronousCopyH2D(Src, Dst.asSlice());
+ }
+
+ /// Similar to synchronousCopyH2D(T*, GlobalDeviceMemorySlice<T>, size_t) but
+ /// the destination is a GlobalDeviceMemory rather than a
+ /// GlobalDeviceMemorySlice.
+ template <typename T>
+ Error synchronousCopyH2D(T *Src, GlobalDeviceMemory<T> Dst,
+ size_t ElementCount) {
+ return synchronousCopyH2D(Src, Dst.asSlice(), ElementCount);
+ }
+
+ /// Host-synchronously copies a slice of an array of elements of type T from
+ /// one location in device memory to another.
+ ///
+ /// Returns an error if ElementCount is too large for the source slice or the
+ /// destination.
+ ///
+ /// The calling host thread is blocked until the copy completes. Can be used
+ /// with any host memory, the host memory does not have to be allocated with
+ /// allocateHostMemory or registered with registerHostMemory. Does not block
+ /// any ongoing device calls.
+ template <typename T>
+ Error synchronousCopyD2D(GlobalDeviceMemorySlice<T> Src,
+ GlobalDeviceMemorySlice<T> Dst,
+ size_t ElementCount) {
+ if (ElementCount > Src.getElementCount())
+ return make_error("copying too many elements, " +
+ llvm::Twine(ElementCount) +
+ ", from a device array of element count " +
+ llvm::Twine(Src.getElementCount()));
+ if (ElementCount > Dst.getElementCount())
+ return make_error("copying too many elements, " +
+ llvm::Twine(ElementCount) +
+ ", to a device array of element count " +
+ llvm::Twine(Dst.getElementCount()));
+ return PExecutor->synchronousCopyD2D(
+ Src.getBaseMemory(), Src.getElementOffset() * sizeof(T),
+ Dst.getBaseMemory(), Dst.getElementOffset() * sizeof(T),
+ ElementCount * sizeof(T));
+ }
+
+ /// Similar to synchronousCopyD2D(GlobalDeviceMemorySlice<T>,
+ /// GlobalDeviceMemorySlice<T>, size_t) but does not take an element count
+ /// argument because it copies the entire source array.
+ ///
+ /// Returns an error if the Src and Dst sizes do not match.
+ template <typename T>
+ Error synchronousCopyD2D(GlobalDeviceMemorySlice<T> Src,
+ GlobalDeviceMemorySlice<T> Dst) {
+ if (Src.getElementCount() != Dst.getElementCount())
+ return make_error(
+ "array size mismatch for D2D, device source has element count " +
+ llvm::Twine(Src.getElementCount()) +
+ " but device destination has element count " +
+ llvm::Twine(Dst.getElementCount()));
+ return synchronousCopyD2D(Src, Dst, Src.getElementCount());
+ }
+
+ /// Similar to synchronousCopyD2D(GlobalDeviceMemorySlice<T>,
+ /// GlobalDeviceMemorySlice<T>, size_t) but the source is a
+ /// GlobalDeviceMemory<T> rather than a GlobalDeviceMemorySlice<T>.
+ template <typename T>
+ Error synchronousCopyD2D(GlobalDeviceMemory<T> Src,
+ GlobalDeviceMemorySlice<T> Dst,
+ size_t ElementCount) {
+ return synchronousCopyD2D(Src.asSlice(), Dst, ElementCount);
+ }
+
+ /// Similar to synchronousCopyD2D(GlobalDeviceMemorySlice<T>,
+ /// GlobalDeviceMemorySlice<T>) but the source is a GlobalDeviceMemory<T>
+ /// rather than a GlobalDeviceMemorySlice<T>.
+ template <typename T>
+ Error synchronousCopyD2D(GlobalDeviceMemory<T> Src,
+ GlobalDeviceMemorySlice<T> Dst) {
+ return synchronousCopyD2D(Src.asSlice(), Dst);
+ }
+
+ /// Similar to synchronousCopyD2D(GlobalDeviceMemorySlice<T>,
+ /// GlobalDeviceMemorySlice<T>, size_t) but the destination is a
+ /// GlobalDeviceMemory<T> rather than a GlobalDeviceMemorySlice<T>.
+ template <typename T>
+ Error synchronousCopyD2D(GlobalDeviceMemorySlice<T> Src,
+ GlobalDeviceMemory<T> Dst, size_t ElementCount) {
+ return synchronousCopyD2D(Src, Dst.asSlice(), ElementCount);
+ }
+
+ /// Similar to synchronousCopyD2D(GlobalDeviceMemorySlice<T>,
+ /// GlobalDeviceMemorySlice<T>) but the destination is a GlobalDeviceMemory<T>
+ /// rather than a GlobalDeviceMemorySlice<T>.
+ template <typename T>
+ Error synchronousCopyD2D(GlobalDeviceMemorySlice<T> Src,
+ GlobalDeviceMemory<T> Dst) {
+ return synchronousCopyD2D(Src, Dst.asSlice());
+ }
+
+ /// Similar to synchronousCopyD2D(GlobalDeviceMemorySlice<T>,
+ /// GlobalDeviceMemorySlice<T>, size_t) but the source and destination are
+ /// GlobalDeviceMemory<T> rather than a GlobalDeviceMemorySlice<T>.
+ template <typename T>
+ Error synchronousCopyD2D(GlobalDeviceMemory<T> Src, GlobalDeviceMemory<T> Dst,
+ size_t ElementCount) {
+ return synchronousCopyD2D(Src.asSlice(), Dst.asSlice(), ElementCount);
+ }
+
+ /// Similar to synchronousCopyD2D(GlobalDeviceMemorySlice<T>,
+ /// GlobalDeviceMemorySlice<T>) but the source and destination are
+ /// GlobalDeviceMemory<T> rather than a GlobalDeviceMemorySlice<T>.
+ template <typename T>
+ Error synchronousCopyD2D(GlobalDeviceMemory<T> Src,
+ GlobalDeviceMemory<T> Dst) {
+ return synchronousCopyD2D(Src.asSlice(), Dst.asSlice());
+ }
+
private:
PlatformExecutor *PExecutor;
};
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h b/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h
index 23bae9e9a05..2c8fce39078 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h
@@ -76,23 +76,32 @@ public:
}
/// Copies data from the device to the host.
- virtual Error memcpyD2H(PlatformStreamHandle *S,
- const GlobalDeviceMemoryBase &DeviceSrc,
- void *HostDst, size_t ByteCount) {
- return make_error("memcpyD2H not implemented for platform " + getName());
+ ///
+ /// HostDst should have been allocated by allocateHostMemory or registered
+ /// with registerHostMemory.
+ virtual Error copyD2H(PlatformStreamHandle *S,
+ const GlobalDeviceMemoryBase &DeviceSrc,
+ size_t SrcByteOffset, void *HostDst,
+ size_t DstByteOffset, size_t ByteCount) {
+ return make_error("copyD2H not implemented for platform " + getName());
}
/// Copies data from the host to the device.
- virtual Error memcpyH2D(PlatformStreamHandle *S, const void *HostSrc,
- GlobalDeviceMemoryBase *DeviceDst, size_t ByteCount) {
- return make_error("memcpyH2D not implemented for platform " + getName());
+ ///
+ /// HostSrc should have been allocated by allocateHostMemory or registered
+ /// with registerHostMemory.
+ virtual Error copyH2D(PlatformStreamHandle *S, const void *HostSrc,
+ size_t SrcByteOffset, GlobalDeviceMemoryBase DeviceDst,
+ size_t DstByteOffset, size_t ByteCount) {
+ return make_error("copyH2D not implemented for platform " + getName());
}
/// Copies data from one device location to another.
- virtual Error memcpyD2D(PlatformStreamHandle *S,
- const GlobalDeviceMemoryBase &DeviceSrc,
- GlobalDeviceMemoryBase *DeviceDst, size_t ByteCount) {
- return make_error("memcpyD2D not implemented for platform " + getName());
+ virtual Error copyD2D(PlatformStreamHandle *S,
+ const GlobalDeviceMemoryBase &DeviceSrc,
+ size_t SrcByteOffset, GlobalDeviceMemoryBase DeviceDst,
+ size_t DstByteOffset, size_t ByteCount) {
+ return make_error("copyD2D not implemented for platform " + getName());
}
/// Blocks the host until the given stream completes all the work enqueued up
@@ -101,6 +110,80 @@ public:
return make_error("blockHostUntilDone not implemented for platform " +
getName());
}
+
+ /// Allocates untyped device memory of a given size in bytes.
+ virtual Expected<GlobalDeviceMemoryBase>
+ allocateDeviceMemory(size_t ByteCount) {
+ return make_error("allocateDeviceMemory not implemented for platform " +
+ getName());
+ }
+
+ /// Frees device memory previously allocated by allocateDeviceMemory.
+ virtual Error freeDeviceMemory(GlobalDeviceMemoryBase Memory) {
+ return make_error("freeDeviceMemory not implemented for platform " +
+ getName());
+ }
+
+ /// Allocates untyped host memory of a given size in bytes.
+ ///
+ /// Host memory allocated via this method is suitable for use with copyH2D and
+ /// copyD2H.
+ virtual Expected<void *> allocateHostMemory(size_t ByteCount) {
+ return make_error("allocateHostMemory not implemented for platform " +
+ getName());
+ }
+
+ /// Frees host memory allocated by allocateHostMemory.
+ virtual Error freeHostMemory(void *Memory) {
+ return make_error("freeHostMemory not implemented for platform " +
+ getName());
+ }
+
+ /// Registers previously allocated host memory so it can be used with copyH2D
+ /// and copyD2H.
+ virtual Error registerHostMemory(void *Memory, size_t ByteCount) {
+ return make_error("registerHostMemory not implemented for platform " +
+ getName());
+ }
+
+ /// Unregisters host memory previously registered with registerHostMemory.
+ virtual Error unregisterHostMemory(void *Memory) {
+ return make_error("unregisterHostMemory not implemented for platform " +
+ getName());
+ }
+
+ /// Copies the given number of bytes from device memory to host memory.
+ ///
+ /// Blocks the calling host thread until the copy is completed. Can operate on
+ /// any host memory, not just registered host memory or host memory allocated
+ /// by allocateHostMemory. Does not block any ongoing device calls.
+ virtual Error synchronousCopyD2H(const GlobalDeviceMemoryBase &DeviceSrc,
+ size_t SrcByteOffset, void *HostDst,
+ size_t DstByteOffset, size_t ByteCount) {
+ return make_error("synchronousCopyD2H not implemented for platform " +
+ getName());
+ }
+
+ /// Similar to synchronousCopyD2H(const GlobalDeviceMemoryBase &, size_t, void
+ /// *, size_t, size_t), but copies memory from host to device rather than
+ /// device to host.
+ virtual Error synchronousCopyH2D(const void *HostSrc, size_t SrcByteOffset,
+ GlobalDeviceMemoryBase DeviceDst,
+ size_t DstByteOffset, size_t ByteCount) {
+ return make_error("synchronousCopyH2D not implemented for platform " +
+ getName());
+ }
+
+ /// Similar to synchronousCopyD2H(const GlobalDeviceMemoryBase &, size_t, void
+ /// *, size_t, size_t), but copies memory from one location in device memory
+ /// to another rather than from device to host.
+ virtual Error synchronousCopyD2D(GlobalDeviceMemoryBase DeviceDst,
+ size_t DstByteOffset,
+ const GlobalDeviceMemoryBase &DeviceSrc,
+ size_t SrcByteOffset, size_t ByteCount) {
+ return make_error("synchronousCopyD2D not implemented for platform " +
+ getName());
+ }
};
} // namespace streamexecutor
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h
index ba126faaa7b..87a2c7c3885 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h
@@ -17,7 +17,7 @@
/// The Stream instance will perform its work on the device managed by the
/// Executor that created it.
///
-/// The various "then" methods of the Stream object, such as thenMemcpyH2D and
+/// The various "then" methods of the Stream object, such as thenCopyH2D and
/// thenLaunch, may be used to enqueue work on the Stream, and the
/// blockHostUntilDone() method may be used to block the host code until the
/// Stream has completed all its work.
@@ -99,102 +99,262 @@ public:
return *this;
}
- /// Entrain onto the stream a memcpy of a given number of elements from a
- /// device source to a host destination.
+ /// Enqueues on this stream a command to copy a slice of an array of elements
+ /// of type T from device to host memory.
///
- /// HostDst must be a pointer to host memory allocated by
- /// Executor::allocateHostMemory or otherwise allocated and then
- /// registered with Executor::registerHostMemory.
+ /// Sets an error if ElementCount is too large for the source or the
+ /// destination.
+ ///
+ /// If the Src memory was not created by allocateHostMemory or registered with
+ /// registerHostMemory, then the copy operation may cause the host and device
+ /// to block until the copy operation is completed.
template <typename T>
- Stream &thenMemcpyD2H(const GlobalDeviceMemory<T> &DeviceSrc,
- llvm::MutableArrayRef<T> HostDst, size_t ElementCount) {
- if (ElementCount > DeviceSrc.getElementCount())
+ Stream &thenCopyD2H(GlobalDeviceMemorySlice<T> Src,
+ llvm::MutableArrayRef<T> Dst, size_t ElementCount) {
+ if (ElementCount > Src.getElementCount())
setError("copying too many elements, " + llvm::Twine(ElementCount) +
- ", from device memory array of size " +
- llvm::Twine(DeviceSrc.getElementCount()));
- else if (ElementCount > HostDst.size())
+ ", from a device array of element count " +
+ llvm::Twine(Src.getElementCount()));
+ else if (ElementCount > Dst.size())
setError("copying too many elements, " + llvm::Twine(ElementCount) +
- ", to host array of size " + llvm::Twine(HostDst.size()));
+ ", to a host array of element count " + llvm::Twine(Dst.size()));
else
- setError(PExecutor->memcpyD2H(ThePlatformStream.get(), DeviceSrc,
- HostDst.data(), ElementCount * sizeof(T)));
+ setError(PExecutor->copyD2H(ThePlatformStream.get(), Src.getBaseMemory(),
+ Src.getElementOffset() * sizeof(T),
+ Dst.data(), 0, ElementCount * sizeof(T)));
return *this;
}
- /// Same as thenMemcpyD2H above, but copies the entire source to the
- /// destination.
+ /// Similar to thenCopyD2H(GlobalDeviceMemorySlice<T>,
+ /// llvm::MutableArrayRef<T>, size_t) but does not take an element count
+ /// argument because it copies the entire source array.
+ ///
+ /// Sets an error if the Src and Dst sizes do not match.
template <typename T>
- Stream &thenMemcpyD2H(const GlobalDeviceMemory<T> &DeviceSrc,
- llvm::MutableArrayRef<T> HostDst) {
- return thenMemcpyD2H(DeviceSrc, HostDst, DeviceSrc.getElementCount());
+ Stream &thenCopyD2H(GlobalDeviceMemorySlice<T> Src,
+ llvm::MutableArrayRef<T> Dst) {
+ if (Src.getElementCount() != Dst.size())
+ setError("array size mismatch for D2H, device source has element count " +
+ llvm::Twine(Src.getElementCount()) +
+ " but host destination has element count " +
+ llvm::Twine(Dst.size()));
+ else
+ thenCopyD2H(Src, Dst, Src.getElementCount());
+ return *this;
}
- /// Entrain onto the stream a memcpy of a given number of elements from a host
- /// source to a device destination.
+ /// Similar to thenCopyD2H(GlobalDeviceMemorySlice<T>,
+ /// llvm::MutableArrayRef<T>, size_t) but copies to a pointer rather than an
+ /// llvm::MutableArrayRef.
///
- /// HostSrc must be a pointer to host memory allocated by
- /// Executor::allocateHostMemory or otherwise allocated and then
- /// registered with Executor::registerHostMemory.
+ /// Sets an error if ElementCount is too large for the source slice.
+ template <typename T>
+ Stream &thenCopyD2H(GlobalDeviceMemorySlice<T> Src, T *Dst,
+ size_t ElementCount) {
+ thenCopyD2H(Src, llvm::MutableArrayRef<T>(Dst, ElementCount), ElementCount);
+ return *this;
+ }
+
+ /// Similar to thenCopyD2H(GlobalDeviceMemorySlice<T>,
+ /// llvm::MutableArrayRef<T>, size_t) but the source is a GlobalDeviceMemory
+ /// rather than a GlobalDeviceMemorySlice.
+ template <typename T>
+ Stream &thenCopyD2H(GlobalDeviceMemory<T> Src, llvm::MutableArrayRef<T> Dst,
+ size_t ElementCount) {
+ thenCopyD2H(Src.asSlice(), Dst, ElementCount);
+ return *this;
+ }
+
+ /// Similar to thenCopyD2H(GlobalDeviceMemorySlice<T>,
+ /// llvm::MutableArrayRef<T>) but the source is a GlobalDeviceMemory rather
+ /// than a GlobalDeviceMemorySlice.
+ template <typename T>
+ Stream &thenCopyD2H(GlobalDeviceMemory<T> Src, llvm::MutableArrayRef<T> Dst) {
+ thenCopyD2H(Src.asSlice(), Dst);
+ return *this;
+ }
+
+ /// Similar to thenCopyD2H(GlobalDeviceMemorySlice<T>, T*, size_t) but the
+ /// source is a GlobalDeviceMemory rather than a GlobalDeviceMemorySlice.
template <typename T>
- Stream &thenMemcpyH2D(llvm::ArrayRef<T> HostSrc,
- GlobalDeviceMemory<T> *DeviceDst, size_t ElementCount) {
- if (ElementCount > HostSrc.size())
+ Stream &thenCopyD2H(GlobalDeviceMemory<T> Src, T *Dst, size_t ElementCount) {
+ thenCopyD2H(Src.asSlice(), Dst, ElementCount);
+ return *this;
+ }
+
+ /// Similar to thenCopyD2H(GlobalDeviceMemorySlice<T>,
+ /// llvm::MutableArrayRef<T>, size_t) but copies from host to device memory
+ /// rather than device to host memory.
+ template <typename T>
+ Stream &thenCopyH2D(llvm::ArrayRef<T> Src, GlobalDeviceMemorySlice<T> Dst,
+ size_t ElementCount) {
+ if (ElementCount > Src.size())
setError("copying too many elements, " + llvm::Twine(ElementCount) +
- ", from host array of size " + llvm::Twine(HostSrc.size()));
- else if (ElementCount > DeviceDst->getElementCount())
+ ", from a host array of element count " +
+ llvm::Twine(Src.size()));
+ else if (ElementCount > Dst.getElementCount())
setError("copying too many elements, " + llvm::Twine(ElementCount) +
- ", to device memory array of size " +
- llvm::Twine(DeviceDst->getElementCount()));
+ ", to a device array of element count " +
+ llvm::Twine(Dst.getElementCount()));
else
- setError(PExecutor->memcpyH2D(ThePlatformStream.get(), HostSrc.data(),
- DeviceDst, ElementCount * sizeof(T)));
+ setError(PExecutor->copyH2D(
+ ThePlatformStream.get(), Src.data(), 0, Dst.getBaseMemory(),
+ Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)));
return *this;
}
- /// Same as thenMemcpyH2D above, but copies the entire source to the
- /// destination.
+ /// Similar to thenCopyH2D(llvm::ArrayRef<T>, GlobalDeviceMemorySlice<T>,
+ /// size_t) but does not take an element count argument because it copies the
+ /// entire source array.
+ ///
+ /// Sets an error if the Src and Dst sizes do not match.
+ template <typename T>
+ Stream &thenCopyH2D(llvm::ArrayRef<T> Src, GlobalDeviceMemorySlice<T> Dst) {
+ if (Src.size() != Dst.getElementCount())
+ setError("array size mismatch for H2D, host source has element count " +
+ llvm::Twine(Src.size()) +
+ " but device destination has element count " +
+ llvm::Twine(Dst.getElementCount()));
+ else
+ thenCopyH2D(Src, Dst, Dst.getElementCount());
+ return *this;
+ }
+
+ /// Similar to thenCopyH2D(llvm::ArrayRef<T>, GlobalDeviceMemorySlice<T>,
+ /// size_t) but copies from a pointer rather than an llvm::ArrayRef.
+ ///
+ /// Sets an error if ElementCount is too large for the destination.
+ template <typename T>
+ Stream &thenCopyH2D(T *Src, GlobalDeviceMemorySlice<T> Dst,
+ size_t ElementCount) {
+ thenCopyH2D(llvm::ArrayRef<T>(Src, ElementCount), Dst, ElementCount);
+ return *this;
+ }
+
+ /// Similar to thenCopyH2D(llvm::ArrayRef<T>, GlobalDeviceMemorySlice<T>,
+ /// size_t) but the destination is a GlobalDeviceMemory rather than a
+ /// GlobalDeviceMemorySlice.
+ template <typename T>
+ Stream &thenCopyH2D(llvm::ArrayRef<T> Src, GlobalDeviceMemory<T> Dst,
+ size_t ElementCount) {
+ thenCopyH2D(Src, Dst.asSlice(), ElementCount);
+ return *this;
+ }
+
+ /// Similar to thenCopyH2D(llvm::ArrayRef<T>, GlobalDeviceMemorySlice<T>) but
+ /// the destination is a GlobalDeviceMemory rather than a
+ /// GlobalDeviceMemorySlice.
template <typename T>
- Stream &thenMemcpyH2D(llvm::ArrayRef<T> HostSrc,
- GlobalDeviceMemory<T> *DeviceDst) {
- return thenMemcpyH2D(HostSrc, DeviceDst, HostSrc.size());
+ Stream &thenCopyH2D(llvm::ArrayRef<T> Src, GlobalDeviceMemory<T> Dst) {
+ thenCopyH2D(Src, Dst.asSlice());
+ return *this;
}
- /// Entrain onto the stream a memcpy of a given number of elements from a
- /// device source to a device destination.
+ /// Similar to thenCopyH2D(T*, GlobalDeviceMemorySlice<T>, size_t) but the
+ /// destination is a GlobalDeviceMemory rather than a GlobalDeviceMemorySlice.
template <typename T>
- Stream &thenMemcpyD2D(const GlobalDeviceMemory<T> &DeviceSrc,
- GlobalDeviceMemory<T> *DeviceDst, size_t ElementCount) {
- if (ElementCount > DeviceSrc.getElementCount())
+ Stream &thenCopyH2D(T *Src, GlobalDeviceMemory<T> Dst, size_t ElementCount) {
+ thenCopyH2D(Src, Dst.asSlice(), ElementCount);
+ return *this;
+ }
+
+ /// Similar to thenCopyD2H(GlobalDeviceMemorySlice<T>,
+ /// llvm::MutableArrayRef<T>, size_t) but copies from one location in device
+ /// memory to another rather than from device to host memory.
+ template <typename T>
+ Stream &thenCopyD2D(GlobalDeviceMemorySlice<T> Src,
+ GlobalDeviceMemorySlice<T> Dst, size_t ElementCount) {
+ if (ElementCount > Src.getElementCount())
setError("copying too many elements, " + llvm::Twine(ElementCount) +
- ", from device memory array of size " +
- llvm::Twine(DeviceSrc.getElementCount()));
- else if (ElementCount > DeviceDst->getElementCount())
+ ", from a device array of element count " +
+ llvm::Twine(Src.getElementCount()));
+ else if (ElementCount > Dst.getElementCount())
setError("copying too many elements, " + llvm::Twine(ElementCount) +
- ", to device memory array of size " +
- llvm::Twine(DeviceDst->getElementCount()));
+ ", to a device array of element count " +
+ llvm::Twine(Dst.getElementCount()));
else
- setError(PExecutor->memcpyD2D(ThePlatformStream.get(), DeviceSrc,
- DeviceDst, ElementCount * sizeof(T)));
+ setError(PExecutor->copyD2D(
+ ThePlatformStream.get(), Src.getBaseMemory(),
+ Src.getElementOffset() * sizeof(T), Dst.getBaseMemory(),
+ Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)));
return *this;
}
- /// Same as thenMemcpyD2D above, but copies the entire source to the
- /// destination.
+ /// Similar to thenCopyD2D(GlobalDeviceMemorySlice<T>,
+ /// GlobalDeviceMemorySlice<T>, size_t) but does not take an element count
+ /// argument because it copies the entire source array.
+ ///
+ /// Sets an error if the Src and Dst sizes do not match.
template <typename T>
- Stream &thenMemcpyD2D(const GlobalDeviceMemory<T> &DeviceSrc,
- GlobalDeviceMemory<T> *DeviceDst) {
- return thenMemcpyD2D(DeviceSrc, DeviceDst, DeviceSrc.getElementCount());
+ Stream &thenCopyD2D(GlobalDeviceMemorySlice<T> Src,
+ GlobalDeviceMemorySlice<T> Dst) {
+ if (Src.getElementCount() != Dst.getElementCount())
+ setError("array size mismatch for D2D, device source has element count " +
+ llvm::Twine(Src.getElementCount()) +
+ " but device destination has element count " +
+ llvm::Twine(Dst.getElementCount()));
+ else
+ thenCopyD2D(Src, Dst, Src.getElementCount());
+ return *this;
}
- /// Blocks the host code, waiting for the operations entrained on the stream
- /// (enqueued up to this point in program execution) to complete.
- ///
- /// Returns true if there are no errors on the stream.
- bool blockHostUntilDone() {
- Error E = PExecutor->blockHostUntilDone(ThePlatformStream.get());
- bool returnValue = static_cast<bool>(E);
- setError(std::move(E));
- return returnValue;
+ /// Similar to thenCopyD2D(GlobalDeviceMemorySlice<T>,
+ /// GlobalDeviceMemorySlice<T>, size_t) but the source is a
+ /// GlobalDeviceMemory<T> rather than a GlobalDeviceMemorySlice<T>.
+ template <typename T>
+ Stream &thenCopyD2D(GlobalDeviceMemory<T> Src, GlobalDeviceMemorySlice<T> Dst,
+ size_t ElementCount) {
+ thenCopyD2D(Src.asSlice(), Dst, ElementCount);
+ return *this;
+ }
+
+ /// Similar to thenCopyD2D(GlobalDeviceMemorySlice<T>,
+ /// GlobalDeviceMemorySlice<T>) but the source is a GlobalDeviceMemory<T>
+ /// rather than a GlobalDeviceMemorySlice<T>.
+ template <typename T>
+ Stream &thenCopyD2D(GlobalDeviceMemory<T> Src,
+ GlobalDeviceMemorySlice<T> Dst) {
+ thenCopyD2D(Src.asSlice(), Dst);
+ return *this;
+ }
+
+ /// Similar to thenCopyD2D(GlobalDeviceMemorySlice<T>,
+ /// GlobalDeviceMemorySlice<T>, size_t) but the destination is a
+ /// GlobalDeviceMemory<T> rather than a GlobalDeviceMemorySlice<T>.
+ template <typename T>
+ Stream &thenCopyD2D(GlobalDeviceMemorySlice<T> Src, GlobalDeviceMemory<T> Dst,
+ size_t ElementCount) {
+ thenCopyD2D(Src, Dst.asSlice(), ElementCount);
+ return *this;
+ }
+
+ /// Similar to thenCopyD2D(GlobalDeviceMemorySlice<T>,
+ /// GlobalDeviceMemorySlice<T>) but the destination is a GlobalDeviceMemory<T>
+ /// rather than a GlobalDeviceMemorySlice<T>.
+ template <typename T>
+ Stream &thenCopyD2D(GlobalDeviceMemorySlice<T> Src,
+ GlobalDeviceMemory<T> Dst) {
+ thenCopyD2D(Src, Dst.asSlice());
+ return *this;
+ }
+
+ /// Similar to thenCopyD2D(GlobalDeviceMemorySlice<T>,
+ /// GlobalDeviceMemorySlice<T>, size_t) but the source and destination are
+ /// GlobalDeviceMemory<T> rather than a GlobalDeviceMemorySlice<T>.
+ template <typename T>
+ Stream &thenCopyD2D(GlobalDeviceMemory<T> Src, GlobalDeviceMemory<T> Dst,
+ size_t ElementCount) {
+ thenCopyD2D(Src.asSlice(), Dst.asSlice(), ElementCount);
+ return *this;
+ }
+
+ /// Similar to thenCopyD2D(GlobalDeviceMemorySlice<T>,
+ /// GlobalDeviceMemorySlice<T>) but the source and destination are
+ /// GlobalDeviceMemory<T> rather than a GlobalDeviceMemorySlice<T>.
+ template <typename T>
+ Stream &thenCopyD2D(GlobalDeviceMemory<T> Src, GlobalDeviceMemory<T> Dst) {
+ thenCopyD2D(Src.asSlice(), Dst.asSlice());
+ return *this;
}
private:
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Utils/Error.h b/parallel-libs/streamexecutor/include/streamexecutor/Utils/Error.h
index 381184717e5..e7b313ea057 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/Utils/Error.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/Utils/Error.h
@@ -30,7 +30,7 @@
/// }
/// \endcode
///
-/// Error instances are implicitly convertable to bool. Error values convert to
+/// Error instances are implicitly convertible to bool. Error values convert to
/// true and successes convert to false. Error instances must have their boolean
/// values checked or they must be moved before they go out of scope, otherwise
/// their destruction will cause the program to abort with a warning about an
@@ -169,10 +169,10 @@ namespace streamexecutor {
using llvm::consumeError;
using llvm::Error;
using llvm::Expected;
-using llvm::StringRef;
+using llvm::Twine;
// Makes an Error object from an error message.
-Error make_error(StringRef Message);
+Error make_error(Twine Message);
// Consumes the input error and returns its error message.
//
diff --git a/parallel-libs/streamexecutor/lib/Utils/Error.cpp b/parallel-libs/streamexecutor/lib/Utils/Error.cpp
index 78912c531ae..f3d09673c21 100644
--- a/parallel-libs/streamexecutor/lib/Utils/Error.cpp
+++ b/parallel-libs/streamexecutor/lib/Utils/Error.cpp
@@ -27,7 +27,7 @@ public:
std::error_code convertToErrorCode() const override {
llvm_unreachable(
- "StreamExecutorError does not support convertion to std::error_code");
+ "StreamExecutorError does not support conversion to std::error_code");
}
std::string getErrorMessage() const { return Message; }
@@ -44,8 +44,8 @@ char StreamExecutorError::ID = 0;
namespace streamexecutor {
-Error make_error(StringRef Message) {
- return llvm::make_error<StreamExecutorError>(Message);
+Error make_error(Twine Message) {
+ return llvm::make_error<StreamExecutorError>(Message.str());
}
std::string consumeAndGetMessage(Error &&E) {
diff --git a/parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt b/parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt
index f6e6edbebfd..244312ff12c 100644
--- a/parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt
+++ b/parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt
@@ -1,4 +1,14 @@
add_executable(
+ executor_test
+ ExecutorTest.cpp)
+target_link_libraries(
+ executor_test
+ streamexecutor
+ ${GTEST_BOTH_LIBRARIES}
+ ${CMAKE_THREAD_LIBS_INIT})
+add_test(ExecutorTest executor_test)
+
+add_executable(
kernel_test
KernelTest.cpp)
target_link_libraries(
diff --git a/parallel-libs/streamexecutor/lib/unittests/ExecutorTest.cpp b/parallel-libs/streamexecutor/lib/unittests/ExecutorTest.cpp
new file mode 100644
index 00000000000..d2d03fb6c88
--- /dev/null
+++ b/parallel-libs/streamexecutor/lib/unittests/ExecutorTest.cpp
@@ -0,0 +1,451 @@
+//===-- ExecutorTest.cpp - Tests for Executor -----------------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file contains the unit tests for Executor code.
+///
+//===----------------------------------------------------------------------===//
+
+#include <cstdlib>
+#include <cstring>
+
+#include "streamexecutor/Executor.h"
+#include "streamexecutor/PlatformInterfaces.h"
+
+#include "gtest/gtest.h"
+
+namespace {
+
+namespace se = ::streamexecutor;
+
+class MockPlatformExecutor : public se::PlatformExecutor {
+public:
+ ~MockPlatformExecutor() override {}
+
+ std::string getName() const override { return "MockPlatformExecutor"; }
+
+ se::Expected<std::unique_ptr<se::PlatformStreamHandle>>
+ createStream() override {
+ return se::make_error("not implemented");
+ }
+
+ se::Expected<se::GlobalDeviceMemoryBase>
+ allocateDeviceMemory(size_t ByteCount) override {
+ return se::GlobalDeviceMemoryBase(std::malloc(ByteCount));
+ }
+
+ se::Error freeDeviceMemory(se::GlobalDeviceMemoryBase Memory) override {
+ std::free(const_cast<void *>(Memory.getHandle()));
+ return se::Error::success();
+ }
+
+ se::Expected<void *> allocateHostMemory(size_t ByteCount) override {
+ return std::malloc(ByteCount);
+ }
+
+ se::Error freeHostMemory(void *Memory) override {
+ std::free(Memory);
+ return se::Error::success();
+ }
+
+ se::Error synchronousCopyD2H(const se::GlobalDeviceMemoryBase &DeviceSrc,
+ size_t SrcByteOffset, void *HostDst,
+ size_t DstByteOffset,
+ size_t ByteCount) override {
+ std::memcpy(static_cast<char *>(HostDst) + DstByteOffset,
+ static_cast<const char *>(DeviceSrc.getHandle()) +
+ SrcByteOffset,
+ ByteCount);
+ return se::Error::success();
+ }
+
+ se::Error synchronousCopyH2D(const void *HostSrc, size_t SrcByteOffset,
+ se::GlobalDeviceMemoryBase DeviceDst,
+ size_t DstByteOffset,
+ size_t ByteCount) override {
+ std::memcpy(static_cast<char *>(const_cast<void *>(DeviceDst.getHandle())) +
+ DstByteOffset,
+ static_cast<const char *>(HostSrc) + SrcByteOffset, ByteCount);
+ return se::Error::success();
+ }
+
+ se::Error synchronousCopyD2D(se::GlobalDeviceMemoryBase DeviceDst,
+ size_t DstByteOffset,
+ const se::GlobalDeviceMemoryBase &DeviceSrc,
+ size_t SrcByteOffset,
+ size_t ByteCount) override {
+ std::memcpy(static_cast<char *>(const_cast<void *>(DeviceDst.getHandle())) +
+ DstByteOffset,
+ static_cast<const char *>(DeviceSrc.getHandle()) +
+ SrcByteOffset,
+ ByteCount);
+ return se::Error::success();
+ }
+};
+
+/// Test fixture to hold objects used by tests.
+class ExecutorTest : public ::testing::Test {
+public:
+ ExecutorTest()
+ : HostA5{0, 1, 2, 3, 4}, HostB5{5, 6, 7, 8, 9},
+ HostA7{10, 11, 12, 13, 14, 15, 16}, HostB7{17, 18, 19, 20, 21, 22, 23},
+ DeviceA5(se::GlobalDeviceMemory<int>::makeFromElementCount(HostA5, 5)),
+ DeviceB5(se::GlobalDeviceMemory<int>::makeFromElementCount(HostB5, 5)),
+ DeviceA7(se::GlobalDeviceMemory<int>::makeFromElementCount(HostA7, 7)),
+ DeviceB7(se::GlobalDeviceMemory<int>::makeFromElementCount(HostB7, 7)),
+ Host5{24, 25, 26, 27, 28}, Host7{29, 30, 31, 32, 33, 34, 35},
+ Executor(&PExecutor) {}
+
+ // Device memory is backed by host arrays.
+ int HostA5[5];
+ int HostB5[5];
+ int HostA7[7];
+ int HostB7[7];
+ se::GlobalDeviceMemory<int> DeviceA5;
+ se::GlobalDeviceMemory<int> DeviceB5;
+ se::GlobalDeviceMemory<int> DeviceA7;
+ se::GlobalDeviceMemory<int> DeviceB7;
+
+ // Host memory to be used as actual host memory.
+ int Host5[5];
+ int Host7[7];
+
+ MockPlatformExecutor PExecutor;
+ se::Executor Executor;
+};
+
+#define EXPECT_NO_ERROR(E) EXPECT_FALSE(static_cast<bool>(E))
+#define EXPECT_ERROR(E) \
+ do { \
+ se::Error E__ = E; \
+ EXPECT_TRUE(static_cast<bool>(E__)); \
+ consumeError(std::move(E__)); \
+ } while (false)
+
+using llvm::ArrayRef;
+using llvm::MutableArrayRef;
+
+// D2H tests
+
+TEST_F(ExecutorTest, SyncCopyD2HToMutableArrayRefByCount) {
+ EXPECT_NO_ERROR(
+ Executor.synchronousCopyD2H(DeviceA5, MutableArrayRef<int>(Host5), 5));
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ EXPECT_NO_ERROR(
+ Executor.synchronousCopyD2H(DeviceB5, MutableArrayRef<int>(Host5), 2));
+ for (int I = 0; I < 2; ++I) {
+ EXPECT_EQ(HostB5[I], Host5[I]);
+ }
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyD2H(DeviceA7, MutableArrayRef<int>(Host5), 7));
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyD2H(DeviceA5, MutableArrayRef<int>(Host7), 7));
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyD2H(DeviceA5, MutableArrayRef<int>(Host5), 7));
+}
+
+TEST_F(ExecutorTest, SyncCopyD2HToMutableArrayRef) {
+ EXPECT_NO_ERROR(
+ Executor.synchronousCopyD2H(DeviceA5, MutableArrayRef<int>(Host5)));
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyD2H(DeviceA7, MutableArrayRef<int>(Host5)));
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyD2H(DeviceA5, MutableArrayRef<int>(Host7)));
+}
+
+TEST_F(ExecutorTest, SyncCopyD2HToPointer) {
+ EXPECT_NO_ERROR(Executor.synchronousCopyD2H(DeviceA5, Host5, 5));
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ EXPECT_ERROR(Executor.synchronousCopyD2H(DeviceA5, Host7, 7));
+}
+
+TEST_F(ExecutorTest, SyncCopyD2HSliceToMutableArrayRefByCount) {
+ EXPECT_NO_ERROR(Executor.synchronousCopyD2H(
+ DeviceA5.asSlice().drop_front(1), MutableArrayRef<int>(Host5 + 1, 4), 4));
+ for (int I = 1; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ EXPECT_NO_ERROR(Executor.synchronousCopyD2H(DeviceB5.asSlice().drop_back(1),
+ MutableArrayRef<int>(Host5), 2));
+ for (int I = 0; I < 2; ++I) {
+ EXPECT_EQ(HostB5[I], Host5[I]);
+ }
+
+ EXPECT_ERROR(Executor.synchronousCopyD2H(DeviceA7.asSlice(),
+ MutableArrayRef<int>(Host5), 7));
+
+ EXPECT_ERROR(Executor.synchronousCopyD2H(DeviceA5.asSlice(),
+ MutableArrayRef<int>(Host7), 7));
+
+ EXPECT_ERROR(Executor.synchronousCopyD2H(DeviceA5.asSlice(),
+ MutableArrayRef<int>(Host5), 7));
+}
+
+TEST_F(ExecutorTest, SyncCopyD2HSliceToMutableArrayRef) {
+ EXPECT_NO_ERROR(Executor.synchronousCopyD2H(DeviceA7.asSlice().slice(1, 5),
+ MutableArrayRef<int>(Host5)));
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA7[I + 1], Host5[I]);
+ }
+
+ EXPECT_ERROR(Executor.synchronousCopyD2H(DeviceA7.asSlice().drop_back(1),
+ MutableArrayRef<int>(Host5)));
+
+ EXPECT_ERROR(Executor.synchronousCopyD2H(DeviceA5.asSlice(),
+ MutableArrayRef<int>(Host7)));
+}
+
+TEST_F(ExecutorTest, SyncCopyD2HSliceToPointer) {
+ EXPECT_NO_ERROR(Executor.synchronousCopyD2H(DeviceA5.asSlice().drop_front(1),
+ Host5 + 1, 4));
+ for (int I = 1; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ EXPECT_ERROR(Executor.synchronousCopyD2H(DeviceA5.asSlice(), Host7, 7));
+}
+
+// H2D tests
+
+TEST_F(ExecutorTest, SyncCopyH2DToArrayRefByCount) {
+ EXPECT_NO_ERROR(
+ Executor.synchronousCopyH2D(ArrayRef<int>(Host5), DeviceA5, 5));
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ EXPECT_NO_ERROR(
+ Executor.synchronousCopyH2D(ArrayRef<int>(Host5), DeviceB5, 2));
+ for (int I = 0; I < 2; ++I) {
+ EXPECT_EQ(HostB5[I], Host5[I]);
+ }
+
+ EXPECT_ERROR(Executor.synchronousCopyH2D(ArrayRef<int>(Host7), DeviceA5, 7));
+
+ EXPECT_ERROR(Executor.synchronousCopyH2D(ArrayRef<int>(Host5), DeviceA7, 7));
+
+ EXPECT_ERROR(Executor.synchronousCopyH2D(ArrayRef<int>(Host5), DeviceA5, 7));
+}
+
+TEST_F(ExecutorTest, SyncCopyH2DToArrayRef) {
+ EXPECT_NO_ERROR(Executor.synchronousCopyH2D(ArrayRef<int>(Host5), DeviceA5));
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ EXPECT_ERROR(Executor.synchronousCopyH2D(ArrayRef<int>(Host5), DeviceA7));
+
+ EXPECT_ERROR(Executor.synchronousCopyH2D(ArrayRef<int>(Host7), DeviceA5));
+}
+
+TEST_F(ExecutorTest, SyncCopyH2DToPointer) {
+ EXPECT_NO_ERROR(Executor.synchronousCopyH2D(Host5, DeviceA5, 5));
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ EXPECT_ERROR(Executor.synchronousCopyH2D(Host7, DeviceA5, 7));
+}
+
+TEST_F(ExecutorTest, SyncCopyH2DSliceToArrayRefByCount) {
+ EXPECT_NO_ERROR(Executor.synchronousCopyH2D(
+ ArrayRef<int>(Host5 + 1, 4), DeviceA5.asSlice().drop_front(1), 4));
+ for (int I = 1; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ EXPECT_NO_ERROR(Executor.synchronousCopyH2D(
+ ArrayRef<int>(Host5), DeviceB5.asSlice().drop_back(1), 2));
+ for (int I = 0; I < 2; ++I) {
+ EXPECT_EQ(HostB5[I], Host5[I]);
+ }
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyH2D(ArrayRef<int>(Host7), DeviceA5.asSlice(), 7));
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyH2D(ArrayRef<int>(Host5), DeviceA7.asSlice(), 7));
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyH2D(ArrayRef<int>(Host5), DeviceA5.asSlice(), 7));
+}
+
+TEST_F(ExecutorTest, SyncCopyH2DSliceToArrayRef) {
+ EXPECT_NO_ERROR(
+ Executor.synchronousCopyH2D(ArrayRef<int>(Host5), DeviceA5.asSlice()));
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyH2D(ArrayRef<int>(Host5), DeviceA7.asSlice()));
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyH2D(ArrayRef<int>(Host7), DeviceA5.asSlice()));
+}
+
+TEST_F(ExecutorTest, SyncCopyH2DSliceToPointer) {
+ EXPECT_NO_ERROR(Executor.synchronousCopyH2D(Host5, DeviceA5.asSlice(), 5));
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ EXPECT_ERROR(Executor.synchronousCopyH2D(Host7, DeviceA5.asSlice(), 7));
+}
+
+// D2D tests
+
+TEST_F(ExecutorTest, SyncCopyD2DByCount) {
+ EXPECT_NO_ERROR(Executor.synchronousCopyD2D(DeviceA5, DeviceB5, 5));
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], HostB5[I]);
+ }
+
+ EXPECT_NO_ERROR(Executor.synchronousCopyD2D(DeviceA7, DeviceB7, 2));
+ for (int I = 0; I < 2; ++I) {
+ EXPECT_EQ(HostA7[I], HostB7[I]);
+ }
+
+ EXPECT_ERROR(Executor.synchronousCopyD2D(DeviceA5, DeviceB5, 7));
+
+ EXPECT_ERROR(Executor.synchronousCopyD2D(DeviceA7, DeviceB5, 7));
+
+ EXPECT_ERROR(Executor.synchronousCopyD2D(DeviceA5, DeviceB7, 7));
+}
+
+TEST_F(ExecutorTest, SyncCopyD2D) {
+ EXPECT_NO_ERROR(Executor.synchronousCopyD2D(DeviceA5, DeviceB5));
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], HostB5[I]);
+ }
+
+ EXPECT_ERROR(Executor.synchronousCopyD2D(DeviceA7, DeviceB5));
+
+ EXPECT_ERROR(Executor.synchronousCopyD2D(DeviceA5, DeviceB7));
+}
+
+TEST_F(ExecutorTest, SyncCopySliceD2DByCount) {
+ EXPECT_NO_ERROR(Executor.synchronousCopyD2D(DeviceA5.asSlice().drop_front(1),
+ DeviceB5, 4));
+ for (int I = 0; I < 4; ++I) {
+ EXPECT_EQ(HostA5[I + 1], HostB5[I]);
+ }
+
+ EXPECT_NO_ERROR(Executor.synchronousCopyD2D(DeviceA7.asSlice().drop_back(1),
+ DeviceB7, 2));
+ for (int I = 0; I < 2; ++I) {
+ EXPECT_EQ(HostA7[I], HostB7[I]);
+ }
+
+ EXPECT_ERROR(Executor.synchronousCopyD2D(DeviceA5.asSlice(), DeviceB5, 7));
+
+ EXPECT_ERROR(Executor.synchronousCopyD2D(DeviceA7.asSlice(), DeviceB5, 7));
+
+ EXPECT_ERROR(Executor.synchronousCopyD2D(DeviceA5.asSlice(), DeviceB7, 7));
+}
+
+TEST_F(ExecutorTest, SyncCopySliceD2D) {
+ EXPECT_NO_ERROR(
+ Executor.synchronousCopyD2D(DeviceA7.asSlice().drop_back(2), DeviceB5));
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA7[I], HostB5[I]);
+ }
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyD2D(DeviceA7.asSlice().drop_front(1), DeviceB5));
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyD2D(DeviceA5.asSlice().drop_back(1), DeviceB7));
+}
+
+TEST_F(ExecutorTest, SyncCopyD2DSliceByCount) {
+ EXPECT_NO_ERROR(Executor.synchronousCopyD2D(
+ DeviceA5, DeviceB7.asSlice().drop_front(2), 5));
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], HostB7[I + 2]);
+ }
+
+ EXPECT_NO_ERROR(Executor.synchronousCopyD2D(
+ DeviceA7, DeviceB7.asSlice().drop_back(3), 2));
+ for (int I = 0; I < 2; ++I) {
+ EXPECT_EQ(HostA7[I], HostB7[I]);
+ }
+
+ EXPECT_ERROR(Executor.synchronousCopyD2D(DeviceA5, DeviceB5.asSlice(), 7));
+
+ EXPECT_ERROR(Executor.synchronousCopyD2D(DeviceA7, DeviceB5.asSlice(), 7));
+
+ EXPECT_ERROR(Executor.synchronousCopyD2D(DeviceA5, DeviceB7.asSlice(), 7));
+}
+
+TEST_F(ExecutorTest, SyncCopyD2DSlice) {
+ EXPECT_NO_ERROR(
+ Executor.synchronousCopyD2D(DeviceA5, DeviceB7.asSlice().drop_back(2)));
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], HostB7[I]);
+ }
+
+ EXPECT_ERROR(Executor.synchronousCopyD2D(DeviceA7, DeviceB5.asSlice()));
+
+ EXPECT_ERROR(Executor.synchronousCopyD2D(DeviceA5, DeviceB7.asSlice()));
+}
+
+TEST_F(ExecutorTest, SyncCopySliceD2DSliceByCount) {
+ EXPECT_NO_ERROR(
+ Executor.synchronousCopyD2D(DeviceA5.asSlice(), DeviceB5.asSlice(), 5));
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], HostB5[I]);
+ }
+
+ EXPECT_NO_ERROR(
+ Executor.synchronousCopyD2D(DeviceA7.asSlice(), DeviceB7.asSlice(), 2));
+ for (int I = 0; I < 2; ++I) {
+ EXPECT_EQ(HostA7[I], HostB7[I]);
+ }
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyD2D(DeviceA5.asSlice(), DeviceB5.asSlice(), 7));
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyD2D(DeviceA7.asSlice(), DeviceB5.asSlice(), 7));
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyD2D(DeviceA5.asSlice(), DeviceB7.asSlice(), 7));
+}
+
+TEST_F(ExecutorTest, SyncCopySliceD2DSlice) {
+ EXPECT_NO_ERROR(
+ Executor.synchronousCopyD2D(DeviceA5.asSlice(), DeviceB5.asSlice()));
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], HostB5[I]);
+ }
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyD2D(DeviceA7.asSlice(), DeviceB5.asSlice()));
+
+ EXPECT_ERROR(
+ Executor.synchronousCopyD2D(DeviceA5.asSlice(), DeviceB7.asSlice()));
+}
+
+} // namespace
diff --git a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp
index 6ef21833108..756467057ac 100644
--- a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp
+++ b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp
@@ -40,26 +40,34 @@ public:
return nullptr;
}
- se::Error memcpyD2H(se::PlatformStreamHandle *,
- const se::GlobalDeviceMemoryBase &DeviceSrc,
- void *HostDst, size_t ByteCount) override {
- std::memcpy(HostDst, DeviceSrc.getHandle(), ByteCount);
+ se::Error copyD2H(se::PlatformStreamHandle *S,
+ const se::GlobalDeviceMemoryBase &DeviceSrc,
+ size_t SrcByteOffset, void *HostDst, size_t DstByteOffset,
+ size_t ByteCount) override {
+ std::memcpy(HostDst, static_cast<const char *>(DeviceSrc.getHandle()) +
+ SrcByteOffset,
+ ByteCount);
return se::Error::success();
}
- se::Error memcpyH2D(se::PlatformStreamHandle *, const void *HostSrc,
- se::GlobalDeviceMemoryBase *DeviceDst,
- size_t ByteCount) override {
- std::memcpy(const_cast<void *>(DeviceDst->getHandle()), HostSrc, ByteCount);
+ se::Error copyH2D(se::PlatformStreamHandle *S, const void *HostSrc,
+ size_t SrcByteOffset, se::GlobalDeviceMemoryBase DeviceDst,
+ size_t DstByteOffset, size_t ByteCount) override {
+ std::memcpy(static_cast<char *>(const_cast<void *>(DeviceDst.getHandle())) +
+ DstByteOffset,
+ HostSrc, ByteCount);
return se::Error::success();
}
- se::Error memcpyD2D(se::PlatformStreamHandle *,
- const se::GlobalDeviceMemoryBase &DeviceSrc,
- se::GlobalDeviceMemoryBase *DeviceDst,
- size_t ByteCount) override {
- std::memcpy(const_cast<void *>(DeviceDst->getHandle()),
- DeviceSrc.getHandle(), ByteCount);
+ se::Error copyD2D(se::PlatformStreamHandle *S,
+ const se::GlobalDeviceMemoryBase &DeviceSrc,
+ size_t SrcByteOffset, se::GlobalDeviceMemoryBase DeviceDst,
+ size_t DstByteOffset, size_t ByteCount) override {
+ std::memcpy(static_cast<char *>(const_cast<void *>(DeviceDst.getHandle())) +
+ DstByteOffset,
+ static_cast<const char *>(DeviceSrc.getHandle()) +
+ SrcByteOffset,
+ ByteCount);
return se::Error::success();
}
};
@@ -68,47 +76,317 @@ public:
class StreamTest : public ::testing::Test {
public:
StreamTest()
- : DeviceA(se::GlobalDeviceMemory<int>::makeFromElementCount(HostA, 10)),
- DeviceB(se::GlobalDeviceMemory<int>::makeFromElementCount(HostB, 10)),
+ : HostA5{0, 1, 2, 3, 4}, HostB5{5, 6, 7, 8, 9},
+ HostA7{10, 11, 12, 13, 14, 15, 16}, HostB7{17, 18, 19, 20, 21, 22, 23},
+ DeviceA5(se::GlobalDeviceMemory<int>::makeFromElementCount(HostA5, 5)),
+ DeviceB5(se::GlobalDeviceMemory<int>::makeFromElementCount(HostB5, 5)),
+ DeviceA7(se::GlobalDeviceMemory<int>::makeFromElementCount(HostA7, 7)),
+ DeviceB7(se::GlobalDeviceMemory<int>::makeFromElementCount(HostB7, 7)),
+ Host5{24, 25, 26, 27, 28}, Host7{29, 30, 31, 32, 33, 34, 35},
Stream(llvm::make_unique<se::PlatformStreamHandle>(&PExecutor)) {}
protected:
// Device memory is backed by host arrays.
- int HostA[10];
- se::GlobalDeviceMemory<int> DeviceA;
- int HostB[10];
- se::GlobalDeviceMemory<int> DeviceB;
+ int HostA5[5];
+ int HostB5[5];
+ int HostA7[7];
+ int HostB7[7];
+ se::GlobalDeviceMemory<int> DeviceA5;
+ se::GlobalDeviceMemory<int> DeviceB5;
+ se::GlobalDeviceMemory<int> DeviceA7;
+ se::GlobalDeviceMemory<int> DeviceB7;
// Host memory to be used as actual host memory.
- int Host[10];
+ int Host5[5];
+ int Host7[7];
MockPlatformExecutor PExecutor;
se::Stream Stream;
};
-TEST_F(StreamTest, MemcpyCorrectSize) {
- Stream.thenMemcpyH2D(llvm::ArrayRef<int>(Host), &DeviceA);
+using llvm::ArrayRef;
+using llvm::MutableArrayRef;
+
+// D2H tests
+
+TEST_F(StreamTest, CopyD2HToMutableArrayRefByCount) {
+ Stream.thenCopyD2H(DeviceA5, MutableArrayRef<int>(Host5), 5);
EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
- Stream.thenMemcpyD2H(DeviceA, llvm::MutableArrayRef<int>(Host));
+ Stream.thenCopyD2H(DeviceB5, MutableArrayRef<int>(Host5), 2);
EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 2; ++I) {
+ EXPECT_EQ(HostB5[I], Host5[I]);
+ }
- Stream.thenMemcpyD2D(DeviceA, &DeviceB);
+ Stream.thenCopyD2H(DeviceA7, MutableArrayRef<int>(Host5), 7);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+TEST_F(StreamTest, CopyD2HToMutableArrayRef) {
+ Stream.thenCopyD2H(DeviceA5, MutableArrayRef<int>(Host5));
EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ Stream.thenCopyD2H(DeviceA5, MutableArrayRef<int>(Host7));
+ EXPECT_FALSE(Stream.isOK());
}
-TEST_F(StreamTest, MemcpyH2DTooManyElements) {
- Stream.thenMemcpyH2D(llvm::ArrayRef<int>(Host), &DeviceA, 20);
+TEST_F(StreamTest, CopyD2HToPointer) {
+ Stream.thenCopyD2H(DeviceA5, Host5, 5);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ Stream.thenCopyD2H(DeviceA5, Host7, 7);
EXPECT_FALSE(Stream.isOK());
}
-TEST_F(StreamTest, MemcpyD2HTooManyElements) {
- Stream.thenMemcpyD2H(DeviceA, llvm::MutableArrayRef<int>(Host), 20);
+TEST_F(StreamTest, CopyD2HSliceToMutableArrayRefByCount) {
+ Stream.thenCopyD2H(DeviceA5.asSlice().drop_front(1),
+ MutableArrayRef<int>(Host5 + 1, 4), 4);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 1; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ Stream.thenCopyD2H(DeviceB5.asSlice().drop_back(1),
+ MutableArrayRef<int>(Host5), 2);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 2; ++I) {
+ EXPECT_EQ(HostB5[I], Host5[I]);
+ }
+
+ Stream.thenCopyD2H(DeviceA5.asSlice(), MutableArrayRef<int>(Host7), 7);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+TEST_F(StreamTest, CopyD2HSliceToMutableArrayRef) {
+ Stream.thenCopyD2H(DeviceA7.asSlice().slice(1, 5),
+ MutableArrayRef<int>(Host5));
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA7[I + 1], Host5[I]);
+ }
+
+ Stream.thenCopyD2H(DeviceA5.asSlice(), MutableArrayRef<int>(Host7));
+ EXPECT_FALSE(Stream.isOK());
+}
+
+TEST_F(StreamTest, CopyD2HSliceToPointer) {
+ Stream.thenCopyD2H(DeviceA5.asSlice().drop_front(1), Host5 + 1, 4);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 1; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ Stream.thenCopyD2H(DeviceA5.asSlice(), Host7, 7);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+// H2D tests
+
+TEST_F(StreamTest, CopyH2DToArrayRefByCount) {
+ Stream.thenCopyH2D(ArrayRef<int>(Host5), DeviceA5, 5);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ Stream.thenCopyH2D(ArrayRef<int>(Host5), DeviceB5, 2);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 2; ++I) {
+ EXPECT_EQ(HostB5[I], Host5[I]);
+ }
+
+ Stream.thenCopyH2D(ArrayRef<int>(Host7), DeviceA5, 7);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+TEST_F(StreamTest, CopyH2DToArrayRef) {
+ Stream.thenCopyH2D(ArrayRef<int>(Host5), DeviceA5);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ Stream.thenCopyH2D(ArrayRef<int>(Host7), DeviceA5);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+TEST_F(StreamTest, CopyH2DToPointer) {
+ Stream.thenCopyH2D(Host5, DeviceA5, 5);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ Stream.thenCopyH2D(Host7, DeviceA5, 7);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+TEST_F(StreamTest, CopyH2DSliceToArrayRefByCount) {
+ Stream.thenCopyH2D(ArrayRef<int>(Host5 + 1, 4),
+ DeviceA5.asSlice().drop_front(1), 4);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 1; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ Stream.thenCopyH2D(ArrayRef<int>(Host5), DeviceB5.asSlice().drop_back(1), 2);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 2; ++I) {
+ EXPECT_EQ(HostB5[I], Host5[I]);
+ }
+
+ Stream.thenCopyH2D(ArrayRef<int>(Host5), DeviceA5.asSlice(), 7);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+TEST_F(StreamTest, CopyH2DSliceToArrayRef) {
+
+ Stream.thenCopyH2D(ArrayRef<int>(Host5), DeviceA5.asSlice());
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ Stream.thenCopyH2D(ArrayRef<int>(Host7), DeviceA5.asSlice());
EXPECT_FALSE(Stream.isOK());
}
-TEST_F(StreamTest, MemcpyD2DTooManyElements) {
- Stream.thenMemcpyD2D(DeviceA, &DeviceB, 20);
+TEST_F(StreamTest, CopyH2DSliceToPointer) {
+ Stream.thenCopyH2D(Host5, DeviceA5.asSlice(), 5);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], Host5[I]);
+ }
+
+ Stream.thenCopyH2D(Host7, DeviceA5.asSlice(), 7);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+// D2D tests
+
+TEST_F(StreamTest, CopyD2DByCount) {
+ Stream.thenCopyD2D(DeviceA5, DeviceB5, 5);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], HostB5[I]);
+ }
+
+ Stream.thenCopyD2D(DeviceA7, DeviceB7, 2);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 2; ++I) {
+ EXPECT_EQ(HostA7[I], HostB7[I]);
+ }
+
+ Stream.thenCopyD2D(DeviceA7, DeviceB5, 7);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+TEST_F(StreamTest, CopyD2D) {
+ Stream.thenCopyD2D(DeviceA5, DeviceB5);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], HostB5[I]);
+ }
+
+ Stream.thenCopyD2D(DeviceA7, DeviceB5);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+TEST_F(StreamTest, CopySliceD2DByCount) {
+ Stream.thenCopyD2D(DeviceA5.asSlice().drop_front(1), DeviceB5, 4);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 4; ++I) {
+ EXPECT_EQ(HostA5[I + 1], HostB5[I]);
+ }
+
+ Stream.thenCopyD2D(DeviceA7.asSlice().drop_back(1), DeviceB7, 2);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 2; ++I) {
+ EXPECT_EQ(HostA7[I], HostB7[I]);
+ }
+
+ Stream.thenCopyD2D(DeviceA5.asSlice(), DeviceB5, 7);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+TEST_F(StreamTest, CopySliceD2D) {
+
+ Stream.thenCopyD2D(DeviceA7.asSlice().drop_back(2), DeviceB5);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA7[I], HostB5[I]);
+ }
+
+ Stream.thenCopyD2D(DeviceA5.asSlice().drop_back(1), DeviceB7);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+TEST_F(StreamTest, CopyD2DSliceByCount) {
+ Stream.thenCopyD2D(DeviceA5, DeviceB7.asSlice().drop_front(2), 5);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], HostB7[I + 2]);
+ }
+
+ Stream.thenCopyD2D(DeviceA7, DeviceB7.asSlice().drop_back(3), 2);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 2; ++I) {
+ EXPECT_EQ(HostA7[I], HostB7[I]);
+ }
+
+ Stream.thenCopyD2D(DeviceA5, DeviceB7.asSlice(), 7);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+TEST_F(StreamTest, CopyD2DSlice) {
+
+ Stream.thenCopyD2D(DeviceA5, DeviceB7.asSlice().drop_back(2));
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], HostB7[I]);
+ }
+
+ Stream.thenCopyD2D(DeviceA5, DeviceB7.asSlice());
+ EXPECT_FALSE(Stream.isOK());
+}
+
+TEST_F(StreamTest, CopySliceD2DSliceByCount) {
+
+ Stream.thenCopyD2D(DeviceA5.asSlice(), DeviceB5.asSlice(), 5);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], HostB5[I]);
+ }
+
+ Stream.thenCopyD2D(DeviceA7.asSlice(), DeviceB7.asSlice(), 2);
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 2; ++I) {
+ EXPECT_EQ(HostA7[I], HostB7[I]);
+ }
+
+ Stream.thenCopyD2D(DeviceA7.asSlice(), DeviceB5.asSlice(), 7);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+TEST_F(StreamTest, CopySliceD2DSlice) {
+
+ Stream.thenCopyD2D(DeviceA5.asSlice(), DeviceB5.asSlice());
+ EXPECT_TRUE(Stream.isOK());
+ for (int I = 0; I < 5; ++I) {
+ EXPECT_EQ(HostA5[I], HostB5[I]);
+ }
+
+ Stream.thenCopyD2D(DeviceA5.asSlice(), DeviceB7.asSlice());
EXPECT_FALSE(Stream.isOK());
}