WHLSLChecker.cpp   [plain text]


/*
 * Copyright (C) 2019 Apple Inc. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
 * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGE.
 */

#include "config.h"
#include "WHLSLChecker.h"

#if ENABLE(WEBGPU)

#include "WHLSLArrayReferenceType.h"
#include "WHLSLArrayType.h"
#include "WHLSLAssignmentExpression.h"
#include "WHLSLCallExpression.h"
#include "WHLSLCommaExpression.h"
#include "WHLSLDereferenceExpression.h"
#include "WHLSLDoWhileLoop.h"
#include "WHLSLDotExpression.h"
#include "WHLSLEntryPointType.h"
#include "WHLSLForLoop.h"
#include "WHLSLGatherEntryPointItems.h"
#include "WHLSLIfStatement.h"
#include "WHLSLIndexExpression.h"
#include "WHLSLInferTypes.h"
#include "WHLSLLogicalExpression.h"
#include "WHLSLLogicalNotExpression.h"
#include "WHLSLMakeArrayReferenceExpression.h"
#include "WHLSLMakePointerExpression.h"
#include "WHLSLNameContext.h"
#include "WHLSLPointerType.h"
#include "WHLSLProgram.h"
#include "WHLSLReadModifyWriteExpression.h"
#include "WHLSLResolvableType.h"
#include "WHLSLResolveOverloadImpl.h"
#include "WHLSLResolvingType.h"
#include "WHLSLReturn.h"
#include "WHLSLSwitchStatement.h"
#include "WHLSLTernaryExpression.h"
#include "WHLSLVisitor.h"
#include "WHLSLWhileLoop.h"
#include <wtf/HashMap.h>
#include <wtf/HashSet.h>
#include <wtf/Ref.h>
#include <wtf/Vector.h>
#include <wtf/text/WTFString.h>

namespace WebCore {

namespace WHLSL {

class PODChecker : public Visitor {
public:
    PODChecker() = default;

    virtual ~PODChecker() = default;

    void visit(AST::EnumerationDefinition& enumerationDefinition) override
    {
        Visitor::visit(enumerationDefinition);
    }

    void visit(AST::NativeTypeDeclaration& nativeTypeDeclaration) override
    {
        if (!nativeTypeDeclaration.isNumber()
            && !nativeTypeDeclaration.isVector()
            && !nativeTypeDeclaration.isMatrix())
            setError(Error("Use of native type is not a POD in entrypoint semantic.", nativeTypeDeclaration.codeLocation()));
    }

    void visit(AST::StructureDefinition& structureDefinition) override
    {
        Visitor::visit(structureDefinition);
    }

    void visit(AST::TypeDefinition& typeDefinition) override
    {
        Visitor::visit(typeDefinition);
    }

    void visit(AST::ArrayType& arrayType) override
    {
        Visitor::visit(arrayType);
    }

    void visit(AST::PointerType& pointerType) override
    {
        setError(Error("Illegal use of pointer in entrypoint semantic.", pointerType.codeLocation()));
    }

    void visit(AST::ArrayReferenceType& arrayReferenceType) override
    {
        setError(Error("Illegal use of array reference in entrypoint semantic.", arrayReferenceType.codeLocation()));
    }

    void visit(AST::TypeReference& typeReference) override
    {
        checkErrorAndVisit(typeReference.resolvedType());
    }
};

class FunctionKey {
public:
    FunctionKey() = default;
    FunctionKey(WTF::HashTableDeletedValueType)
    {
        m_castReturnType = bitwise_cast<AST::NamedType*>(static_cast<uintptr_t>(1));
    }

    FunctionKey(String name, Vector<std::reference_wrapper<AST::UnnamedType>> types, AST::NamedType* castReturnType = nullptr)
        : m_name(WTFMove(name))
        , m_types(WTFMove(types))
        , m_castReturnType(castReturnType)
    { }

    bool isEmptyValue() const { return m_name.isNull(); }
    bool isHashTableDeletedValue() const { return m_castReturnType == bitwise_cast<AST::NamedType*>(static_cast<uintptr_t>(1)); }

    unsigned hash() const
    {
        unsigned hash = IntHash<size_t>::hash(m_types.size());
        hash ^= m_name.hash();
        for (size_t i = 0; i < m_types.size(); ++i)
            hash ^= m_types[i].get().hash() + i;

        if (m_castReturnType)
            hash ^= WTF::PtrHash<AST::Type*>::hash(&m_castReturnType->unifyNode());

        return hash;
    }

    bool operator==(const FunctionKey& other) const
    {
        if (m_types.size() != other.m_types.size())
            return false;

        if (m_name != other.m_name)
            return false;

        for (size_t i = 0; i < m_types.size(); ++i) {
            if (!matches(m_types[i].get(), other.m_types[i].get()))
                return false;
        }

        if (static_cast<bool>(m_castReturnType) != static_cast<bool>(other.m_castReturnType))
            return false;

        if (!m_castReturnType)
            return true;

        if (&m_castReturnType->unifyNode() == &other.m_castReturnType->unifyNode())
            return true;

        return false;
    }

    struct Hash {
        static unsigned hash(const FunctionKey& key)
        {
            return key.hash();
        }

        static bool equal(const FunctionKey& a, const FunctionKey& b)
        {
            return a == b;
        }

        static const bool safeToCompareToEmptyOrDeleted = false;
        static const bool emptyValueIsZero = false;
    };

