summaryrefslogtreecommitdiff
path: root/parallel-libs
diff options
context:
space:
mode:
authorJason Henline <jhen@google.com>2016-08-30 23:35:24 +0000
committerJason Henline <jhen@google.com>2016-08-30 23:35:24 +0000
commit7b79cbd85f70a2f3ee9d467d4dee55989770de04 (patch)
tree5a34faf7360202efec60d599182f5d2b274e39a2 /parallel-libs
parentcb460e203894b46465f3a2e441dbb66938a5c5c0 (diff)
[StreamExecutor] Simplify Kernel classes
Summary: Make the Kernel class follow the pattern of the other classes. It now has a type-safe user wrapper and a typeless, platform-specific handle. Reviewers: jlebar Subscribers: jprice, parallel_libs-commits Differential Revision: https://reviews.llvm.org/D24043
Diffstat (limited to 'parallel-libs')
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/Device.h26
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/Kernel.h114
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h26
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/Stream.h6
-rw-r--r--parallel-libs/streamexecutor/lib/Kernel.cpp24
-rw-r--r--parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt10
-rw-r--r--parallel-libs/streamexecutor/lib/unittests/KernelTest.cpp93
7 files changed, 87 insertions, 212 deletions
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Device.h b/parallel-libs/streamexecutor/include/streamexecutor/Device.h
index 34bba80859d..c37f9b1affb 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/Device.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/Device.h
@@ -15,13 +15,14 @@
#ifndef STREAMEXECUTOR_DEVICE_H
#define STREAMEXECUTOR_DEVICE_H
+#include <type_traits>
+
#include "streamexecutor/KernelSpec.h"
#include "streamexecutor/PlatformInterfaces.h"
#include "streamexecutor/Utils/Error.h"
namespace streamexecutor {
-class KernelInterface;
class Stream;
class Device {
@@ -29,11 +30,24 @@ public:
explicit Device(PlatformDevice *PDevice);
virtual ~Device();
- /// Gets the kernel implementation for the underlying platform.
- virtual Expected<std::unique_ptr<KernelInterface>>
- getKernelImplementation(const MultiKernelLoaderSpec &Spec) {
- // TODO(jhen): Implement this.
- return nullptr;
+ /// Creates a kernel object for this device.
+ ///
+ /// If the return value is not an error, the returned pointer will never be
+ /// null.
+ ///
+ /// See \ref CompilerGeneratedKernelExample "Kernel.h" for an example of how
+ /// this method is used.
+ template <typename KernelT>
+ Expected<std::unique_ptr<typename std::enable_if<
+ std::is_base_of<KernelBase, KernelT>::value, KernelT>::type>>
+ createKernel(const MultiKernelLoaderSpec &Spec) {
+ Expected<std::unique_ptr<PlatformKernelHandle>> MaybeKernelHandle =
+ PDevice->createKernel(Spec);
+ if (!MaybeKernelHandle) {
+ return MaybeKernelHandle.takeError();
+ }
+ return llvm::make_unique<KernelT>(Spec.getKernelName(),
+ std::move(*MaybeKernelHandle));
}
Expected<std::unique_ptr<Stream>> createStream();
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h b/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h
index 4a2eeb4b915..63d9c711425 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h
@@ -11,62 +11,64 @@
/// Types to represent device kernels (code compiled to run on GPU or other
/// accelerator).
///
-/// The TypedKernel class is used to provide type safety to the user API's
-/// launch functions, and the KernelBase class is used like a void* function
-/// pointer to perform type-unsafe operations inside StreamExecutor.
-///
-/// With the kernel parameter types recorded in the TypedKernel template
-/// parameters, type-safe kernel launch functions can be written with signatures
-/// like the following:
+/// With the kernel parameter types recorded in the Kernel template parameters,
+/// type-safe kernel launch functions can be written with signatures like the
+/// following:
/// \code
/// template <typename... ParameterTs>
/// void Launch(
-/// const TypedKernel<ParameterTs...> &Kernel, ParamterTs... Arguments);
+/// const Kernel<ParameterTs...> &Kernel, ParamterTs... Arguments);
/// \endcode
/// and the compiler will check that the user passes in arguments with types
/// matching the corresponding kernel parameters.
///
-/// A problem is that a TypedKernel template specialization with the right
-/// parameter types must be passed as the first argument to the Launch function,
-/// and it's just as hard to get the types right in that template specialization
-/// as it is to get them right for the kernel arguments.
+/// A problem is that a Kernel template specialization with the right parameter
+/// types must be passed as the first argument to the Launch function, and it's
+/// just as hard to get the types right in that template specialization as it is
+/// to get them right for the kernel arguments.
///
/// With this problem in mind, it is not recommended for users to specialize the
-/// TypedKernel template class themselves, but instead to let the compiler do it
-/// for them. When the compiler encounters a device kernel function, it can
-/// create a TypedKernel template specialization in the host code that has the
-/// right parameter types for that kernel and which has a type name based on the
-/// name of the kernel function.
+/// Kernel template class themselves, but instead to let the compiler do it for
+/// them. When the compiler encounters a device kernel function, it can create a
+/// Kernel template specialization in the host code that has the right parameter
+/// types for that kernel and which has a type name based on the name of the
+/// kernel function.
///
+/// \anchor CompilerGeneratedKernelExample
/// For example, if a CUDA device kernel function with the following signature
/// has been defined:
/// \code
-/// void Saxpy(float *A, float *X, float *Y);
+/// void Saxpy(float A, float *X, float *Y);
/// \endcode
/// the compiler can insert the following declaration in the host code:
/// \code
/// namespace compiler_cuda_namespace {
+/// namespace se = streamexecutor;
/// using SaxpyKernel =
-/// streamexecutor::TypedKernel<float *, float *, float *>;
+/// se::Kernel<
+/// float,
+/// se::GlobalDeviceMemory<float>,
+/// se::GlobalDeviceMemory<float>>;
/// } // namespace compiler_cuda_namespace
/// \endcode
/// and then the user can launch the kernel by calling the StreamExecutor launch
/// function as follows:
/// \code
/// namespace ccn = compiler_cuda_namespace;
+/// using KernelPtr = std::unique_ptr<cnn::SaxpyKernel>;
/// // Assumes Device is a pointer to the Device on which to launch the
/// // kernel.
/// //
/// // See KernelSpec.h for details on how the compiler can create a
/// // MultiKernelLoaderSpec instance like SaxpyKernelLoaderSpec below.
-/// Expected<ccn::SaxpyKernel> MaybeKernel =
-/// ccn::SaxpyKernel::create(Device, ccn::SaxpyKernelLoaderSpec);
+/// Expected<KernelPtr> MaybeKernel =
+/// Device->createKernel<ccn::SaxpyKernel>(ccn::SaxpyKernelLoaderSpec);
/// if (!MaybeKernel) { /* Handle error */ }
-/// ccn::SaxpyKernel SaxpyKernel = *MaybeKernel;
-/// Launch(SaxpyKernel, A, X, Y);
+/// KernelPtr SaxpyKernel = std::move(*MaybeKernel);
+/// Launch(*SaxpyKernel, A, X, Y);
/// \endcode
///
-/// With the compiler's help in specializing TypedKernel for each device kernel
+/// With the compiler's help in specializing Kernel for each device kernel
/// function (and generating a MultiKernelLoaderSpec instance for each kernel),
/// the user can safely launch the device kernel from the host and get an error
/// message at compile time if the argument types don't match the kernel
@@ -84,73 +86,37 @@
namespace streamexecutor {
-class Device;
-class KernelInterface;
+class PlatformKernelHandle;
-/// The base class for device kernel functions.
-///
-/// This class has no information about the types of the parameters taken by the
-/// kernel, so it is analogous to a void* pointer to a device function.
+/// The base class for all kernel types.
///
-/// See the TypedKernel class below for the subclass which does have information
-/// about parameter types.
+/// Stores the name of the kernel in both mangled and demangled forms.
class KernelBase {
public:
- KernelBase(KernelBase &&) = default;
- KernelBase &operator=(KernelBase &&) = default;
- ~KernelBase();
-
- /// Creates a kernel object from a Device and a MultiKernelLoaderSpec.
- ///
- /// The Device knows which platform it belongs to and the
- /// MultiKernelLoaderSpec knows how to find the kernel code for different
- /// platforms, so the combined information is enough to get the kernel code
- /// for the appropriate platform.
- static Expected<KernelBase> create(Device *Dev,
- const MultiKernelLoaderSpec &Spec);
+ KernelBase(llvm::StringRef Name);
const std::string &getName() const { return Name; }
const std::string &getDemangledName() const { return DemangledName; }
- /// Gets a pointer to the platform-specific implementation of this kernel.
- KernelInterface *getImplementation() { return Implementation.get(); }
-
private:
- KernelBase(Device *Dev, const std::string &Name,
- const std::string &DemangledName,
- std::unique_ptr<KernelInterface> Implementation);
-
- Device *TheDevice;
std::string Name;
std::string DemangledName;
- std::unique_ptr<KernelInterface> Implementation;
-
- KernelBase(const KernelBase &) = delete;
- KernelBase &operator=(const KernelBase &) = delete;
};
-/// A device kernel function with specified parameter types.
-template <typename... ParameterTs> class TypedKernel : public KernelBase {
+/// A StreamExecutor kernel.
+///
+/// The template parameters are the types of the parameters to the kernel
+/// function.
+template <typename... ParameterTs> class Kernel : public KernelBase {
public:
- TypedKernel(TypedKernel &&) = default;
- TypedKernel &operator=(TypedKernel &&) = default;
+ Kernel(llvm::StringRef Name, std::unique_ptr<PlatformKernelHandle> PHandle)
+ : KernelBase(Name), PHandle(std::move(PHandle)) {}
- /// Parameters here have the same meaning as in KernelBase::create.
- static Expected<TypedKernel> create(Device *Dev,
- const MultiKernelLoaderSpec &Spec) {
- auto MaybeBase = KernelBase::create(Dev, Spec);
- if (!MaybeBase) {
- return MaybeBase.takeError();
- }
- TypedKernel Instance(std::move(*MaybeBase));
- return std::move(Instance);
- }
+ /// Gets the underlying platform-specific handle for this kernel.
+ PlatformKernelHandle *getPlatformHandle() const { return PHandle.get(); }
private:
- TypedKernel(KernelBase &&Base) : KernelBase(std::move(Base)) {}
-
- TypedKernel(const TypedKernel &) = delete;
- TypedKernel &operator=(const TypedKernel &) = delete;
+ std::unique_ptr<PlatformKernelHandle> PHandle;
};
} // namespace streamexecutor
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h b/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h
index b7737e82e7d..8fa31b63ef2 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h
@@ -33,9 +33,17 @@ namespace streamexecutor {
class PlatformDevice;
-/// Methods supported by device kernel function objects on all platforms.
-class KernelInterface {
- // TODO(jhen): Add methods.
+/// Platform-specific kernel handle.
+class PlatformKernelHandle {
+public:
+ explicit PlatformKernelHandle(PlatformDevice *PDevice) : PDevice(PDevice) {}
+
+ virtual ~PlatformKernelHandle();
+
+ PlatformDevice *getDevice() { return PDevice; }
+
+private:
+ PlatformDevice *PDevice;
};
/// Platform-specific stream handle.
@@ -64,12 +72,20 @@ public:
virtual std::string getName() const = 0;
+ /// Creates a platform-specific kernel.
+ virtual Expected<std::unique_ptr<PlatformKernelHandle>>
+ createKernel(const MultiKernelLoaderSpec &Spec) {
+ return make_error("createKernel not implemented for platform " + getName());
+ }
+
/// Creates a platform-specific stream.
- virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() = 0;
+ virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() {
+ return make_error("createStream not implemented for platform " + getName());
+ }
/// Launches a kernel on the given stream.
virtual Error launch(PlatformStreamHandle *S, BlockDimensions BlockSize,
- GridDimensions GridSize, const KernelBase &Kernel,
+ GridDimensions GridSize, PlatformKernelHandle *K,
const PackedKernelArgumentArrayBase &ArgumentArray) {
return make_error("launch not implemented for platform " + getName());
}
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h
index 0e6e898b473..2937c5842e8 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h
@@ -86,15 +86,15 @@ public:
/// These arguments can be device memory types like GlobalDeviceMemory<T> and
/// SharedDeviceMemory<T>, or they can be primitive types such as int. The
/// allowable argument types are determined by the template parameters to the
- /// TypedKernel argument.
+ /// Kernel argument.
template <typename... ParameterTs>
Stream &thenLaunch(BlockDimensions BlockSize, GridDimensions GridSize,
- const TypedKernel<ParameterTs...> &Kernel,
+ const Kernel<ParameterTs...> &K,
const ParameterTs &... Arguments) {
auto ArgumentArray =
make_kernel_argument_pack<ParameterTs...>(Arguments...);
setError(PDevice->launch(ThePlatformStream.get(), BlockSize, GridSize,
- Kernel, ArgumentArray));
+ K.getPlatformHandle(), ArgumentArray));
return *this;
}
diff --git a/parallel-libs/streamexecutor/lib/Kernel.cpp b/parallel-libs/streamexecutor/lib/Kernel.cpp
index fa0992003a6..1f4218c4df3 100644
--- a/parallel-libs/streamexecutor/lib/Kernel.cpp
+++ b/parallel-libs/streamexecutor/lib/Kernel.cpp
@@ -20,26 +20,8 @@
namespace streamexecutor {
-KernelBase::KernelBase(Device *Dev, const std::string &Name,
- const std::string &DemangledName,
- std::unique_ptr<KernelInterface> Implementation)
- : TheDevice(Dev), Name(Name), DemangledName(DemangledName),
- Implementation(std::move(Implementation)) {}
-
-KernelBase::~KernelBase() = default;
-
-Expected<KernelBase> KernelBase::create(Device *Dev,
- const MultiKernelLoaderSpec &Spec) {
- auto MaybeImplementation = Dev->getKernelImplementation(Spec);
- if (!MaybeImplementation) {
- return MaybeImplementation.takeError();
- }
- std::string Name = Spec.getKernelName();
- std::string DemangledName =
- llvm::symbolize::LLVMSymbolizer::DemangleName(Name, nullptr);
- KernelBase Instance(Dev, Name, DemangledName,
- std::move(*MaybeImplementation));
- return std::move(Instance);
-}
+KernelBase::KernelBase(llvm::StringRef Name)
+ : Name(Name), DemangledName(llvm::symbolize::LLVMSymbolizer::DemangleName(
+ Name, nullptr)) {}
} // namespace streamexecutor
diff --git a/parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt b/parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt
index 3b414e342d9..e12b675f2c4 100644
--- a/parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt
+++ b/parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt
@@ -9,16 +9,6 @@ target_link_libraries(
add_test(DeviceTest device_test)
add_executable(
- kernel_test
- KernelTest.cpp)
-target_link_libraries(
- kernel_test
- streamexecutor
- ${GTEST_BOTH_LIBRARIES}
- ${CMAKE_THREAD_LIBS_INIT})
-add_test(KernelTest kernel_test)
-
-add_executable(
kernel_spec_test
KernelSpecTest.cpp)
target_link_libraries(
diff --git a/parallel-libs/streamexecutor/lib/unittests/KernelTest.cpp b/parallel-libs/streamexecutor/lib/unittests/KernelTest.cpp
deleted file mode 100644
index a19ebfb96bd..00000000000
--- a/parallel-libs/streamexecutor/lib/unittests/KernelTest.cpp
+++ /dev/null
@@ -1,93 +0,0 @@
-//===-- KernelTest.cpp - Tests for Kernel objects -------------------------===//
-//
-// The LLVM Compiler Infrastructure
-//
-// This file is distributed under the University of Illinois Open Source
-// License. See LICENSE.TXT for details.
-//
-//===----------------------------------------------------------------------===//
-///
-/// \file
-/// This file contains the unit tests for the code in Kernel.
-///
-//===----------------------------------------------------------------------===//
-
-#include <cassert>
-
-#include "streamexecutor/Device.h"
-#include "streamexecutor/Kernel.h"
-#include "streamexecutor/KernelSpec.h"
-#include "streamexecutor/PlatformInterfaces.h"
-
-#include "llvm/ADT/STLExtras.h"
-
-#include "gtest/gtest.h"
-
-namespace {
-
-namespace se = ::streamexecutor;
-
-// A Device that returns a dummy KernelInterface.
-//
-// During construction it creates a unique_ptr to a dummy KernelInterface and it
-// also stores a separate copy of the raw pointer that is stored by that
-// unique_ptr.
-//
-// The expectation is that the code being tested will call the
-// getKernelImplementation method and will thereby take ownership of the
-// unique_ptr, but the copy of the raw pointer will stay behind in this mock
-// object. The raw pointer copy can then be used to identify the unique_ptr in
-// its new location (by comparing the raw pointer with unique_ptr::get), to
-// verify that the unique_ptr ended up where it was supposed to be.
-class MockDevice : public se::Device {
-public:
- MockDevice()
- : se::Device(nullptr), Unique(llvm::make_unique<se::KernelInterface>()),
- Raw(Unique.get()) {}
-
- // Moves the unique pointer into the returned se::Expected instance.
- //
- // Asserts that it is not called again after the unique pointer has been moved
- // out.
- se::Expected<std::unique_ptr<se::KernelInterface>>
- getKernelImplementation(const se::MultiKernelLoaderSpec &) override {
- assert(Unique && "MockDevice getKernelImplementation should not be "
- "called more than once");
- return std::move(Unique);
- }
-
- // Gets the copy of the raw pointer from the original unique pointer.
- const se::KernelInterface *getRaw() const { return Raw; }
-
-private:
- std::unique_ptr<se::KernelInterface> Unique;
- const se::KernelInterface *Raw;
-};
-
-// Test fixture class for typed tests for KernelBase.getImplementation.
-//
-// The only purpose of this class is to provide a name that types can be bound
-// to in the gtest infrastructure.
-template <typename T> class GetImplementationTest : public ::testing::Test {};
-
-// Types used with the GetImplementationTest fixture class.
-typedef ::testing::Types<se::KernelBase, se::TypedKernel<>,
- se::TypedKernel<int>>
- GetImplementationTypes;
-
-TYPED_TEST_CASE(GetImplementationTest, GetImplementationTypes);
-
-// Tests that the kernel create functions properly fetch the implementation
-// pointers for the kernel objects they construct from the passed-in
-// Device objects.
-TYPED_TEST(GetImplementationTest, SetImplementationDuringCreate) {
- se::MultiKernelLoaderSpec Spec;
- MockDevice Dev;
-
- auto MaybeKernel = TypeParam::create(&Dev, Spec);
- EXPECT_TRUE(static_cast<bool>(MaybeKernel));
- se::KernelInterface *Implementation = MaybeKernel->getImplementation();
- EXPECT_EQ(Dev.getRaw(), Implementation);
-}
-
-} // namespace