summary | refs | log | tree | commit | diff
path: root/parallel-libs
diff options
context:
space:
mode:
author Jason Henline <jhen@google.com> 2016-09-01 18:48:21 +0000
committer Jason Henline <jhen@google.com> 2016-09-01 18:48:21 +0000
commit dd2d97ef3d9e355d53b86439f1212a96d9d0381a (patch)
tree e82cc870fe84b04e9856b7f24b4c10e78dd1fe6c /parallel-libs
parent b4a49d836bef4d805e03bf671296b394d117282e (diff)
[StreamExecutor] Dev handles in platform interface
Summary:
This is the first in a series of patches that will convert GlobalDeviceMemory to own its device memory handle. The first step is to remove GlobalDeviceMemoryBase from the PlatformInterface interfaces and use raw handles there instead. This is useful because GlobalDeviceMemoryBase is going to lose its importance in this process.

Reviewers: jlebar

Subscribers: jprice, parallel_libs-commits

Differential Revision: https://reviews.llvm.org/D24114
Diffstat (limited to 'parallel-libs')
-rw-r--r-- parallel-libs/streamexecutor/include/streamexecutor/Device.h 23
-rw-r--r-- parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h 27
-rw-r--r-- parallel-libs/streamexecutor/include/streamexecutor/Stream.h 14
-rw-r--r-- parallel-libs/streamexecutor/lib/unittests/DeviceTest.cpp 76
-rw-r--r-- parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h 135
-rw-r--r-- parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp 49
6 files changed, 171 insertions, 153 deletions
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Device.h b/parallel-libs/streamexecutor/include/streamexecutor/Device.h
index 48ecf22ae76..24937816a75 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/Device.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/Device.h
@@ -56,16 +56,17 @@ public:
/// Allocates an array of ElementCount entries of type T in device memory.
template <typename T>
Expected<GlobalDeviceMemory<T>> allocateDeviceMemory(size_t ElementCount) {
- Expected<GlobalDeviceMemoryBase> MaybeBase =
+ Expected<void *> MaybeMemory =
PDevice->allocateDeviceMemory(ElementCount * sizeof(T));
- if (!MaybeBase)
- return MaybeBase.takeError();
- return GlobalDeviceMemory<T>(*MaybeBase);
+ if (!MaybeMemory)
+ return MaybeMemory.takeError();
+ return GlobalDeviceMemory<T>::makeFromElementCount(*MaybeMemory,
+ ElementCount);
}
/// Frees memory previously allocated with allocateDeviceMemory.
template <typename T> Error freeDeviceMemory(GlobalDeviceMemory<T> Memory) {
- return PDevice->freeDeviceMemory(Memory);
+ return PDevice->freeDeviceMemory(Memory.getHandle());
}
/// Allocates an array of ElementCount entries of type T in host memory.
@@ -140,7 +141,7 @@ public:
return make_error(
"copying too many elements, " + llvm::Twine(ElementCount) +
", to a host array of element count " + llvm::Twine(Dst.size()));
- return PDevice->synchronousCopyD2H(Src.getBaseMemory(),
+ return PDevice->synchronousCopyD2H(Src.getBaseMemory().getHandle(),
Src.getElementOffset() * sizeof(T),
Dst.data(), 0, ElementCount * sizeof(T));
}
@@ -194,9 +195,9 @@ public:
llvm::Twine(ElementCount) +
", to a device array of element count " +
llvm::Twine(Dst.getElementCount()));
- return PDevice->synchronousCopyH2D(Src.data(), 0, Dst.getBaseMemory(),
- Dst.getElementOffset() * sizeof(T),
- ElementCount * sizeof(T));
+ return PDevice->synchronousCopyH2D(
+ Src.data(), 0, Dst.getBaseMemory().getHandle(),
+ Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T));
}
template <typename T>
@@ -250,8 +251,8 @@ public:
", to a device array of element count " +
llvm::Twine(Dst.getElementCount()));
return PDevice->synchronousCopyD2D(
- Src.getBaseMemory(), Src.getElementOffset() * sizeof(T),
- Dst.getBaseMemory(), Dst.getElementOffset() * sizeof(T),
+ Src.getBaseMemory().getHandle(), Src.getElementOffset() * sizeof(T),
+ Dst.getBaseMemory().getHandle(), Dst.getElementOffset() * sizeof(T),
ElementCount * sizeof(T));
}
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h b/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h
index 8fa31b63ef2..b3deff31f50 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h
@@ -94,8 +94,7 @@ public:
///
/// HostDst should have been allocated by allocateHostMemory or registered
/// with registerHostMemory.
- virtual Error copyD2H(PlatformStreamHandle *S,
- const GlobalDeviceMemoryBase &DeviceSrc,
+ virtual Error copyD2H(PlatformStreamHandle *S, const void *DeviceSrcHandle,
size_t SrcByteOffset, void *HostDst,
size_t DstByteOffset, size_t ByteCount) {
return make_error("copyD2H not implemented for platform " + getName());
@@ -106,15 +105,14 @@ public:
/// HostSrc should have been allocated by allocateHostMemory or registered
/// with registerHostMemory.
virtual Error copyH2D(PlatformStreamHandle *S, const void *HostSrc,
- size_t SrcByteOffset, GlobalDeviceMemoryBase DeviceDst,
+ size_t SrcByteOffset, const void *DeviceDstHandle,
size_t DstByteOffset, size_t ByteCount) {
return make_error("copyH2D not implemented for platform " + getName());
}
/// Copies data from one device location to another.
- virtual Error copyD2D(PlatformStreamHandle *S,
- const GlobalDeviceMemoryBase &DeviceSrc,
- size_t SrcByteOffset, GlobalDeviceMemoryBase DeviceDst,
+ virtual Error copyD2D(PlatformStreamHandle *S, const void *DeviceSrcHandle,
+ size_t SrcByteOffset, const void *DeviceDstHandle,
size_t DstByteOffset, size_t ByteCount) {
return make_error("copyD2D not implemented for platform " + getName());
}
@@ -127,14 +125,13 @@ public:
}
/// Allocates untyped device memory of a given size in bytes.
- virtual Expected<GlobalDeviceMemoryBase>
- allocateDeviceMemory(size_t ByteCount) {
+ virtual Expected<void *> allocateDeviceMemory(size_t ByteCount) {
return make_error("allocateDeviceMemory not implemented for platform " +
getName());
}
/// Frees device memory previously allocated by allocateDeviceMemory.
- virtual Error freeDeviceMemory(GlobalDeviceMemoryBase Memory) {
+ virtual Error freeDeviceMemory(const void *Handle) {
return make_error("freeDeviceMemory not implemented for platform " +
getName());
}
@@ -172,29 +169,29 @@ public:
/// Blocks the calling host thread until the copy is completed. Can operate on
/// any host memory, not just registered host memory or host memory allocated
/// by allocateHostMemory. Does not block any ongoing device calls.
- virtual Error synchronousCopyD2H(const GlobalDeviceMemoryBase &DeviceSrc,
+ virtual Error synchronousCopyD2H(const void *DeviceSrcHandle,
size_t SrcByteOffset, void *HostDst,
size_t DstByteOffset, size_t ByteCount) {
return make_error("synchronousCopyD2H not implemented for platform " +
getName());
}
- /// Similar to synchronousCopyD2H(const GlobalDeviceMemoryBase &, size_t, void
+ /// Similar to synchronousCopyD2H(const void *, size_t, void
/// *, size_t, size_t), but copies memory from host to device rather than
/// device to host.
virtual Error synchronousCopyH2D(const void *HostSrc, size_t SrcByteOffset,
- GlobalDeviceMemoryBase DeviceDst,
+ const void *DeviceDstHandle,
size_t DstByteOffset, size_t ByteCount) {
return make_error("synchronousCopyH2D not implemented for platform " +
getName());
}
- /// Similar to synchronousCopyD2H(const GlobalDeviceMemoryBase &, size_t, void
+ /// Similar to synchronousCopyD2H(const void *, size_t, void
/// *, size_t, size_t), but copies memory from one location in device memory
/// to another rather than from device to host.
- virtual Error synchronousCopyD2D(GlobalDeviceMemoryBase DeviceDst,
+ virtual Error synchronousCopyD2D(const void *DeviceDstHandle,
size_t DstByteOffset,
- const GlobalDeviceMemoryBase &DeviceSrc,
+ const void *DeviceSrcHandle,
size_t SrcByteOffset, size_t ByteCount) {
return make_error("synchronousCopyD2D not implemented for platform " +
getName());
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h
index 1acb18139d8..054b1593aa8 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h
@@ -136,7 +136,8 @@ public:
setError("copying too many elements, " + llvm::Twine(ElementCount) +
", to a host array of element count " + llvm::Twine(Dst.size()));
else
- setError(PDevice->copyD2H(ThePlatformStream.get(), Src.getBaseMemory(),
+ setError(PDevice->copyD2H(ThePlatformStream.get(),
+ Src.getBaseMemory().getHandle(),
Src.getElementOffset() * sizeof(T), Dst.data(),
0, ElementCount * sizeof(T)));
return *this;
@@ -193,9 +194,10 @@ public:
", to a device array of element count " +
llvm::Twine(Dst.getElementCount()));
else
- setError(PDevice->copyH2D(
- ThePlatformStream.get(), Src.data(), 0, Dst.getBaseMemory(),
- Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)));
+ setError(PDevice->copyH2D(ThePlatformStream.get(), Src.data(), 0,
+ Dst.getBaseMemory().getHandle(),
+ Dst.getElementOffset() * sizeof(T),
+ ElementCount * sizeof(T)));
return *this;
}
@@ -250,8 +252,8 @@ public:
llvm::Twine(Dst.getElementCount()));
else
setError(PDevice->copyD2D(
- ThePlatformStream.get(), Src.getBaseMemory(),
- Src.getElementOffset() * sizeof(T), Dst.getBaseMemory(),
+ ThePlatformStream.get(), Src.getBaseMemory().getHandle(),
+ Src.getElementOffset() * sizeof(T), Dst.getBaseMemory().getHandle(),
Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)));
return *this;
}
diff --git a/parallel-libs/streamexecutor/lib/unittests/DeviceTest.cpp b/parallel-libs/streamexecutor/lib/unittests/DeviceTest.cpp
index cb34b8b92d5..93d378fff5a 100644
--- a/parallel-libs/streamexecutor/lib/unittests/DeviceTest.cpp
+++ b/parallel-libs/streamexecutor/lib/unittests/DeviceTest.cpp
@@ -15,6 +15,7 @@
#include <cstdlib>
#include <cstring>
+#include "SimpleHostPlatformDevice.h"
#include "streamexecutor/Device.h"
#include "streamexecutor/PlatformInterfaces.h"
@@ -24,79 +25,6 @@ namespace {
namespace se = ::streamexecutor;
-class MockPlatformDevice : public se::PlatformDevice {
-public:
- ~MockPlatformDevice() override {}
-
- std::string getName() const override { return "MockPlatformDevice"; }
-
- se::Expected<std::unique_ptr<se::PlatformStreamHandle>>
- createStream() override {
- return se::make_error("not implemented");
- }
-
- se::Expected<se::GlobalDeviceMemoryBase>
- allocateDeviceMemory(size_t ByteCount) override {
- return se::GlobalDeviceMemoryBase(std::malloc(ByteCount));
- }
-
- se::Error freeDeviceMemory(se::GlobalDeviceMemoryBase Memory) override {
- std::free(const_cast<void *>(Memory.getHandle()));
- return se::Error::success();
- }
-
- se::Expected<void *> allocateHostMemory(size_t ByteCount) override {
- return std::malloc(ByteCount);
- }
-
- se::Error freeHostMemory(void *Memory) override {
- std::free(Memory);
- return se::Error::success();
- }
-
- se::Error registerHostMemory(void *, size_t) override {
- return se::Error::success();
- }
-
- se::Error unregisterHostMemory(void *) override {
- return se::Error::success();
- }
-
- se::Error synchronousCopyD2H(const se::GlobalDeviceMemoryBase &DeviceSrc,
- size_t SrcByteOffset, void *HostDst,
- size_t DstByteOffset,
- size_t ByteCount) override {
- std::memcpy(static_cast<char *>(HostDst) + DstByteOffset,
- static_cast<const char *>(DeviceSrc.getHandle()) +
- SrcByteOffset,
- ByteCount);
- return se::Error::success();
- }
-
- se::Error synchronousCopyH2D(const void *HostSrc, size_t SrcByteOffset,
- se::GlobalDeviceMemoryBase DeviceDst,
- size_t DstByteOffset,
- size_t ByteCount) override {
- std::memcpy(static_cast<char *>(const_cast<void *>(DeviceDst.getHandle())) +
- DstByteOffset,
- static_cast<const char *>(HostSrc) + SrcByteOffset, ByteCount);
- return se::Error::success();
- }
-
- se::Error synchronousCopyD2D(se::GlobalDeviceMemoryBase DeviceDst,
- size_t DstByteOffset,
- const se::GlobalDeviceMemoryBase &DeviceSrc,
- size_t SrcByteOffset,
- size_t ByteCount) override {
- std::memcpy(static_cast<char *>(const_cast<void *>(DeviceDst.getHandle())) +
- DstByteOffset,
- static_cast<const char *>(DeviceSrc.getHandle()) +
- SrcByteOffset,
- ByteCount);
- return se::Error::success();
- }
-};
-
/// Test fixture to hold objects used by tests.
class DeviceTest : public ::testing::Test {
public:
@@ -124,7 +52,7 @@ public:
int Host5[5];
int Host7[7];
- MockPlatformDevice PDevice;
+ SimpleHostPlatformDevice PDevice;
se::Device Device;
};
diff --git a/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h b/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h
new file mode 100644
index 00000000000..a2dd3c8738f
--- /dev/null
+++ b/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h
@@ -0,0 +1,135 @@
+//===-- SimpleHostPlatformDevice.h - Host device for testing ----*- C++ -*-===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// The SimpleHostPlatformDevice class is a streamexecutor::PlatformDevice that
+/// is really just the host processor and memory. It is useful for testing
+/// because no extra device platform is required.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef STREAMEXECUTOR_LIB_UNITTESTS_SIMPLEHOSTPLATFORMDEVICE_H
+#define STREAMEXECUTOR_LIB_UNITTESTS_SIMPLEHOSTPLATFORMDEVICE_H
+
+#include <cstdlib>
+#include <cstring>
+
+#include "streamexecutor/PlatformInterfaces.h"
+
+/// A streamexecutor::PlatformDevice that simply forwards all operations to the
+/// host platform.
+///
+/// The allocate and copy methods are simple wrappers for std::malloc and
+/// std::memcpy.
+class SimpleHostPlatformDevice : public streamexecutor::PlatformDevice {
+ std::string getName() const override { return "SimpleHostPlatformDevice"; }
+
+ streamexecutor::Expected<
+ std::unique_ptr<streamexecutor::PlatformStreamHandle>>
+ createStream() override {
+ return nullptr;
+ }
+
+ streamexecutor::Expected<void *>
+ allocateDeviceMemory(size_t ByteCount) override {
+ return std::malloc(ByteCount);
+ }
+
+ streamexecutor::Error freeDeviceMemory(const void *Handle) override {
+ std::free(const_cast<void *>(Handle));
+ return streamexecutor::Error::success();
+ }
+
+ streamexecutor::Expected<void *>
+ allocateHostMemory(size_t ByteCount) override {
+ return std::malloc(ByteCount);
+ }
+
+ streamexecutor::Error freeHostMemory(void *Memory) override {
+ std::free(const_cast<void *>(Memory));
+ return streamexecutor::Error::success();
+ }
+
+ streamexecutor::Error registerHostMemory(void *Memory,
+ size_t ByteCount) override {
+ return streamexecutor::Error::success();
+ }
+
+ streamexecutor::Error unregisterHostMemory(void *Memory) override {
+ return streamexecutor::Error::success();
+ }
+
+ streamexecutor::Error copyD2H(streamexecutor::PlatformStreamHandle *S,
+ const void *DeviceHandleSrc,
+ size_t SrcByteOffset, void *HostDst,
+ size_t DstByteOffset,
+ size_t ByteCount) override {
+ std::memcpy(static_cast<char *>(HostDst) + DstByteOffset,
+ static_cast<const char *>(DeviceHandleSrc) + SrcByteOffset,
+ ByteCount);
+ return streamexecutor::Error::success();
+ }
+
+ streamexecutor::Error copyH2D(streamexecutor::PlatformStreamHandle *S,
+ const void *HostSrc, size_t SrcByteOffset,
+ const void *DeviceHandleDst,
+ size_t DstByteOffset,
+ size_t ByteCount) override {
+ std::memcpy(static_cast<char *>(const_cast<void *>(DeviceHandleDst)) +
+ DstByteOffset,
+ static_cast<const char *>(HostSrc) + SrcByteOffset, ByteCount);
+ return streamexecutor::Error::success();
+ }
+
+ streamexecutor::Error
+ copyD2D(streamexecutor::PlatformStreamHandle *S, const void *DeviceHandleSrc,
+ size_t SrcByteOffset, const void *DeviceHandleDst,
+ size_t DstByteOffset, size_t ByteCount) override {
+ std::memcpy(static_cast<char *>(const_cast<void *>(DeviceHandleDst)) +
+ DstByteOffset,
+ static_cast<const char *>(DeviceHandleSrc) + SrcByteOffset,
+ ByteCount);
+ return streamexecutor::Error::success();
+ }
+
+ streamexecutor::Error synchronousCopyD2H(const void *DeviceHandleSrc,
+ size_t SrcByteOffset, void *HostDst,
+ size_t DstByteOffset,
+ size_t ByteCount) override {
+ std::memcpy(static_cast<char *>(HostDst) + DstByteOffset,
+ static_cast<const char *>(DeviceHandleSrc) + SrcByteOffset,
+ ByteCount);
+ return streamexecutor::Error::success();
+ }
+
+ streamexecutor::Error synchronousCopyH2D(const void *HostSrc,
+ size_t SrcByteOffset,
+ const void *DeviceHandleDst,
+ size_t DstByteOffset,
+ size_t ByteCount) override {
+ std::memcpy(static_cast<char *>(const_cast<void *>(DeviceHandleDst)) +
+ DstByteOffset,
+ static_cast<const char *>(HostSrc) + SrcByteOffset, ByteCount);
+ return streamexecutor::Error::success();
+ }
+
+ streamexecutor::Error synchronousCopyD2D(const void *DeviceHandleSrc,
+ size_t SrcByteOffset,
+ const void *DeviceHandleDst,
+ size_t DstByteOffset,
+ size_t ByteCount) override {
+ std::memcpy(static_cast<char *>(const_cast<void *>(DeviceHandleDst)) +
+ DstByteOffset,
+ static_cast<const char *>(DeviceHandleSrc) + SrcByteOffset,
+ ByteCount);
+ return streamexecutor::Error::success();
+ }
+};
+
+#endif // STREAMEXECUTOR_LIB_UNITTESTS_SIMPLEHOSTPLATFORMDEVICE_H
diff --git a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp
index d05c928dcb3..b194bf02082 100644
--- a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp
+++ b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp
@@ -14,6 +14,7 @@
#include <cstring>
+#include "SimpleHostPlatformDevice.h"
#include "streamexecutor/Device.h"
#include "streamexecutor/Kernel.h"
#include "streamexecutor/KernelSpec.h"
@@ -26,52 +27,6 @@ namespace {
namespace se = ::streamexecutor;
-/// Mock PlatformDevice that performs asynchronous memcpy operations by
-/// ignoring the stream argument and calling std::memcpy on device memory
-/// handles.
-class MockPlatformDevice : public se::PlatformDevice {
-public:
- ~MockPlatformDevice() override {}
-
- std::string getName() const override { return "MockPlatformDevice"; }
-
- se::Expected<std::unique_ptr<se::PlatformStreamHandle>>
- createStream() override {
- return nullptr;
- }
-
- se::Error copyD2H(se::PlatformStreamHandle *S,
- const se::GlobalDeviceMemoryBase &DeviceSrc,
- size_t SrcByteOffset, void *HostDst, size_t DstByteOffset,
- size_t ByteCount) override {
- std::memcpy(HostDst, static_cast<const char *>(DeviceSrc.getHandle()) +
- SrcByteOffset,
- ByteCount);
- return se::Error::success();
- }
-
- se::Error copyH2D(se::PlatformStreamHandle *S, const void *HostSrc,
- size_t SrcByteOffset, se::GlobalDeviceMemoryBase DeviceDst,
- size_t DstByteOffset, size_t ByteCount) override {
- std::memcpy(static_cast<char *>(const_cast<void *>(DeviceDst.getHandle())) +
- DstByteOffset,
- HostSrc, ByteCount);
- return se::Error::success();
- }
-
- se::Error copyD2D(se::PlatformStreamHandle *S,
- const se::GlobalDeviceMemoryBase &DeviceSrc,
- size_t SrcByteOffset, se::GlobalDeviceMemoryBase DeviceDst,
- size_t DstByteOffset, size_t ByteCount) override {
- std::memcpy(static_cast<char *>(const_cast<void *>(DeviceDst.getHandle())) +
- DstByteOffset,
- static_cast<const char *>(DeviceSrc.getHandle()) +
- SrcByteOffset,
- ByteCount);
- return se::Error::success();
- }
-};
-
/// Test fixture to hold objects used by tests.
class StreamTest : public ::testing::Test {
public:
@@ -100,7 +55,7 @@ protected:
int Host5[5];
int Host7[7];
- MockPlatformDevice PDevice;
+ SimpleHostPlatformDevice PDevice;
se::Stream Stream;
};