    struct Traits : public WTF::SimpleClassHashTraits<FunctionKey> {
        static const bool hasIsEmptyValueFunction = true;
        static bool isEmptyValue(const FunctionKey& key) { return key.isEmptyValue(); }
    };

private:
    String m_name;
    Vector<std::reference_wrapper<AST::UnnamedType>> m_types;
    AST::NamedType* m_castReturnType;
};

// Builds a native "operator==" declaration returning bool and taking two parameters
// of the same type. That type is the first argument's concrete type when it has one,
// otherwise the second argument's concrete type.
static AST::NativeFunctionDeclaration resolveWithReferenceComparator(CodeLocation location, ResolvingType& firstArgument, ResolvingType& secondArgument, const Intrinsics& intrinsics)
{
    const bool isOperator = true;
    auto returnType = AST::TypeReference::wrap(location, intrinsics.boolType());

    auto copyConcreteType = [](Ref<AST::UnnamedType>& concreteType) -> Ref<AST::UnnamedType> {
        return concreteType.copyRef();
    };
    auto fallBackToSecondArgument = [&](RefPtr<ResolvableTypeReference>&) -> Ref<AST::UnnamedType> {
        return secondArgument.visit(WTF::makeVisitor(copyConcreteType, [&](RefPtr<ResolvableTypeReference>&) -> Ref<AST::UnnamedType> {
            // We encountered "null == null".
            // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198162 This can probably be generalized, using the "preferred type" infrastructure used by generic literals
            ASSERT_NOT_REACHED();
            return AST::TypeReference::wrap(location, intrinsics.intType());
        }));
    };
    auto argumentType = firstArgument.visit(WTF::makeVisitor(copyConcreteType, fallBackToSecondArgument));

    // Both parameters are unnamed and share the same type.
    AST::VariableDeclarations parameters;
    parameters.append(makeUniqueRef<AST::VariableDeclaration>(location, AST::Qualifiers(), argumentType.copyRef(), String(), nullptr, nullptr));
    parameters.append(makeUniqueRef<AST::VariableDeclaration>(location, AST::Qualifiers(), WTFMove(argumentType), String(), nullptr, nullptr));

    AST::FunctionDeclaration declaration(location, AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator==", String::ConstructFromLiteral), WTFMove(parameters), nullptr, isOperator, ParsingMode::StandardLibrary);
    return AST::NativeFunctionDeclaration(WTFMove(declaration));
}

// Result of classifying an argument in resolveByInstantiation(): Yes when the argument
// may take part in the instantiated reference comparison, No otherwise.
enum class Acceptability {
    Yes,
    No
};

// Tries to synthesize a function for a call that matched no known function.
// Only a two-argument "operator==" is handled: both arguments must have a concrete
// type whose unify node is a reference type, and the two types must match.
// Returns WTF::nullopt in every other case.
static Optional<AST::NativeFunctionDeclaration> resolveByInstantiation(const String& name, CodeLocation location, const Vector<std::reference_wrapper<ResolvingType>>& types, const Intrinsics& intrinsics)
{
    if (!(name == "operator==") || types.size() != 2)
        return WTF::nullopt;

    // An argument is acceptable only when it is already concrete and is a reference type.
    auto classify = [](ResolvingType& candidate) -> Acceptability {
        return candidate.visit(WTF::makeVisitor([](Ref<AST::UnnamedType>& concreteType) -> Acceptability {
            auto& unifyNode = concreteType->unifyNode();
            if (is<AST::UnnamedType>(unifyNode) && is<AST::ReferenceType>(downcast<AST::UnnamedType>(unifyNode)))
                return Acceptability::Yes;
            return Acceptability::No;
        }, [](RefPtr<ResolvableTypeReference>&) -> Acceptability {
            return Acceptability::No;
        }));
    };

    ResolvingType& leftType = types[0].get();
    ResolvingType& rightType = types[1].get();
    auto leftVerdict = classify(leftType);
    auto rightVerdict = classify(rightType);
    if (leftVerdict != Acceptability::Yes || rightVerdict != Acceptability::Yes)
        return WTF::nullopt;

    if (!matches(*leftType.getUnnamedType(), *rightType.getUnnamedType()))
        return WTF::nullopt;

    return resolveWithReferenceComparator(location, leftType, rightType, intrinsics);
}

// Validates the gathered inputs and outputs of an entry point. Four checks run in
// order, inputs before outputs within each check; the first failure returns false:
//  1. no two items of one list have equal semantics,
//  2. each semantic accepts the type of its item,
//  3. each semantic is acceptable for its direction and the entry point type,
//  4. the element type of pointer, array reference and array items passes PODChecker.
static bool checkSemantics(Vector<EntryPointItem>& inputItems, Vector<EntryPointItem>& outputItems, const Optional<AST::EntryPointType>& entryPointType, const Intrinsics& intrinsics)
{
    // NOTE(review): this compares the `semantic` members directly; whether that is an
    // identity or a value comparison depends on EntryPointItem — confirm.
    auto semanticsAreUnique = [](const Vector<EntryPointItem>& items) -> bool {
        for (size_t first = 0; first < items.size(); ++first) {
            for (size_t second = first + 1; second < items.size(); ++second) {
                if (items[first].semantic == items[second].semantic)
                    return false;
            }
        }
        return true;
    };
    if (!semanticsAreUnique(inputItems) || !semanticsAreUnique(outputItems))
        return false;

    auto typesAreAcceptable = [&](const Vector<EntryPointItem>& items) -> bool {
        for (auto& item : items) {
            bool accepted = WTF::visit(WTF::makeVisitor([&](const AST::BaseSemantic& semantic) -> bool {
                return semantic.isAcceptableType(*item.unnamedType, intrinsics);
            }), *item.semantic);
            if (!accepted)
                return false;
        }
        return true;
    };
    if (!typesAreAcceptable(inputItems) || !typesAreAcceptable(outputItems))
        return false;

    auto directionIsAcceptable = [&](const Vector<EntryPointItem>& items, AST::BaseSemantic::ShaderItemDirection direction) -> bool {
        for (auto& item : items) {
            bool accepted = WTF::visit(WTF::makeVisitor([&](const AST::BaseSemantic& semantic) -> bool {
                return semantic.isAcceptableForShaderItemDirection(direction, entryPointType);
            }), *item.semantic);
            if (!accepted)
                return false;
        }
        return true;
    };
    if (!directionIsAcceptable(inputItems, AST::BaseSemantic::ShaderItemDirection::Input)
        || !directionIsAcceptable(outputItems, AST::BaseSemantic::ShaderItemDirection::Output))
        return false;

    // Items that are not a pointer, array reference or array are skipped entirely.
    auto elementTypesArePOD = [](const Vector<EntryPointItem>& items) -> bool {
        for (auto& item : items) {
            PODChecker elementChecker;
            if (is<AST::PointerType>(item.unnamedType))
                elementChecker.checkErrorAndVisit(downcast<AST::PointerType>(*item.unnamedType).elementType());
            else if (is<AST::ArrayReferenceType>(item.unnamedType))
                elementChecker.checkErrorAndVisit(downcast<AST::ArrayReferenceType>(*item.unnamedType).elementType());
            else if (is<AST::ArrayType>(item.unnamedType))
                elementChecker.checkErrorAndVisit(downcast<AST::ArrayType>(*item.unnamedType).type());
            else
                continue;

            if (elementChecker.hasError())
                return false;
        }
        return true;
    };
    if (!elementTypesArePOD(inputItems) || !elementTypesArePOD(outputItems))
        return false;

    return true;
}

// Checks that a user-defined operator has a signature this checker accepts.
// Non-operators and casts always pass. Otherwise:
//  - operator++ / operator--: exactly one parameter whose type matches the return type,
//  - operator+ / operator-: one or two parameters,
//  - operator* / % & | ^ << >>: exactly two parameters,
//  - operator~: exactly one parameter,
//  - any other operator name is rejected.
static bool checkOperatorOverload(const AST::FunctionDefinition& functionDefinition)
{
    if (!functionDefinition.isOperator())
        return true;
    if (functionDefinition.isCast())
        return true;

    const auto& name = functionDefinition.name();
    const auto parameterCount = functionDefinition.parameters().size();

    if (name == "operator++" || name == "operator--") {
        // The parameter type is only inspected once we know there is exactly one.
        return parameterCount == 1
            && matches(*functionDefinition.parameters()[0]->type(), functionDefinition.type());
    }

    if (name == "operator+" || name == "operator-")
        return parameterCount == 1 || parameterCount == 2;

    static const char* const binaryOnlyOperators[] = {
        "operator*",
        "operator/",
        "operator%",
        "operator&",
        "operator|",
        "operator^",
        "operator<<",
        "operator>>",
    };
    for (const char* binaryOperator : binaryOnlyOperators) {
        if (name == binaryOperator)
            return parameterCount == 2;
    }

    if (name == "operator~")
        return parameterCount == 1;

    return false;
}

// Visitor that type checks a Program. While visiting, it records a ResolvingType for
// each expression in m_typeMap; assignTypes() later writes those types back onto the
// expressions.
class Checker : public Visitor {
public:
    // Indexes every function definition and native function declaration of the program
    // in m_functions, keyed by name, normalized parameter types and, for casts whose
    // return type unifies to a named type, that named type.
    Checker(const Intrinsics& intrinsics, Program& program)
        : m_intrinsics(intrinsics)
        , m_program(program)
    {
        auto addFunction = [&] (AST::FunctionDeclaration& function) {
            AST::NamedType* castReturnType = nullptr;
            if (function.isCast() && is<AST::NamedType>(function.type().unifyNode()))
                castReturnType = &downcast<AST::NamedType>(function.type().unifyNode());

            Vector<std::reference_wrapper<AST::UnnamedType>> types;
            types.reserveInitialCapacity(function.parameters().size());

            for (auto& param : function.parameters())
                types.uncheckedAppend(normalizedTypeForFunctionKey(*param->type()));

            // add() keeps an existing entry, so functions sharing a key end up in one vector.
            auto addResult = m_functions.add(FunctionKey { function.name(), WTFMove(types), castReturnType }, Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>());
            addResult.iterator->value.append(function);
        };

        for (auto& function : m_program.functionDefinitions())
            addFunction(function.get());
        for (auto& function : m_program.nativeFunctionDeclarations())
            addFunction(function.get());
    }

    virtual ~Checker() = default;

    void visit(Program&) override;

    // Writes the types recorded in m_typeMap onto their expressions.
    Expected<void, Error> assignTypes();

private:
    bool checkShaderType(const AST::FunctionDefinition&);
    bool isBoolType(ResolvingType&);
    // Type information about an already visited expression.
    struct RecurseInfo {
        ResolvingType& resolvingType;
        const AST::TypeAnnotation typeAnnotation;
    };
    Optional<RecurseInfo> recurseAndGetInfo(AST::Expression&, bool requiresLeftValue = false);
    Optional<RecurseInfo> getInfo(AST::Expression&, bool requiresLeftValue = false);
    RefPtr<AST::UnnamedType> recurseAndWrapBaseType(AST::PropertyAccessExpression&);
    bool recurseAndRequireBoolType(AST::Expression&);
    void assignConcreteType(AST::Expression&, Ref<AST::UnnamedType>, AST::TypeAnnotation);
    void assignConcreteType(AST::Expression&, AST::NamedType&, AST::TypeAnnotation);
    void assignType(AST::Expression&, RefPtr<ResolvableTypeReference>, AST::TypeAnnotation);
    void forwardType(AST::Expression&, ResolvingType&, AST::TypeAnnotation);

    void visit(AST::FunctionDefinition&) override;
    void visit(AST::FunctionDeclaration&) override;
    void visit(AST::EnumerationDefinition&) override;
    void visit(AST::TypeReference&) override;
    void visit(AST::VariableDeclaration&) override;
    void visit(AST::AssignmentExpression&) override;
    void visit(AST::ReadModifyWriteExpression&) override;
    void visit(AST::DereferenceExpression&) override;
    void visit(AST::MakePointerExpression&) override;
    void visit(AST::MakeArrayReferenceExpression&) override;
    void visit(AST::DotExpression&) override;
    void visit(AST::IndexExpression&) override;
    void visit(AST::VariableReference&) override;
    void visit(AST::Return&) override;
    void visit(AST::PointerType&) override;
    void visit(AST::ArrayReferenceType&) override;
    void visit(AST::IntegerLiteral&) override;
    void visit(AST::UnsignedIntegerLiteral&) override;
    void visit(AST::FloatLiteral&) override;
    void visit(AST::BooleanLiteral&) override;
    void visit(AST::EnumerationMemberLiteral&) override;
    void visit(AST::LogicalNotExpression&) override;
    void visit(AST::LogicalExpression&) override;
    void visit(AST::IfStatement&) override;
    void visit(AST::WhileLoop&) override;
    void visit(AST::DoWhileLoop&) override;
    void visit(AST::ForLoop&) override;
    void visit(AST::SwitchStatement&) override;
    void visit(AST::CommaExpression&) override;
    void visit(AST::TernaryExpression&) override;
    void visit(AST::CallExpression&) override;

    // Looks up (or instantiates) the function to call for the given argument types.
    AST::FunctionDeclaration* resolveFunction(Vector<std::reference_wrapper<ResolvingType>>& types, const String& name, CodeLocation, AST::NamedType* castReturnType = nullptr);

    // Lazily created type reference wrapping the intrinsic float type.
    AST::UnnamedType& wrappedFloatType()
    {
        if (!m_wrappedFloatType)
            m_wrappedFloatType = AST::TypeReference::wrap({ }, m_intrinsics.floatType());
        return *m_wrappedFloatType;
    }

    // Lazily created type reference wrapping the intrinsic uint type.
    AST::UnnamedType& wrappedUintType()
    {
        if (!m_wrappedUintType)
            m_wrappedUintType = AST::TypeReference::wrap({ }, m_intrinsics.uintType());
        return *m_wrappedUintType;
    }

    // Types that unify to the intrinsic int or uint type are replaced by the wrapped
    // float type, so int, uint and float parameters produce the same FunctionKey.
    // Every other type is returned unchanged.
    AST::UnnamedType& normalizedTypeForFunctionKey(AST::UnnamedType& type)
    {
        auto* unifyNode = &type.unifyNode();
        if (unifyNode == &m_intrinsics.uintType() || unifyNode == &m_intrinsics.intType())
            return wrappedFloatType();

        return type;
    }

    RefPtr<AST::TypeReference> m_wrappedFloatType;
    RefPtr<AST::TypeReference> m_wrappedUintType;
    // Type recorded for each visited expression; consumed by assignTypes().
    HashMap<AST::Expression*, std::unique_ptr<ResolvingType>> m_typeMap;
    // Names of the entry points seen so far, per namespace; used to reject duplicates.
    HashSet<String> m_vertexEntryPoints[AST::nameSpaceCount];
    HashSet<String> m_fragmentEntryPoints[AST::nameSpaceCount];
    HashSet<String> m_computeEntryPoints[AST::nameSpaceCount];
    const Intrinsics& m_intrinsics;
    Program& m_program;
    AST::FunctionDefinition* m_currentFunction { nullptr };
    // Functions grouped by FunctionKey; filled in the constructor.
    HashMap<FunctionKey, Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>, FunctionKey::Hash, FunctionKey::Traits> m_functions;
    AST::NameSpace m_currentNameSpace { AST::NameSpace::StandardLibrary };
    // True while visit(AST::FunctionDeclaration&) is running.
    bool m_isVisitingParameters { false };
};

// Visits every global of the program: type definitions, structure definitions,
// enumeration definitions and native type declarations first, then function
// definitions and native function declarations.
void Checker::visit(Program& program)
{
    // Visiting may append new global statements to these lists, so each loop indexes
    // the list and re-reads its size on every iteration instead of using range-for.
    auto& typeDefinitions = program.typeDefinitions();
    for (size_t index = 0; index < typeDefinitions.size(); ++index)
        checkErrorAndVisit(typeDefinitions[index]);

    auto& structureDefinitions = program.structureDefinitions();
    for (size_t index = 0; index < structureDefinitions.size(); ++index)
        checkErrorAndVisit(structureDefinitions[index]);

    auto& enumerationDefinitions = program.enumerationDefinitions();
    for (size_t index = 0; index < enumerationDefinitions.size(); ++index)
        checkErrorAndVisit(enumerationDefinitions[index]);

    auto& nativeTypeDeclarations = program.nativeTypeDeclarations();
    for (size_t index = 0; index < nativeTypeDeclarations.size(); ++index)
        checkErrorAndVisit(nativeTypeDeclarations[index]);

    auto& functionDefinitions = program.functionDefinitions();
    for (size_t index = 0; index < functionDefinitions.size(); ++index)
        checkErrorAndVisit(functionDefinitions[index]);

    auto& nativeFunctionDeclarations = program.nativeFunctionDeclarations();
    for (size_t index = 0; index < nativeFunctionDeclarations.size(); ++index)
        checkErrorAndVisit(nativeFunctionDeclarations[index]);
}

// Writes every type recorded in m_typeMap onto its expression. A concrete type is
// copied over as is. A resolvable type that has no resolved type yet is committed
// first; if committing fails, an error is returned.
Expected<void, Error> Checker::assignTypes()
{
    for (auto& entry : m_typeMap) {
        auto* expression = entry.key;

        auto applyConcreteType = [&](Ref<AST::UnnamedType>& concreteType) -> bool {
            expression->setType(concreteType.copyRef());
            return true;
        };
        auto applyResolvableType = [&](RefPtr<ResolvableTypeReference>& reference) -> bool {
            auto& resolvableType = reference->resolvableType();
            if (!resolvableType.maybeResolvedType() && !static_cast<bool>(commit(resolvableType)))
                return false;
            expression->setType(resolvableType.resolvedType());
            return true;
        };

        bool assigned = entry.value->visit(WTF::makeVisitor(applyConcreteType, applyResolvableType));
        if (!assigned)
            return makeUnexpected(Error("Could not resolve the type of a constant."));
    }

    return { };
}

// Registers the entry point's name in the set for its shader stage and the current
// namespace. Returns true when the name was newly added and false when that set
// already contained it. The caller must only pass functions that have an entry
// point type.
bool Checker::checkShaderType(const AST::FunctionDefinition& functionDefinition)
{
    ASSERT(functionDefinition.entryPointType());
    auto index = static_cast<unsigned>(m_currentNameSpace);
    switch (*functionDefinition.entryPointType()) {
    case AST::EntryPointType::Vertex:
        return static_cast<bool>(m_vertexEntryPoints[index].add(functionDefinition.name()));
    case AST::EntryPointType::Fragment:
        return static_cast<bool>(m_fragmentEntryPoints[index].add(functionDefinition.name()));
    case AST::EntryPointType::Compute:
        return static_cast<bool>(m_computeEntryPoints[index].add(functionDefinition.name()));
    }
    // Every enumerator returns above; this only guards against an out-of-range value so
    // that control never flows off the end of a non-void function.
    ASSERT_NOT_REACHED();
    return false;
}

// Visits a function declaration through the base visitor with m_isVisitingParameters
// set. visit(AST::VariableDeclaration&) reads that flag to skip the "reference
// variable needs an initializer" error.
void Checker::visit(AST::FunctionDeclaration& functionDeclaration)
{
    m_isVisitingParameters = true;
    Visitor::visit(functionDeclaration);
    m_isVisitingParameters = false;
}

// Checks one function definition. The definition becomes the current function and
// its namespace the current namespace. Entry points are additionally checked for
// duplicate names and valid semantics. Operators must have an accepted signature.
// Only when all of that passes is the definition visited through the base visitor.
void Checker::visit(AST::FunctionDefinition& functionDefinition)
{
    m_currentNameSpace = functionDefinition.nameSpace();
    m_currentFunction = &functionDefinition;

    const auto& entryPointType = functionDefinition.entryPointType();
    if (entryPointType) {
        bool isFirstWithThisName = checkShaderType(functionDefinition);
        if (!isFirstWithThisName) {
            setError(Error("Duplicate entrypoint function.", functionDefinition.codeLocation()));
            return;
        }

        auto gatheredItems = gatherEntryPointItems(m_intrinsics, functionDefinition);
        if (!gatheredItems) {
            setError(gatheredItems.error());
            return;
        }

        bool semanticsAreValid = checkSemantics(gatheredItems->inputs, gatheredItems->outputs, entryPointType, m_intrinsics);
        if (!semanticsAreValid) {
            setError(Error("Bad semantics for entrypoint.", functionDefinition.codeLocation()));
            return;
        }
    }

    bool operatorSignatureIsValid = checkOperatorOverload(functionDefinition);
    if (!operatorSignatureIsValid) {
        setError(Error("Operator does not match expected signature.", functionDefinition.codeLocation()));
        return;
    }

    Visitor::visit(functionDefinition);
}

// Matches two resolving types against each other. When both sides are concrete, the
// left type is returned if they match and null otherwise. When at least one side is
// still resolvable, the work is handed to the matchAndCommit overload that takes
// that side's resolvable type.
static RefPtr<AST::UnnamedType> matchAndCommit(ResolvingType& left, ResolvingType& right)
{
    using Result = RefPtr<AST::UnnamedType>;

    auto whenLeftIsConcrete = [&](Ref<AST::UnnamedType>& concreteLeft) -> Result {
        return right.visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& concreteRight) -> Result {
            if (!matches(concreteLeft, concreteRight))
                return nullptr;
            return concreteLeft.copyRef();
        }, [&](RefPtr<ResolvableTypeReference>& resolvableRight) -> Result {
            return matchAndCommit(concreteLeft, resolvableRight->resolvableType());
        }));
    };

