summaryrefslogtreecommitdiff
path: root/parallel-libs
diff options
context:
space:
mode:
authorJason Henline <jhen@google.com>2016-09-06 17:07:22 +0000
committerJason Henline <jhen@google.com>2016-09-06 17:07:22 +0000
commitc092db2acabdd22c647addffdb6e9506555bc2b3 (patch)
tree7eb53c9fc170514b9d12dc55c29192774df3ad00 /parallel-libs
parent4702a886ba000ca1331ad51e06244a2c91784d17 (diff)
[SE] Remove Platform*Handle classes
Summary: As pointed out by jprice, these classes don't serve a purpose. Instead, we stay consistent with the way memory is managed and let the Stream and Kernel classes directly hold opaque handles to device Stream and Kernel instances, respectively. Reviewers: jprice, jlebar Subscribers: parallel_libs-commits Differential Revision: https://reviews.llvm.org/D24213
Diffstat (limited to 'parallel-libs')
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/Device.h5
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/Kernel.h28
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h64
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/Stream.h29
-rw-r--r--parallel-libs/streamexecutor/lib/Device.cpp7
-rw-r--r--parallel-libs/streamexecutor/lib/Kernel.cpp41
-rw-r--r--parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp2
-rw-r--r--parallel-libs/streamexecutor/lib/Stream.cpp37
-rw-r--r--parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h12
-rw-r--r--parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp12
10 files changed, 142 insertions, 95 deletions
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Device.h b/parallel-libs/streamexecutor/include/streamexecutor/Device.h
index 95d9b5c62fb..0ee2b2fbc0b 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/Device.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/Device.h
@@ -35,12 +35,11 @@ public:
Expected<typename std::enable_if<std::is_base_of<KernelBase, KernelT>::value,
KernelT>::type>
createKernel(const MultiKernelLoaderSpec &Spec) {
- Expected<std::unique_ptr<PlatformKernelHandle>> MaybeKernelHandle =
- PDevice->createKernel(Spec);
+ Expected<const void *> MaybeKernelHandle = PDevice->createKernel(Spec);
if (!MaybeKernelHandle) {
return MaybeKernelHandle.takeError();
}
- return KernelT(Spec.getKernelName(), std::move(*MaybeKernelHandle));
+ return KernelT(PDevice, *MaybeKernelHandle, Spec.getKernelName());
}
/// Creates a stream object for this device.
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h b/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h
index c9b4180afee..6ea7c361803 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h
@@ -28,19 +28,32 @@
namespace streamexecutor {
-class PlatformKernelHandle;
+class PlatformDevice;
/// The base class for all kernel types.
///
/// Stores the name of the kernel in both mangled and demangled forms.
class KernelBase {
public:
- KernelBase(llvm::StringRef Name);
+ KernelBase(PlatformDevice *D, const void *PlatformKernelHandle,
+ llvm::StringRef Name);
+ KernelBase(const KernelBase &Other) = delete;
+ KernelBase &operator=(const KernelBase &Other) = delete;
+
+ KernelBase(KernelBase &&Other);
+ KernelBase &operator=(KernelBase &&Other);
+
+ ~KernelBase();
+
+ const void *getPlatformHandle() const { return PlatformKernelHandle; }
const std::string &getName() const { return Name; }
const std::string &getDemangledName() const { return DemangledName; }
private:
+ PlatformDevice *PDevice;
+ const void *PlatformKernelHandle;
+
std::string Name;
std::string DemangledName;
};
@@ -51,17 +64,12 @@ private:
/// function.
template <typename... ParameterTs> class Kernel : public KernelBase {
public:
- Kernel(llvm::StringRef Name, std::unique_ptr<PlatformKernelHandle> PHandle)
- : KernelBase(Name), PHandle(std::move(PHandle)) {}
+ Kernel(PlatformDevice *D, const void *PlatformKernelHandle,
+ llvm::StringRef Name)
+ : KernelBase(D, PlatformKernelHandle, Name) {}
Kernel(Kernel &&Other) = default;
Kernel &operator=(Kernel &&Other) = default;
-
- /// Gets the underlying platform-specific handle for this kernel.
- PlatformKernelHandle *getPlatformHandle() const { return PHandle.get(); }
-
-private:
- std::unique_ptr<PlatformKernelHandle> PHandle;
};
} // namespace streamexecutor
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h b/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h
index b3deff31f50..946f8f96a94 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h
@@ -31,34 +31,6 @@
namespace streamexecutor {
-class PlatformDevice;
-
-/// Platform-specific kernel handle.
-class PlatformKernelHandle {
-public:
- explicit PlatformKernelHandle(PlatformDevice *PDevice) : PDevice(PDevice) {}
-
- virtual ~PlatformKernelHandle();
-
- PlatformDevice *getDevice() { return PDevice; }
-
-private:
- PlatformDevice *PDevice;
-};
-
-/// Platform-specific stream handle.
-class PlatformStreamHandle {
-public:
- explicit PlatformStreamHandle(PlatformDevice *PDevice) : PDevice(PDevice) {}
-
- virtual ~PlatformStreamHandle();
-
- PlatformDevice *getDevice() { return PDevice; }
-
-private:
- PlatformDevice *PDevice;
-};
-
/// Raw executor methods that must be implemented by each platform.
///
/// This class defines the platform interface that supports executing work on a
@@ -73,19 +45,30 @@ public:
virtual std::string getName() const = 0;
/// Creates a platform-specific kernel.
- virtual Expected<std::unique_ptr<PlatformKernelHandle>>
+ virtual Expected<const void *>
createKernel(const MultiKernelLoaderSpec &Spec) {
return make_error("createKernel not implemented for platform " + getName());
}
+ virtual Error destroyKernel(const void *Handle) {
+ return make_error("destroyKernel not implemented for platform " +
+ getName());
+ }
+
/// Creates a platform-specific stream.
- virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() {
+ virtual Expected<const void *> createStream() {
return make_error("createStream not implemented for platform " + getName());
}
+ virtual Error destroyStream(const void *Handle) {
+ return make_error("destroyStream not implemented for platform " +
+ getName());
+ }
+
/// Launches a kernel on the given stream.
- virtual Error launch(PlatformStreamHandle *S, BlockDimensions BlockSize,
- GridDimensions GridSize, PlatformKernelHandle *K,
+ virtual Error launch(const void *PlatformStreamHandle,
+ BlockDimensions BlockSize, GridDimensions GridSize,
+ const void *PKernelHandle,
const PackedKernelArgumentArrayBase &ArgumentArray) {
return make_error("launch not implemented for platform " + getName());
}
@@ -94,9 +77,9 @@ public:
///
/// HostDst should have been allocated by allocateHostMemory or registered
/// with registerHostMemory.
- virtual Error copyD2H(PlatformStreamHandle *S, const void *DeviceSrcHandle,
- size_t SrcByteOffset, void *HostDst,
- size_t DstByteOffset, size_t ByteCount) {
+ virtual Error copyD2H(const void *PlatformStreamHandle,
+ const void *DeviceSrcHandle, size_t SrcByteOffset,
+ void *HostDst, size_t DstByteOffset, size_t ByteCount) {
return make_error("copyD2H not implemented for platform " + getName());
}
@@ -104,22 +87,23 @@ public:
///
/// HostSrc should have been allocated by allocateHostMemory or registered
/// with registerHostMemory.
- virtual Error copyH2D(PlatformStreamHandle *S, const void *HostSrc,
+ virtual Error copyH2D(const void *PlatformStreamHandle, const void *HostSrc,
size_t SrcByteOffset, const void *DeviceDstHandle,
size_t DstByteOffset, size_t ByteCount) {
return make_error("copyH2D not implemented for platform " + getName());
}
/// Copies data from one device location to another.
- virtual Error copyD2D(PlatformStreamHandle *S, const void *DeviceSrcHandle,
- size_t SrcByteOffset, const void *DeviceDstHandle,
- size_t DstByteOffset, size_t ByteCount) {
+ virtual Error copyD2D(const void *PlatformStreamHandle,
+ const void *DeviceSrcHandle, size_t SrcByteOffset,
+ const void *DeviceDstHandle, size_t DstByteOffset,
+ size_t ByteCount) {
return make_error("copyD2D not implemented for platform " + getName());
}
/// Blocks the host until the given stream completes all the work enqueued up
/// to the point this function is called.
- virtual Error blockHostUntilDone(PlatformStreamHandle *S) {
+ virtual Error blockHostUntilDone(const void *PlatformStreamHandle) {
return make_error("blockHostUntilDone not implemented for platform " +
getName());
}
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h
index 81f9ada7792..48dcf32368a 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h
@@ -59,10 +59,13 @@ namespace streamexecutor {
/// of a stream once it is in an error state.
class Stream {
public:
- explicit Stream(std::unique_ptr<PlatformStreamHandle> PStream);
+ Stream(PlatformDevice *D, const void *PlatformStreamHandle);
- Stream(Stream &&Other) = default;
- Stream &operator=(Stream &&Other) = default;
+ Stream(const Stream &Other) = delete;
+ Stream &operator=(const Stream &Other) = delete;
+
+ Stream(Stream &&Other);
+ Stream &operator=(Stream &&Other);
~Stream();
@@ -88,7 +91,7 @@ public:
//
// Returns the result of getStatus() after the Stream work completes.
Error blockHostUntilDone() {
- setError(PDevice->blockHostUntilDone(ThePlatformStream.get()));
+ setError(PDevice->blockHostUntilDone(PlatformStreamHandle));
return getStatus();
}
@@ -105,7 +108,7 @@ public:
const ParameterTs &... Arguments) {
auto ArgumentArray =
make_kernel_argument_pack<ParameterTs...>(Arguments...);
- setError(PDevice->launch(ThePlatformStream.get(), BlockSize, GridSize,
+ setError(PDevice->launch(PlatformStreamHandle, BlockSize, GridSize,
K.getPlatformHandle(), ArgumentArray));
return *this;
}
@@ -136,7 +139,7 @@ public:
setError("copying too many elements, " + llvm::Twine(ElementCount) +
", to a host array of element count " + llvm::Twine(Dst.size()));
else
- setError(PDevice->copyD2H(ThePlatformStream.get(),
+ setError(PDevice->copyD2H(PlatformStreamHandle,
Src.getBaseMemory().getHandle(),
Src.getElementOffset() * sizeof(T), Dst.data(),
0, ElementCount * sizeof(T)));
@@ -196,10 +199,9 @@ public:
", to a device array of element count " +
llvm::Twine(Dst.getElementCount()));
else
- setError(PDevice->copyH2D(ThePlatformStream.get(), Src.data(), 0,
- Dst.getBaseMemory().getHandle(),
- Dst.getElementOffset() * sizeof(T),
- ElementCount * sizeof(T)));
+ setError(PDevice->copyH2D(
+ PlatformStreamHandle, Src.data(), 0, Dst.getBaseMemory().getHandle(),
+ Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)));
return *this;
}
@@ -254,7 +256,7 @@ public:
llvm::Twine(Dst.getElementCount()));
else
setError(PDevice->copyD2D(
- ThePlatformStream.get(), Src.getBaseMemory().getHandle(),
+ PlatformStreamHandle, Src.getBaseMemory().getHandle(),
Src.getElementOffset() * sizeof(T), Dst.getBaseMemory().getHandle(),
Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)));
return *this;
@@ -342,7 +344,7 @@ private:
PlatformDevice *PDevice;
/// The platform-specific stream handle for this instance.
- std::unique_ptr<PlatformStreamHandle> ThePlatformStream;
+ const void *PlatformStreamHandle;
/// Mutex that guards the error state flags.
std::unique_ptr<llvm::sys::RWMutex> ErrorMessageMutex;
@@ -350,9 +352,6 @@ private:
/// First error message for an operation in this stream or empty if there have
/// been no errors.
llvm::Optional<std::string> ErrorMessage;
-
- Stream(const Stream &) = delete;
- void operator=(const Stream &) = delete;
};
} // namespace streamexecutor
diff --git a/parallel-libs/streamexecutor/lib/Device.cpp b/parallel-libs/streamexecutor/lib/Device.cpp
index 54f03849c68..0d81fb78e2d 100644
--- a/parallel-libs/streamexecutor/lib/Device.cpp
+++ b/parallel-libs/streamexecutor/lib/Device.cpp
@@ -28,14 +28,11 @@ Device::Device(PlatformDevice *PDevice) : PDevice(PDevice) {}
Device::~Device() = default;
Expected<Stream> Device::createStream() {
- Expected<std::unique_ptr<PlatformStreamHandle>> MaybePlatformStream =
- PDevice->createStream();
+ Expected<const void *> MaybePlatformStream = PDevice->createStream();
if (!MaybePlatformStream) {
return MaybePlatformStream.takeError();
}
- assert((*MaybePlatformStream)->getDevice() == PDevice &&
- "an executor created a stream with a different stored executor");
- return Stream(std::move(*MaybePlatformStream));
+ return Stream(PDevice, *MaybePlatformStream);
}
} // namespace streamexecutor
diff --git a/parallel-libs/streamexecutor/lib/Kernel.cpp b/parallel-libs/streamexecutor/lib/Kernel.cpp
index 1f4218c4df3..61305372f18 100644
--- a/parallel-libs/streamexecutor/lib/Kernel.cpp
+++ b/parallel-libs/streamexecutor/lib/Kernel.cpp
@@ -12,16 +12,49 @@
///
//===----------------------------------------------------------------------===//
-#include "streamexecutor/Kernel.h"
+#include <cassert>
+
#include "streamexecutor/Device.h"
+#include "streamexecutor/Kernel.h"
#include "streamexecutor/PlatformInterfaces.h"
#include "llvm/DebugInfo/Symbolize/Symbolize.h"
namespace streamexecutor {
-KernelBase::KernelBase(llvm::StringRef Name)
- : Name(Name), DemangledName(llvm::symbolize::LLVMSymbolizer::DemangleName(
- Name, nullptr)) {}
+KernelBase::KernelBase(PlatformDevice *D, const void *PlatformKernelHandle,
+ llvm::StringRef Name)
+ : PDevice(D), PlatformKernelHandle(PlatformKernelHandle), Name(Name),
+ DemangledName(
+ llvm::symbolize::LLVMSymbolizer::DemangleName(Name, nullptr)) {
+ assert(D != nullptr &&
+ "cannot construct a kernel object with a null platform device");
+ assert(PlatformKernelHandle != nullptr &&
+ "cannot construct a kernel object with a null platform kernel handle");
+}
+
+KernelBase::KernelBase(KernelBase &&Other)
+ : PDevice(Other.PDevice), PlatformKernelHandle(Other.PlatformKernelHandle),
+ Name(std::move(Other.Name)),
+ DemangledName(std::move(Other.DemangledName)) {
+ Other.PDevice = nullptr;
+ Other.PlatformKernelHandle = nullptr;
+}
+
+KernelBase &KernelBase::operator=(KernelBase &&Other) {
+ PDevice = Other.PDevice;
+ PlatformKernelHandle = Other.PlatformKernelHandle;
+ Name = std::move(Other.Name);
+ DemangledName = std::move(Other.DemangledName);
+ Other.PDevice = nullptr;
+ Other.PlatformKernelHandle = nullptr;
+ return *this;
+}
+
+KernelBase::~KernelBase() {
+ if (PlatformKernelHandle)
+ // TODO(jhen): Handle the error here.
+ consumeError(PDevice->destroyKernel(PlatformKernelHandle));
+}
} // namespace streamexecutor
diff --git a/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp b/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp
index 770cd170c4f..e9378b519df 100644
--- a/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp
+++ b/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp
@@ -16,8 +16,6 @@
namespace streamexecutor {
-PlatformStreamHandle::~PlatformStreamHandle() = default;
-
PlatformDevice::~PlatformDevice() = default;
} // namespace streamexecutor
diff --git a/parallel-libs/streamexecutor/lib/Stream.cpp b/parallel-libs/streamexecutor/lib/Stream.cpp
index e1fca58cc19..96aad044c9c 100644
--- a/parallel-libs/streamexecutor/lib/Stream.cpp
+++ b/parallel-libs/streamexecutor/lib/Stream.cpp
@@ -12,14 +12,43 @@
///
//===----------------------------------------------------------------------===//
+#include <cassert>
+
#include "streamexecutor/Stream.h"
namespace streamexecutor {
-Stream::Stream(std::unique_ptr<PlatformStreamHandle> PStream)
- : PDevice(PStream->getDevice()), ThePlatformStream(std::move(PStream)),
- ErrorMessageMutex(llvm::make_unique<llvm::sys::RWMutex>()) {}
+Stream::Stream(PlatformDevice *D, const void *PlatformStreamHandle)
+ : PDevice(D), PlatformStreamHandle(PlatformStreamHandle),
+ ErrorMessageMutex(llvm::make_unique<llvm::sys::RWMutex>()) {
+ assert(D != nullptr &&
+ "cannot construct a stream object with a null platform device");
+ assert(PlatformStreamHandle != nullptr &&
+ "cannot construct a stream object with a null platform stream handle");
+}
+
+Stream::Stream(Stream &&Other)
+ : PDevice(Other.PDevice), PlatformStreamHandle(Other.PlatformStreamHandle),
+ ErrorMessageMutex(std::move(Other.ErrorMessageMutex)),
+ ErrorMessage(std::move(Other.ErrorMessage)) {
+ Other.PDevice = nullptr;
+ Other.PlatformStreamHandle = nullptr;
+}
+
+Stream &Stream::operator=(Stream &&Other) {
+ PDevice = Other.PDevice;
+ PlatformStreamHandle = Other.PlatformStreamHandle;
+ ErrorMessageMutex = std::move(Other.ErrorMessageMutex);
+ ErrorMessage = std::move(Other.ErrorMessage);
+ Other.PDevice = nullptr;
+ Other.PlatformStreamHandle = nullptr;
+ return *this;
+}
-Stream::~Stream() = default;
+Stream::~Stream() {
+ if (PlatformStreamHandle)
+ // TODO(jhen): Handle error condition here.
+ consumeError(PDevice->destroyStream(PlatformStreamHandle));
+}
} // namespace streamexecutor
diff --git a/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h b/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h
index 184c2d7f273..b54b31dd457 100644
--- a/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h
+++ b/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h
@@ -34,9 +34,7 @@ class SimpleHostPlatformDevice : public streamexecutor::PlatformDevice {
public:
std::string getName() const override { return "SimpleHostPlatformDevice"; }
- streamexecutor::Expected<
- std::unique_ptr<streamexecutor::PlatformStreamHandle>>
- createStream() override {
+ streamexecutor::Expected<const void *> createStream() override {
return nullptr;
}
@@ -69,7 +67,7 @@ public:
return streamexecutor::Error::success();
}
- streamexecutor::Error copyD2H(streamexecutor::PlatformStreamHandle *S,
+ streamexecutor::Error copyD2H(const void *StreamHandle,
const void *DeviceHandleSrc,
size_t SrcByteOffset, void *HostDst,
size_t DstByteOffset,
@@ -80,8 +78,8 @@ public:
return streamexecutor::Error::success();
}
- streamexecutor::Error copyH2D(streamexecutor::PlatformStreamHandle *S,
- const void *HostSrc, size_t SrcByteOffset,
+ streamexecutor::Error copyH2D(const void *StreamHandle, const void *HostSrc,
+ size_t SrcByteOffset,
const void *DeviceHandleDst,
size_t DstByteOffset,
size_t ByteCount) override {
@@ -92,7 +90,7 @@ public:
}
streamexecutor::Error
- copyD2D(streamexecutor::PlatformStreamHandle *S, const void *DeviceHandleSrc,
+ copyD2D(const void *StreamHandle, const void *DeviceHandleSrc,
size_t SrcByteOffset, const void *DeviceHandleDst,
size_t DstByteOffset, size_t ByteCount) override {
std::memcpy(static_cast<char *>(const_cast<void *>(DeviceHandleDst)) +
diff --git a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp
index 4f42bbe8e72..3a0f4e6fdd2 100644
--- a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp
+++ b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp
@@ -34,11 +34,11 @@ const auto &getDeviceValue =
class StreamTest : public ::testing::Test {
public:
StreamTest()
- : Device(&PDevice),
- Stream(llvm::make_unique<se::PlatformStreamHandle>(&PDevice)),
- HostA5{0, 1, 2, 3, 4}, HostB5{5, 6, 7, 8, 9},
- HostA7{10, 11, 12, 13, 14, 15, 16}, HostB7{17, 18, 19, 20, 21, 22, 23},
- Host5{24, 25, 26, 27, 28}, Host7{29, 30, 31, 32, 33, 34, 35},
+ : DummyPlatformStream(1), Device(&PDevice),
+ Stream(&PDevice, &DummyPlatformStream), HostA5{0, 1, 2, 3, 4},
+ HostB5{5, 6, 7, 8, 9}, HostA7{10, 11, 12, 13, 14, 15, 16},
+ HostB7{17, 18, 19, 20, 21, 22, 23}, Host5{24, 25, 26, 27, 28},
+ Host7{29, 30, 31, 32, 33, 34, 35},
DeviceA5(getOrDie(Device.allocateDeviceMemory<int>(5))),
DeviceB5(getOrDie(Device.allocateDeviceMemory<int>(5))),
DeviceA7(getOrDie(Device.allocateDeviceMemory<int>(7))),
@@ -50,6 +50,8 @@ public:
}
protected:
+ int DummyPlatformStream; // Mimicking a platform where the platform stream
+ // handle is just a stream number.
se::test::SimpleHostPlatformDevice PDevice;
se::Device Device;
se::Stream Stream;