diff options
Diffstat (limited to 'parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp')
-rw-r--r-- | parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp | 49 |
1 files changed, 2 insertions, 47 deletions
diff --git a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp index d05c928dcb3..b194bf02082 100644 --- a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp +++ b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp @@ -14,6 +14,7 @@ #include <cstring> +#include "SimpleHostPlatformDevice.h" #include "streamexecutor/Device.h" #include "streamexecutor/Kernel.h" #include "streamexecutor/KernelSpec.h" @@ -26,52 +27,6 @@ namespace { namespace se = ::streamexecutor; -/// Mock PlatformDevice that performs asynchronous memcpy operations by -/// ignoring the stream argument and calling std::memcpy on device memory -/// handles. -class MockPlatformDevice : public se::PlatformDevice { -public: - ~MockPlatformDevice() override {} - - std::string getName() const override { return "MockPlatformDevice"; } - - se::Expected<std::unique_ptr<se::PlatformStreamHandle>> - createStream() override { - return nullptr; - } - - se::Error copyD2H(se::PlatformStreamHandle *S, - const se::GlobalDeviceMemoryBase &DeviceSrc, - size_t SrcByteOffset, void *HostDst, size_t DstByteOffset, - size_t ByteCount) override { - std::memcpy(HostDst, static_cast<const char *>(DeviceSrc.getHandle()) + - SrcByteOffset, - ByteCount); - return se::Error::success(); - } - - se::Error copyH2D(se::PlatformStreamHandle *S, const void *HostSrc, - size_t SrcByteOffset, se::GlobalDeviceMemoryBase DeviceDst, - size_t DstByteOffset, size_t ByteCount) override { - std::memcpy(static_cast<char *>(const_cast<void *>(DeviceDst.getHandle())) + - DstByteOffset, - HostSrc, ByteCount); - return se::Error::success(); - } - - se::Error copyD2D(se::PlatformStreamHandle *S, - const se::GlobalDeviceMemoryBase &DeviceSrc, - size_t SrcByteOffset, se::GlobalDeviceMemoryBase DeviceDst, - size_t DstByteOffset, size_t ByteCount) override { - std::memcpy(static_cast<char *>(const_cast<void *>(DeviceDst.getHandle())) + - DstByteOffset, - static_cast<const char *>(DeviceSrc.getHandle()) + - SrcByteOffset, - ByteCount); - return se::Error::success(); - } -}; - /// Test fixture to hold objects used by tests. class StreamTest : public ::testing::Test { public: @@ -100,7 +55,7 @@ protected: int Host5[5]; int Host7[7]; - MockPlatformDevice PDevice; + SimpleHostPlatformDevice PDevice; se::Stream Stream; }; |