    auto whenLeftIsResolvable = [&](RefPtr<ResolvableTypeReference>& resolvableLeft) -> Result {
        return right.visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& concreteRight) -> Result {
            return matchAndCommit(concreteRight, resolvableLeft->resolvableType());
        }, [&](RefPtr<ResolvableTypeReference>& resolvableRight) -> Result {
            return matchAndCommit(resolvableLeft->resolvableType(), resolvableRight->resolvableType());
        }));
    };

    return left.visit(WTF::makeVisitor(whenLeftIsConcrete, whenLeftIsResolvable));
}

// Matches a resolving type against a concrete unnamed type. A concrete resolving type
// yields the given unnamed type when the two match and null otherwise; a resolvable
// one is handed to the matchAndCommit overload taking its resolvable type.
static RefPtr<AST::UnnamedType> matchAndCommit(ResolvingType& resolvingType, AST::UnnamedType& unnamedType)
{
    using Result = RefPtr<AST::UnnamedType>;

    auto whenConcrete = [&](Ref<AST::UnnamedType>& concreteType) -> Result {
        if (!matches(unnamedType, concreteType))
            return nullptr;
        return &unnamedType;
    };
    auto whenResolvable = [&](RefPtr<ResolvableTypeReference>& reference) -> Result {
        return matchAndCommit(unnamedType, reference->resolvableType());
    };

    return resolvingType.visit(WTF::makeVisitor(whenConcrete, whenResolvable));
}

