mirror of
https://github.com/w23/xash3d-fwgs
synced 2024-12-15 21:50:59 +01:00
seba: collect all bindings for pipelines
This commit is contained in:
parent
6f80bc0015
commit
556440df27
@ -2,6 +2,7 @@
|
||||
import json
|
||||
import argparse
|
||||
import struct
|
||||
import copy
|
||||
from spirv import spv
|
||||
|
||||
parser = argparse.ArgumentParser(description='Build pipeline descriptor')
|
||||
@ -151,6 +152,38 @@ def parseSpirv(raw_data):
|
||||
|
||||
return ctx
|
||||
|
||||
class Binding:
|
||||
STAGE_VERTEX_BIT = 0x00000001
|
||||
STAGE_TESSELLATION_CONTROL_BIT = 0x00000002
|
||||
STAGE_TESSELLATION_EVALUATION_BIT = 0x00000004
|
||||
STAGE_GEOMETRY_BIT = 0x00000008
|
||||
STAGE_FRAGMENT_BIT = 0x00000010
|
||||
STAGE_COMPUTE_BIT = 0x00000020
|
||||
STAGE_ALL_GRAPHICS = 0x0000001F
|
||||
STAGE_ALL = 0x7FFFFFFF
|
||||
STAGE_RAYGEN_BIT_KHR = 0x00000100
|
||||
STAGE_ANY_HIT_BIT_KHR = 0x00000200
|
||||
STAGE_CLOSEST_HIT_BIT_KHR = 0x00000400
|
||||
STAGE_MISS_BIT_KHR = 0x00000800
|
||||
STAGE_INTERSECTION_BIT_KHR = 0x00001000
|
||||
STAGE_CALLABLE_BIT_KHR = 0x00002000
|
||||
STAGE_TASK_BIT_NV = 0x00000040
|
||||
STAGE_MESH_BIT_NV = 0x00000080
|
||||
STAGE_SUBPASS_SHADING_BIT_HUAWEI = 0x00004000
|
||||
|
||||
def __init__(self, name, descriptor_set, index, stages):
|
||||
self.name = name
|
||||
self.index = index
|
||||
self.descriptor_set = descriptor_set
|
||||
self.stages = stages
|
||||
#TODO: type, count, etc
|
||||
|
||||
def serialize(self, out):
|
||||
out.writeString(self.name)
|
||||
out.writeU32(self.descriptor_set)
|
||||
out.writeU32(self.index)
|
||||
out.writeU32(self.stages)
|
||||
|
||||
class Shader:
|
||||
def __init__(self, name, file):
|
||||
self.name = name
|
||||
@ -165,7 +198,23 @@ class Shader:
|
||||
ret += ('[%d:%d] (id=%d) %s\n' % (node.descriptor_set, node.binding, index, node.name))
|
||||
return ret
|
||||
|
||||
def getBindings(self):
|
||||
ret = []
|
||||
for node in self.spirv.nodes:
|
||||
if node.binding == None or node.descriptor_set == None:
|
||||
continue
|
||||
ret.append(Binding(node.name, node.descriptor_set, node.binding, 0))
|
||||
return ret
|
||||
|
||||
class Shaders:
|
||||
__suffixes = {
|
||||
Binding.STAGE_COMPUTE_BIT: '.comp.spv',
|
||||
Binding.STAGE_RAYGEN_BIT_KHR: '.rgen.spv',
|
||||
Binding.STAGE_ANY_HIT_BIT_KHR: '.rahit.spv',
|
||||
Binding.STAGE_CLOSEST_HIT_BIT_KHR: '.rchit.spv',
|
||||
Binding.STAGE_MISS_BIT_KHR: '.rmiss.spv'
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.__map = dict()
|
||||
self.__shaders = []
|
||||
@ -185,7 +234,8 @@ class Shaders:
|
||||
|
||||
raise Exception('Cannot load shader ' + name)
|
||||
|
||||
def load(self, name):
|
||||
def load(self, name, stage):
|
||||
name = name + self.__suffixes[stage]
|
||||
if name in self.__map:
|
||||
return self.__shaders[self.__map[name]]
|
||||
|
||||
@ -209,27 +259,46 @@ class Shaders:
|
||||
|
||||
shaders = Shaders()
|
||||
|
||||
|
||||
PIPELINE_COMPUTE = 1
|
||||
PIPELINE_RAYTRACING = 2
|
||||
NO_SHADER = 0xffffffff
|
||||
|
||||
class PipelineRayTracing:
|
||||
def __loadHit(hit):
|
||||
ret = dict()
|
||||
suffixes = {'closest': '.rchit.spv', 'any': '.rahit.spv'}
|
||||
for k, v in hit.items():
|
||||
ret[k] = shaders.load(v + suffixes[k])
|
||||
return ret
|
||||
|
||||
def __init__(self, name, desc):
|
||||
self.type = PIPELINE_RAYTRACING
|
||||
class Pipeline:
|
||||
def __init__(self, name, type_id):
|
||||
self.name = name
|
||||
self.rgen = shaders.load(desc['rgen'] + '.rgen.spv')
|
||||
self.miss = [] if not 'miss' in desc else [shaders.load(s + '.rmiss.spv') for s in desc['miss']]
|
||||
self.hit = [] if not 'hit' in desc else [PipelineRayTracing.__loadHit(hit) for hit in desc['hit']]
|
||||
self.type = type_id
|
||||
self.__bindings = {}
|
||||
|
||||
def addShader(self, shader_name, stage):
|
||||
shader = shaders.load(shader_name, stage)
|
||||
for binding in shader.getBindings():
|
||||
addr = (binding.descriptor_set, binding.index)
|
||||
if addr in self.__bindings:
|
||||
self.__bindings[addr].stages |= stage
|
||||
else:
|
||||
self.__bindings[addr] = copy.deepcopy(binding)
|
||||
|
||||
return shader
|
||||
|
||||
def serialize(self, out):
|
||||
print(self.__bindings)
|
||||
out.writeU32(self.type)
|
||||
out.writeString(self.name)
|
||||
#out.writeArray(self.__bindings)
|
||||
|
||||
class PipelineRayTracing(Pipeline):
|
||||
__hit2stage = {
|
||||
'closest': Binding.STAGE_CLOSEST_HIT_BIT_KHR,
|
||||
'any': Binding.STAGE_ANY_HIT_BIT_KHR,
|
||||
}
|
||||
def __init__(self, name, desc):
|
||||
super().__init__(name, PIPELINE_RAYTRACING)
|
||||
self.rgen = self.addShader(desc['rgen'], Binding.STAGE_RAYGEN_BIT_KHR)
|
||||
self.miss = [] if not 'miss' in desc else [self.addShader(s, Binding.STAGE_MISS_BIT_KHR) for s in desc['miss']]
|
||||
self.hit = [] if not 'hit' in desc else [self.__loadHit(hit) for hit in desc['hit']]
|
||||
|
||||
def serialize(self, out):
|
||||
super().serialize(out)
|
||||
out.writeU32(shaders.getIndex(self.rgen))
|
||||
out.writeArray([shaders.getIndex(s) for s in self.miss])
|
||||
|
||||
@ -238,13 +307,19 @@ class PipelineRayTracing:
|
||||
out.writeU32(shaders.getIndex(hit['closest']) if 'closest' in hit else NO_SHADER)
|
||||
out.writeU32(shaders.getIndex(hit['any']) if 'any' in hit else NO_SHADER)
|
||||
|
||||
class PipelineCompute:
|
||||
def __loadHit(self, hit):
|
||||
ret = dict()
|
||||
for k, v in hit.items():
|
||||
ret[k] = self.addShader(v, self.__hit2stage[k])
|
||||
return ret
|
||||
|
||||
class PipelineCompute(Pipeline):
|
||||
def __init__(self, name, desc):
|
||||
self.type = PIPELINE_COMPUTE
|
||||
self.name = name
|
||||
self.comp = shaders.load(desc['comp'] + '.comp.spv')
|
||||
super().__init__(name, PIPELINE_COMPUTE)
|
||||
self.comp = self.addShader(desc['comp'], Binding.STAGE_COMPUTE_BIT)
|
||||
|
||||
def serialize(self, out):
|
||||
super().serialize(out)
|
||||
out.writeU32(shaders.getIndex(self.comp))
|
||||
|
||||
def parsePipeline(pipelines, name, desc):
|
||||
@ -276,8 +351,6 @@ def writeOutput(file, pipelines):
|
||||
|
||||
out.writeU32(len(pipelines))
|
||||
for name, pipeline in pipelines.items():
|
||||
out.writeU32(pipeline.type)
|
||||
out.writeString(pipeline.name)
|
||||
pipeline.serialize(out)
|
||||
|
||||
pipelines = loadPipelines()
|
||||
|
Loading…
Reference in New Issue
Block a user