summaryrefslogtreecommitdiff
path: root/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h
diff options
context:
space:
mode:
Diffstat (limited to 'parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h')
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h64
1 files changed, 24 insertions, 40 deletions
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h b/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h
index b3deff31f50..946f8f96a94 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h
@@ -31,34 +31,6 @@
namespace streamexecutor {
-class PlatformDevice;
-
-/// Platform-specific kernel handle.
-class PlatformKernelHandle {
-public:
- explicit PlatformKernelHandle(PlatformDevice *PDevice) : PDevice(PDevice) {}
-
- virtual ~PlatformKernelHandle();
-
- PlatformDevice *getDevice() { return PDevice; }
-
-private:
- PlatformDevice *PDevice;
-};
-
-/// Platform-specific stream handle.
-class PlatformStreamHandle {
-public:
- explicit PlatformStreamHandle(PlatformDevice *PDevice) : PDevice(PDevice) {}
-
- virtual ~PlatformStreamHandle();
-
- PlatformDevice *getDevice() { return PDevice; }
-
-private:
- PlatformDevice *PDevice;
-};
-
/// Raw executor methods that must be implemented by each platform.
///
/// This class defines the platform interface that supports executing work on a
@@ -73,19 +45,30 @@ public:
virtual std::string getName() const = 0;
/// Creates a platform-specific kernel.
- virtual Expected<std::unique_ptr<PlatformKernelHandle>>
+ virtual Expected<const void *>
createKernel(const MultiKernelLoaderSpec &Spec) {
return make_error("createKernel not implemented for platform " + getName());
}
+ virtual Error destroyKernel(const void *Handle) {
+ return make_error("destroyKernel not implemented for platform " +
+ getName());
+ }
+
/// Creates a platform-specific stream.
- virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() {
+ virtual Expected<const void *> createStream() {
return make_error("createStream not implemented for platform " + getName());
}
+ virtual Error destroyStream(const void *Handle) {
+ return make_error("destroyStream not implemented for platform " +
+ getName());
+ }
+
/// Launches a kernel on the given stream.
- virtual Error launch(PlatformStreamHandle *S, BlockDimensions BlockSize,
- GridDimensions GridSize, PlatformKernelHandle *K,
+ virtual Error launch(const void *PlatformStreamHandle,
+ BlockDimensions BlockSize, GridDimensions GridSize,
+ const void *PKernelHandle,
const PackedKernelArgumentArrayBase &ArgumentArray) {
return make_error("launch not implemented for platform " + getName());
}
@@ -94,9 +77,9 @@ public:
///
/// HostDst should have been allocated by allocateHostMemory or registered
/// with registerHostMemory.
- virtual Error copyD2H(PlatformStreamHandle *S, const void *DeviceSrcHandle,
- size_t SrcByteOffset, void *HostDst,
- size_t DstByteOffset, size_t ByteCount) {
+ virtual Error copyD2H(const void *PlatformStreamHandle,
+ const void *DeviceSrcHandle, size_t SrcByteOffset,
+ void *HostDst, size_t DstByteOffset, size_t ByteCount) {
return make_error("copyD2H not implemented for platform " + getName());
}
@@ -104,22 +87,23 @@ public:
///
/// HostSrc should have been allocated by allocateHostMemory or registered
/// with registerHostMemory.
- virtual Error copyH2D(PlatformStreamHandle *S, const void *HostSrc,
+ virtual Error copyH2D(const void *PlatformStreamHandle, const void *HostSrc,
size_t SrcByteOffset, const void *DeviceDstHandle,
size_t DstByteOffset, size_t ByteCount) {
return make_error("copyH2D not implemented for platform " + getName());
}
/// Copies data from one device location to another.
- virtual Error copyD2D(PlatformStreamHandle *S, const void *DeviceSrcHandle,
- size_t SrcByteOffset, const void *DeviceDstHandle,
- size_t DstByteOffset, size_t ByteCount) {
+ virtual Error copyD2D(const void *PlatformStreamHandle,
+ const void *DeviceSrcHandle, size_t SrcByteOffset,
+ const void *DeviceDstHandle, size_t DstByteOffset,
+ size_t ByteCount) {
return make_error("copyD2D not implemented for platform " + getName());
}
/// Blocks the host until the given stream completes all the work enqueued up
/// to the point this function is called.
- virtual Error blockHostUntilDone(PlatformStreamHandle *S) {
+ virtual Error blockHostUntilDone(const void *PlatformStreamHandle) {
return make_error("blockHostUntilDone not implemented for platform " +
getName());
}