// Matches a resolving type against a named type. A concrete resolving type succeeds
// when it matches the named type; a resolvable one is handed to the matchAndCommit
// overload taking its resolvable type.
static bool matchAndCommit(ResolvingType& resolvingType, AST::NamedType& namedType)
{
    auto whenConcrete = [&](Ref<AST::UnnamedType>& concreteType) -> bool {
        return matches(concreteType, namedType);
    };
    auto whenResolvable = [&](RefPtr<ResolvableTypeReference>& reference) -> bool {
        return matchAndCommit(namedType, reference->resolvableType());
    };

    return resolvingType.visit(WTF::makeVisitor(whenConcrete, whenResolvable));
}

// Produces the final type of a resolving type. A concrete type is returned as is. A
// resolvable type returns its resolved type when it already has one; otherwise it is
// handed to the commit overload taking the resolvable type.
static RefPtr<AST::UnnamedType> commit(ResolvingType& resolvingType)
{
    using Result = RefPtr<AST::UnnamedType>;

    auto whenConcrete = [&](Ref<AST::UnnamedType>& concreteType) -> Result {
        return concreteType.copyRef();
    };
    auto whenResolvable = [&](RefPtr<ResolvableTypeReference>& reference) -> Result {
        if (reference->resolvableType().maybeResolvedType())
            return &reference->resolvableType().resolvedType();
        return commit(reference->resolvableType());
    };

    return resolvingType.visit(WTF::makeVisitor(whenConcrete, whenResolvable));
}

// Finds the function to call for `name` with the given argument types.
//
// A FunctionKey is built from the argument types (see below), looked up in
// m_functions, and the candidates found are passed to resolveFunctionOverload(). If
// that yields nothing, resolveByInstantiation() gets a chance to synthesize a native
// function, which is appended to the program. Returns null when no function is
// found; an error is set only when an argument's type cannot be determined.
AST::FunctionDeclaration* Checker::resolveFunction(Vector<std::reference_wrapper<ResolvingType>>& types, const String& name, CodeLocation location, AST::NamedType* castReturnType)
{
    Vector<std::reference_wrapper<AST::UnnamedType>> unnamedTypes;
    unnamedTypes.reserveInitialCapacity(types.size());

    for (ResolvingType& resolvingType : types) {
        AST::UnnamedType* type = resolvingType.visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& unnamedType) -> AST::UnnamedType* {
            return unnamedType.ptr();
        }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> AST::UnnamedType* {
            auto& resolvableType = resolvableTypeReference->resolvableType();
            if (resolvableType.maybeResolvedType())
                return &resolvableType.resolvedType();

            // Unresolved numeric literals are keyed as float without being committed,
            // matching what normalizedTypeForFunctionKey() does for int and uint.
            if (resolvableType.isFloatLiteralType()
                || resolvableType.isIntegerLiteralType()
                || resolvableType.isUnsignedIntegerLiteralType())
                return &wrappedFloatType();

            return commit(resolvableType).get();
        }));

        if (!type) {
            // Report where the failing call is, like the other errors in this file.
            setError(Error("Could not resolve the type of a constant.", location));
            return nullptr;
        }

        unnamedTypes.uncheckedAppend(normalizedTypeForFunctionKey(*type));
    }

    {
        auto iter = m_functions.find(FunctionKey { name, WTFMove(unnamedTypes), castReturnType });
        if (iter != m_functions.end()) {
            if (AST::FunctionDeclaration* function = resolveFunctionOverload(iter->value, types, castReturnType, m_currentNameSpace))
                return function;
        }
    }

    if (auto newFunction = resolveByInstantiation(name, location, types, m_intrinsics)) {
        m_program.append(WTFMove(*newFunction));
        return &m_program.nativeFunctionDeclarations().last();
    }

    return nullptr;
}

// Checks an enumeration definition. Errors are reported in this order:
//  1. the base type must unify to a native type for which isInt() is true,
//  2. every member value must survive a round trip through int32_t (signed base)
//     or uint32_t (unsigned base),
//  3. no two members may share a value,
//  4. some member must have the value zero.
void Checker::visit(AST::EnumerationDefinition& enumerationDefinition)
{
    // Written by the lambda below on success; initialized so it never holds an
    // indeterminate value.
    bool isSigned = false;
    auto* baseType = ([&]() -> AST::NativeTypeDeclaration* {
        checkErrorAndVisit(enumerationDefinition.type());
        auto& baseType = enumerationDefinition.type().unifyNode();
        if (!is<AST::NamedType>(baseType))
            return nullptr;
        auto& namedType = downcast<AST::NamedType>(baseType);
        if (!is<AST::NativeTypeDeclaration>(namedType))
            return nullptr;
        auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
        if (!nativeTypeDeclaration.isInt())
            return nullptr;
        isSigned = nativeTypeDeclaration.isSigned();
        return &nativeTypeDeclaration;
    })();
    if (!baseType) {
        setError(Error("Invalid base type for enum.", enumerationDefinition.codeLocation()));
        return;
    }

    auto enumerationMembers = enumerationDefinition.enumerationMembers();

    for (auto& member : enumerationMembers) {
        int64_t value = member.get().value();
        // Narrow to the 32-bit base type and widen again; a changed value does not fit.
        int64_t roundTripped = isSigned
            ? static_cast<int64_t>(static_cast<int32_t>(value))
            : static_cast<int64_t>(static_cast<uint32_t>(value));
        if (roundTripped != value) {
            setError(Error("Invalid enumeration value.", member.get().codeLocation()));
            return;
        }
    }

    // The error is attached to the later member of the first duplicate pair found.
    for (size_t i = 0; i < enumerationMembers.size(); ++i) {
        auto value = enumerationMembers[i].get().value();
        for (size_t j = i + 1; j < enumerationMembers.size(); ++j) {
            auto otherValue = enumerationMembers[j].get().value();
            if (value == otherValue) {
                setError(Error("Cannot declare duplicate enumeration values.", enumerationMembers[j].get().codeLocation()));
                return;
            }
        }
    }

    bool foundZero = false;
    for (auto& member : enumerationMembers) {
        if (!member.get().value()) {
            foundZero = true;
            break;
        }
    }
    if (!foundZero) {
        setError(Error("enum definition must contain a zero value.", enumerationDefinition.codeLocation()));
        return;
    }
}

// Visits the type arguments of a type reference. The reference itself is expected to
// have been resolved already (asserted).
void Checker::visit(AST::TypeReference& typeReference)
{
    ASSERT(typeReference.maybeResolvedType());

    for (auto& typeArgument : typeReference.typeArguments())
        checkErrorAndVisit(typeArgument);
}

// Visits the expression and then returns its recorded type information via getInfo().
// Returns WTF::nullopt when visiting left the checker in an error state.
auto Checker::recurseAndGetInfo(AST::Expression& expression, bool requiresLeftValue) -> Optional<RecurseInfo>
{
    Visitor::visit(expression);
    if (hasError())
        return WTF::nullopt;
    return getInfo(expression, requiresLeftValue);
}

// Returns the type recorded in m_typeMap for an expression together with its type
// annotation. The expression must already have an entry (asserted). When
// requiresLeftValue is set and the annotation is a right value, an error is set and
// WTF::nullopt is returned.
auto Checker::getInfo(AST::Expression& expression, bool requiresLeftValue) -> Optional<RecurseInfo>
{
    auto typeIterator = m_typeMap.find(&expression);
    ASSERT(typeIterator != m_typeMap.end());

    const auto& typeAnnotation = expression.typeAnnotation();
    if (requiresLeftValue && typeAnnotation.isRightValue()) {
        setError(Error("Unexpected rvalue.", expression.codeLocation()));
        return WTF::nullopt;
    }
    return {{ *typeIterator->value, typeAnnotation }};
}

