Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ endif ()

if (APPLE)
set(METAL_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/third-party/metal-cpp
/usr/local/include/metal_irconverter
${CMAKE_CURRENT_SOURCE_DIR}/third-party/metal_irconverter_runtime)
find_library(METAL_IRCONVERTER_LIBRARY metalirconverter PATHS /usr/local/lib REQUIRED)
set(METAL_LIBRARIES ${METAL_IRCONVERTER_LIBRARY})
set(OFFLOADTEST_ENABLE_METAL On)
endif ()

Expand Down
28 changes: 27 additions & 1 deletion include/Support/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,33 @@ enum class DataFormat {
Bool,
};

enum class DescriptorKind { UAV, SRV, CBV, SAMPLER };
Comment thread
kcloudy0717 marked this conversation as resolved.

static DescriptorKind getDescriptorKind(ResourceKind RK) {
switch (RK) {
case ResourceKind::Buffer:
case ResourceKind::StructuredBuffer:
case ResourceKind::ByteAddressBuffer:
case ResourceKind::Texture2D:
return DescriptorKind::SRV;

case ResourceKind::RWStructuredBuffer:
case ResourceKind::RWBuffer:
case ResourceKind::RWByteAddressBuffer:
case ResourceKind::RWTexture2D:
return DescriptorKind::UAV;

case ResourceKind::ConstantBuffer:
return DescriptorKind::CBV;

case ResourceKind::Sampler:
return DescriptorKind::SAMPLER;
case ResourceKind::SampledTexture2D:
llvm_unreachable("Sampled textures aren't supported!");
}
llvm_unreachable("All cases handled");
}

enum class FilterMode { Nearest, Linear };

enum class AddressMode { Clamp, Repeat, Mirror, Border, MirrorOnce };
Expand Down Expand Up @@ -408,7 +435,6 @@ struct Shader {
Stages Stage;
std::string Entry;
std::unique_ptr<llvm::MemoryBuffer> Shader;
std::unique_ptr<llvm::MemoryBuffer> Reflection;
llvm::SmallVector<SpecializationConstant> SpecializationConstants;
};

Expand Down
7 changes: 5 additions & 2 deletions lib/API/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@ if (OFFLOADTEST_ENABLE_D3D12)
endif()

if (APPLE)
list(APPEND api_sources MTL/MTLDevice.cpp)
list(APPEND api_sources MTL/MTLDevice.cpp
MTL/MTLDescriptorHeap.cpp
MTL/MTLTopLevelArgumentBuffer.cpp)
list(APPEND api_libraries "-framework Metal"
"-framework MetalKit"
"-framework AppKit"
"-framework Foundation"
"-framework QuartzCore")
"-framework QuartzCore"
${METAL_LIBRARIES})
list(APPEND api_headers PRIVATE ${METAL_INCLUDE_DIRS})
endif()

Expand Down
70 changes: 22 additions & 48 deletions lib/API/DX/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,33 +193,6 @@ static D3D12_RESOURCE_DIMENSION getDXDimension(ResourceKind RK) {
llvm_unreachable("All cases handled");
}

enum DXResourceKind { UAV, SRV, CBV, SAMPLER };

static DXResourceKind getDXKind(offloadtest::ResourceKind RK) {
switch (RK) {
case ResourceKind::Buffer:
case ResourceKind::StructuredBuffer:
case ResourceKind::ByteAddressBuffer:
case ResourceKind::Texture2D:
return SRV;

case ResourceKind::RWStructuredBuffer:
case ResourceKind::RWBuffer:
case ResourceKind::RWByteAddressBuffer:
case ResourceKind::RWTexture2D:
return UAV;

case ResourceKind::ConstantBuffer:
return CBV;

case ResourceKind::Sampler:
return SAMPLER;
case ResourceKind::SampledTexture2D:
llvm_unreachable("Sampled textures aren't supported in DirectX!");
}
llvm_unreachable("All cases handled");
}

