Serenity Operating System
1/*
2 * Copyright (c) 2023, Matthew Olsson <mattco@serenityos.org>
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 */
6
#include "CellsHandler.h"
#include <clang/AST/DeclCXX.h>
#include <clang/AST/Type.h>
#include <clang/ASTMatchers/ASTMatchFinder.h>
#include <clang/ASTMatchers/ASTMatchers.h>
#include <clang/Basic/Diagnostic.h>
#include <clang/Basic/Specifiers.h>
#include <clang/Frontend/CompilerInstance.h>
#include <filesystem>
#include <llvm/Support/Casting.h>
#include <llvm/Support/raw_ostream.h>
#include <system_error>
#include <unordered_set>
#include <vector>
20
21CollectCellsHandler::CollectCellsHandler()
22{
23 using namespace clang::ast_matchers;
24
25 m_finder.addMatcher(
26 traverse(
27 clang::TK_IgnoreUnlessSpelledInSource,
28 cxxRecordDecl(decl().bind("record-decl"))),
29 this);
30}
31
32bool CollectCellsHandler::handleBeginSource(clang::CompilerInstance& ci)
33{
34 auto const& source_manager = ci.getSourceManager();
35 ci.getFileManager().getNumUniqueRealFiles();
36 auto file_id = source_manager.getMainFileID();
37 auto const* file_entry = source_manager.getFileEntryForID(file_id);
38 if (!file_entry)
39 return false;
40
41 auto current_filepath = std::filesystem::canonical(file_entry->getName().str());
42 llvm::outs() << "Processing " << current_filepath.string() << "\n";
43
44 return true;
45}
46
47bool record_inherits_from_cell(clang::CXXRecordDecl const& record)
48{
49 if (!record.isCompleteDefinition())
50 return false;
51
52 bool inherits_from_cell = record.getQualifiedNameAsString() == "JS::Cell";
53 record.forallBases([&](clang::CXXRecordDecl const* base) -> bool {
54 if (base->getQualifiedNameAsString() == "JS::Cell") {
55 inherits_from_cell = true;
56 return false;
57 }
58 return true;
59 });
60 return inherits_from_cell;
61}
62
63std::vector<clang::QualType> get_all_qualified_types(clang::QualType const& type)
64{
65 std::vector<clang::QualType> qualified_types;
66
67 if (auto const* template_specialization = type->getAs<clang::TemplateSpecializationType>()) {
68 auto specialization_name = template_specialization->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
69 // Do not unwrap GCPtr/NonnullGCPtr
70 if (specialization_name == "JS::GCPtr" || specialization_name == "JS::NonnullGCPtr") {
71 qualified_types.push_back(type);
72 } else {
73 for (size_t i = 0; i < template_specialization->getNumArgs(); i++) {
74 auto const& template_arg = template_specialization->getArg(i);
75 if (template_arg.getKind() == clang::TemplateArgument::Type) {
76 auto template_qualified_types = get_all_qualified_types(template_arg.getAsType());
77 std::move(template_qualified_types.begin(), template_qualified_types.end(), std::back_inserter(qualified_types));
78 }
79 }
80 }
81 } else {
82 qualified_types.push_back(type);
83 }
84
85 return qualified_types;
86}
87
// Outcome of validate_field() for a single field.
struct FieldValidationResult {
    // True when no GC-safety problem was detected for the field.
    bool is_valid = false;
    // When is_valid is false, tells whether the offending type was a JS::GCPtr /
    // JS::NonnullGCPtr specialization (true) or a raw pointer/reference (false).
    bool is_wrapped_in_gcptr = false;
};
92
93FieldValidationResult validate_field(clang::FieldDecl const* field_decl)
94{
95 auto type = field_decl->getType();
96 if (auto const* elaborated_type = llvm::dyn_cast<clang::ElaboratedType>(type.getTypePtr()))
97 type = elaborated_type->desugar();
98
99 FieldValidationResult result { .is_valid = true };
100
101 for (auto const& qualified_type : get_all_qualified_types(type)) {
102 if (auto const* pointer_decl = qualified_type->getAs<clang::PointerType>()) {
103 if (auto const* pointee = pointer_decl->getPointeeCXXRecordDecl()) {
104 if (record_inherits_from_cell(*pointee)) {
105 result.is_valid = false;
106 result.is_wrapped_in_gcptr = false;
107 return result;
108 }
109 }
110 } else if (auto const* reference_decl = qualified_type->getAs<clang::ReferenceType>()) {
111 if (auto const* pointee = reference_decl->getPointeeCXXRecordDecl()) {
112 if (record_inherits_from_cell(*pointee)) {
113 result.is_valid = false;
114 result.is_wrapped_in_gcptr = false;
115 return result;
116 }
117 }
118 } else if (auto const* specialization = qualified_type->getAs<clang::TemplateSpecializationType>()) {
119 auto template_type_name = specialization->getTemplateName().getAsTemplateDecl()->getName();
120 if (template_type_name != "GCPtr" && template_type_name != "NonnullGCPtr")
121 return result;
122
123 if (specialization->getNumArgs() != 1)
124 return result; // Not really valid, but will produce a compilation error anyway
125
126 auto const& type_arg = specialization->getArg(0);
127 auto const* record_type = type_arg.getAsType()->getAs<clang::RecordType>();
128 if (!record_type)
129 return result;
130
131 auto const* record_decl = record_type->getAsCXXRecordDecl();
132 if (!record_decl->hasDefinition())
133 return result;
134
135 result.is_wrapped_in_gcptr = true;
136 result.is_valid = record_inherits_from_cell(*record_decl);
137 }
138 }
139
140 return result;
141}
142
143void CollectCellsHandler::run(clang::ast_matchers::MatchFinder::MatchResult const& result)
144{
145 clang::CXXRecordDecl const* record = result.Nodes.getNodeAs<clang::CXXRecordDecl>("record-decl");
146 if (!record || !record->isCompleteDefinition() || (!record->isClass() && !record->isStruct()))
147 return;
148
149 auto& diag_engine = result.Context->getDiagnostics();
150
151 for (clang::FieldDecl const* field : record->fields()) {
152 auto const& type = field->getType();
153
154 auto validation_results = validate_field(field);
155 if (!validation_results.is_valid) {
156 if (validation_results.is_wrapped_in_gcptr) {
157 auto diag_id = diag_engine.getCustomDiagID(clang::DiagnosticsEngine::Warning, "Specialization type must inherit from JS::Cell");
158 diag_engine.Report(field->getLocation(), diag_id);
159 } else {
160 auto diag_id = diag_engine.getCustomDiagID(clang::DiagnosticsEngine::Warning, "%0 to JS::Cell type should be wrapped in %1");
161 auto builder = diag_engine.Report(field->getLocation(), diag_id);
162 if (type->isReferenceType()) {
163 builder << "reference"
164 << "JS::NonnullGCPtr";
165 } else {
166 builder << "pointer"
167 << "JS::GCPtr";
168 }
169 }
170 }
171 }
172}