WHLSLFunctionStageChecker.cpp [plain text]
#include "config.h"
#include "WHLSLFunctionStageChecker.h"
#if ENABLE(WEBGPU)
#include "WHLSLCallExpression.h"
#include "WHLSLEntryPointType.h"
#include "WHLSLIntrinsics.h"
#include "WHLSLProgram.h"
#include "WHLSLVisitor.h"
namespace WebCore {
namespace WHLSL {
class FunctionStageChecker : public Visitor {
public:
FunctionStageChecker(AST::EntryPointType entryPointType, const Intrinsics& intrinsics)
: m_entryPointType(entryPointType)
, m_intrinsics(intrinsics)
{
}
public:
void visit(AST::CallExpression& callExpression) override
{
if ((&callExpression.function() == m_intrinsics.ddx() || &callExpression.function() == m_intrinsics.ddy()) && m_entryPointType != AST::EntryPointType::Fragment) {
setError(Error("Cannot use ddx or ddy in a non-fragment shader.", callExpression.codeLocation()));
return;
}
if ((&callExpression.function() == m_intrinsics.allMemoryBarrier() || &callExpression.function() == m_intrinsics.deviceMemoryBarrier() || &callExpression.function() == m_intrinsics.groupMemoryBarrier())
&& m_entryPointType != AST::EntryPointType::Compute) {
setError(Error("Cannot use memory barrier in a non-compute shader.", callExpression.codeLocation()));
return;
}
Visitor::visit(callExpression.function());
}
AST::EntryPointType m_entryPointType;
const Intrinsics& m_intrinsics;
};
// Validates stage-restricted intrinsic usage for every entry point in
// `program`. Functions without an entry point type are not checked directly.
// Each entry point gets a fresh FunctionStageChecker configured with its own
// stage. Returns the first error reported, or an empty (success) value when
// every entry point passes.
Expected<void, Error> checkFunctionStages(Program& program)
{
    for (auto& function : program.functionDefinitions()) {
        if (auto stage = function->entryPointType()) {
            FunctionStageChecker checker(*stage, program.intrinsics());
            checker.Visitor::visit(function);
            if (checker.hasError())
                return makeUnexpected(checker.result().error());
        }
    }
    return { };
}
}
}
#endif