diff --git a/ref_vk/sebastian.py b/ref_vk/sebastian.py index 858b880e..583beee5 100755 --- a/ref_vk/sebastian.py +++ b/ref_vk/sebastian.py @@ -2,6 +2,7 @@ import json import argparse import struct +import copy from spirv import spv parser = argparse.ArgumentParser(description='Build pipeline descriptor') @@ -151,6 +152,38 @@ def parseSpirv(raw_data): return ctx +class Binding: + STAGE_VERTEX_BIT = 0x00000001 + STAGE_TESSELLATION_CONTROL_BIT = 0x00000002 + STAGE_TESSELLATION_EVALUATION_BIT = 0x00000004 + STAGE_GEOMETRY_BIT = 0x00000008 + STAGE_FRAGMENT_BIT = 0x00000010 + STAGE_COMPUTE_BIT = 0x00000020 + STAGE_ALL_GRAPHICS = 0x0000001F + STAGE_ALL = 0x7FFFFFFF + STAGE_RAYGEN_BIT_KHR = 0x00000100 + STAGE_ANY_HIT_BIT_KHR = 0x00000200 + STAGE_CLOSEST_HIT_BIT_KHR = 0x00000400 + STAGE_MISS_BIT_KHR = 0x00000800 + STAGE_INTERSECTION_BIT_KHR = 0x00001000 + STAGE_CALLABLE_BIT_KHR = 0x00002000 + STAGE_TASK_BIT_NV = 0x00000040 + STAGE_MESH_BIT_NV = 0x00000080 + STAGE_SUBPASS_SHADING_BIT_HUAWEI = 0x00004000 + + def __init__(self, name, descriptor_set, index, stages): + self.name = name + self.index = index + self.descriptor_set = descriptor_set + self.stages = stages + #TODO: type, count, etc + + def serialize(self, out): + out.writeString(self.name) + out.writeU32(self.descriptor_set) + out.writeU32(self.index) + out.writeU32(self.stages) + class Shader: def __init__(self, name, file): self.name = name @@ -165,7 +198,23 @@ class Shader: ret += ('[%d:%d] (id=%d) %s\n' % (node.descriptor_set, node.binding, index, node.name)) return ret + def getBindings(self): + ret = [] + for node in self.spirv.nodes: + if node.binding == None or node.descriptor_set == None: + continue + ret.append(Binding(node.name, node.descriptor_set, node.binding, 0)) + return ret + class Shaders: + __suffixes = { + Binding.STAGE_COMPUTE_BIT: '.comp.spv', + Binding.STAGE_RAYGEN_BIT_KHR: '.rgen.spv', + Binding.STAGE_ANY_HIT_BIT_KHR: '.rahit.spv', + Binding.STAGE_CLOSEST_HIT_BIT_KHR: '.rchit.spv', + Binding.STAGE_MISS_BIT_KHR: '.rmiss.spv' + } + def __init__(self): self.__map = dict() self.__shaders = [] @@ -185,7 +234,8 @@ class Shaders: raise Exception('Cannot load shader ' + name) - def load(self, name): + def load(self, name, stage): + name = name + self.__suffixes[stage] if name in self.__map: return self.__shaders[self.__map[name]] @@ -209,27 +259,46 @@ class Shaders: shaders = Shaders() - PIPELINE_COMPUTE = 1 PIPELINE_RAYTRACING = 2 NO_SHADER = 0xffffffff -class PipelineRayTracing: - def __loadHit(hit): - ret = dict() - suffixes = {'closest': '.rchit.spv', 'any': '.rahit.spv'} - for k, v in hit.items(): - ret[k] = shaders.load(v + suffixes[k]) - return ret - - def __init__(self, name, desc): - self.type = PIPELINE_RAYTRACING +class Pipeline: + def __init__(self, name, type_id): self.name = name - self.rgen = shaders.load(desc['rgen'] + '.rgen.spv') - self.miss = [] if not 'miss' in desc else [shaders.load(s + '.rmiss.spv') for s in desc['miss']] - self.hit = [] if not 'hit' in desc else [PipelineRayTracing.__loadHit(hit) for hit in desc['hit']] + self.type = type_id + self.__bindings = {} + + def addShader(self, shader_name, stage): + shader = shaders.load(shader_name, stage) + for binding in shader.getBindings(): + addr = (binding.descriptor_set, binding.index) + if addr in self.__bindings: + self.__bindings[addr].stages |= stage + else: + self.__bindings[addr] = copy.deepcopy(binding) + + return shader def serialize(self, out): + print(self.__bindings) + out.writeU32(self.type) + out.writeString(self.name) + #out.writeArray(self.__bindings) + +class PipelineRayTracing(Pipeline): + __hit2stage = { + 'closest': Binding.STAGE_CLOSEST_HIT_BIT_KHR, + 'any': Binding.STAGE_ANY_HIT_BIT_KHR, + } + def __init__(self, name, desc): + super().__init__(name, PIPELINE_RAYTRACING) + self.rgen = self.addShader(desc['rgen'], Binding.STAGE_RAYGEN_BIT_KHR) + self.miss = [] if not 'miss' in desc else [self.addShader(s, Binding.STAGE_MISS_BIT_KHR) for s in desc['miss']] + self.hit = [] if not 'hit' in desc else [self.__loadHit(hit) for hit in desc['hit']] + + def serialize(self, out): + super().serialize(out) out.writeU32(shaders.getIndex(self.rgen)) out.writeArray([shaders.getIndex(s) for s in self.miss]) @@ -238,13 +307,19 @@ class PipelineRayTracing: out.writeU32(shaders.getIndex(hit['closest']) if 'closest' in hit else NO_SHADER) out.writeU32(shaders.getIndex(hit['any']) if 'any' in hit else NO_SHADER) -class PipelineCompute: + def __loadHit(self, hit): + ret = dict() + for k, v in hit.items(): + ret[k] = self.addShader(v, self.__hit2stage[k]) + return ret + +class PipelineCompute(Pipeline): def __init__(self, name, desc): - self.type = PIPELINE_COMPUTE - self.name = name - self.comp = shaders.load(desc['comp'] + '.comp.spv') + super().__init__(name, PIPELINE_COMPUTE) + self.comp = self.addShader(desc['comp'], Binding.STAGE_COMPUTE_BIT) def serialize(self, out): + super().serialize(out) out.writeU32(shaders.getIndex(self.comp)) def parsePipeline(pipelines, name, desc): @@ -276,8 +351,6 @@ def writeOutput(file, pipelines): out.writeU32(len(pipelines)) for name, pipeline in pipelines.items(): - out.writeU32(pipeline.type) - out.writeString(pipeline.name) pipeline.serialize(out) pipelines = loadPipelines()