diff options
Diffstat (limited to 'parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h')
-rw-r--r-- | parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h | 64 |
1 files changed, 24 insertions, 40 deletions
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h b/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h index b3deff31f50..946f8f96a94 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h @@ -31,34 +31,6 @@ namespace streamexecutor { -class PlatformDevice; - -/// Platform-specific kernel handle. -class PlatformKernelHandle { -public: - explicit PlatformKernelHandle(PlatformDevice *PDevice) : PDevice(PDevice) {} - - virtual ~PlatformKernelHandle(); - - PlatformDevice *getDevice() { return PDevice; } - -private: - PlatformDevice *PDevice; -}; - -/// Platform-specific stream handle. -class PlatformStreamHandle { -public: - explicit PlatformStreamHandle(PlatformDevice *PDevice) : PDevice(PDevice) {} - - virtual ~PlatformStreamHandle(); - - PlatformDevice *getDevice() { return PDevice; } - -private: - PlatformDevice *PDevice; -}; - /// Raw executor methods that must be implemented by each platform. /// /// This class defines the platform interface that supports executing work on a @@ -73,19 +45,30 @@ public: virtual std::string getName() const = 0; /// Creates a platform-specific kernel. - virtual Expected<std::unique_ptr<PlatformKernelHandle>> + virtual Expected<const void *> createKernel(const MultiKernelLoaderSpec &Spec) { return make_error("createKernel not implemented for platform " + getName()); } + virtual Error destroyKernel(const void *Handle) { + return make_error("destroyKernel not implemented for platform " + + getName()); + } + /// Creates a platform-specific stream. - virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() { + virtual Expected<const void *> createStream() { return make_error("createStream not implemented for platform " + getName()); } + virtual Error destroyStream(const void *Handle) { + return make_error("destroyStream not implemented for platform " + + getName()); + } + /// Launches a kernel on the given stream. - virtual Error launch(PlatformStreamHandle *S, BlockDimensions BlockSize, - GridDimensions GridSize, PlatformKernelHandle *K, + virtual Error launch(const void *PlatformStreamHandle, + BlockDimensions BlockSize, GridDimensions GridSize, + const void *PKernelHandle, const PackedKernelArgumentArrayBase &ArgumentArray) { return make_error("launch not implemented for platform " + getName()); } @@ -94,9 +77,9 @@ public: /// /// HostDst should have been allocated by allocateHostMemory or registered /// with registerHostMemory. - virtual Error copyD2H(PlatformStreamHandle *S, const void *DeviceSrcHandle, - size_t SrcByteOffset, void *HostDst, - size_t DstByteOffset, size_t ByteCount) { + virtual Error copyD2H(const void *PlatformStreamHandle, + const void *DeviceSrcHandle, size_t SrcByteOffset, + void *HostDst, size_t DstByteOffset, size_t ByteCount) { return make_error("copyD2H not implemented for platform " + getName()); } @@ -104,22 +87,23 @@ public: /// /// HostSrc should have been allocated by allocateHostMemory or registered /// with registerHostMemory. - virtual Error copyH2D(PlatformStreamHandle *S, const void *HostSrc, + virtual Error copyH2D(const void *PlatformStreamHandle, const void *HostSrc, size_t SrcByteOffset, const void *DeviceDstHandle, size_t DstByteOffset, size_t ByteCount) { return make_error("copyH2D not implemented for platform " + getName()); } /// Copies data from one device location to another. - virtual Error copyD2D(PlatformStreamHandle *S, const void *DeviceSrcHandle, - size_t SrcByteOffset, const void *DeviceDstHandle, - size_t DstByteOffset, size_t ByteCount) { + virtual Error copyD2D(const void *PlatformStreamHandle, + const void *DeviceSrcHandle, size_t SrcByteOffset, + const void *DeviceDstHandle, size_t DstByteOffset, + size_t ByteCount) { return make_error("copyD2D not implemented for platform " + getName()); } /// Blocks the host until the given stream completes all the work enqueued up /// to the point this function is called. - virtual Error blockHostUntilDone(PlatformStreamHandle *S) { + virtual Error blockHostUntilDone(const void *PlatformStreamHandle) { return make_error("blockHostUntilDone not implemented for platform " + getName()); } |