// WHLSLPropertyResolver.cpp
#include "config.h"
#include "WHLSLPropertyResolver.h"
#if ENABLE(WEBGPU)
#include "WHLSLAST.h"
#include "WHLSLProgram.h"
#include "WHLSLReplaceWith.h"
#include "WHLSLVisitor.h"
namespace WebCore {
namespace WHLSL {
// Rewrites property accesses (dot and index expressions) and read-modify-write
// expressions so that their potentially effectful sub-expressions are hoisted
// into assignments to fresh temporaries. Each rewritten expression is replaced
// in place by a CommaExpression: the hoisted assignments first, then a rebuilt
// access / read-modify-write that refers to the temporaries. The temporaries
// are collected in m_variables and declared at the top of the enclosing
// function body (see visit(AST::FunctionDefinition&)).
class PropertyResolver : public Visitor {
    // Stores into `slot` a left-value expression equivalent to `base`.
    //
    // If `base` may not be effectful, it is stored into `slot` unchanged.
    // Otherwise a pointer temporary `p` (pointer to base's type, in base's left
    // address space) is created, `p = &base` is appended to `expressions`, and
    // `slot` becomes `*p`, annotated as a left value in that same address space.
    // `base` must be a left value: its type annotation must carry a left address
    // space (RELEASE_ASSERTed).
    void handleLeftHandSideBase(UniqueRef<AST::Expression> base, UniqueRef<AST::Expression>& slot, Vector<UniqueRef<AST::Expression>>& expressions)
    {
        if (!base->mayBeEffectful()) {
            slot = WTFMove(base);
            return;
        }

        auto leftAddressSpace = base->typeAnnotation().leftAddressSpace();
        RELEASE_ASSERT(leftAddressSpace);

        // Everything read from `base` must be captured before it is moved below.
        CodeLocation codeLocation = base->codeLocation();
        Ref<AST::UnnamedType> baseType = base->resolvedType();
        Ref<AST::PointerType> pointerType = AST::PointerType::create(codeLocation, *leftAddressSpace, baseType.copyRef());
        UniqueRef<AST::VariableDeclaration> pointerVariable = makeUniqueRef<AST::VariableDeclaration>(codeLocation, AST::Qualifiers { }, pointerType.ptr(), String(), nullptr, nullptr);

        // Each use of the temporary needs its own VariableReference node.
        auto makeVariableReference = [&] {
            auto variableReference = makeUniqueRef<AST::VariableReference>(AST::VariableReference::wrap(pointerVariable));
            variableReference->setType(pointerType.copyRef());
            variableReference->setTypeAnnotation(AST::LeftValue { AST::AddressSpace::Thread });
            return variableReference;
        };

        // pointerVariable = &base;
        {
            auto pointerOfBase = makeUniqueRef<AST::MakePointerExpression>(codeLocation, WTFMove(base), AST::AddressEscapeMode::DoesNotEscape);
            pointerOfBase->setType(pointerType.copyRef());
            pointerOfBase->setTypeAnnotation(AST::RightValue());

            auto assignment = makeUniqueRef<AST::AssignmentExpression>(codeLocation, makeVariableReference(), WTFMove(pointerOfBase));
            assignment->setType(pointerType.copyRef());
            assignment->setTypeAnnotation(AST::RightValue());
            expressions.append(WTFMove(assignment));
        }

        // slot = *pointerVariable;
        {
            auto dereference = makeUniqueRef<AST::DereferenceExpression>(codeLocation, makeVariableReference());
            dereference->setType(baseType.copyRef());
            dereference->setTypeAnnotation(AST::LeftValue { *leftAddressSpace });
            slot = WTFMove(dereference);
        }

        m_variables.append(WTFMove(pointerVariable));
    }