// Checks a variable declaration: its type is visited and must not be void. With an
// initializer, the initializer's type must match the declared type. Without one, a
// declaration whose type unifies to a reference type is an error, unless parameters
// are currently being visited.
void Checker::visit(AST::VariableDeclaration& variableDeclaration)
{
    // Anonymous variables only exist inside ReadModifyWriteExpressions, which do not
    // recurse on them, so the declaration is assumed to have a type here.
    checkErrorAndVisit(*variableDeclaration.type());
    if (matches(*variableDeclaration.type(), m_intrinsics.voidType())) {
        setError(Error("Variables can't have void type.", variableDeclaration.codeLocation()));
        return;
    }

    if (!variableDeclaration.initializer()) {
        if (m_isVisitingParameters)
            return;
        if (!is<AST::ReferenceType>(variableDeclaration.type()->unifyNode()))
            return;

        if (is<AST::PointerType>(variableDeclaration.type()->unifyNode()))
            setError(Error("Must assign to a pointer variable declaration in its initializer.", variableDeclaration.codeLocation()));
        else {
            ASSERT(is<AST::ArrayReferenceType>(variableDeclaration.type()->unifyNode()));
            setError(Error("Must assign to an array reference variable declaration in its initializer.", variableDeclaration.codeLocation()));
        }
        return;
    }

    auto& declaredType = *variableDeclaration.type();
    auto initializerInfo = recurseAndGetInfo(*variableDeclaration.initializer());
    if (!initializerInfo)
        return;

    bool initializerMatches = static_cast<bool>(matchAndCommit(initializerInfo->resolvingType, declaredType));
    if (!initializerMatches) {
        setError(Error("Declared variable type does not match its initializer's type.", variableDeclaration.codeLocation()));
        return;
    }
}

// Records |unnamedType| as the type of |expression| in m_typeMap and stores |typeAnnotation|
// on the expression. The expression must not already have an entry.
void Checker::assignConcreteType(AST::Expression& expression, Ref<AST::UnnamedType> unnamedType, AST::TypeAnnotation typeAnnotation = AST::RightValue())
{
    auto concreteType = makeUnique<ResolvingType>(WTFMove(unnamedType));
    auto insertion = m_typeMap.add(&expression, WTFMove(concreteType));
    ASSERT_UNUSED(insertion, insertion.isNewEntry);

    expression.setTypeAnnotation(WTFMove(typeAnnotation));
}

// Named-type convenience overload: wraps |type| in a TypeReference, visits that reference
// with the base Visitor, and then records it as the concrete type of |expression|.
void Checker::assignConcreteType(AST::Expression& expression, AST::NamedType& type, AST::TypeAnnotation annotation)
{
    auto wrappedType = AST::TypeReference::wrap(type.codeLocation(), type);
    Visitor::visit(wrappedType);

    assignConcreteType(expression, WTFMove(wrappedType), annotation);
}

// Records a not-yet-committed (resolvable) type for |expression| in m_typeMap and stores
// |typeAnnotation| on the expression. The expression must not already have an entry.
void Checker::assignType(AST::Expression& expression, RefPtr<ResolvableTypeReference> resolvableTypeReference, AST::TypeAnnotation typeAnnotation = AST::RightValue())
{
    auto pendingType = makeUnique<ResolvingType>(WTFMove(resolvableTypeReference));
    auto insertion = m_typeMap.add(&expression, WTFMove(pendingType));
    ASSERT_UNUSED(insertion, insertion.isNewEntry);

    expression.setTypeAnnotation(WTFMove(typeAnnotation));
}

// Gives |expression| a copy of |resolvingType| (whichever alternative it currently holds)
// and stores |typeAnnotation| on the expression. The expression must not already have an entry.
void Checker::forwardType(AST::Expression& expression, ResolvingType& resolvingType, AST::TypeAnnotation typeAnnotation = AST::RightValue())
{
    // Both alternatives are recorded the same way once they have been copied.
    auto recordCopy = [&](auto copiedType) {
        auto insertion = m_typeMap.add(&expression, WTFMove(copiedType));
        ASSERT_UNUSED(insertion, insertion.isNewEntry);
    };

    resolvingType.visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& concreteType) {
        recordCopy(makeUnique<ResolvingType>(concreteType.copyRef()));
    }, [&](RefPtr<ResolvableTypeReference>& resolvableType) {
        recordCopy(makeUnique<ResolvingType>(resolvableType.copyRef()));
    }));

    expression.setTypeAnnotation(WTFMove(typeAnnotation));
}

// Checks an assignment: the left side must be an lvalue and both sides must unify to one type,
// which becomes the type of the assignment expression itself.
void Checker::visit(AST::AssignmentExpression& assignmentExpression)
{
    constexpr bool requiresLeftValue = true;
    auto targetInfo = recurseAndGetInfo(assignmentExpression.left(), requiresLeftValue);
    if (!targetInfo)
        return;

    auto sourceInfo = recurseAndGetInfo(assignmentExpression.right());
    if (!sourceInfo)
        return;

    if (auto unifiedType = matchAndCommit(targetInfo->resolvingType, sourceInfo->resolvingType)) {
        assignConcreteType(assignmentExpression, *unifiedType);
        return;
    }

    setError(Error("Left hand side of assignment does not match the type of the right hand side.", assignmentExpression.codeLocation()));
}

// Checks a read-modify-write expression. The statement order matters: the anonymous old-value
// variable is typed before the new-value expression is visited, and the new-value variable is
// typed before the result expression is visited.
void Checker::visit(AST::ReadModifyWriteExpression& readModifyWriteExpression)
{
    constexpr bool requiresLeftValue = true;
    auto baseInfo = recurseAndGetInfo(readModifyWriteExpression.leftValue(), requiresLeftValue);
    if (!baseInfo)
        return;

    readModifyWriteExpression.oldValue().setType(*baseInfo->resolvingType.getUnnamedType());

    auto updatedInfo = recurseAndGetInfo(readModifyWriteExpression.newValueExpression());
    if (!updatedInfo)
        return;

    RefPtr<AST::UnnamedType> unifiedType = matchAndCommit(baseInfo->resolvingType, updatedInfo->resolvingType);
    if (!unifiedType) {
        setError(Error("Base of the read-modify-write expression does not match the type of the new value.", readModifyWriteExpression.codeLocation()));
        return;
    }
    readModifyWriteExpression.newValue().setType(*unifiedType);

    // The whole expression takes the type of its result expression, as an rvalue.
    auto finalInfo = recurseAndGetInfo(readModifyWriteExpression.resultExpression());
    if (finalInfo)
        forwardType(readModifyWriteExpression, finalInfo->resolvingType, AST::RightValue());
}

// Extracts a raw UnnamedType pointer from |resolvingType|. A concrete type yields its own
// pointer; a resolvable type yields whatever maybeResolvedType() reports. Callers in this
// file null-check the result.
static AST::UnnamedType* getUnnamedType(ResolvingType& resolvingType)
{
    return resolvingType.visit(WTF::makeVisitor([](Ref<AST::UnnamedType>& concreteType) -> AST::UnnamedType* {
        return concreteType.ptr();
    }, [](RefPtr<ResolvableTypeReference>& pendingType) -> AST::UnnamedType* {
        // FIXME: If the type isn't committed, should we just commit() it now?
        auto& resolvable = pendingType->resolvableType();
        return resolvable.maybeResolvedType();
    }));
}

// Checks a dereference: the operand's type must unify to a pointer type. The result is an
// lvalue of the pointer's element type in the pointer's address space.
void Checker::visit(AST::DereferenceExpression& dereferenceExpression)
{
    auto operandInfo = recurseAndGetInfo(dereferenceExpression.pointer());
    if (!operandInfo)
        return;

    // Drill down from the operand's type to a PointerType; any miss along the way leaves this null.
    AST::PointerType* pointerType = nullptr;
    if (auto* operandType = getUnnamedType(operandInfo->resolvingType)) {
        auto& unifyNode = operandType->unifyNode();
        if (is<AST::UnnamedType>(unifyNode)) {
            auto& unnamedUnifyType = downcast<AST::UnnamedType>(unifyNode);
            if (is<AST::PointerType>(unnamedUnifyType))
                pointerType = &downcast<AST::PointerType>(unnamedUnifyType);
        }
    }

    if (!pointerType) {
        setError(Error("Cannot dereference a non-pointer type.", dereferenceExpression.codeLocation()));
        return;
    }

    assignConcreteType(dereferenceExpression, pointerType->elementType(), AST::LeftValue { pointerType->addressSpace() });
}

// Checks an address-of expression: the operand must be an lvalue with a known address space
// and a known type. The result is a pointer to that type in that address space.
void Checker::visit(AST::MakePointerExpression& makePointerExpression)
{
    constexpr bool requiresLeftValue = true;
    auto operandInfo = recurseAndGetInfo(makePointerExpression.leftValue(), requiresLeftValue);
    if (!operandInfo)
        return;

    auto addressSpace = operandInfo->typeAnnotation.leftAddressSpace();
    if (!addressSpace) {
        setError(Error("Cannot take the address of a non lvalue.", makePointerExpression.codeLocation()));
        return;
    }

    auto* operandType = getUnnamedType(operandInfo->resolvingType);
    if (!operandType) {
        setError(Error("Cannot take the address of a value without a type.", makePointerExpression.codeLocation()));
        return;
    }

    auto pointerType = AST::PointerType::create(makePointerExpression.codeLocation(), *addressSpace, *operandType);
    assignConcreteType(makePointerExpression, WTFMove(pointerType));
}

