Skip to content

Commit 901ceae

Browse files
committed
[SPIR-V] Add descriptor heap stride specialization constant attribute
Adds [[vk::resource_heap_stride_constant_id(id)]] and [[vk::sampler_heap_stride_constant_id(id)]] attributes that emit the descriptor heap ArrayStride as a SPIR-V specialization constant using ArrayStrideIdEXT, allowing applications to override the stride at pipeline creation time via VkSpecializationInfo without recompiling the shader. The attribute initializer supplies the default stride value and must be a power of two in [8, 256]. The CLI flags -fvk-resource-heap-stride and -fvk-sampler-heap-stride take higher precedence and suppress these attributes with a warning when both are specified.
1 parent 5e7b928 commit 901ceae

14 files changed

Lines changed: 491 additions & 14 deletions

tools/clang/include/clang/Basic/Attr.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,6 +1611,24 @@ def VKConstantId : InheritableAttr {
16111611
let Documentation = [Undocumented];
16121612
}
16131613

1614+
def VKResourceHeapStrideConstantId : InheritableAttr {
1615+
let Spellings = [CXX11<"vk", "resource_heap_stride_constant_id">];
1616+
let Subjects =
1617+
SubjectList<[ScalarGlobalVar], ErrorDiag, "ExpectedScalarGlobalVar">;
1618+
let Args = [IntArgument<"SpecConstId">];
1619+
let LangOpts = [SPIRV];
1620+
let Documentation = [Undocumented];
1621+
}
1622+
1623+
def VKSamplerHeapStrideConstantId : InheritableAttr {
1624+
let Spellings = [CXX11<"vk", "sampler_heap_stride_constant_id">];
1625+
let Subjects =
1626+
SubjectList<[ScalarGlobalVar], ErrorDiag, "ExpectedScalarGlobalVar">;
1627+
let Args = [IntArgument<"SpecConstId">];
1628+
let LangOpts = [SPIRV];
1629+
let Documentation = [Undocumented];
1630+
}
1631+
16141632
def VKPostDepthCoverage : InheritableAttr {
16151633
let Spellings = [CXX11<"vk", "post_depth_coverage">];
16161634
let Subjects = SubjectList<[Function], ErrorDiag>;

tools/clang/include/clang/SPIRV/SpirvContext.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ struct RuntimeArrayTypeMapInfo {
8989
static inline RuntimeArrayType *getTombstoneKey() { return nullptr; }
9090
static unsigned getHashValue(const RuntimeArrayType *Val) {
9191
return llvm::hash_combine(Val->getElementType(),
92-
Val->getStride().hasValue());
92+
Val->getStride().hasValue(),
93+
Val->getStrideSpecConst());
9394
}
9495
static bool isEqual(const RuntimeArrayType *LHS,
9596
const RuntimeArrayType *RHS) {
@@ -284,7 +285,8 @@ class SpirvContext {
284285
llvm::Optional<uint32_t> arrayStride);
285286
const RuntimeArrayType *
286287
getRuntimeArrayType(const SpirvType *elemType,
287-
llvm::Optional<uint32_t> arrayStride);
288+
llvm::Optional<uint32_t> arrayStride,
289+
SpirvInstruction *strideSpecConst = nullptr);
288290
const NodePayloadArrayType *
289291
getNodePayloadArrayType(const SpirvType *elemType,
290292
const ParmVarDecl *nodeDecl);

tools/clang/include/clang/SPIRV/SpirvType.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace clang {
2323
namespace spirv {
2424

2525
class HybridType;
26+
class SpirvInstruction;
2627

2728
enum class StructInterfaceType : uint32_t {
2829
InternalStorage = 0,
@@ -274,9 +275,10 @@ class ArrayType : public SpirvType {
274275
class RuntimeArrayType : public SpirvType {
275276
public:
276277
RuntimeArrayType(const SpirvType *elemType,
277-
llvm::Optional<uint32_t> arrayStride)
278-
: SpirvType(TK_RuntimeArray), elementType(elemType), stride(arrayStride) {
279-
}
278+
llvm::Optional<uint32_t> arrayStride,
279+
SpirvInstruction *strideSpecConst = nullptr)
280+
: SpirvType(TK_RuntimeArray), elementType(elemType), stride(arrayStride),
281+
strideSpecConst(strideSpecConst) {}
280282

281283
static bool classof(const SpirvType *t) {
282284
return t->getKind() == TK_RuntimeArray;
@@ -286,12 +288,16 @@ class RuntimeArrayType : public SpirvType {
286288

287289
const SpirvType *getElementType() const { return elementType; }
288290
llvm::Optional<uint32_t> getStride() const { return stride; }
291+
SpirvInstruction *getStrideSpecConst() const { return strideSpecConst; }
289292

290293
private:
291294
const SpirvType *elementType;
292295
// Two runtime arrays with different ArrayStride decorations, are in fact two
293296
// different types. If no layout information is needed, use llvm::None.
297+
// Ignored when strideSpecConst is non-null (the spec-const wins).
294298
llvm::Optional<uint32_t> stride;
299+
// When non-null, ArrayStrideIdEXT %id is emitted instead of ArrayStride N.
300+
SpirvInstruction *strideSpecConst;
295301
};
296302

297303
class NodePayloadArrayType : public SpirvType {

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1889,6 +1889,22 @@ void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl,
18891889
SpirvInstruction *specConstant) {
18901890
specConstant->setRValue();
18911891
registerVariableForDecl(decl, createDeclSpirvInfo(specConstant));
1892+
if (decl->hasAttr<VKResourceHeapStrideConstantIdAttr>()) {
1893+
resourceHeapStride = HeapStrideSpecConst{
1894+
specConstant,
1895+
static_cast<uint32_t>(
1896+
decl->getAttr<VKResourceHeapStrideConstantIdAttr>()
1897+
->getSpecConstId()),
1898+
decl};
1899+
} else if (decl->hasAttr<VKSamplerHeapStrideConstantIdAttr>()) {
1900+
samplerHeapStride = HeapStrideSpecConst{
1901+
specConstant,
1902+
static_cast<uint32_t>(decl->getAttr<VKSamplerHeapStrideConstantIdAttr>()
1903+
->getSpecConstId()),
1904+
decl};
1905+
} else if (const auto *attr = decl->getAttr<VKConstantIdAttr>()) {
1906+
userSpecConstIdMap[attr->getSpecConstId()] = decl;
1907+
}
18921908
}
18931909

18941910
void DeclResultIdMapper::createCounterVar(

tools/clang/lib/SPIRV/DeclResultIdMapper.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,29 @@ class DeclResultIdMapper {
296296
/// |type|.
297297
SpirvVariableLike *createResourceHeap(const VarDecl *var, QualType type);
298298

299+
/// Records the [[vk::*_heap_stride_constant_id]] spec constant for one heap.
300+
/// Present iff such an attribute was declared in the translation unit.
301+
struct HeapStrideSpecConst {
302+
SpirvInstruction *specConst; // the OpSpecConstant
303+
uint32_t specId; // its SpecId
304+
const VarDecl *decl; // declaring var (diagnostics)
305+
};
306+
307+
/// Returns the resource/sampler heap stride spec constant, or None if no
308+
/// [[vk::*_heap_stride_constant_id]] was declared.
309+
const llvm::Optional<HeapStrideSpecConst> &getResourceHeapStride() const {
310+
return resourceHeapStride;
311+
}
312+
const llvm::Optional<HeapStrideSpecConst> &getSamplerHeapStride() const {
313+
return samplerHeapStride;
314+
}
315+
/// Returns the user [[vk::constant_id]] VarDecl that owns \p specId, or
316+
/// nullptr if no user spec-const has claimed that ID.
317+
const VarDecl *getUserSpecConstForId(uint32_t specId) const {
318+
auto it = userSpecConstIdMap.find(specId);
319+
return it != userSpecConstIdMap.end() ? it->second : nullptr;
320+
}
321+
299322
/// \brief Creates an external-visible variable and returns its instruction.
300323
SpirvVariable *createExternVar(const VarDecl *var);
301324

@@ -1058,6 +1081,11 @@ class DeclResultIdMapper {
10581081

10591082
SpirvUntypedVariableKHR *ResourceHeapVar = nullptr;
10601083
SpirvUntypedVariableKHR *SamplerHeapVar = nullptr;
1084+
llvm::Optional<HeapStrideSpecConst> resourceHeapStride;
1085+
llvm::Optional<HeapStrideSpecConst> samplerHeapStride;
1086+
/// Maps SpecId -> VarDecl for user [[vk::constant_id]] declarations, used to
1087+
/// detect SpecId collisions with heap-stride attributes.
1088+
llvm::DenseMap<uint32_t, const VarDecl *> userSpecConstIdMap;
10611089

10621090
/// Mapping from {RW|Append|Consume}StructuredBuffers to their
10631091
/// counter variables' (instr-ptr, is-alias-or-not) pairs

tools/clang/lib/SPIRV/EmitVisitor.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2695,9 +2695,15 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
26952695
curTypeInst.push_back(elemTypeId);
26962696
finalizeTypeInstruction();
26972697

2698-
auto stride = raType->getStride();
2699-
if (stride.hasValue())
2700-
emitDecoration(id, spv::Decoration::ArrayStride, {stride.getValue()});
2698+
if (auto *sc = raType->getStrideSpecConst()) {
2699+
const uint32_t scId = getOrAssignResultId<SpirvInstruction>(sc);
2700+
emitDecoration(id, spv::Decoration::ArrayStrideIdEXT, {scId}, llvm::None,
2701+
/*usesIdParams=*/true);
2702+
} else {
2703+
auto stride = raType->getStride();
2704+
if (stride.hasValue())
2705+
emitDecoration(id, spv::Decoration::ArrayStride, {stride.getValue()});
2706+
}
27012707
}
27022708
// NodePayloadArray types
27032709
else if (const auto *npaType = dyn_cast<NodePayloadArrayType>(type)) {

tools/clang/lib/SPIRV/SpirvContext.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,14 +268,15 @@ SpirvContext::getArrayType(const SpirvType *elemType, uint32_t elemCount,
268268

269269
const RuntimeArrayType *
270270
SpirvContext::getRuntimeArrayType(const SpirvType *elemType,
271-
llvm::Optional<uint32_t> arrayStride) {
272-
RuntimeArrayType type(elemType, arrayStride);
271+
llvm::Optional<uint32_t> arrayStride,
272+
SpirvInstruction *strideSpecConst) {
273+
RuntimeArrayType type(elemType, arrayStride, strideSpecConst);
273274
auto found = runtimeArrayTypes.find(&type);
274275
if (found != runtimeArrayTypes.end())
275276
return *found;
276277

277278
auto inserted = runtimeArrayTypes.insert(
278-
new (this) RuntimeArrayType(elemType, arrayStride));
279+
new (this) RuntimeArrayType(elemType, arrayStride, strideSpecConst));
279280
return *(inserted.first);
280281
}
281282

0 commit comments

Comments
 (0)