summaryrefslogtreecommitdiff
path: root/parallel-libs
diff options
context:
space:
mode:
authorJason Henline <jhen@google.com>2016-09-13 23:56:46 +0000
committerJason Henline <jhen@google.com>2016-09-13 23:56:46 +0000
commit409fb55669c8cf4dc98be65886d02a80bf45c6e0 (patch)
tree3526868ca45a19bc050445973f2e4f8219883de8 /parallel-libs
parentdc9664bd6182cde517fc1a19b28ab4fccf92a7ef (diff)
[SE] Platforms return Device values
Summary: Platforms were returning Device pointers, but a Device is now basically just a pointer to an underlying PlatformDevice, so we will now just pass it around as a value. Reviewers: jlebar Subscribers: jprice, jlebar, parallel_libs-commits Differential Revision: https://reviews.llvm.org/D24537
Diffstat (limited to 'parallel-libs')
-rw-r--r--parallel-libs/streamexecutor/examples/CUDASaxpy.cpp14
-rw-r--r--parallel-libs/streamexecutor/examples/HostSaxpy.cpp14
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/Platform.h6
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/platforms/host/HostPlatform.h9
4 files changed, 19 insertions, 24 deletions
diff --git a/parallel-libs/streamexecutor/examples/CUDASaxpy.cpp b/parallel-libs/streamexecutor/examples/CUDASaxpy.cpp
index 0fce5ed046b..6b2c59e5cd6 100644
--- a/parallel-libs/streamexecutor/examples/CUDASaxpy.cpp
+++ b/parallel-libs/streamexecutor/examples/CUDASaxpy.cpp
@@ -108,25 +108,25 @@ int main() {
if (Platform->getDeviceCount() == 0) {
return EXIT_FAILURE;
}
- se::Device *Device = getOrDie(Platform->getDevice(0));
+ se::Device Device = getOrDie(Platform->getDevice(0));
// Load the kernel onto the device.
cg::SaxpyKernel Kernel =
- getOrDie(Device->createKernel<cg::SaxpyKernel>(cg::SaxpyLoaderSpec));
+ getOrDie(Device.createKernel<cg::SaxpyKernel>(cg::SaxpyLoaderSpec));
se::RegisteredHostMemory<float> RegisteredX =
- getOrDie(Device->registerHostMemory<float>(HostX));
+ getOrDie(Device.registerHostMemory<float>(HostX));
se::RegisteredHostMemory<float> RegisteredY =
- getOrDie(Device->registerHostMemory<float>(HostY));
+ getOrDie(Device.registerHostMemory<float>(HostY));
// Allocate memory on the device.
se::GlobalDeviceMemory<float> X =
- getOrDie(Device->allocateDeviceMemory<float>(ArraySize));
+ getOrDie(Device.allocateDeviceMemory<float>(ArraySize));
se::GlobalDeviceMemory<float> Y =
- getOrDie(Device->allocateDeviceMemory<float>(ArraySize));
+ getOrDie(Device.allocateDeviceMemory<float>(ArraySize));
// Run operations on a stream.
- se::Stream Stream = getOrDie(Device->createStream());
+ se::Stream Stream = getOrDie(Device.createStream());
Stream.thenCopyH2D(RegisteredX, X)
.thenCopyH2D(RegisteredY, Y)
.thenLaunch(ArraySize, 1, Kernel, A, X, Y)
diff --git a/parallel-libs/streamexecutor/examples/HostSaxpy.cpp b/parallel-libs/streamexecutor/examples/HostSaxpy.cpp
index 525c4453b01..5bcbcc898ce 100644
--- a/parallel-libs/streamexecutor/examples/HostSaxpy.cpp
+++ b/parallel-libs/streamexecutor/examples/HostSaxpy.cpp
@@ -62,25 +62,25 @@ int main() {
if (Platform->getDeviceCount() == 0) {
return EXIT_FAILURE;
}
- se::Device *Device = getOrDie(Platform->getDevice(0));
+ se::Device Device = getOrDie(Platform->getDevice(0));
// Load the kernel onto the device.
cg::SaxpyKernel Kernel =
- getOrDie(Device->createKernel<cg::SaxpyKernel>(cg::SaxpyLoaderSpec));
+ getOrDie(Device.createKernel<cg::SaxpyKernel>(cg::SaxpyLoaderSpec));
se::RegisteredHostMemory<float> RegisteredX =
- getOrDie(Device->registerHostMemory<float>(HostX));
+ getOrDie(Device.registerHostMemory<float>(HostX));
se::RegisteredHostMemory<float> RegisteredY =
- getOrDie(Device->registerHostMemory<float>(HostY));
+ getOrDie(Device.registerHostMemory<float>(HostY));
// Allocate memory on the device.
se::GlobalDeviceMemory<float> X =
- getOrDie(Device->allocateDeviceMemory<float>(ArraySize));
+ getOrDie(Device.allocateDeviceMemory<float>(ArraySize));
se::GlobalDeviceMemory<float> Y =
- getOrDie(Device->allocateDeviceMemory<float>(ArraySize));
+ getOrDie(Device.allocateDeviceMemory<float>(ArraySize));
// Run operations on a stream.
- se::Stream Stream = getOrDie(Device->createStream());
+ se::Stream Stream = getOrDie(Device.createStream());
Stream.thenCopyH2D(RegisteredX, X)
.thenCopyH2D(RegisteredY, Y)
.thenLaunch(1, 1, Kernel, A, X, Y, ArraySize)
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Platform.h b/parallel-libs/streamexecutor/include/streamexecutor/Platform.h
index 7b26f4972c8..8ced35d2066 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/Platform.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/Platform.h
@@ -31,10 +31,8 @@ public:
/// Gets the number of devices available for this platform.
virtual size_t getDeviceCount() const = 0;
- /// Gets a pointer to a Device with the given index for this platform.
- ///
- /// Ownership of the Device instance is NOT transferred to the caller.
- virtual Expected<Device *> getDevice(size_t DeviceIndex) = 0;
+ /// Gets a Device with the given index for this platform.
+ virtual Expected<Device> getDevice(size_t DeviceIndex) = 0;
};
} // namespace streamexecutor
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/platforms/host/HostPlatform.h b/parallel-libs/streamexecutor/include/streamexecutor/platforms/host/HostPlatform.h
index 52ad1ead5da..338e3f6265a 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/platforms/host/HostPlatform.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/platforms/host/HostPlatform.h
@@ -30,24 +30,21 @@ class HostPlatform : public Platform {
public:
size_t getDeviceCount() const override { return 1; }
- Expected<Device *> getDevice(size_t DeviceIndex) override {
+ Expected<Device> getDevice(size_t DeviceIndex) override {
if (DeviceIndex != 0) {
return make_error(
"Requested device index " + llvm::Twine(DeviceIndex) +
" from host platform which only supports device index 0");
}
llvm::sys::ScopedLock Lock(Mutex);
- if (!TheDevice) {
+ if (!ThePlatformDevice)
ThePlatformDevice = llvm::make_unique<HostPlatformDevice>();
- TheDevice = llvm::make_unique<Device>(ThePlatformDevice.get());
- }
- return TheDevice.get();
+ return Device(ThePlatformDevice.get());
}
private:
llvm::sys::Mutex Mutex;
std::unique_ptr<HostPlatformDevice> ThePlatformDevice;
- std::unique_ptr<Device> TheDevice;
};
} // namespace host