1030% faster comp shader

This commit is contained in:
Aada 2026-03-09 08:01:39 +02:00
parent 47ca1265e6
commit 9ab7191c95
2 changed files with 30 additions and 33 deletions

View file

@ -1,7 +1,7 @@
#[compute]
#version 450
// Invocations in the (x, y, z) dimension
layout(local_size_x = 1000, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x = 25, local_size_y = 1, local_size_z = 1) in;
// A binding to the buffer we create in our script
layout(set = 0, binding = 0, std430) restrict buffer PointBuffer {

View file

@ -101,6 +101,9 @@ public class PlanetHelper
public Oct Octree = new Oct();
private Curve _remapCurve;
private RenderingDevice _rd;
private Rid _shader;
public PlanetHelper(MeshInstance3D meshInstance, TextureRect textureRect, Curve remapCurve)
{
_meshInstance = meshInstance;
@ -132,6 +135,11 @@ public class PlanetHelper
textureShaderMaterial.SetShaderParameter("mode", 1);
}
UpdateMesh();
_rd = RenderingServer.CreateLocalRenderingDevice();
var shaderFile = GD.Load<RDShaderFile>("res://shaders/compute/PlateExpansion.glsl");
var shaderBytecode = shaderFile.GetSpirV();
_shader = _rd.ShaderCreateFromSpirV(shaderBytecode);
}
public void InitializeGeneration()
@ -295,18 +303,6 @@ public class PlanetHelper
public void PlateGeneration()
{
var rd = RenderingServer.CreateLocalRenderingDevice();
var shaderFile = GD.Load<RDShaderFile>("res://shaders/compute/PlateExpansion.glsl");
GD.Print(shaderFile.BaseError);
var shaderBytecode = shaderFile.GetSpirV();
GD.Print(shaderBytecode.CompileErrorCompute);
GD.Print(shaderBytecode.CompileErrorFragment);
GD.Print(shaderBytecode.CompileErrorVertex);
GD.Print(shaderBytecode.CompileErrorTesselationEvaluation);
GD.Print(shaderBytecode.CompileErrorTesselationControl);
var shader = rd.ShaderCreateFromSpirV(shaderBytecode);
GD.Print(shader.IsValid);
// Prepare our data. We use floats in the shader, so we need 32 bit.
int[] points = Enumerable.Range(0, Mdt.GetVertexCount()).ToArray();
var sizePoints = points.Length * sizeof(int);
@ -330,7 +326,7 @@ public class PlanetHelper
UniformType = RenderingDevice.UniformType.StorageBuffer,
Binding = 0
};
var pointBuffer = rd.StorageBufferCreate((uint)pointBytes.Length, pointBytes);
var pointBuffer = _rd.StorageBufferCreate((uint)pointBytes.Length, pointBytes);
pointUniform.AddId(pointBuffer);
var neighborUniform = new RDUniform
@ -338,7 +334,7 @@ public class PlanetHelper
UniformType = RenderingDevice.UniformType.StorageBuffer,
Binding = 1
};
var neighborBuffer = rd.StorageBufferCreate((uint)neighborBytes.Length, neighborBytes);
var neighborBuffer = _rd.StorageBufferCreate((uint)neighborBytes.Length, neighborBytes);
neighborUniform.AddId(neighborBuffer);
var plateUniform = new RDUniform
@ -346,23 +342,23 @@ public class PlanetHelper
UniformType = RenderingDevice.UniformType.StorageBuffer,
Binding = 2
};
var plateBuffer = rd.StorageBufferCreate((uint)plateBytes.Length, plateBytes);
var plateBuffer = _rd.StorageBufferCreate((uint)plateBytes.Length, plateBytes);
plateUniform.AddId(plateBuffer);
var uniformSet = rd.UniformSetCreate([pointUniform, neighborUniform, plateUniform], shader, 0);
var uniformSet = _rd.UniformSetCreate([pointUniform, neighborUniform, plateUniform], _shader, 0);
// Create a compute pipeline
var pipeline = rd.ComputePipelineCreate(shader);
var computeList = rd.ComputeListBegin();
rd.ComputeListBindComputePipeline(computeList, pipeline);
rd.ComputeListBindUniformSet(computeList, uniformSet, 0);
uint xgroups = (uint)(Mathf.Ceil((double)points.Length / 1000.0));
rd.ComputeListDispatch(computeList, xGroups: xgroups, yGroups: 50, zGroups: 1);
rd.ComputeListEnd();
rd.Submit();
rd.Sync();
var pipeline = _rd.ComputePipelineCreate(_shader);
var computeList = _rd.ComputeListBegin();
_rd.ComputeListBindComputePipeline(computeList, pipeline);
_rd.ComputeListBindUniformSet(computeList, uniformSet, 0);
uint xgroups = (uint)(Mathf.Ceil((double)points.Length / 25.0));
_rd.ComputeListDispatch(computeList, xGroups: xgroups, yGroups: 50, zGroups: 1);
_rd.ComputeListEnd();
_rd.Submit();
_rd.Sync();
// Read back the data from the buffers
var outputBytes = rd.BufferGetData(plateBuffer);
var outputBytes = _rd.BufferGetData(plateBuffer);
var output = new int[plateids.Length];
Buffer.BlockCopy(outputBytes, 0, output, 0, plateBytes.Length);
@ -376,11 +372,12 @@ public class PlanetHelper
}
index++;
}
rd.FreeRid(pointBuffer);
rd.FreeRid(plateBuffer);
rd.FreeRid(neighborBuffer);
rd.FreeRid(pipeline);
rd.FreeRid(uniformSet);
_rd.FreeRid(pipeline);
_rd.FreeRid(uniformSet);
_rd.FreeRid(pointBuffer);
_rd.FreeRid(plateBuffer);
_rd.FreeRid(neighborBuffer);
AssignOceanPlates(Plates);
CompleteStage();