    // Prepares the chain of nested property accesses rooted at `propertyAccess`,
    // appending any hoisted assignments to `expressions`.
    //
    // 1. Walks from `propertyAccess` down through nested PropertyAccessExpression
    //    bases to the innermost non-property-access base, visiting every index
    //    sub-expression on the way.
    // 2. Visits that innermost base, then replaces it with one of:
    //    - a reference to a fresh copy variable, if the base is a right value;
    //    - the result of handleLeftHandSideBase(), if the base is a left value.
    // 3. Hoists every effectful index along the chain into its own temporary,
    //    innermost access first.
    //
    // The access nodes are modified in place (their base/index operands are
    // swapped for temporaries). `propertyAccess` itself is not replaced here.
    void handlePropertyAccess(AST::PropertyAccessExpression& propertyAccess, Vector<UniqueRef<AST::Expression>>& expressions)
    {
        // chain[0] is `propertyAccess`; chain.last() is the innermost access.
        AST::PropertyAccessExpression* currentPtr = &propertyAccess;
        Vector<std::reference_wrapper<AST::PropertyAccessExpression>> chain;
        while (true) {
            AST::PropertyAccessExpression& current = *currentPtr;
            chain.append(current);
            if (is<AST::IndexExpression>(current))
                checkErrorAndVisit(downcast<AST::IndexExpression>(current).indexExpression());
            if (!is<AST::PropertyAccessExpression>(current.base()))
                break;
            currentPtr = &downcast<AST::PropertyAccessExpression>(current.base());
        }

        AST::PropertyAccessExpression& current = *currentPtr;
        checkErrorAndVisit(current.base());

        CodeLocation baseCodeLocation = current.base().codeLocation();
        if (current.base().typeAnnotation().isRightValue()) {
            // Copy a right-value base into a temporary so that the accesses
            // operate on a variable: copy = base; ...copy...
            Ref<AST::UnnamedType> baseType = current.base().resolvedType();
            UniqueRef<AST::VariableDeclaration> copy = makeUniqueRef<AST::VariableDeclaration>(baseCodeLocation, AST::Qualifiers { }, baseType.ptr(), String(), nullptr, nullptr);

            auto makeVariableReference = [&] {
                auto variableReference = makeUniqueRef<AST::VariableReference>(AST::VariableReference::wrap(copy));
                variableReference->setType(baseType.copyRef());
                variableReference->setTypeAnnotation(AST::LeftValue { AST::AddressSpace::Thread });
                return variableReference;
            };

            auto assignment = makeUniqueRef<AST::AssignmentExpression>(baseCodeLocation, makeVariableReference(), current.takeBase());
            assignment->setType(baseType.copyRef());
            assignment->setTypeAnnotation(AST::RightValue());
            expressions.append(WTFMove(assignment));

            current.baseReference() = makeVariableReference();
            m_variables.append(WTFMove(copy));
        } else
            handleLeftHandSideBase(current.takeBase(), current.baseReference(), expressions);

        // Hoist effectful indices. Iterating the chain backwards appends their
        // assignments innermost-first, after the base's assignment above.
        for (size_t i = chain.size(); i--; ) {
            auto& access = chain[i].get();
            if (!is<AST::IndexExpression>(access))
                continue;
            auto& indexExpression = downcast<AST::IndexExpression>(access);
            if (!indexExpression.indexExpression().mayBeEffectful())
                continue;

            // Both the temporary and its assignment are attributed to the access
            // that owns the index.
            CodeLocation indexCodeLocation = access.codeLocation();
            Ref<AST::UnnamedType> indexType = indexExpression.indexExpression().resolvedType();
            UniqueRef<AST::VariableDeclaration> indexVariable = makeUniqueRef<AST::VariableDeclaration>(indexCodeLocation, AST::Qualifiers { }, indexType.ptr(), String(), nullptr, nullptr);

            auto makeVariableReference = [&] {
                auto variableReference = makeUniqueRef<AST::VariableReference>(AST::VariableReference::wrap(indexVariable));
                variableReference->setType(indexType.copyRef());
                variableReference->setTypeAnnotation(AST::LeftValue { AST::AddressSpace::Thread });
                return variableReference;
            };

            {
                auto assignment = makeUniqueRef<AST::AssignmentExpression>(indexCodeLocation, makeVariableReference(), indexExpression.takeIndex());
                assignment->setType(indexType.copyRef());
                assignment->setTypeAnnotation(AST::RightValue());
                expressions.append(WTFMove(assignment));
            }

            indexExpression.indexReference() = makeVariableReference();
            m_variables.append(WTFMove(indexVariable));
        }
    }

    // Replaces `propertyAccess` in place with a CommaExpression. The comma's
    // leading elements are the assignments hoisted by the overload above. Its
    // final element is a freshly built Index/DotExpression over the rewritten
    // base (and index), carrying the original type and type annotation. The
    // comma itself keeps the access's type and is annotated as a right value.
    void handlePropertyAccess(AST::PropertyAccessExpression& propertyAccess)
    {
        Vector<UniqueRef<AST::Expression>> expressions;
        handlePropertyAccess(propertyAccess, expressions);

        Ref<AST::UnnamedType> accessType = propertyAccess.resolvedType();
        CodeLocation codeLocation = propertyAccess.codeLocation();
        AST::CommaExpression* comma = nullptr;
        if (is<AST::IndexExpression>(propertyAccess)) {
            auto& indexExpression = downcast<AST::IndexExpression>(propertyAccess);
            auto newIndexExpression = makeUniqueRef<AST::IndexExpression>(codeLocation, indexExpression.takeBase(), indexExpression.takeIndex());
            newIndexExpression->setType(indexExpression.resolvedType());
            newIndexExpression->setTypeAnnotation(AST::TypeAnnotation(indexExpression.typeAnnotation()));
            expressions.append(WTFMove(newIndexExpression));
            comma = AST::replaceWith<AST::CommaExpression>(indexExpression, codeLocation, WTFMove(expressions));
        } else {
            RELEASE_ASSERT(is<AST::DotExpression>(propertyAccess));
            auto& dotExpression = downcast<AST::DotExpression>(propertyAccess);
            auto newDotExpression = makeUniqueRef<AST::DotExpression>(codeLocation, dotExpression.takeBase(), String(dotExpression.fieldName()));
            newDotExpression->setType(dotExpression.resolvedType());
            newDotExpression->setTypeAnnotation(AST::TypeAnnotation(dotExpression.typeAnnotation()));
            expressions.append(WTFMove(newDotExpression));
            comma = AST::replaceWith<AST::CommaExpression>(dotExpression, codeLocation, WTFMove(expressions));
        }

        comma->setType(WTFMove(accessType));
        comma->setTypeAnnotation(AST::RightValue());
    }

public:
    void visit(AST::DotExpression& dotExpression) override
    {
        handlePropertyAccess(dotExpression);
    }