// Checks an array-reference creation. Three operand shapes are accepted:
//  - a pointer: yields an array reference to the pointee in the pointer's address space
//    (the operand itself does not need to be an lvalue);
//  - an lvalue array: yields an array reference to the array's element type;
//  - any other lvalue: yields an array reference to the operand's own type.
void Checker::visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression)
{
    auto operandInfo = recurseAndGetInfo(makeArrayReferenceExpression.leftValue());
    if (!operandInfo)
        return;

    auto* operandType = getUnnamedType(operandInfo->resolvingType);
    if (!operandType) {
        setError(Error("Cannot make an array reference of a value without a type.", makeArrayReferenceExpression.codeLocation()));
        return;
    }

    auto& unifyNode = operandType->unifyNode();
    AST::UnnamedType* unnamedUnifyType = nullptr;
    if (is<AST::UnnamedType>(unifyNode))
        unnamedUnifyType = &downcast<AST::UnnamedType>(unifyNode);

    if (unnamedUnifyType && is<AST::PointerType>(*unnamedUnifyType)) {
        auto& pointerType = downcast<AST::PointerType>(*unnamedUnifyType);
        assignConcreteType(makeArrayReferenceExpression, AST::ArrayReferenceType::create(makeArrayReferenceExpression.codeLocation(), pointerType.addressSpace(), pointerType.elementType()));
        return;
    }

    // Everything other than a pointer needs an address space, i.e. must be a left value.
    auto addressSpace = operandInfo->typeAnnotation.leftAddressSpace();
    if (!addressSpace) {
        setError(Error("Cannot make an array reference from a non-left-value.", makeArrayReferenceExpression.codeLocation()));
        return;
    }

    if (unnamedUnifyType && is<AST::ArrayType>(*unnamedUnifyType)) {
        auto& arrayType = downcast<AST::ArrayType>(*unnamedUnifyType);
        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198163 Save the number of elements.
        assignConcreteType(makeArrayReferenceExpression, AST::ArrayReferenceType::create(makeArrayReferenceExpression.codeLocation(), *addressSpace, arrayType.type()));
        return;
    }

    assignConcreteType(makeArrayReferenceExpression, AST::ArrayReferenceType::create(makeArrayReferenceExpression.codeLocation(), *addressSpace, *operandType));
}

// Checks a dot expression. Handled in order:
//  1. structure bases: the field must exist; the result keeps the base's type annotation;
//  2. ".length" on arrays, array references, and vectors: an rvalue of the wrapped uint type;
//  3. vector properties: validated by the program; one-character names yield the vector's
//     element type, longer names yield a vector of that many elements.
// Anything else is an error.
void Checker::visit(AST::DotExpression& dotExpression)
{
    auto baseInfo = recurseAndGetInfo(dotExpression.base());
    if (!baseInfo)
        return;

    auto baseUnnamedType = commit(baseInfo->resolvingType);
    if (!baseUnnamedType) {
        setError(Error("Cannot resolve the type of the base of a dot expression.", dotExpression.codeLocation()));
        return;
    }

    auto& baseType = baseUnnamedType->unifyNode();
    auto baseIsVector = [&]() -> bool {
        return is<AST::NativeTypeDeclaration>(baseType) && downcast<AST::NativeTypeDeclaration>(baseType).isVector();
    };

    if (is<AST::StructureDefinition>(baseType)) {
        auto& structure = downcast<AST::StructureDefinition>(baseType);
        AST::StructureElement* element = structure.find(dotExpression.fieldName());
        if (!element) {
            setError(Error(makeString("Field name: '", dotExpression.fieldName(), "' does not exist on structure: ", structure.name()), dotExpression.codeLocation()));
            return;
        }
        assignConcreteType(dotExpression, element->type(), baseInfo->typeAnnotation);
        return;
    }

    if (dotExpression.fieldName() == "length") {
        bool supportsLength = is<AST::ArrayReferenceType>(baseType) || is<AST::ArrayType>(baseType) || baseIsVector();
        if (!supportsLength) {
            setError(Error(".length field is only available on arrays, array references, or vectors.", dotExpression.codeLocation()));
            return;
        }
        assignConcreteType(dotExpression, wrappedUintType(), AST::RightValue());
        return;
    }

    if (!baseIsVector()) {
        setError(Error("Base value of dot expression must be a structure, array, or vector.", dotExpression.codeLocation()));
        return;
    }

    if (!m_program.isValidVectorProperty(dotExpression.fieldName())) {
        setError(Error(makeString("'.", dotExpression.fieldName(), "' is not a valid property on a vector."), dotExpression.codeLocation()));
        return;
    }

    // A vector property is an rvalue if the base is one, and an abstract lvalue otherwise.
    auto typeAnnotation = baseInfo->typeAnnotation.isRightValue() ? AST::TypeAnnotation { AST::RightValue() } : AST::TypeAnnotation { AST::AbstractLeftValue() };

    size_t propertyLength = dotExpression.fieldName().length();
    auto& elementType = downcast<AST::NativeTypeDeclaration>(baseType).vectorTypeArgument();
    if (propertyLength == 1) {
        assignConcreteType(dotExpression, elementType, typeAnnotation);
        return;
    }

    if (matches(elementType, m_intrinsics.boolType()))
        assignConcreteType(dotExpression, m_intrinsics.boolVectorTypeForSize(propertyLength), typeAnnotation);
    else if (matches(elementType, m_intrinsics.intType()))
        assignConcreteType(dotExpression, m_intrinsics.intVectorTypeForSize(propertyLength), typeAnnotation);
    else if (matches(elementType, m_intrinsics.uintType()))
        assignConcreteType(dotExpression, m_intrinsics.uintVectorTypeForSize(propertyLength), typeAnnotation);
    else if (matches(elementType, m_intrinsics.floatType()))
        assignConcreteType(dotExpression, m_intrinsics.floatVectorTypeForSize(propertyLength), typeAnnotation);
    else
        RELEASE_ASSERT_NOT_REACHED();
}

// Checks an index expression. The index is checked first and must be a uint; the base must be
// an array reference, an array, a vector, or a matrix. Indexing a matrix yields a vector with
// numberOfMatrixRows() elements.
void Checker::visit(AST::IndexExpression& indexExpression)
{
    {
        auto indexInfo = recurseAndGetInfo(indexExpression.indexExpression());
        if (!indexInfo)
            return;

        bool indexIsUint = !!matchAndCommit(indexInfo->resolvingType, m_intrinsics.uintType());
        if (!indexIsUint) {
            setError(Error("Index in an index expression must be a uint.", indexExpression.codeLocation()));
            return;
        }
    }

    auto baseInfo = recurseAndGetInfo(indexExpression.base());
    if (!baseInfo)
        return;

    auto baseUnnamedType = commit(baseInfo->resolvingType);
    if (!baseUnnamedType) {
        setError(Error("Cannot resolve the type of the base of an index expression.", indexExpression.codeLocation()));
        return;
    }

    auto& baseType = baseUnnamedType->unifyNode();

    if (is<AST::ArrayReferenceType>(baseType)) {
        // Elements live in the address space named by the array reference itself.
        auto& arrayReferenceType = downcast<AST::ArrayReferenceType>(baseType);
        assignConcreteType(indexExpression, arrayReferenceType.elementType(), AST::LeftValue { arrayReferenceType.addressSpace() });
        return;
    }

    if (is<AST::ArrayType>(baseType)) {
        // Array elements keep the annotation of the array they come from.
        assignConcreteType(indexExpression, downcast<AST::ArrayType>(baseType).type(), baseInfo->typeAnnotation);
        return;
    }

    if (!is<AST::NativeTypeDeclaration>(baseType)) {
        setError(Error("Index expression on an unknown base type. Base type must be an array, array reference, vector, or matrix.", indexExpression.codeLocation()));
        return;
    }

    auto& nativeType = downcast<AST::NativeTypeDeclaration>(baseType);
    auto typeAnnotation = baseInfo->typeAnnotation.isRightValue() ? AST::TypeAnnotation { AST::RightValue() } : AST::TypeAnnotation { AST::AbstractLeftValue() };

    if (nativeType.isVector()) {
        assignConcreteType(indexExpression, nativeType.vectorTypeArgument(), typeAnnotation);
        return;
    }

    if (!nativeType.isMatrix()) {
        setError(Error("Index expression on unknown type.", indexExpression.codeLocation()));
        return;
    }

    auto& elementType = nativeType.matrixTypeArgument();
    unsigned rowCount = nativeType.numberOfMatrixRows();
    if (matches(elementType, m_intrinsics.boolType()))
        assignConcreteType(indexExpression, m_intrinsics.boolVectorTypeForSize(rowCount), typeAnnotation);
    else if (matches(elementType, m_intrinsics.floatType()))
        assignConcreteType(indexExpression, m_intrinsics.floatVectorTypeForSize(rowCount), typeAnnotation);
    else
        RELEASE_ASSERT_NOT_REACHED();
}

// A variable reference takes the declared type of its variable and is an lvalue in the
// thread address space. The reference must already be bound to a typed variable.
void Checker::visit(AST::VariableReference& variableReference)
{
    ASSERT(variableReference.variable());
    ASSERT(variableReference.variable()->type());

    auto& declaredType = *variableReference.variable()->type();
    assignConcreteType(variableReference, declaredType, AST::LeftValue { AST::AddressSpace::Thread });
}

