diff options
author | Jason Henline <jhen@google.com> | 2016-09-13 23:56:46 +0000 |
---|---|---|
committer | Jason Henline <jhen@google.com> | 2016-09-13 23:56:46 +0000 |
commit | 409fb55669c8cf4dc98be65886d02a80bf45c6e0 (patch) | |
tree | 3526868ca45a19bc050445973f2e4f8219883de8 /parallel-libs | |
parent | dc9664bd6182cde517fc1a19b28ab4fccf92a7ef (diff) |
[SE] Platforms return Device values
Summary:
Platforms were returning Device pointers, but a Device is now basically
just a pointer to an underlying PlatformDevice, so we will now just pass
it around as a value.
Reviewers: jlebar
Subscribers: jprice, jlebar, parallel_libs-commits
Differential Revision: https://reviews.llvm.org/D24537
Diffstat (limited to 'parallel-libs')
4 files changed, 19 insertions, 24 deletions
diff --git a/parallel-libs/streamexecutor/examples/CUDASaxpy.cpp b/parallel-libs/streamexecutor/examples/CUDASaxpy.cpp index 0fce5ed046b..6b2c59e5cd6 100644 --- a/parallel-libs/streamexecutor/examples/CUDASaxpy.cpp +++ b/parallel-libs/streamexecutor/examples/CUDASaxpy.cpp @@ -108,25 +108,25 @@ int main() { if (Platform->getDeviceCount() == 0) { return EXIT_FAILURE; } - se::Device *Device = getOrDie(Platform->getDevice(0)); + se::Device Device = getOrDie(Platform->getDevice(0)); // Load the kernel onto the device. cg::SaxpyKernel Kernel = - getOrDie(Device->createKernel<cg::SaxpyKernel>(cg::SaxpyLoaderSpec)); + getOrDie(Device.createKernel<cg::SaxpyKernel>(cg::SaxpyLoaderSpec)); se::RegisteredHostMemory<float> RegisteredX = - getOrDie(Device->registerHostMemory<float>(HostX)); + getOrDie(Device.registerHostMemory<float>(HostX)); se::RegisteredHostMemory<float> RegisteredY = - getOrDie(Device->registerHostMemory<float>(HostY)); + getOrDie(Device.registerHostMemory<float>(HostY)); // Allocate memory on the device. se::GlobalDeviceMemory<float> X = - getOrDie(Device->allocateDeviceMemory<float>(ArraySize)); + getOrDie(Device.allocateDeviceMemory<float>(ArraySize)); se::GlobalDeviceMemory<float> Y = - getOrDie(Device->allocateDeviceMemory<float>(ArraySize)); + getOrDie(Device.allocateDeviceMemory<float>(ArraySize)); // Run operations on a stream. - se::Stream Stream = getOrDie(Device->createStream()); + se::Stream Stream = getOrDie(Device.createStream()); Stream.thenCopyH2D(RegisteredX, X) .thenCopyH2D(RegisteredY, Y) .thenLaunch(ArraySize, 1, Kernel, A, X, Y) diff --git a/parallel-libs/streamexecutor/examples/HostSaxpy.cpp b/parallel-libs/streamexecutor/examples/HostSaxpy.cpp index 525c4453b01..5bcbcc898ce 100644 --- a/parallel-libs/streamexecutor/examples/HostSaxpy.cpp +++ b/parallel-libs/streamexecutor/examples/HostSaxpy.cpp @@ -62,25 +62,25 @@ int main() { if (Platform->getDeviceCount() == 0) { return EXIT_FAILURE; } - se::Device *Device = getOrDie(Platform->getDevice(0)); + se::Device Device = getOrDie(Platform->getDevice(0)); // Load the kernel onto the device. cg::SaxpyKernel Kernel = - getOrDie(Device->createKernel<cg::SaxpyKernel>(cg::SaxpyLoaderSpec)); + getOrDie(Device.createKernel<cg::SaxpyKernel>(cg::SaxpyLoaderSpec)); se::RegisteredHostMemory<float> RegisteredX = - getOrDie(Device->registerHostMemory<float>(HostX)); + getOrDie(Device.registerHostMemory<float>(HostX)); se::RegisteredHostMemory<float> RegisteredY = - getOrDie(Device->registerHostMemory<float>(HostY)); + getOrDie(Device.registerHostMemory<float>(HostY)); // Allocate memory on the device. se::GlobalDeviceMemory<float> X = - getOrDie(Device->allocateDeviceMemory<float>(ArraySize)); + getOrDie(Device.allocateDeviceMemory<float>(ArraySize)); se::GlobalDeviceMemory<float> Y = - getOrDie(Device->allocateDeviceMemory<float>(ArraySize)); + getOrDie(Device.allocateDeviceMemory<float>(ArraySize)); // Run operations on a stream. - se::Stream Stream = getOrDie(Device->createStream()); + se::Stream Stream = getOrDie(Device.createStream()); Stream.thenCopyH2D(RegisteredX, X) .thenCopyH2D(RegisteredY, Y) .thenLaunch(1, 1, Kernel, A, X, Y, ArraySize) diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Platform.h b/parallel-libs/streamexecutor/include/streamexecutor/Platform.h index 7b26f4972c8..8ced35d2066 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/Platform.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/Platform.h @@ -31,10 +31,8 @@ public: /// Gets the number of devices available for this platform. virtual size_t getDeviceCount() const = 0; - /// Gets a pointer to a Device with the given index for this platform. - /// - /// Ownership of the Device instance is NOT transferred to the caller. - virtual Expected<Device *> getDevice(size_t DeviceIndex) = 0; + /// Gets a Device with the given index for this platform. + virtual Expected<Device> getDevice(size_t DeviceIndex) = 0; }; } // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/include/streamexecutor/platforms/host/HostPlatform.h b/parallel-libs/streamexecutor/include/streamexecutor/platforms/host/HostPlatform.h index 52ad1ead5da..338e3f6265a 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/platforms/host/HostPlatform.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/platforms/host/HostPlatform.h @@ -30,24 +30,21 @@ class HostPlatform : public Platform { public: size_t getDeviceCount() const override { return 1; } - Expected<Device *> getDevice(size_t DeviceIndex) override { + Expected<Device> getDevice(size_t DeviceIndex) override { if (DeviceIndex != 0) { return make_error( "Requested device index " + llvm::Twine(DeviceIndex) + " from host platform which only supports device index 0"); } llvm::sys::ScopedLock Lock(Mutex); - if (!TheDevice) { + if (!ThePlatformDevice) ThePlatformDevice = llvm::make_unique<HostPlatformDevice>(); - TheDevice = llvm::make_unique<Device>(ThePlatformDevice.get()); - } - return TheDevice.get(); + return Device(ThePlatformDevice.get()); } private: llvm::sys::Mutex Mutex; std::unique_ptr<HostPlatformDevice> ThePlatformDevice; - std::unique_ptr<Device> TheDevice; }; } // namespace host |