static llvm::Expected<D3D12_RESOURCE_DESC>
getResourceDescription(const Resource &R) {
const D3D12_RESOURCE_DIMENSION Dimension = getDXDimension(R.Kind);
Expand All @@ -239,7 +212,8 @@ getResourceDescription(const Resource &R) {

if (R.isTexture())
Layout =
R.IsReserved && (getDXKind(R.Kind) == SRV || getDXKind(R.Kind) == UAV)
R.IsReserved && (getDescriptorKind(R.Kind) == DescriptorKind::SRV ||
getDescriptorKind(R.Kind) == DescriptorKind::UAV)
? D3D12_TEXTURE_LAYOUT_64KB_UNDEFINED_SWIZZLE
: D3D12_TEXTURE_LAYOUT_UNKNOWN;
else
Expand Down Expand Up @@ -786,17 +760,17 @@ class DXDevice : public offloadtest::Device {
uint32_t DescriptorIdx = 0;
const uint32_t StartRangeIdx = RangeIdx;
for (const auto &Binding : Set.ResourceBindings) {
switch (getDXKind(Binding.Kind)) {
case SRV:
switch (getDescriptorKind(Binding.Kind)) {
case DescriptorKind::SRV:
Ranges.get()[RangeIdx].RangeType = D3D12_DESCRIPTOR_RANGE_TYPE_SRV;
break;
case UAV:
case DescriptorKind::UAV:
Ranges.get()[RangeIdx].RangeType = D3D12_DESCRIPTOR_RANGE_TYPE_UAV;
break;
case CBV:
case DescriptorKind::CBV:
Ranges.get()[RangeIdx].RangeType = D3D12_DESCRIPTOR_RANGE_TYPE_CBV;
break;
case SAMPLER:
case DescriptorKind::SAMPLER:
llvm_unreachable("Not implemented yet."); // Requires a separate heap
}
Ranges.get()[RangeIdx].NumDescriptors = Binding.DescriptorCount;
Expand Down Expand Up @@ -1592,29 +1566,29 @@ class DXDevice : public offloadtest::Device {
[&IS,
this](Resource &R,
llvm::SmallVectorImpl<ResourcePair> &Resources) -> llvm::Error {
switch (getDXKind(R.Kind)) {
case SRV: {
switch (getDescriptorKind(R.Kind)) {
case DescriptorKind::SRV: {
auto ExRes = createSRV(R, IS);
if (!ExRes)
return ExRes.takeError();
Resources.push_back(std::make_pair(&R, *ExRes));
break;
}
case UAV: {
case DescriptorKind::UAV: {
auto ExRes = createUAV(R, IS);
if (!ExRes)
return ExRes.takeError();
Resources.push_back(std::make_pair(&R, *ExRes));
break;
}
case CBV: {
case DescriptorKind::CBV: {
auto ExRes = createCBV(R, IS);
if (!ExRes)
return ExRes.takeError();
Resources.push_back(std::make_pair(&R, *ExRes));
break;
}
case SAMPLER:
case DescriptorKind::SAMPLER:
return llvm::createStringError(
std::errc::not_supported,
"Samplers are not yet implemented for DirectX.");
Expand All @@ -1634,17 +1608,17 @@ class DXDevice : public offloadtest::Device {
uint32_t HeapIndex = 0;
for (auto &T : IS.DescTables) {
for (auto &R : T.Resources) {
switch (getDXKind(R.first->Kind)) {
case SRV:
switch (getDescriptorKind(R.first->Kind)) {
case DescriptorKind::SRV:
HeapIndex = bindSRV(*(R.first), IS, HeapIndex, R.second);
break;
case UAV:
case DescriptorKind::UAV:
HeapIndex = bindUAV(*(R.first), IS, HeapIndex, R.second);
break;
case CBV:
case DescriptorKind::CBV:
HeapIndex = bindCBV(*(R.first), IS, HeapIndex, R.second);
break;
case SAMPLER:
case DescriptorKind::SAMPLER:
llvm_unreachable("Not implemented yet.");
}
}
Expand Down Expand Up @@ -1748,23 +1722,23 @@ class DXDevice : public offloadtest::Device {
return llvm::createStringError(
std::errc::value_too_large,
"Root descriptor cannot refer to resource arrays.");
switch (getDXKind(RootDescIt->first->Kind)) {
case SRV:
switch (getDescriptorKind(RootDescIt->first->Kind)) {
case DescriptorKind::SRV:
IS.CB->CmdList->SetComputeRootShaderResourceView(
RootParamIndex++,
RootDescIt->second.back().Buffer->GetGPUVirtualAddress());
break;
case UAV:
case DescriptorKind::UAV:
IS.CB->CmdList->SetComputeRootUnorderedAccessView(
RootParamIndex++,
RootDescIt->second.back().Buffer->GetGPUVirtualAddress());
break;
case CBV:
case DescriptorKind::CBV:
IS.CB->CmdList->SetComputeRootConstantBufferView(
RootParamIndex++,
RootDescIt->second.back().Buffer->GetGPUVirtualAddress());
break;
case SAMPLER:
case DescriptorKind::SAMPLER:
llvm_unreachable("Not implemented yet.");
}
++RootDescIt;
Expand Down
80 changes: 80 additions & 0 deletions lib/API/MTL/MTLDescriptorHeap.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
//===- MTL/MTLDescriptorHeap.cpp - Metal Descriptor Heap ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "MTLDescriptorHeap.h"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing file header - we do want the license statement in every file.

#include "MetalIRConverter.h"

using namespace offloadtest;

static NS::UInteger getDescriptorHeapBindPoint(MTLDescriptorHeapType Type) {
switch (Type) {
case MTLDescriptorHeapType::CBV_SRV_UAV:
return kIRDescriptorHeapBindPoint;
case MTLDescriptorHeapType::Sampler:
return kIRSamplerHeapBindPoint;
}
llvm_unreachable("All cases handled.");
}

MTLGPUDescriptorHandle &
MTLGPUDescriptorHandle::addOffset(int32_t OffsetInDescriptors) {
Ptr = MTL::GPUAddress(int64_t(Ptr) + int64_t(OffsetInDescriptors) *
sizeof(IRDescriptorTableEntry));
return *this;
}

llvm::Expected<std::unique_ptr<MTLDescriptorHeap>>
MTLDescriptorHeap::create(MTL::Device *Device,
const MTLDescriptorHeapDesc &Desc) {
if (!Device)
return llvm::createStringError(std::errc::invalid_argument,
"Invalid MTL::Device pointer.");

if (Desc.NumDescriptors == 0)
return llvm::createStringError(std::errc::invalid_argument,
"Invalid descriptor heap description.");

MTL::Buffer *Buf =
Device->newBuffer(Desc.NumDescriptors * sizeof(IRDescriptorTableEntry),
MTL::ResourceStorageModeShared);
if (!Buf)
return llvm::createStringError(std::errc::not_enough_memory,
"Failed to create MTLDescriptorHeap.");
return std::make_unique<MTLDescriptorHeap>(Desc, Buf);
}

MTLDescriptorHeap::~MTLDescriptorHeap() {
if (Buffer)
Buffer->release();
}

MTLGPUDescriptorHandle
MTLDescriptorHeap::getGPUDescriptorHandleForHeapStart() const {
return MTLGPUDescriptorHandle{Buffer->gpuAddress()};
}

IRDescriptorTableEntry *
MTLDescriptorHeap::getEntryHandle(uint32_t Index) const {
assert(Index < Desc.NumDescriptors && "Descriptor index out of bounds.");
return static_cast<IRDescriptorTableEntry *>(Buffer->contents()) + Index;
}

void MTLDescriptorHeap::bind(MTL::RenderCommandEncoder *Encoder) {
Encoder->useResource(Buffer, MTL::ResourceUsageRead);
// Dynamic resource indexing
const NS::UInteger BindPoint = getDescriptorHeapBindPoint(Desc.Type);
Encoder->setVertexBuffer(Buffer, 0, BindPoint);
Encoder->setFragmentBuffer(Buffer, 0, BindPoint);
}

void MTLDescriptorHeap::bind(MTL::ComputeCommandEncoder *Encoder) {
Encoder->useResource(Buffer, MTL::ResourceUsageRead);
// Dynamic resource indexing
const NS::UInteger BindPoint = getDescriptorHeapBindPoint(Desc.Type);
Encoder->setBuffer(Buffer, 0, BindPoint);
}
66 changes: 66 additions & 0 deletions lib/API/MTL/MTLDescriptorHeap.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//===- MTLDescriptorHeap.h - Metal Descriptor Heap ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef OFFLOADTEST_API_MTL_MTLDESCRIPTORHEAP_H
#define OFFLOADTEST_API_MTL_MTLDESCRIPTORHEAP_H

#include "llvm/Support/Error.h"
#include <memory>

// Forward declarations
namespace MTL {
class Device;
class Buffer;
class RenderCommandEncoder;
class ComputeCommandEncoder;
} // namespace MTL
struct IRDescriptorTableEntry;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit awkward that metal_irconverter_runtime.h just puts all of its definitions in the global namespace. Maybe a comment saying this forward declaration is for an object from there would be helpful.


namespace offloadtest {
struct MTLGPUDescriptorHandle {
MTLGPUDescriptorHandle &addOffset(int32_t OffsetInDescriptors);

uint64_t Ptr;
};

enum class MTLDescriptorHeapType {
CBV_SRV_UAV,
Sampler,
};

struct MTLDescriptorHeapDesc {
MTLDescriptorHeapType Type;
uint32_t NumDescriptors;
};

// MTLDescriptorHeap mimics the D3D12 descriptor heap concept, except
// MTLDescriptorHeap is always shader visible and meant to be used
// by the argument buffer for shader resource binding with the explicit root
// signature layout.
class MTLDescriptorHeap {
MTLDescriptorHeapDesc Desc;
MTL::Buffer *Buffer;

public:
static llvm::Expected<std::unique_ptr<MTLDescriptorHeap>>
create(MTL::Device *Device, const MTLDescriptorHeapDesc &Desc);

MTLDescriptorHeap(const MTLDescriptorHeapDesc &Desc, MTL::Buffer *Buffer)
: Desc(Desc), Buffer(Buffer) {}
~MTLDescriptorHeap();

MTLGPUDescriptorHandle getGPUDescriptorHandleForHeapStart() const;

IRDescriptorTableEntry *getEntryHandle(uint32_t Index) const;

void bind(MTL::RenderCommandEncoder *Encoder);
void bind(MTL::ComputeCommandEncoder *Encoder);
};
} // namespace offloadtest

#endif // OFFLOADTEST_API_MTL_MTLDESCRIPTORHEAP_H
Loading
Loading