// Checks a return statement against the return type of the function currently being visited.
//  - "return <value>;": the value's type must match the function's return type.
//  - "return;": the function's return type must be void.
void Checker::visit(AST::Return& returnStatement)
{
    if (returnStatement.value()) {
        auto valueInfo = recurseAndGetInfo(*returnStatement.value());
        if (!valueInfo)
            return;
        if (!matchAndCommit(valueInfo->resolvingType, m_currentFunction->type()))
            setError(Error("Type of the return value must match the return type of the function.", returnStatement.codeLocation()));
        return;
    }

    // Reaching here means the statement has no value, so the only possible mistake is that the
    // function is not void. The message used to describe the opposite situation.
    if (!matches(m_currentFunction->type(), m_intrinsics.voidType()))
        setError(Error("Must return a value from a non-void function.", returnStatement.codeLocation()));
}

// Intentionally does nothing, so the checker does not recurse through pointer types.
void Checker::visit(AST::PointerType&)
{
    // Following pointer types can cause infinite loops because of data structures
    // like linked lists.
    // FIXME: Make sure this function should be empty.
}

// Intentionally does nothing, so the checker does not recurse through array reference types.
void Checker::visit(AST::ArrayReferenceType&)
{
    // Following array reference types can cause infinite loops because of data
    // structures like linked lists.
    // FIXME: Make sure this function should be empty.
}

// An integer literal gets a resolvable (not yet committed) type built from the literal's own type.
void Checker::visit(AST::IntegerLiteral& integerLiteral)
{
    auto literalType = adoptRef(*new ResolvableTypeReference(integerLiteral.type()));
    assignType(integerLiteral, WTFMove(literalType));
}

// An unsigned integer literal gets a resolvable (not yet committed) type built from the literal's own type.
void Checker::visit(AST::UnsignedIntegerLiteral& unsignedIntegerLiteral)
{
    auto literalType = adoptRef(*new ResolvableTypeReference(unsignedIntegerLiteral.type()));
    assignType(unsignedIntegerLiteral, WTFMove(literalType));
}

// A float literal gets a resolvable (not yet committed) type built from the literal's own type.
void Checker::visit(AST::FloatLiteral& floatLiteral)
{
    auto literalType = adoptRef(*new ResolvableTypeReference(floatLiteral.type()));
    assignType(floatLiteral, WTFMove(literalType));
}

// A boolean literal is always of the intrinsic bool type.
void Checker::visit(AST::BooleanLiteral& booleanLiteral)
{
    auto boolTypeReference = AST::TypeReference::wrap(booleanLiteral.codeLocation(), m_intrinsics.boolType());
    assignConcreteType(booleanLiteral, WTFMove(boolTypeReference));
}

// An enumeration member literal has the type of the enumeration it belongs to.
// The literal must already be bound to its enumeration definition.
void Checker::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
{
    ASSERT(enumerationMemberLiteral.enumerationDefinition());

    auto& owningEnumeration = *enumerationMemberLiteral.enumerationDefinition();
    auto enumerationTypeReference = AST::TypeReference::wrap(enumerationMemberLiteral.codeLocation(), owningEnumeration);
    assignConcreteType(enumerationMemberLiteral, WTFMove(enumerationTypeReference));
}

// Reports whether |resolvingType| is bool. A concrete type is compared against the intrinsic
// bool type; a resolvable type is matched against bool with matchAndCommit().
bool Checker::isBoolType(ResolvingType& resolvingType)
{
    auto& boolType = m_intrinsics.boolType();
    return resolvingType.visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& concreteType) -> bool {
        return matches(concreteType, boolType);
    }, [&](RefPtr<ResolvableTypeReference>& pendingType) -> bool {
        return !!matchAndCommit(boolType, pendingType->resolvableType());
    }));
}

// Visits |expression| and returns true only if that succeeded and its type is bool.
// Records an error when the expression checks fine but is not of bool type.
bool Checker::recurseAndRequireBoolType(AST::Expression& expression)
{
    auto info = recurseAndGetInfo(expression);
    if (!info)
        return false;

    if (isBoolType(info->resolvingType))
        return true;

    setError(Error("Expected bool type from expression.", expression.codeLocation()));
    return false;
}

// Logical not: the operand must be bool, and so is the result.
void Checker::visit(AST::LogicalNotExpression& logicalNotExpression)
{
    bool operandIsBool = recurseAndRequireBoolType(logicalNotExpression.operand());
    if (!operandIsBool)
        return;

    auto boolTypeReference = AST::TypeReference::wrap(logicalNotExpression.codeLocation(), m_intrinsics.boolType());
    assignConcreteType(logicalNotExpression, WTFMove(boolTypeReference));
}

// Binary logical expression: both operands must be bool, and so is the result.
// The right operand is only visited if the left one checked out.
void Checker::visit(AST::LogicalExpression& logicalExpression)
{
    if (!recurseAndRequireBoolType(logicalExpression.left()) || !recurseAndRequireBoolType(logicalExpression.right()))
        return;

    auto boolTypeReference = AST::TypeReference::wrap(logicalExpression.codeLocation(), m_intrinsics.boolType());
    assignConcreteType(logicalExpression, WTFMove(boolTypeReference));
}

// If statement: the condition must be bool; then the body and the optional else body are visited.
void Checker::visit(AST::IfStatement& ifStatement)
{
    bool conditionIsBool = recurseAndRequireBoolType(ifStatement.conditional());
    if (!conditionIsBool)
        return;

    checkErrorAndVisit(ifStatement.body());

    if (!ifStatement.elseBody())
        return;
    checkErrorAndVisit(*ifStatement.elseBody());
}

// While loop: the body is only visited once the condition has been verified to be bool.
void Checker::visit(AST::WhileLoop& whileLoop)
{
    if (recurseAndRequireBoolType(whileLoop.conditional()))
        checkErrorAndVisit(whileLoop.body());
}

// Do-while loop: the body is visited first, then the condition, which must be bool.
void Checker::visit(AST::DoWhileLoop& doWhileLoop)
{
    checkErrorAndVisit(doWhileLoop.body());

    // The boolean result is not needed here: there is nothing left to do after this check.
    auto& condition = doWhileLoop.conditional();
    recurseAndRequireBoolType(condition);
}

// For loop: visits the initialization, the optional condition (which must be bool),
// the optional increment, and finally the body, in that order.
void Checker::visit(AST::ForLoop& forLoop)
{
    checkErrorAndVisit(forLoop.initialization());
    if (hasError())
        return;

    if (forLoop.condition() && !recurseAndRequireBoolType(*forLoop.condition()))
        return;

    if (forLoop.increment())
        checkErrorAndVisit(*forLoop.increment());

    checkErrorAndVisit(forLoop.body());
}

