summaryrefslogtreecommitdiff
path: root/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp')
-rw-r--r--parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp49
1 files changed, 2 insertions, 47 deletions
diff --git a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp
index d05c928dcb3..b194bf02082 100644
--- a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp
+++ b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp
@@ -14,6 +14,7 @@
#include <cstring>
+#include "SimpleHostPlatformDevice.h"
#include "streamexecutor/Device.h"
#include "streamexecutor/Kernel.h"
#include "streamexecutor/KernelSpec.h"
@@ -26,52 +27,6 @@ namespace {
namespace se = ::streamexecutor;
-/// Mock PlatformDevice that performs asynchronous memcpy operations by
-/// ignoring the stream argument and calling std::memcpy on device memory
-/// handles.
-class MockPlatformDevice : public se::PlatformDevice {
-public:
- ~MockPlatformDevice() override {}
-
- std::string getName() const override { return "MockPlatformDevice"; }
-
- se::Expected<std::unique_ptr<se::PlatformStreamHandle>>
- createStream() override {
- return nullptr;
- }
-
- se::Error copyD2H(se::PlatformStreamHandle *S,
- const se::GlobalDeviceMemoryBase &DeviceSrc,
- size_t SrcByteOffset, void *HostDst, size_t DstByteOffset,
- size_t ByteCount) override {
- std::memcpy(HostDst, static_cast<const char *>(DeviceSrc.getHandle()) +
- SrcByteOffset,
- ByteCount);
- return se::Error::success();
- }
-
- se::Error copyH2D(se::PlatformStreamHandle *S, const void *HostSrc,
- size_t SrcByteOffset, se::GlobalDeviceMemoryBase DeviceDst,
- size_t DstByteOffset, size_t ByteCount) override {
- std::memcpy(static_cast<char *>(const_cast<void *>(DeviceDst.getHandle())) +
- DstByteOffset,
- HostSrc, ByteCount);
- return se::Error::success();
- }
-
- se::Error copyD2D(se::PlatformStreamHandle *S,
- const se::GlobalDeviceMemoryBase &DeviceSrc,
- size_t SrcByteOffset, se::GlobalDeviceMemoryBase DeviceDst,
- size_t DstByteOffset, size_t ByteCount) override {
- std::memcpy(static_cast<char *>(const_cast<void *>(DeviceDst.getHandle())) +
- DstByteOffset,
- static_cast<const char *>(DeviceSrc.getHandle()) +
- SrcByteOffset,
- ByteCount);
- return se::Error::success();
- }
-};
-
/// Test fixture to hold objects used by tests.
class StreamTest : public ::testing::Test {
public:
@@ -100,7 +55,7 @@ protected:
int Host5[5];
int Host7[7];
- MockPlatformDevice PDevice;
+ SimpleHostPlatformDevice PDevice;
se::Stream Stream;
};