    void visit(AST::IndexExpression& indexExpression) override
    {
        handlePropertyAccess(indexExpression);
    }

    // Rewrites a ReadModifyWriteExpression. It visits the new-value and result
    // expressions, hoists the effectful parts of the left value into
    // temporaries, and then replaces the expression in place with a
    // CommaExpression: hoisted assignments first, then a rebuilt
    // ReadModifyWriteExpression. The comma keeps the original type and is
    // annotated as a right value.
    void visit(AST::ReadModifyWriteExpression& readModifyWrite) override
    {
        checkErrorAndVisit(readModifyWrite.newValueExpression());
        checkErrorAndVisit(readModifyWrite.resultExpression());

        Vector<UniqueRef<AST::Expression>> expressions;
        CodeLocation codeLocation = readModifyWrite.codeLocation();
        Ref<AST::UnnamedType> type = readModifyWrite.resolvedType();

        if (is<AST::PropertyAccessExpression>(readModifyWrite.leftValue()))
            handlePropertyAccess(downcast<AST::PropertyAccessExpression>(readModifyWrite.leftValue()), expressions);
        else {
            // handlePropertyAccess() visits the sub-expressions it wraps, and
            // handleLeftHandSideBase() does not. Visit the left value here so
            // that nested accesses inside it (e.g. under a dereference) are
            // rewritten too, matching how handlePropertyAccess() visits its base
            // before wrapping it.
            checkErrorAndVisit(readModifyWrite.leftValue());
            handleLeftHandSideBase(readModifyWrite.takeLeftValue(), readModifyWrite.leftValueReference(), expressions);
        }

        {
            UniqueRef<AST::ReadModifyWriteExpression> newReadModifyWrite = makeUniqueRef<AST::ReadModifyWriteExpression>(
                codeLocation, readModifyWrite.takeLeftValue(), readModifyWrite.takeOldValue(), readModifyWrite.takeNewValue());
            newReadModifyWrite->setNewValueExpression(readModifyWrite.takeNewValueExpression());
            newReadModifyWrite->setResultExpression(readModifyWrite.takeResultExpression());
            newReadModifyWrite->setType(type.copyRef());
            newReadModifyWrite->setTypeAnnotation(AST::TypeAnnotation(readModifyWrite.typeAnnotation()));
            expressions.append(WTFMove(newReadModifyWrite));
        }

        auto* comma = AST::replaceWith<AST::CommaExpression>(readModifyWrite, codeLocation, WTFMove(expressions));
        comma->setType(WTFMove(type));
        comma->setTypeAnnotation(AST::RightValue());
    }

    // Visits one function and declares every temporary created while doing so
    // in a single VariableDeclarationsStatement inserted as the first statement
    // of the function body. m_variables must be empty on entry; it is moved
    // out on exit.
    void visit(AST::FunctionDefinition& functionDefinition) override
    {
        RELEASE_ASSERT(m_variables.isEmpty());
        checkErrorAndVisit(static_cast<AST::FunctionDeclaration&>(functionDefinition));
        checkErrorAndVisit(functionDefinition.block());
        if (!m_variables.isEmpty()) {
            functionDefinition.block().statements().insert(0,
                makeUniqueRef<AST::VariableDeclarationsStatement>(functionDefinition.block().codeLocation(), WTFMove(m_variables)));
        }
    }

private:
    // Temporaries created while rewriting the function currently being visited.
    AST::VariableDeclarations m_variables;
};
// Entry point of the pass. Runs a single PropertyResolver over every function
// definition in `program`, rewriting property accesses and read-modify-write
// expressions in place.
void resolveProperties(Program& program)
{
    PropertyResolver propertyResolver;
    for (auto& functionDefinition : program.functionDefinitions())
        propertyResolver.visit(functionDefinition);
}
}
}
#endif // ENABLE(WEBGPU)