// Checks a switch statement in four steps:
//  1. the switched-on value must be of an integer native type or an enumeration;
//  2. every case block is visited and every case value must be compatible with that type;
//  3. no two cases may carry the same value, and there may be at most one default case;
//  4. without a default case, the cases must cover every value of the type.
void Checker::visit(AST::SwitchStatement& switchStatement)
{
    // Resolve the type being switched on; null means it is missing or not switchable.
    auto* valueType = ([&]() -> AST::NamedType* {
        auto valueInfo = recurseAndGetInfo(switchStatement.value());
        if (!valueInfo)
            return nullptr;
        auto* valueType = getUnnamedType(valueInfo->resolvingType);
        if (!valueType)
            return nullptr;
        auto& valueUnifyNode = valueType->unifyNode();
        if (!is<AST::NamedType>(valueUnifyNode))
            return nullptr;
        auto& valueNamedUnifyNode = downcast<AST::NamedType>(valueUnifyNode);
        if (!(is<AST::NativeTypeDeclaration>(valueNamedUnifyNode) && downcast<AST::NativeTypeDeclaration>(valueNamedUnifyNode).isInt())
            && !is<AST::EnumerationDefinition>(valueNamedUnifyNode))
            return nullptr;
        return &valueNamedUnifyNode;
    })();
    if (!valueType) {
        setError(Error("Invalid type for the expression condition of the switch statement.", switchStatement.codeLocation()));
        return;
    }

    // Visit each case body and check that each case value fits the switched-on type.
    // A case without a value is the default case.
    bool hasDefault = false;
    for (auto& switchCase : switchStatement.switchCases()) {
        checkErrorAndVisit(switchCase.block());
        if (!switchCase.value()) {
            hasDefault = true;
            continue;
        }
        auto success = switchCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> bool {
            return static_cast<bool>(matchAndCommit(*valueType, integerLiteral.type()));
        }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> bool {
            return static_cast<bool>(matchAndCommit(*valueType, unsignedIntegerLiteral.type()));
        }, [&](AST::FloatLiteral& floatLiteral) -> bool {
            return static_cast<bool>(matchAndCommit(*valueType, floatLiteral.type()));
        }, [&](AST::BooleanLiteral&) -> bool {
            return matches(*valueType, m_intrinsics.boolType());
        }, [&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) -> bool {
            ASSERT(enumerationMemberLiteral.enumerationDefinition());
            return matches(*valueType, *enumerationMemberLiteral.enumerationDefinition());
        }));
        if (!success) {
            setError(Error("Invalid type for switch case.", switchCase.codeLocation()));
            return;
        }
    }

    // Pairwise duplicate detection. In the visitors below, "success" means "the two cases differ".
    for (size_t i = 0; i < switchStatement.switchCases().size(); ++i) {
        auto& firstCase = switchStatement.switchCases()[i];
        for (size_t j = i + 1; j < switchStatement.switchCases().size(); ++j) {
            auto& secondCase = switchStatement.switchCases()[j];

            // A default case and a valued case can never collide.
            if (static_cast<bool>(firstCase.value()) != static_cast<bool>(secondCase.value()))
                continue;

            // Both are defaults.
            if (!static_cast<bool>(firstCase.value())) {
                setError(Error("Cannot define multiple default cases in switch statement.", secondCase.codeLocation()));
                return;
            }

            // Signed and unsigned integer literals are compared with each other through int64_t;
            // enumeration members are compared by identity; every other pairing counts as distinct.
            auto success = firstCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& firstIntegerLiteral) -> bool {
                return secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& secondIntegerLiteral) -> bool {
                    return firstIntegerLiteral.value() != secondIntegerLiteral.value();
                }, [&](AST::UnsignedIntegerLiteral& secondUnsignedIntegerLiteral) -> bool {
                    return static_cast<int64_t>(firstIntegerLiteral.value()) != static_cast<int64_t>(secondUnsignedIntegerLiteral.value());
                }, [](auto&) -> bool {
                    return true;
                }));
            }, [&](AST::UnsignedIntegerLiteral& firstUnsignedIntegerLiteral) -> bool {
                return secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& secondIntegerLiteral) -> bool {
                    return static_cast<int64_t>(firstUnsignedIntegerLiteral.value()) != static_cast<int64_t>(secondIntegerLiteral.value());
                }, [&](AST::UnsignedIntegerLiteral& secondUnsignedIntegerLiteral) -> bool {
                    return firstUnsignedIntegerLiteral.value() != secondUnsignedIntegerLiteral.value();
                }, [](auto&) -> bool {
                    return true;
                }));
            }, [&](AST::EnumerationMemberLiteral& firstEnumerationMemberLiteral) -> bool {
                return secondCase.value()->visit(WTF::makeVisitor([&](AST::EnumerationMemberLiteral& secondEnumerationMemberLiteral) -> bool {
                    ASSERT(firstEnumerationMemberLiteral.enumerationMember());
                    ASSERT(secondEnumerationMemberLiteral.enumerationMember());
                    return firstEnumerationMemberLiteral.enumerationMember() != secondEnumerationMemberLiteral.enumerationMember();
                }, [](auto&) -> bool {
                    return true;
                }));
            }, [](auto&) -> bool {
                return true;
            }));
            if (!success) {
                setError(Error("Cannot define duplicate case statements in a switch.", secondCase.codeLocation()));
                return;
            }
        }
    }

    // Without a default case, the listed cases have to be exhaustive.
    if (!hasDefault) {
        if (is<AST::NativeTypeDeclaration>(*valueType)) {
            // Zero is tracked outside the set.
            // NOTE(review): presumably because the default integer hash traits reserve 0 as the
            // empty value — confirm, and check whether a case value of -1 needs similar care.
            HashSet<int64_t> values;
            // Must start out false: it is only ever set to true below, and it is read in the
            // exhaustiveness check even when no case has the value 0.
            bool zeroValueExists = false;
            for (auto& switchCase : switchStatement.switchCases()) {
                auto value = switchCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> int64_t {
                    return integerLiteral.valueForSelectedType();
                }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> int64_t {
                    return unsignedIntegerLiteral.valueForSelectedType();
                }, [](auto&) -> int64_t {
                    ASSERT_NOT_REACHED();
                    return 0;
                }));
                if (!value)
                    zeroValueExists = true;
                else
                    values.add(value);
            }
            // The callback returns true as soon as an uncovered value is found.
            // NOTE(review): presumably a true return stops the iteration — confirm against iterateAllValues().
            bool success = true;
            downcast<AST::NativeTypeDeclaration>(*valueType).iterateAllValues([&](int64_t value) -> bool {
                if (!value) {
                    if (!zeroValueExists) {
                        success = false;
                        return true;
                    }
                    return false;
                }
                if (!values.contains(value)) {
                    success = false;
                    return true;
                }
                return false;
            });
            if (!success) {
                setError(Error("Switch cases must be exhaustive or you must define a default case.", switchStatement.codeLocation()));
                return;
            }
        } else {
            // Enumeration: every member of the enumeration must appear as a case.
            HashSet<AST::EnumerationMember*> values;
            for (auto& switchCase : switchStatement.switchCases()) {
                switchCase.value()->visit(WTF::makeVisitor([&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) {
                    ASSERT(enumerationMemberLiteral.enumerationMember());
                    values.add(enumerationMemberLiteral.enumerationMember());
                }, [](auto&) {
                    ASSERT_NOT_REACHED();
                }));
            }
            for (auto& enumerationMember : downcast<AST::EnumerationDefinition>(*valueType).enumerationMembers()) {
                if (!values.contains(&enumerationMember.get())) {
                    setError(Error("Switch cases over an enum must be exhaustive or you must define a default case.", switchStatement.codeLocation()));
                    return;
                }
            }
        }
    }
}

// Comma expression: visits every sub-expression through the base Visitor, then takes on the
// type of the last one (with the default annotation of forwardType()).
void Checker::visit(AST::CommaExpression& commaExpression)
{
    ASSERT(commaExpression.list().size() > 0);

    Visitor::visit(commaExpression);
    if (hasError())
        return;

    auto finalInfo = getInfo(commaExpression.list().last());
    forwardType(commaExpression, finalInfo->resolvingType);
}

// Checks a ternary expression: the predicate must be bool and both branches must unify to a
// single type, which becomes the type of the whole expression.
void Checker::visit(AST::TernaryExpression& ternaryExpression)
{
    if (!recurseAndRequireBoolType(ternaryExpression.predicate()))
        return;

    // Either branch can fail to check, leaving its info empty. Bail out before dereferencing it.
    auto bodyInfo = recurseAndGetInfo(ternaryExpression.bodyExpression());
    if (!bodyInfo)
        return;

    auto elseInfo = recurseAndGetInfo(ternaryExpression.elseExpression());
    if (!elseInfo)
        return;

    auto resultType = matchAndCommit(bodyInfo->resolvingType, elseInfo->resolvingType);
    if (!resultType) {
        setError(Error("lhs and rhs of a ternary expression must match.", ternaryExpression.codeLocation()));
        return;
    }

    assignConcreteType(ternaryExpression, *resultType);
}

// Checks a call expression: visits the arguments, resolves the callee by name and argument
// types (falling back to "operator cast" when the name denotes exactly one type), commits each
// argument type against the matching parameter, and gives the call the callee's return type.
void Checker::visit(AST::CallExpression& callExpression)
{
    Vector<std::reference_wrapper<ResolvingType>> argumentTypes;
    argumentTypes.reserveInitialCapacity(callExpression.arguments().size());
    for (auto& argument : callExpression.arguments()) {
        auto argumentInfo = recurseAndGetInfo(argument);
        if (!argumentInfo)
            return;
        argumentTypes.uncheckedAppend(argumentInfo->resolvingType);
    }
    // Don't recurse on the castReturnType, because it's guaranteed to be a NamedType, which will get visited later.
    // We don't want to recurse to the same node twice.

    auto* callee = resolveFunction(argumentTypes, callExpression.name(), callExpression.codeLocation());
    if (hasError())
        return;

    if (!callee) {
        // No function of that name fits; see whether the name is a type we can cast to instead.
        auto castTypes = m_program.nameContext().getTypes(callExpression.name(), m_currentNameSpace);
        if (castTypes.size() == 1) {
            AST::NamedType& castType = castTypes[0].get();
            callee = resolveFunction(argumentTypes, "operator cast"_str, callExpression.codeLocation(), &castType);
            if (hasError())
                return;
            if (callee)
                callExpression.setCastData(castType);
        }
    }

    if (!callee) {
        // FIXME: Add better error messages for why we can't resolve to one of the overrides.
        // https://bugs.webkit.org/show_bug.cgi?id=200133
        setError(Error("Cannot resolve function call to a concrete callee. Make sure you are using compatible types.", callExpression.codeLocation()));
        return;
    }

    size_t parameterCount = callee->parameters().size();
    for (size_t parameterIndex = 0; parameterIndex < parameterCount; ++parameterIndex) {
        auto& parameterType = *callee->parameters()[parameterIndex]->type();
        if (matchAndCommit(argumentTypes[parameterIndex].get(), parameterType))
            continue;
        setError(Error(makeString("Invalid type for parameter number ", parameterIndex + 1, " in function call."), callExpression.codeLocation()));
        return;
    }

    callExpression.setFunction(*callee);

    assignConcreteType(callExpression, callee->type());
}

// Entry point of the checker: visits the whole program and, if that produced no error,
// returns the outcome of assignTypes(). Otherwise the recorded result is returned.
Expected<void, Error> check(Program& program)
{
    Checker checker(program.intrinsics(), program);
    checker.checkErrorAndVisit(program);

    if (!checker.hasError())
        return checker.assignTypes();
    return checker.result();
}

} // namespace WHLSL

} // namespace WebCore

#endif // ENABLE(WEBGPU)