Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
1From fc7a8e35819bda632bdcf1cf75fd9abe4d4e077a Mon Sep 17 00:00:00 2001 2From: Christian Sigg <chsigg@users.noreply.github.com> 3Date: Thu, 16 Feb 2023 15:40:53 +0100 4Subject: [PATCH] Rebase Triton to LLVM-15. (#1070) 5 6This PR rebases Triton from LLVM-14 to LLVM-15. Most changes are 7mechanical, except for the analysis framework changes. 8--- 9 CMakeLists.txt | 6 +- 10 bin/CMakeLists.txt | 2 +- 11 bin/FileCheck/FileCheck.cpp | 3 + 12 bin/triton-opt.cpp | 6 +- 13 bin/triton-translate.cpp | 7 +- 14 include/triton/Analysis/Alias.h | 21 +- 15 include/triton/Analysis/Allocation.h | 2 + 16 include/triton/Analysis/AxisInfo.h | 56 ++- 17 include/triton/Analysis/Utility.h | 6 +- 18 include/triton/Conversion/Passes.td | 4 +- 19 include/triton/Dialect/Triton/IR/Dialect.h | 7 +- 20 .../triton/Dialect/Triton/IR/TritonDialect.td | 8 +- 21 include/triton/Dialect/Triton/IR/TritonOps.td | 12 +- 22 .../triton/Dialect/Triton/IR/TritonTypes.td | 2 + 23 .../Dialect/Triton/Transforms/Passes.td | 3 +- 24 include/triton/Dialect/TritonGPU/IR/Dialect.h | 4 +- 25 .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 7 + 26 .../Dialect/TritonGPU/IR/TritonGPUDialect.td | 2 +- 27 .../Dialect/TritonGPU/IR/TritonGPUOps.td | 13 +- 28 lib/Analysis/Alias.cpp | 14 +- 29 lib/Analysis/Allocation.cpp | 30 +- 30 lib/Analysis/AxisInfo.cpp | 79 ++-- 31 lib/Analysis/CMakeLists.txt | 2 +- 32 lib/Analysis/Membar.cpp | 2 +- 33 lib/Analysis/Utility.cpp | 54 +++ 34 .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 3 - 35 lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h | 10 +- 36 .../TritonGPUToLLVM/DotOpToLLVM.cpp | 5 - 37 .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 2 - 38 .../TritonGPUToLLVM/LoadStoreOpToLLVM.cpp | 5 +- 39 .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 2 - 40 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 7 +- 41 .../TritonGPUToLLVM/TritonGPUToLLVMBase.h | 26 +- 42 .../TritonGPUToLLVM/TritonGPUToLLVMPass.cpp | 52 +-- 43 lib/Conversion/TritonGPUToLLVM/Utility.h | 5 +- 44 .../TritonToTritonGPUPass.cpp | 69 ++-- 45 lib/Dialect/Triton/IR/CMakeLists.txt | 10 +- 46 lib/Dialect/Triton/IR/Ops.cpp | 34 +- 47 lib/Dialect/Triton/Transforms/Combine.cpp | 6 +- 48 lib/Dialect/Triton/Transforms/Combine.td | 2 +- 49 lib/Dialect/TritonGPU/IR/Dialect.cpp | 27 +- 50 lib/Dialect/TritonGPU/Transforms/Coalesce.cpp | 20 +- 51 lib/Dialect/TritonGPU/Transforms/Combine.cpp | 2 +- 52 lib/Dialect/TritonGPU/Transforms/Combine.td | 1 + 53 .../Transforms/DecomposeConversions.cpp | 2 +- 54 lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 10 +- 55 .../Transforms/ReorderInstructions.cpp | 2 +- 56 .../Transforms/TritonGPUConversion.cpp | 12 +- 57 .../Transforms/UpdateMmaForVolta.cpp | 6 +- 58 lib/Dialect/TritonGPU/Transforms/Utility.cpp | 2 +- 59 lib/Target/LLVMIR/CMakeLists.txt | 3 +- 60 lib/Target/PTX/PTXTranslation.cpp | 3 + 61 python/setup.py | 15 +- 62 python/src/triton.cc | 85 +++-- 63 python/test/unit/language/test_core.py | 2 +- 64 python/triton/compiler.py | 4 +- 65 test/Analysis/test-alias.mlir | 24 +- 66 test/Analysis/test-alignment.mlir | 344 +++++++++--------- 67 test/Analysis/test-allocation.mlir | 32 +- 68 test/Analysis/test-membar.mlir | 38 +- 69 test/Conversion/triton_ops.mlir | 10 +- 70 test/Conversion/triton_to_tritongpu.mlir | 6 +- 71 test/Conversion/tritongpu_to_llvm.mlir | 94 ++--- 72 test/Target/tritongpu_to_llvmir.mlir | 4 +- 73 test/Target/tritongpu_to_ptx.mlir | 2 +- 74 test/Triton/combine.mlir | 40 +- 75 test/Triton/vecadd.mlir | 4 +- 76 test/TritonGPU/coalesce.mlir | 2 +- 77 test/TritonGPU/combine.mlir | 38 +- 78 test/TritonGPU/loop-pipeline.mlir | 22 +- 79 test/TritonGPU/matmul.mlir | 4 +- 80 test/TritonGPU/prefetch.mlir | 4 +- 81 test/TritonGPU/update-mma-for-volta.mlir | 4 +- 82 test/lib/Analysis/TestAlias.cpp | 29 +- 83 test/lib/Analysis/TestAllocation.cpp | 5 +- 84 test/lib/Analysis/TestAxisInfo.cpp | 51 +-- 85 test/lib/Analysis/TestMembar.cpp | 7 +- 86 78 files changed, 808 insertions(+), 742 deletions(-) 87 88diff --git a/CMakeLists.txt b/CMakeLists.txt 89index d0d361fc7c..b281a28400 100644 90--- a/CMakeLists.txt 91+++ b/CMakeLists.txt 92@@ -1,4 +1,7 @@ 93 cmake_minimum_required(VERSION 3.6) 94+ 95+cmake_policy(SET CMP0116 OLD) 96+ 97 include(ExternalProject) 98 99 set(CMAKE_CXX_STANDARD 17) 100@@ -155,7 +158,6 @@ if(TRITON_BUILD_PYTHON_MODULE) 101 endif() 102 endif() 103 104- 105 # # Triton 106 # file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc) 107 # if (WIN32 AND TRITON_BUILD_PYTHON_MODULE) 108@@ -212,7 +214,7 @@ if(TRITON_BUILD_PYTHON_MODULE) 109 # optimizations 110 MLIRPass 111 MLIRTransforms 112- MLIRLLVMIR 113+ MLIRLLVMDialect 114 MLIRSupport 115 MLIRTargetLLVMIRExport 116 MLIRExecutionEngine 117diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt 118index 906f635f8b..695b3479fd 100644 119--- a/bin/CMakeLists.txt 120+++ b/bin/CMakeLists.txt 121@@ -48,7 +48,7 @@ llvm_update_compile_flags(triton-translate) 122 # MLIR core 123 MLIROptLib 124 MLIRIR 125- MLIRLLVMIR 126+ MLIRLLVMDialect 127 MLIRPass 128 MLIRSupport 129 MLIRTransforms 130diff --git a/bin/FileCheck/FileCheck.cpp b/bin/FileCheck/FileCheck.cpp 131index 819efc3541..9ac6f1b277 100644 132--- a/bin/FileCheck/FileCheck.cpp 133+++ b/bin/FileCheck/FileCheck.cpp 134@@ -19,6 +19,7 @@ 135 #include "llvm/Support/CommandLine.h" 136 #include "llvm/Support/InitLLVM.h" 137 #include "llvm/Support/Process.h" 138+#include "llvm/Support/SourceMgr.h" 139 #include "llvm/Support/WithColor.h" 140 #include "llvm/Support/raw_ostream.h" 141 #include <cmath> 142@@ -360,6 +361,8 @@ static std::string GetCheckTypeAbbreviation(Check::FileCheckType Ty) { 143 return "bad-not"; 144 case Check::CheckBadCount: 145 return "bad-count"; 146+ case Check::CheckMisspelled: 147+ return "misspelled"; 148 case Check::CheckNone: 149 llvm_unreachable("invalid FileCheckType"); 150 } 151diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp 152index 9f3b53b7ae..f96232e1b0 100644 153--- a/bin/triton-opt.cpp 154+++ b/bin/triton-opt.cpp 155@@ -8,7 +8,7 @@ 156 157 #include "mlir/IR/Dialect.h" 158 #include "mlir/InitAllPasses.h" 159-#include "mlir/Support/MlirOptMain.h" 160+#include "mlir/Tools/mlir-opt/MlirOptMain.h" 161 162 namespace mlir { 163 namespace test { 164@@ -33,8 +33,8 @@ int main(int argc, char **argv) { 165 // TODO: register Triton & TritonGPU passes 166 mlir::DialectRegistry registry; 167 registry.insert<mlir::triton::TritonDialect, 168- mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect, 169- mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect, 170+ mlir::triton::gpu::TritonGPUDialect, mlir::func::FuncDialect, 171+ mlir::math::MathDialect, mlir::arith::ArithmeticDialect, 172 mlir::scf::SCFDialect, mlir::gpu::GPUDialect>(); 173 174 return mlir::asMainReturnCode(mlir::MlirOptMain( 175diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp 176index 05ba15e453..56b5d65857 100644 177--- a/bin/triton-translate.cpp 178+++ b/bin/triton-translate.cpp 179@@ -3,7 +3,7 @@ 180 #include "mlir/IR/AsmState.h" 181 #include "mlir/IR/BuiltinOps.h" 182 #include "mlir/IR/Dialect.h" 183-#include "mlir/Parser.h" 184+#include "mlir/Parser/Parser.h" 185 #include "mlir/Pass/Pass.h" 186 #include "mlir/Pass/PassManager.h" 187 #include "mlir/Support/FileUtilities.h" 188@@ -38,7 +38,7 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename, 189 mlir::DialectRegistry registry; 190 registry.insert<TritonDialect, triton::gpu::TritonGPUDialect, 191 mlir::math::MathDialect, arith::ArithmeticDialect, 192- StandardOpsDialect, scf::SCFDialect>(); 193+ scf::SCFDialect>(); 194 195 context.appendDialectRegistry(registry); 196 197@@ -50,7 +50,8 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename, 198 context.loadAllAvailableDialects(); 199 context.allowUnregisteredDialects(); 200 201- OwningOpRef<ModuleOp> module(parseSourceFile(sourceMgr, &context)); 202+ OwningOpRef<ModuleOp> module = 203+ parseSourceFile<ModuleOp>(sourceMgr, &context); 204 if (!module) { 205 llvm::errs() << "Parse MLIR file failed."; 206 return nullptr; 207diff --git a/include/triton/Analysis/Alias.h b/include/triton/Analysis/Alias.h 208index fa6b906fc9..631df518bc 100644 209--- a/include/triton/Analysis/Alias.h 210+++ b/include/triton/Analysis/Alias.h 211@@ -2,7 +2,7 @@ 212 #define TRITON_ANALYSIS_ALIAS_H 213 214 #include "mlir/Analysis/AliasAnalysis.h" 215-#include "mlir/Analysis/DataFlowAnalysis.h" 216+#include "mlir/Analysis/DataFlow/SparseAnalysis.h" 217 #include "llvm/ADT/DenseSet.h" 218 219 namespace mlir { 220@@ -21,7 +21,7 @@ class AliasInfo { 221 } 222 223 /// The pessimistic value state of a value without alias 224- static AliasInfo getPessimisticValueState(MLIRContext *context) { 225+ static AliasInfo getPessimisticValueState(MLIRContext *context = nullptr) { 226 return AliasInfo(); 227 } 228 static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); } 229@@ -29,6 +29,10 @@ class AliasInfo { 230 /// The union of both arguments 231 static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs); 232 233+ void print(raw_ostream &os) const { 234+ llvm::interleaveComma(allocs, os, [&](Value alloc) { alloc.print(os); }); 235+ } 236+ 237 private: 238 /// The set of allocated values that are aliased by this lattice. 239 /// For now, we only consider aliased value produced by the following 240@@ -58,9 +62,13 @@ class AliasInfo { 241 //===----------------------------------------------------------------------===// 242 // Shared Memory Alias Analysis 243 //===----------------------------------------------------------------------===// 244-class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis<AliasInfo> { 245+class SharedMemoryAliasAnalysis 246+ : public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AliasInfo>> { 247 public: 248- using ForwardDataFlowAnalysis<AliasInfo>::ForwardDataFlowAnalysis; 249+ using dataflow::SparseDataFlowAnalysis< 250+ dataflow::Lattice<AliasInfo>>::SparseDataFlowAnalysis; 251+ using dataflow::SparseDataFlowAnalysis< 252+ dataflow::Lattice<AliasInfo>>::getLatticeElement; 253 254 /// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use. 255 /// Given two values, returns their aliasing behavior. 256@@ -70,9 +78,10 @@ class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis<AliasInfo> { 257 ModRefResult getModRef(Operation *op, Value location); 258 259 /// Computes if the alloc set of the results are changed. 260- ChangeResult 261+ void 262 visitOperation(Operation *op, 263- ArrayRef<LatticeElement<AliasInfo> *> operands) override; 264+ ArrayRef<const dataflow::Lattice<AliasInfo> *> operands, 265+ ArrayRef<dataflow::Lattice<AliasInfo> *> results) override; 266 }; 267 268 } // namespace mlir 269diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h 270index b7c136d602..89b77034cc 100644 271--- a/include/triton/Analysis/Allocation.h 272+++ b/include/triton/Analysis/Allocation.h 273@@ -188,6 +188,8 @@ class Allocation { 274 friend class triton::AllocationAnalysis; 275 }; 276 277+template <typename T> Interval(T, T) -> Interval<T>; 278+ 279 } // namespace mlir 280 281 #endif // TRITON_ANALYSIS_ALLOCATION_H 282diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h 283index fdfbd8fbb3..7083b9c43b 100644 284--- a/include/triton/Analysis/AxisInfo.h 285+++ b/include/triton/Analysis/AxisInfo.h 286@@ -1,9 +1,10 @@ 287 #ifndef TRITON_ANALYSIS_AXISINFO_H 288 #define TRITON_ANALYSIS_AXISINFO_H 289 290-#include "mlir/Analysis/DataFlowAnalysis.h" 291+#include "mlir/Analysis/DataFlow/SparseAnalysis.h" 292 #include "llvm/Support/raw_ostream.h" 293 294+#include "mlir/Support/LLVM.h" 295 #include "triton/Analysis/Utility.h" 296 #include "triton/Dialect/Triton/IR/Dialect.h" 297 #include "triton/Dialect/TritonGPU/IR/Dialect.h" 298@@ -62,7 +63,7 @@ class AxisInfo { 299 } 300 301 /// The pessimistic value state of the contiguity is unknown. 302- static AxisInfo getPessimisticValueState(MLIRContext *context) { 303+ static AxisInfo getPessimisticValueState(MLIRContext *context = nullptr) { 304 return AxisInfo(); 305 } 306 static AxisInfo getPessimisticValueState(Value value); 307@@ -70,6 +71,22 @@ class AxisInfo { 308 /// The gcd of both arguments for each dimension 309 static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs); 310 311+ void print(raw_ostream &os) const { 312+ auto print = [&](StringRef name, DimVectorT vec) { 313+ os << name << " = ["; 314+ llvm::interleaveComma(vec, os); 315+ os << "]"; 316+ }; 317+ print("contiguity", contiguity); 318+ print(", divisibility", divisibility); 319+ print(", constancy", constancy); 320+ os << ", constant_value = "; 321+ if (constantValue) 322+ os << *constantValue; 323+ else 324+ os << "<none>"; 325+ } 326+ 327 private: 328 /// The _contiguity_ information maps the `d`-th 329 /// dimension to the length of the shortest 330@@ -147,7 +164,8 @@ class AxisInfoVisitor { 331 } 332 333 virtual AxisInfo 334- getAxisInfo(Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) = 0; 335+ getAxisInfo(Operation *op, 336+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) = 0; 337 338 virtual bool match(Operation *op) = 0; 339 }; 340@@ -157,15 +175,16 @@ template <typename OpTy> class AxisInfoVisitorImpl : public AxisInfoVisitor { 341 public: 342 using AxisInfoVisitor::AxisInfoVisitor; 343 344- AxisInfo getAxisInfo(Operation *op, 345- ArrayRef<LatticeElement<AxisInfo> *> operands) final { 346+ AxisInfo 347+ getAxisInfo(Operation *op, 348+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) final { 349 return getAxisInfo(cast<OpTy>(op), operands); 350 } 351 352 bool match(Operation *op) final { return isa<OpTy>(op); } 353 354- virtual AxisInfo getAxisInfo(OpTy op, 355- ArrayRef<LatticeElement<AxisInfo> *> operands) { 356+ virtual AxisInfo 357+ getAxisInfo(OpTy op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) { 358 llvm_unreachable("Unimplemented getAxisInfo"); 359 } 360 }; 361@@ -176,8 +195,9 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> { 362 public: 363 using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl; 364 365- AxisInfo getAxisInfo(OpTy op, 366- ArrayRef<LatticeElement<AxisInfo> *> operands) override { 367+ AxisInfo 368+ getAxisInfo(OpTy op, 369+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override { 370 auto lhsInfo = operands[0]->getValue(); 371 auto rhsInfo = operands[1]->getValue(); 372 auto rank = lhsInfo.getRank(); 373@@ -230,7 +250,8 @@ class AxisInfoVisitorList { 374 (visitors.emplace_back(std::make_unique<Ts>()), ...); 375 } 376 377- AxisInfo apply(Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) { 378+ AxisInfo apply(Operation *op, 379+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) { 380 for (auto &visitor : visitors) 381 if (visitor->match(op)) 382 return visitor->getAxisInfo(op, operands); 383@@ -241,16 +262,19 @@ class AxisInfoVisitorList { 384 std::vector<std::unique_ptr<AxisInfoVisitor>> visitors; 385 }; 386 387-class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> { 388+class AxisInfoAnalysis 389+ : public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>> { 390 private: 391 AxisInfoVisitorList visitors; 392 393 public: 394- AxisInfoAnalysis(MLIRContext *context); 395+ AxisInfoAnalysis(DataFlowSolver &solver); 396+ using dataflow::SparseDataFlowAnalysis< 397+ dataflow::Lattice<AxisInfo>>::getLatticeElement; 398 399- ChangeResult 400- visitOperation(Operation *op, 401- ArrayRef<LatticeElement<AxisInfo> *> operands) override; 402+ void visitOperation(Operation *op, 403+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands, 404+ ArrayRef<dataflow::Lattice<AxisInfo> *> results) override; 405 406 unsigned getPtrContiguity(Value ptr); 407 408@@ -261,4 +285,4 @@ class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> { 409 410 } // namespace mlir 411 412-#endif 413\ No newline at end of file 414+#endif 415diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h 416index c5ac137dc1..ee7fadb59d 100644 417--- a/include/triton/Analysis/Utility.h 418+++ b/include/triton/Analysis/Utility.h 419@@ -1,6 +1,7 @@ 420 #ifndef TRITON_ANALYSIS_UTILITY_H 421 #define TRITON_ANALYSIS_UTILITY_H 422 423+#include "mlir/Analysis/DataFlowFramework.h" 424 #include "mlir/Analysis/SliceAnalysis.h" 425 #include "triton/Dialect/TritonGPU/IR/Dialect.h" 426 #include <algorithm> 427@@ -12,7 +13,7 @@ namespace mlir { 428 class ReduceOpHelper { 429 public: 430 explicit ReduceOpHelper(triton::ReduceOp op) : op(op) { 431- srcTy = op.operand().getType().cast<RankedTensorType>(); 432+ srcTy = op.getOperand().getType().cast<RankedTensorType>(); 433 } 434 435 ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); } 436@@ -103,6 +104,9 @@ SetVector<Operation *> 437 multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr, 438 TransitiveFilter forwardFilter = nullptr); 439 440+// Create a basic DataFlowSolver with constant and dead code analysis included. 441+std::unique_ptr<DataFlowSolver> createDataFlowSolver(); 442+ 443 } // namespace mlir 444 445 #endif // TRITON_ANALYSIS_UTILITY_H 446diff --git a/include/triton/Conversion/Passes.td b/include/triton/Conversion/Passes.td 447index 70bb20b78e..be00eb2dac 100644 448--- a/include/triton/Conversion/Passes.td 449+++ b/include/triton/Conversion/Passes.td 450@@ -12,7 +12,6 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO 451 452 let dependentDialects = ["mlir::arith::ArithmeticDialect", 453 "mlir::math::MathDialect", 454- "mlir::StandardOpsDialect", 455 // TODO: Does this pass depend on SCF? 456 "mlir::scf::SCFDialect", 457 "mlir::triton::TritonDialect", 458@@ -41,8 +40,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp" 459 "mlir::tensor::TensorDialect", 460 "mlir::triton::TritonDialect", 461 "mlir::triton::gpu::TritonGPUDialect", 462- "mlir::NVVM::NVVMDialect", 463- "mlir::StandardOpsDialect"]; 464+ "mlir::NVVM::NVVMDialect"]; 465 466 let options = [ 467 Option<"computeCapability", "compute-capability", 468diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h 469index e8012a51df..15869e262e 100644 470--- a/include/triton/Dialect/Triton/IR/Dialect.h 471+++ b/include/triton/Dialect/Triton/IR/Dialect.h 472@@ -1,14 +1,15 @@ 473 #ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_ 474 #define TRITON_DIALECT_TRITON_IR_DIALECT_H_ 475 476+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 477+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" 478+#include "mlir/Dialect/Func/IR/FuncOps.h" 479 #include "mlir/Dialect/Math/IR/Math.h" 480-#include "mlir/Dialect/SCF/SCF.h" 481-#include "mlir/Dialect/StandardOps/IR/Ops.h" 482+#include "mlir/Dialect/SCF/IR/SCF.h" 483 #include "mlir/Dialect/Tensor/IR/Tensor.h" 484 #include "mlir/IR/BuiltinOps.h" 485 #include "mlir/IR/Dialect.h" 486 #include "mlir/Interfaces/ControlFlowInterfaces.h" 487- 488 #include "triton/Dialect/Triton/IR/Dialect.h.inc" 489 #include "triton/Dialect/Triton/IR/OpsEnums.h.inc" 490 #include "triton/Dialect/Triton/IR/Traits.h" 491diff --git a/include/triton/Dialect/Triton/IR/TritonDialect.td b/include/triton/Dialect/Triton/IR/TritonDialect.td 492index 07b069e14f..d98ce73884 100644 493--- a/include/triton/Dialect/Triton/IR/TritonDialect.td 494+++ b/include/triton/Dialect/Triton/IR/TritonDialect.td 495@@ -25,12 +25,9 @@ def Triton_Dialect : Dialect { 496 let dependentDialects = [ 497 "arith::ArithmeticDialect", 498 "math::MathDialect", 499- "StandardOpsDialect", 500 "scf::SCFDialect", 501- 502- // Since LLVM 15 503- // "cf::ControlFlowDialect", 504- // "func::FuncDialect" 505+ "cf::ControlFlowDialect", 506+ "func::FuncDialect" 507 ]; 508 509 let extraClassDeclaration = [{ 510@@ -38,6 +35,7 @@ def Triton_Dialect : Dialect { 511 }]; 512 513 let hasConstantMaterializer = 1; 514+ let useDefaultTypePrinterParser = 1; 515 } 516 517 include "triton/Dialect/Triton/IR/TritonTypes.td" 518diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td 519index 779e0b648c..0a69211179 100644 520--- a/include/triton/Dialect/Triton/IR/TritonOps.td 521+++ b/include/triton/Dialect/Triton/IR/TritonOps.td 522@@ -141,11 +141,7 @@ def TT_LoadOp : TT_Op<"load", 523 "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, 524 ]; 525 526- // let assemblyFormat = "operands attr-dict `:` type($result)"; 527- let parser = [{ return mlir::triton::parseLoadOp(parser, result); }]; 528- 529- let printer = [{ return mlir::triton::printLoadOp(p, *this); }]; 530- 531+ let hasCustomAssemblyFormat = 1; 532 let hasCanonicalizer = 1; 533 } 534 535@@ -170,11 +166,7 @@ def TT_StoreOp : TT_Op<"store", 536 "triton::EvictionPolicy":$evict)>, 537 ]; 538 539- // let assemblyFormat = "operands attr-dict `:` type($value)"; 540- let parser = [{ return mlir::triton::parseStoreOp(parser, result); }]; 541- 542- let printer = [{ return mlir::triton::printStoreOp(p, *this); }]; 543- 544+ let hasCustomAssemblyFormat = 1; 545 let hasCanonicalizer = 1; 546 } 547 548diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td 549index 66d2a7b9a9..2fe2fd077d 100644 550--- a/include/triton/Dialect/Triton/IR/TritonTypes.td 551+++ b/include/triton/Dialect/Triton/IR/TritonTypes.td 552@@ -1,6 +1,7 @@ 553 #ifndef TRITON_TYPES 554 #define TRITON_TYPES 555 556+include "mlir/IR/AttrTypeBase.td" 557 include "triton/Dialect/Triton/IR/TritonDialect.td" 558 559 // 560@@ -58,6 +59,7 @@ def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> { 561 }]> 562 ]; 563 564+ let hasCustomAssemblyFormat = 1; 565 let skipDefaultBuilders = 1; 566 } 567 def TT_PtrTensor : TensorOf<[TT_Ptr]>; 568diff --git a/include/triton/Dialect/Triton/Transforms/Passes.td b/include/triton/Dialect/Triton/Transforms/Passes.td 569index 8f77aed774..a25cdc5680 100644 570--- a/include/triton/Dialect/Triton/Transforms/Passes.td 571+++ b/include/triton/Dialect/Triton/Transforms/Passes.td 572@@ -16,8 +16,7 @@ def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp"> 573 574 let constructor = "mlir::triton::createCombineOpsPass()"; 575 576- let dependentDialects = ["mlir::arith::ArithmeticDialect", 577- /*SelectOp*/"mlir::StandardOpsDialect"]; 578+ let dependentDialects = ["mlir::arith::ArithmeticDialect"]; 579 } 580 581 #endif 582diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h 583index b4c8daec7b..dfc5f53ab1 100644 584--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h 585+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h 586@@ -1,19 +1,17 @@ 587 #ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ 588 #define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ 589 590-#include "mlir/Dialect/GPU/GPUDialect.h" 591+#include "mlir/Dialect/GPU/IR/GPUDialect.h" 592 #include "mlir/Dialect/Tensor/IR/Tensor.h" 593 #include "mlir/IR/BuiltinOps.h" 594 #include "mlir/IR/Dialect.h" 595 596 // TritonGPU depends on Triton 597 #include "triton/Dialect/Triton/IR/Dialect.h" 598- 599 #include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" 600 #include "triton/Dialect/TritonGPU/IR/Traits.h" 601 602 #define GET_ATTRDEF_CLASSES 603-#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" 604 #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc" 605 606 #define GET_OP_CLASSES 607diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td 608index 0242c3cc17..af2aeb03a8 100644 609--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td 610+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td 611@@ -1,6 +1,7 @@ 612 #ifndef TRITONGPU_ATTRDEFS 613 #define TRITONGPU_ATTRDEFS 614 615+include "mlir/IR/AttrTypeBase.td" 616 include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" 617 include "triton/Dialect/Triton/IR/TritonInterfaces.td" 618 619@@ -136,6 +137,7 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] / 620 ]; 621 622 let extraClassDeclaration = extraBaseClassDeclaration; 623+ let hasCustomAssemblyFormat = 1; 624 } 625 626 //===----------------------------------------------------------------------===// 627@@ -273,6 +275,7 @@ for 628 // ArrayRefParameter<"unsigned">:$sizePerCTA 629 ); 630 631+ let hasCustomAssemblyFormat = 1; 632 } 633 634 //===----------------------------------------------------------------------===// 635@@ -422,6 +425,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: 636 static constexpr int numBitsToHoldMmaV1ID{5}; 637 }]; 638 639+ let hasCustomAssemblyFormat = 1; 640 } 641 642 def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> { 643@@ -456,6 +460,8 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> { 644 template<class T> 645 SmallVector<T> paddedShape(ArrayRef<T> shape) const; 646 }]; 647+ 648+ let hasCustomAssemblyFormat = 1; 649 } 650 651 def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> { 652@@ -492,6 +498,7 @@ section 9.7.13.4.1 for more details. 653 654 ]; 655 656+ let hasCustomAssemblyFormat = 1; 657 let extraClassDeclaration = extraBaseClassDeclaration; 658 } 659 660diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td 661index 87ec1d36c6..6489a721b4 100644 662--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td 663+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td 664@@ -30,7 +30,7 @@ def TritonGPU_Dialect : Dialect { 665 } 666 }]; 667 668- 669+ let useDefaultAttributePrinterParser = 1; 670 } 671 672 #endif 673diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td 674index 510f8d0183..7aba11dc75 100644 675--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td 676+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td 677@@ -59,7 +59,7 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> { 678 // This is needed because these ops don't 679 // handle encodings 680 // e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111 681-def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise, 682+def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise, 683 SameOperandsAndResultShape, 684 SameOperandsAndResultEncoding]> { 685 let summary = "integer comparison operation"; 686@@ -73,7 +73,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise, 687 let results = (outs TT_BoolLike:$result); 688 } 689 690-def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise, 691+def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise, 692 SameOperandsAndResultShape, 693 SameOperandsAndResultEncoding]> { 694 let summary = "floating-point comparison operation"; 695@@ -88,8 +88,8 @@ def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise, 696 } 697 698 // TODO: migrate to arith::SelectOp on LLVM16 699-def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise, 700- SameOperandsAndResultShape, 701+def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise, 702+ SameOperandsAndResultShape, 703 SameOperandsAndResultEncoding]> { 704 let summary = "select operation"; 705 706@@ -188,10 +188,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async", 707 } 708 }]; 709 710- // The custom parser could be replaced with oilist in LLVM-16 711- let parser = [{ return parseInsertSliceAsyncOp(parser, result); }]; 712- 713- let printer = [{ return printInsertSliceAsyncOp(p, *this); }]; 714+ let hasCustomAssemblyFormat = 1; 715 } 716 717 def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [MemoryEffects<[MemAlloc]>, // Allocate shared memory 718diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp 719index a39e4de9aa..208fdd4afc 100644 720--- a/lib/Analysis/Alias.cpp 721+++ b/lib/Analysis/Alias.cpp 722@@ -18,8 +18,9 @@ AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) { 723 return ret; 724 } 725 726-ChangeResult SharedMemoryAliasAnalysis::visitOperation( 727- Operation *op, ArrayRef<LatticeElement<AliasInfo> *> operands) { 728+void SharedMemoryAliasAnalysis::visitOperation( 729+ Operation *op, ArrayRef<const dataflow::Lattice<AliasInfo> *> operands, 730+ ArrayRef<dataflow::Lattice<AliasInfo> *> results) { 731 AliasInfo aliasInfo; 732 bool pessimistic = true; 733 if (maybeSharedAllocationOp(op)) { 734@@ -44,14 +45,11 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation( 735 } 736 737 if (pessimistic) { 738- return markAllPessimisticFixpoint(op->getResults()); 739+ return markAllPessimisticFixpoint(results); 740 } 741 // Join all lattice elements 742- ChangeResult result = ChangeResult::NoChange; 743- for (Value value : op->getResults()) { 744- result |= getLatticeElement(value).join(aliasInfo); 745- } 746- return result; 747+ for (auto *result : results) 748+ propagateIfChanged(result, result->join(aliasInfo)); 749 } 750 751 AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) { 752diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp 753index 712c08c475..b4de8dcd9d 100644 754--- a/lib/Analysis/Allocation.cpp 755+++ b/lib/Analysis/Allocation.cpp 756@@ -1,4 +1,5 @@ 757 #include "triton/Analysis/Allocation.h" 758+#include "mlir/Analysis/DataFlowFramework.h" 759 #include "mlir/Analysis/Liveness.h" 760 #include "mlir/Analysis/SliceAnalysis.h" 761 #include "mlir/Dialect/Tensor/IR/Tensor.h" 762@@ -33,10 +34,8 @@ constexpr int kPtrBitWidth = 64; 763 764 static std::pair<SmallVector<unsigned>, SmallVector<unsigned>> 765 getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) { 766- auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>(); 767 auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>(); 768 auto srcDotLayout = srcLayout.dyn_cast<DotOperandEncodingAttr>(); 769- auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>(); 770 auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>(); 771 auto dstDotLayout = dstLayout.dyn_cast<DotOperandEncodingAttr>(); 772 assert(!(srcMmaLayout && dstMmaLayout) && 773@@ -224,14 +223,12 @@ class AllocationAnalysis { 774 } 775 776 void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) { 777- LatticeElement<AliasInfo> *latticeElement = 778- analysis.lookupLatticeElement(value); 779- if (latticeElement) { 780- auto &info = latticeElement->getValue(); 781- if (!info.getAllocs().empty()) { 782- for (auto alloc : info.getAllocs()) { 783- allocation->addAlias(value, alloc); 784- } 785+ dataflow::Lattice<AliasInfo> *latticeElement = 786+ analysis.getLatticeElement(value); 787+ if (latticeElement && !latticeElement->isUninitialized()) { 788+ AliasInfo &info = latticeElement->getValue(); 789+ for (auto alloc : info.getAllocs()) { 790+ allocation->addAlias(value, alloc); 791 } 792 } 793 } 794@@ -244,14 +241,19 @@ class AllocationAnalysis { 795 getScratchValueSize(op); 796 }); 797 // Get the alias values 798- SharedMemoryAliasAnalysis aliasAnalysis(operation->getContext()); 799- aliasAnalysis.run(operation); 800+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver(); 801+ SharedMemoryAliasAnalysis *aliasAnalysis = 802+ solver->load<SharedMemoryAliasAnalysis>(); 803+ if (failed(solver->initializeAndRun(operation))) { 804+ // TODO: return error instead of bailing out.. 805+ llvm_unreachable("failed to run SharedMemoryAliasAnalysis"); 806+ } 807 operation->walk<WalkOrder::PreOrder>([&](Operation *op) { 808 for (auto operand : op->getOperands()) { 809- getValueAlias(operand, aliasAnalysis); 810+ getValueAlias(operand, *aliasAnalysis); 811 } 812 for (auto value : op->getResults()) { 813- getValueAlias(value, aliasAnalysis); 814+ getValueAlias(value, *aliasAnalysis); 815 } 816 }); 817 } 818diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp 819index 0b7142b04d..4af46c3fbb 100644 820--- a/lib/Analysis/AxisInfo.cpp 821+++ b/lib/Analysis/AxisInfo.cpp 822@@ -1,4 +1,4 @@ 823-#include "mlir/Analysis/DataFlowAnalysis.h" 824+#include "mlir/Analysis/DataFlowFramework.h" 825 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 826 #include "llvm/Support/raw_ostream.h" 827 828@@ -52,7 +52,7 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) { 829 BlockArgument blockArg = value.dyn_cast<BlockArgument>(); 830 if (blockArg && blockArg.getOwner()->isEntryBlock()) { 831 Operation *op = blockArg.getOwner()->getParentOp(); 832- if (FuncOp fun = dyn_cast<FuncOp>(op)) { 833+ if (func::FuncOp fun = dyn_cast<func::FuncOp>(op)) { 834 Attribute attr = 835 fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility"); 836 if (attr) 837@@ -136,8 +136,9 @@ class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> { 838 public: 839 using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl; 840 841- AxisInfo getAxisInfo(OpTy op, 842- ArrayRef<LatticeElement<AxisInfo> *> operands) override { 843+ AxisInfo 844+ getAxisInfo(OpTy op, 845+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override { 846 return operands[0]->getValue(); 847 } 848 }; 849@@ -147,8 +148,9 @@ class MakeRangeOpAxisInfoVisitor final 850 public: 851 using AxisInfoVisitorImpl<triton::MakeRangeOp>::AxisInfoVisitorImpl; 852 853- AxisInfo getAxisInfo(triton::MakeRangeOp op, 854- ArrayRef<LatticeElement<AxisInfo> *> operands) override { 855+ AxisInfo 856+ getAxisInfo(triton::MakeRangeOp op, 857+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override { 858 auto start = op.start(); 859 auto end = op.end(); 860 return AxisInfo(/*contiguity=*/{end - start}, 861@@ -162,8 +164,9 @@ class ConstantOpAxisInfoVisitor final 862 public: 863 using AxisInfoVisitorImpl<arith::ConstantOp>::AxisInfoVisitorImpl; 864 865- AxisInfo getAxisInfo(arith::ConstantOp op, 866- ArrayRef<LatticeElement<AxisInfo> *> operands) override { 867+ AxisInfo 868+ getAxisInfo(arith::ConstantOp op, 869+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override { 870 auto intAttr = op.getValue().dyn_cast<IntegerAttr>(); 871 auto boolAttr = op.getValue().dyn_cast<BoolAttr>(); 872 if (intAttr || boolAttr) { 873@@ -416,8 +419,9 @@ class SplatOpAxisInfoVisitor final 874 public: 875 using AxisInfoVisitorImpl<triton::SplatOp>::AxisInfoVisitorImpl; 876 877- AxisInfo getAxisInfo(triton::SplatOp op, 878- ArrayRef<LatticeElement<AxisInfo> *> operands) override { 879+ AxisInfo 880+ getAxisInfo(triton::SplatOp op, 881+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override { 882 Type _retTy = *op->result_type_begin(); 883 TensorType retTy = _retTy.cast<TensorType>(); 884 AxisInfo opInfo = operands[0]->getValue(); 885@@ -439,8 +443,9 @@ class ExpandDimsOpAxisInfoVisitor final 886 public: 887 using AxisInfoVisitorImpl<triton::ExpandDimsOp>::AxisInfoVisitorImpl; 888 889- AxisInfo getAxisInfo(triton::ExpandDimsOp op, 890- ArrayRef<LatticeElement<AxisInfo> *> operands) override { 891+ AxisInfo 892+ getAxisInfo(triton::ExpandDimsOp op, 893+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override { 894 AxisInfo opInfo = operands[0]->getValue(); 895 AxisInfo::DimVectorT contiguity = opInfo.getContiguity(); 896 AxisInfo::DimVectorT divisibility = opInfo.getDivisibility(); 897@@ -458,8 +463,9 @@ class BroadcastOpAxisInfoVisitor final 898 public: 899 using AxisInfoVisitorImpl<triton::BroadcastOp>::AxisInfoVisitorImpl; 900 901- AxisInfo getAxisInfo(triton::BroadcastOp op, 902- ArrayRef<LatticeElement<AxisInfo> *> operands) override { 903+ AxisInfo 904+ getAxisInfo(triton::BroadcastOp op, 905+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override { 906 Type _retTy = *op->result_type_begin(); 907 Type _opTy = *op->operand_type_begin(); 908 TensorType retTy = _retTy.cast<TensorType>(); 909@@ -486,8 +492,9 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> { 910 public: 911 using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl; 912 913- AxisInfo getAxisInfo(OpTy op, 914- ArrayRef<LatticeElement<AxisInfo> *> operands) override { 915+ AxisInfo 916+ getAxisInfo(OpTy op, 917+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override { 918 auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>(); 919 if (!resTy) 920 return AxisInfo(); 921@@ -596,8 +603,9 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> { 922 public: 923 using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl; 924 925- AxisInfo getAxisInfo(OpTy op, 926- ArrayRef<LatticeElement<AxisInfo> *> operands) override { 927+ AxisInfo 928+ getAxisInfo(OpTy op, 929+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override { 930 auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>(); 931 if (!resTy) 932 return AxisInfo(); 933@@ -757,8 +765,9 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> { 934 public: 935 using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl; 936 937- AxisInfo getAxisInfo(OpTy op, 938- ArrayRef<LatticeElement<AxisInfo> *> operands) override { 939+ AxisInfo 940+ getAxisInfo(OpTy op, 941+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override { 942 auto lhsInfo = operands[0]->getValue(); 943 auto rhsInfo = operands[1]->getValue(); 944 std::optional<int64_t> constantValue; 945@@ -786,8 +795,8 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> { 946 // AxisInfoAnalysis 947 //===----------------------------------------------------------------------===// 948 949-AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context) 950- : ForwardDataFlowAnalysis<AxisInfo>(context) { 951+AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver) 952+ : dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(solver) { 953 // UnrealizedConversionCast: 954 // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is 955 // in the process of a PartialConversion, where UnrealizedConversionCast 956@@ -819,7 +828,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context) 957 visitors.append<LogicalOpAxisInfoVisitor<arith::AndIOp>, 958 LogicalOpAxisInfoVisitor<arith::OrIOp>, 959 LogicalOpAxisInfoVisitor<arith::XOrIOp>>(); 960- visitors.append<SelectOpAxisInfoVisitor<mlir::SelectOp>, 961+ visitors.append<SelectOpAxisInfoVisitor<mlir::arith::SelectOp>, 962 SelectOpAxisInfoVisitor<triton::gpu::SelectOp>>(); 963 visitors.append<ShLIOpAxisInfoVisitor, ShROpAxisInfoVisitor<arith::ShRUIOp>, 964 ShROpAxisInfoVisitor<arith::ShRSIOp>>(); 965@@ -829,11 +838,12 @@ AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context) 966 MaxMinOpAxisInfoVisitor<arith::MinUIOp>>(); 967 } 968 969-ChangeResult AxisInfoAnalysis::visitOperation( 970- Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) { 971+void AxisInfoAnalysis::visitOperation( 972+ Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands, 973+ ArrayRef<dataflow::Lattice<AxisInfo> *> results) { 974 AxisInfo curr = visitors.apply(op, operands); 975 if (curr.getRank() == 0) { 976- return markAllPessimisticFixpoint(op->getResults()); 977+ return markAllPessimisticFixpoint(results); 978 } 979 // override with hint 980 auto newContiguity = curr.getContiguity(); 981@@ -854,11 +864,8 @@ ChangeResult AxisInfoAnalysis::visitOperation( 982 curr = mlir::AxisInfo(newContiguity, newDivisibility, newConstancy, 983 curr.getConstantValue()); 984 // join all lattice elements 985- ChangeResult result = ChangeResult::NoChange; 986- for (Value value : op->getResults()) { 987- result |= getLatticeElement(value).join(curr); 988- } 989- return result; 990+ for (auto *result : results) 991+ propagateIfChanged(result, result->join(curr)); 992 } 993 994 unsigned AxisInfoAnalysis::getPtrContiguity(Value ptr) { 995@@ -884,7 +891,10 @@ unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) { 996 auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>(); 997 if (!tensorTy) 998 return 1; 999- auto axisInfo = lookupLatticeElement(ptr)->getValue(); 1000+ dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(ptr); 1001+ if (!latticeElement || latticeElement->isUninitialized()) 1002+ return 1; 1003+ auto axisInfo = latticeElement->getValue(); 1004 auto layout = tensorTy.getEncoding(); 1005 auto order = triton::gpu::getOrder(layout); 1006 auto maxMultipleBytes = axisInfo.getDivisibility(order[0]); 1007@@ -900,8 +910,11 @@ unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) { 1008 auto tensorTy = mask.getType().dyn_cast<RankedTensorType>(); 1009 if (!tensorTy) 1010 return 1; 1011+ dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(mask); 1012+ if (!latticeElement || latticeElement->isUninitialized()) 1013+ return 1; 1014+ auto maskAxis = latticeElement->getValue(); 1015 auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding()); 1016- auto maskAxis = lookupLatticeElement(mask)->getValue(); 1017 auto alignment = std::max<unsigned>(maskAxis.getConstancy(maskOrder[0]), 1); 1018 return alignment; 1019 } 1020diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt 1021index afbc692510..1f761f845c 100644 1022--- a/lib/Analysis/CMakeLists.txt 1023+++ b/lib/Analysis/CMakeLists.txt 1024@@ -8,7 +8,7 @@ add_mlir_library(TritonAnalysis 1025 DEPENDS 1026 TritonTableGen 1027 TritonGPUAttrDefsIncGen 1028- 1029+ 1030 LINK_LIBS PUBLIC 1031 MLIRAnalysis 1032 ) 1033diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp 1034index acc885e827..910274b2ac 100644 1035--- a/lib/Analysis/Membar.cpp 1036+++ b/lib/Analysis/Membar.cpp 1037@@ -2,7 +2,7 @@ 1038 #include "triton/Analysis/Alias.h" 1039 #include "triton/Dialect/TritonGPU/IR/Dialect.h" 1040 1041-#include "mlir/Dialect/GPU/GPUDialect.h" 1042+#include "mlir/Dialect/GPU/IR/GPUDialect.h" 1043 #include "mlir/Dialect/Tensor/IR/Tensor.h" 1044 1045 namespace mlir { 1046diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp 1047index d9e917e731..6ea52df272 100644 1048--- a/lib/Analysis/Utility.cpp 1049+++ b/lib/Analysis/Utility.cpp 1050@@ -1,5 +1,8 @@ 1051 #include "triton/Analysis/Utility.h" 1052+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" 1053+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 1054 #include "mlir/IR/Dialect.h" 1055+#include "mlir/IR/Matchers.h" 1056 #include "triton/Dialect/Triton/IR/Dialect.h" 1057 #include "triton/Dialect/TritonGPU/IR/Dialect.h" 1058 #include <deque> 1059@@ -325,4 +328,55 @@ SetVector<Operation *> multiRootGetSlice(Operation *op, 1060 return multiRootTopologicalSort(slice); 1061 } 1062 1063+namespace { 1064+// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis 1065+// interacts with constant propagation, but SparseConstantPropagation 1066+// doesn't seem to be sufficient. 1067+struct ConstantAnalysis : public DataFlowAnalysis { 1068+ using DataFlowAnalysis::DataFlowAnalysis; 1069+ 1070+ LogicalResult initialize(Operation *top) override { 1071+ WalkResult result = top->walk([&](Operation *op) { 1072+ if (failed(visit(op))) 1073+ return WalkResult::interrupt(); 1074+ return WalkResult::advance(); 1075+ }); 1076+ return success(!result.wasInterrupted()); 1077+ } 1078+ 1079+ LogicalResult visit(ProgramPoint point) override { 1080+ Operation *op = point.get<Operation *>(); 1081+ Attribute value; 1082+ if (matchPattern(op, m_Constant(&value))) { 1083+ auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>( 1084+ op->getResult(0)); 1085+ propagateIfChanged(constant, constant->join(dataflow::ConstantValue( 1086+ value, op->getDialect()))); 1087+ return success(); 1088+ } 1089+ setAllToUnknownConstants(op->getResults()); 1090+ for (Region &region : op->getRegions()) 1091+ setAllToUnknownConstants(region.getArguments()); 1092+ return success(); 1093+ } 1094+ 1095+ /// Set all given values as not constants. 1096+ void setAllToUnknownConstants(ValueRange values) { 1097+ dataflow::ConstantValue unknownConstant(nullptr, nullptr); 1098+ for (Value value : values) { 1099+ auto *constant = 1100+ getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(value); 1101+ propagateIfChanged(constant, constant->join(unknownConstant)); 1102+ } 1103+ } 1104+}; 1105+} // namespace 1106+ 1107+std::unique_ptr<DataFlowSolver> createDataFlowSolver() { 1108+ auto solver = std::make_unique<DataFlowSolver>(); 1109+ solver->load<dataflow::DeadCodeAnalysis>(); 1110+ solver->load<ConstantAnalysis>(); 1111+ return solver; 1112+} 1113+ 1114 } // namespace mlir 1115diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp 1116index 6a46265bd7..e352eb3698 100644 1117--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp 1118+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp 1119@@ -159,9 +159,6 @@ struct ConvertLayoutOpConversion 1120 Value smemBase) const { 1121 auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep); 1122 auto layout = type.getEncoding(); 1123- auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>(); 1124- auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>(); 1125- auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>(); 1126 auto rank = type.getRank(); 1127 auto sizePerThread = getSizePerThread(layout); 1128 auto accumSizePerThread = product<unsigned>(sizePerThread); 1129diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h 1130index 4b89965aa9..1d9e00519b 100644 1131--- a/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h 1132+++ b/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h 1133@@ -7,10 +7,8 @@ 1134 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" 1135 #include "mlir/Conversion/LLVMCommon/Pattern.h" 1136 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" 1137-#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" 1138-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 1139 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 1140-#include "mlir/Dialect/GPU/GPUDialect.h" 1141+#include "mlir/Dialect/GPU/IR/GPUDialect.h" 1142 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1143 #include "mlir/Dialect/Tensor/IR/Tensor.h" 1144 #include "mlir/IR/Matchers.h" 1145@@ -422,9 +420,9 @@ struct MMA16816ConversionHelper { 1146 MMA16816ConversionHelper(Type dotOperand, MmaEncodingAttr mmaLayout, 1147 Value thread, ConversionPatternRewriter &rewriter, 1148 TypeConverter *typeConverter, Location loc) 1149- : mmaLayout(mmaLayout), thread(thread), helper(mmaLayout), 1150- rewriter(rewriter), typeConverter(typeConverter), loc(loc), 1151- ctx(mmaLayout.getContext()), wpt(mmaLayout.getWarpsPerCTA()) { 1152+ : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()), thread(thread), 1153+ helper(mmaLayout), rewriter(rewriter), typeConverter(typeConverter), 1154+ loc(loc), ctx(mmaLayout.getContext()) { 1155 helper.deduceMmaType(dotOperand); 1156 1157 Value _32 = i32_val(32); 1158diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp 1159index 0f8070ca9f..e4bd47c411 100644 1160--- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp 1161+++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp 1162@@ -115,8 +115,6 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> { 1163 auto DTensorTy = D.getType().cast<RankedTensorType>(); 1164 auto AShape = ATensorTy.getShape(); 1165 auto BShape = BTensorTy.getShape(); 1166- auto DShape = DTensorTy.getShape(); 1167- auto wpt = mmaLayout.getWarpsPerCTA(); 1168 1169 bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue(); 1170 bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue(); 1171@@ -221,7 +219,6 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> { 1172 ConversionPatternRewriter &rewriter) const { 1173 auto *ctx = rewriter.getContext(); 1174 auto loc = op.getLoc(); 1175- auto threadId = getThreadId(rewriter, loc); 1176 1177 auto A = op.a(); 1178 auto B = op.b(); 1179@@ -230,12 +227,10 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> { 1180 1181 auto aTensorTy = A.getType().cast<RankedTensorType>(); 1182 auto bTensorTy = B.getType().cast<RankedTensorType>(); 1183- auto cTensorTy = C.getType().cast<RankedTensorType>(); 1184 auto dTensorTy = D.getType().cast<RankedTensorType>(); 1185 1186 auto aShape = aTensorTy.getShape(); 1187 auto bShape = bTensorTy.getShape(); 1188- auto cShape = cTensorTy.getShape(); 1189 1190 BlockedEncodingAttr dLayout = 1191 dTensorTy.getEncoding().cast<BlockedEncodingAttr>(); 1192diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp 1193index deb71b9597..0b9e67674b 100644 1194--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp 1195+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp 1196@@ -61,7 +61,6 @@ struct FpToFpOpConversion 1197 convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter, 1198 const Value &v0, const Value &v1, const Value &v2, 1199 const Value &v3) { 1200- auto ctx = rewriter.getContext(); 1201 auto fp16x2VecTy = vec_ty(f16_ty, 2); 1202 Value fp16x2Vec0 = undef(fp16x2VecTy); 1203 Value fp16x2Vec1 = undef(fp16x2VecTy); 1204@@ -153,7 +152,6 @@ struct FpToFpOpConversion 1205 convertBf16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter, 1206 const Value &v0, const Value &v1, const Value &v2, 1207 const Value &v3) { 1208- auto ctx = rewriter.getContext(); 1209 auto bf16x2VecTy = vec_ty(i16_ty, 2); 1210 Value bf16x2Vec0 = undef(bf16x2VecTy); 1211 Value bf16x2Vec1 = undef(bf16x2VecTy); 1212diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp 1213index 9a8b4702bc..bae675f0cb 100644 1214--- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp 1215+++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp 1216@@ -109,7 +109,8 @@ struct LoadOpConversion 1217 DenseElementsAttr constAttr; 1218 int64_t splatVal = 0; 1219 if (other && valueElemTy.isa<IntegerType>() && 1220- matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat()) { 1221+ matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat() && 1222+ constAttr.getElementType().isa<IntegerType>()) { 1223 otherIsSplatConstInt = true; 1224 splatVal = constAttr.getSplatValue<APInt>().getSExtValue(); 1225 } 1226@@ -333,7 +334,6 @@ struct StoreOpConversion 1227 elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem); 1228 elem = bitcast(elem, valueElemTy); 1229 1230- Type u32Ty = typeConverter->convertType(type::u32Ty(ctx)); 1231 llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx)); 1232 } 1233 llWord = bitcast(llWord, valArgTy); 1234@@ -387,7 +387,6 @@ struct AtomicCASOpConversion 1235 ConversionPatternRewriter &rewriter) const override { 1236 auto loc = op.getLoc(); 1237 MLIRContext *ctx = rewriter.getContext(); 1238- Value ptr = op.ptr(); 1239 1240 Value llPtr = adaptor.ptr(); 1241 Value llCmp = adaptor.cmp(); 1242diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp 1243index 69abd889be..1c973dc196 100644 1244--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp 1245+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp 1246@@ -286,7 +286,6 @@ struct ReduceOpConversion 1247 auto srcTy = op.operand().getType().cast<RankedTensorType>(); 1248 auto srcLayout = srcTy.getEncoding(); 1249 auto srcShape = srcTy.getShape(); 1250- auto srcRank = srcTy.getRank(); 1251 auto order = getOrder(srcLayout); 1252 1253 auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout); 1254@@ -351,7 +350,6 @@ struct ReduceOpConversion 1255 1256 Value zero = i32_val(0); 1257 Value laneZero = icmp_eq(laneIdAxis, zero); 1258- Value warpZero = icmp_eq(warpIdAxis, zero); 1259 1260 for (auto it : accs) { 1261 const SmallVector<unsigned> &key = it.first; 1262diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp 1263index 5b77150b1a..78cfa076bd 100644 1264--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp 1265+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp 1266@@ -11,11 +11,11 @@ using ::mlir::LLVM::getStructFromElements; 1267 using ::mlir::triton::gpu::getElemsPerThread; 1268 using ::mlir::triton::gpu::SharedEncodingAttr; 1269 1270-struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> { 1271- using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern; 1272+struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> { 1273+ using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern; 1274 1275 LogicalResult 1276- matchAndRewrite(ReturnOp op, OpAdaptor adaptor, 1277+ matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, 1278 ConversionPatternRewriter &rewriter) const override { 1279 unsigned numArguments = op.getNumOperands(); 1280 1281@@ -476,7 +476,6 @@ struct ExtractSliceOpConversion 1282 1283 auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); 1284 auto elemPtrTy = ptr_ty(llvmElemTy, 3); 1285- auto resTy = op.getType().dyn_cast<RankedTensorType>(); 1286 smemObj = SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset), 1287 strideVals, offsetVals); 1288 auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); 1289diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h 1290index bb10d5b24a..00e399f848 100644 1291--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h 1292+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h 1293@@ -4,6 +4,7 @@ 1294 // TODO: refactor so that it doesn't fail if Allocation.h 1295 // is included after utility.h (due to conflict in `store` macro 1296 // and <atomic> 1297+#include "mlir/Dialect/Func/IR/FuncOps.h" 1298 #include "triton/Analysis/Allocation.h" 1299 1300 // 1301@@ -39,15 +40,15 @@ void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info, 1302 // TODO(Superjomn): remove the code when MLIR v15.0 is included. 1303 // All the rights are reserved by the LLVM community. 1304 1305-struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> { 1306+struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> { 1307 private: 1308 /// Only retain those attributes that are not constructed by 1309 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument 1310 /// attributes. 1311- static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs, 1312- bool filterArgAttrs, 1313+ static void filterFuncAttributes(func::FuncOp op, bool filterArgAttrs, 1314 SmallVectorImpl<NamedAttribute> &result) { 1315- for (const auto &attr : attrs) { 1316+ 1317+ for (const auto &attr : op->getAttrs()) { 1318 if (attr.getName() == SymbolTable::getSymbolAttrName() || 1319 attr.getName() == FunctionOpInterface::getTypeAttrName() || 1320 attr.getName() == "std.varargs" || 1321@@ -65,27 +66,27 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> { 1322 } 1323 1324 protected: 1325- using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern; 1326+ using ConvertOpToLLVMPattern<func::FuncOp>::ConvertOpToLLVMPattern; 1327 1328 // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided 1329 // to this legalization pattern. 1330 LLVM::LLVMFuncOp 1331- convertFuncOpToLLVMFuncOp(FuncOp funcOp, 1332+ convertFuncOpToLLVMFuncOp(func::FuncOp funcOp, 1333 ConversionPatternRewriter &rewriter) const { 1334 // Convert the original function arguments. They are converted using the 1335 // LLVMTypeConverter provided to this legalization pattern. 1336 auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs"); 1337 TypeConverter::SignatureConversion result(funcOp.getNumArguments()); 1338 auto llvmType = getTypeConverter()->convertFunctionSignature( 1339- funcOp.getType(), varargsAttr && varargsAttr.getValue(), result); 1340+ funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(), 1341+ result); 1342 if (!llvmType) 1343 return nullptr; 1344 1345 // Propagate argument/result attributes to all converted arguments/result 1346 // obtained after converting a given original argument/result. 1347 SmallVector<NamedAttribute, 4> attributes; 1348- filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true, 1349- attributes); 1350+ filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, attributes); 1351 if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) { 1352 assert(!resAttrDicts.empty() && "expected array to be non-empty"); 1353 auto newResAttrDicts = 1354@@ -131,7 +132,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> { 1355 } 1356 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>( 1357 funcOp.getLoc(), funcOp.getName(), llvmType, linkage, 1358- /*dsoLocal*/ false, attributes); 1359+ /*dsoLocal*/ false, LLVM::CConv::C, attributes); 1360 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 1361 newFuncOp.end()); 1362 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter, 1363@@ -191,8 +192,8 @@ class ConvertTritonGPUOpToLLVMPatternBase { 1364 const Allocation *allocation, 1365 Value smem, 1366 IndexCacheInfo indexCacheInfo) 1367- : converter(&typeConverter), indexCacheInfo(indexCacheInfo), 1368- allocation(allocation), smem(smem) {} 1369+ : converter(&typeConverter), allocation(allocation), smem(smem), 1370+ indexCacheInfo(indexCacheInfo) {} 1371 1372 LLVMTypeConverter *getTypeConverter() const { return converter; } 1373 1374@@ -861,7 +862,6 @@ class ConvertTritonGPUOpToLLVMPatternBase { 1375 ArrayRef<int64_t> shape) const { 1376 auto parent = sliceLayout.getParent(); 1377 unsigned dim = sliceLayout.getDim(); 1378- size_t rank = shape.size(); 1379 auto parentIndices = 1380 emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape)); 1381 unsigned numIndices = parentIndices.size(); 1382diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp 1383index ff1af09835..6f66af4e34 100644 1384--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp 1385+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp 1386@@ -1,10 +1,11 @@ 1387 #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h" 1388 1389+#include "mlir/Analysis/DataFlowFramework.h" 1390 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 1391+#include "mlir/Conversion/ControlFlowToLLVM//ControlFlowToLLVM.h" 1392 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" 1393 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" 1394-#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" 1395-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 1396+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" 1397 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1398 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 1399 #include "mlir/Pass/Pass.h" 1400@@ -40,7 +41,6 @@ class TritonLLVMConversionTarget : public ConversionTarget { 1401 addIllegalDialect<triton::TritonDialect>(); 1402 addIllegalDialect<triton::gpu::TritonGPUDialect>(); 1403 addIllegalDialect<mlir::gpu::GPUDialect>(); 1404- addIllegalDialect<mlir::StandardOpsDialect>(); 1405 addLegalOp<mlir::UnrealizedConversionCastOp>(); 1406 } 1407 }; 1408@@ -51,7 +51,7 @@ class TritonLLVMFunctionConversionTarget : public ConversionTarget { 1409 : ConversionTarget(ctx) { 1410 addLegalDialect<LLVM::LLVMDialect>(); 1411 addLegalDialect<NVVM::NVVMDialect>(); 1412- addIllegalOp<mlir::FuncOp>(); 1413+ addIllegalOp<mlir::func::FuncOp>(); 1414 addLegalOp<mlir::UnrealizedConversionCastOp>(); 1415 } 1416 }; 1417@@ -69,7 +69,7 @@ struct FuncOpConversion : public FuncOpConversionBase { 1418 : FuncOpConversionBase(converter, benefit), numWarps(numWarps) {} 1419 1420 LogicalResult 1421- matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, 1422+ matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, 1423 ConversionPatternRewriter &rewriter) const override { 1424 auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); 1425 if (!newFuncOp) 1426@@ -133,7 +133,8 @@ class ConvertTritonGPUToLLVM 1427 decomposeBlockedToDotOperand(mod); 1428 1429 // Step 2 1430- decomposeInsertSliceAsyncOp(mod); 1431+ if (failed(decomposeInsertSliceAsyncOp(mod))) 1432+ return signalPassFailure(); 1433 1434 // Step 3 1435 Allocation allocation(mod); 1436@@ -142,7 +143,7 @@ class ConvertTritonGPUToLLVM 1437 1438 // Step 4 1439 RewritePatternSet scf_patterns(context); 1440- mlir::populateLoopToStdConversionPatterns(scf_patterns); 1441+ mlir::populateSCFToControlFlowConversionPatterns(scf_patterns); 1442 mlir::ConversionTarget scf_target(*context); 1443 scf_target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, 1444 scf::WhileOp, scf::ExecuteRegionOp>(); 1445@@ -159,8 +160,10 @@ class ConvertTritonGPUToLLVM 1446 return signalPassFailure(); 1447 1448 // Step 6 - get axis and shared memory info 1449- AxisInfoAnalysis axisInfoAnalysis(mod.getContext()); 1450- axisInfoAnalysis.run(mod); 1451+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver(); 1452+ AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>(); 1453+ if (failed(solver->initializeAndRun(mod))) 1454+ return signalPassFailure(); 1455 initSharedMemory(allocation.getSharedMemorySize(), typeConverter); 1456 mod->setAttr("triton_gpu.shared", 1457 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32), 1458@@ -178,38 +181,39 @@ class ConvertTritonGPUToLLVM 1459 1460 // Normal conversions 1461 populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps, 1462- axisInfoAnalysis, &allocation, smem, 1463+ *axisInfoAnalysis, &allocation, smem, 1464 indexCacheInfo, /*benefit=*/10); 1465 // ConvertLayoutOp 1466 populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps, 1467- axisInfoAnalysis, &allocation, smem, 1468+ *axisInfoAnalysis, &allocation, smem, 1469 indexCacheInfo, /*benefit=*/10); 1470 // DotOp 1471 populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps, 1472- axisInfoAnalysis, &allocation, smem, 1473+ *axisInfoAnalysis, &allocation, smem, 1474 /*benefit=*/10); 1475 // ElementwiseOp 1476 populateElementwiseOpToLLVMPatterns(typeConverter, patterns, numWarps, 1477- axisInfoAnalysis, &allocation, smem, 1478+ *axisInfoAnalysis, &allocation, smem, 1479 /*benefit=*/10); 1480 // LoadStoreOp 1481 populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps, 1482- axisInfoAnalysis, &allocation, smem, 1483+ *axisInfoAnalysis, &allocation, smem, 1484 indexCacheInfo, /*benefit=*/10); 1485 // ReduceOp 1486 populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps, 1487- axisInfoAnalysis, &allocation, smem, 1488+ *axisInfoAnalysis, &allocation, smem, 1489 indexCacheInfo, /*benefit=*/10); 1490 // ViewOp 1491 populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps, 1492- axisInfoAnalysis, &allocation, smem, 1493+ *axisInfoAnalysis, &allocation, smem, 1494 /*benefit=*/10); 1495 1496 // Add arith/math's patterns to help convert scalar expression to LLVM. 1497 mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, 1498 patterns); 1499 mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns); 1500- mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns); 1501+ mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, 1502+ patterns); 1503 mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns); 1504 1505 if (failed(applyPartialConversion(mod, target, std::move(patterns)))) 1506@@ -306,9 +310,11 @@ class ConvertTritonGPUToLLVM 1507 }); 1508 } 1509 1510- void decomposeInsertSliceAsyncOp(ModuleOp mod) const { 1511- AxisInfoAnalysis axisInfoAnalysis(mod.getContext()); 1512- axisInfoAnalysis.run(mod); 1513+ LogicalResult decomposeInsertSliceAsyncOp(ModuleOp mod) const { 1514+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver(); 1515+ AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>(); 1516+ if (failed(solver->initializeAndRun(mod))) 1517+ return failure(); 1518 // TODO(Keren): This is a hacky knob that may cause performance regression 1519 // when decomposition has been performed. We should remove this knob once we 1520 // have thorough analysis on async wait. Currently, we decompose 1521@@ -342,7 +348,7 @@ class ConvertTritonGPUToLLVM 1522 auto resSharedLayout = 1523 dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>(); 1524 auto resElemTy = dstTy.getElementType(); 1525- unsigned inVec = axisInfoAnalysis.getPtrContiguity(src); 1526+ unsigned inVec = axisInfoAnalysis->getPtrContiguity(src); 1527 unsigned outVec = resSharedLayout.getVec(); 1528 unsigned minVec = std::min(outVec, inVec); 1529 auto maxBitWidth = 1530@@ -400,11 +406,11 @@ class ConvertTritonGPUToLLVM 1531 } else if (decomposed) { 1532 // Wait for all previous async ops 1533 OpBuilder builder(asyncWaitOp); 1534- auto newAsyncWaitOp = 1535- builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0); 1536+ builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0); 1537 asyncWaitOp.erase(); 1538 } 1539 }); 1540+ return success(); 1541 } 1542 }; 1543 1544diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h 1545index d35dac28c5..11976908cf 100644 1546--- a/lib/Conversion/TritonGPUToLLVM/Utility.h 1547+++ b/lib/Conversion/TritonGPUToLLVM/Utility.h 1548@@ -220,10 +220,7 @@ struct SharedMemoryObject { 1549 ConversionPatternRewriter &rewriter) 1550 : base(base) { 1551 strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter); 1552- 1553- for (auto idx : order) { 1554- offsets.emplace_back(i32_val(0)); 1555- } 1556+ offsets.append(order.size(), i32_val(0)); 1557 } 1558 1559 SmallVector<Value> getElems() const { 1560diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp 1561index fe42202c34..5f230f787f 100644 1562--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp 1563+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp 1564@@ -1,10 +1,10 @@ 1565 #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" 1566 1567 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 1568-#include "mlir/Dialect/GPU/GPUDialect.h" 1569+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 1570+#include "mlir/Dialect/GPU/IR/GPUDialect.h" 1571 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1572 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 1573-#include "mlir/Dialect/StandardOps/IR/Ops.h" 1574 #include "mlir/Pass/Pass.h" 1575 #include "mlir/Transforms/DialectConversion.h" 1576 #include "triton/Dialect/Triton/IR/Dialect.h" 1577@@ -59,10 +59,13 @@ class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> { 1578 Type retType = getTypeConverter()->convertType(op.getType()); 1579 auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>(); 1580 assert(value); 1581- rewriter.replaceOpWithNewOp<arith::ConstantOp>( 1582- op, retType, 1583- value.reshape(retType) // This is a hack. We just want to add encoding 1584- ); 1585+ if (value.getElementType().isInteger(1) && value.isSplat()) 1586+ // Workaround until https://reviews.llvm.org/D133743 is included. 1587+ value = DenseElementsAttr::get(retType, value.getSplatValue<bool>()); 1588+ else 1589+ // This is a hack. We just want to add encoding 1590+ value = value.reshape(retType); 1591+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retType, value); 1592 return success(); 1593 } 1594 }; 1595@@ -127,12 +130,12 @@ void populateArithmeticPatternsAndLegality( 1596 } 1597 1598 // this shouldn't exist if mlir's SelectOp checked encodings properly 1599-class StdSelectPattern : public OpConversionPattern<SelectOp> { 1600+class StdSelectPattern : public OpConversionPattern<arith::SelectOp> { 1601 public: 1602- using OpConversionPattern<SelectOp>::OpConversionPattern; 1603+ using OpConversionPattern<arith::SelectOp>::OpConversionPattern; 1604 1605 LogicalResult 1606- matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor, 1607+ matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, 1608 ConversionPatternRewriter &rewriter) const override { 1609 Type retType = this->getTypeConverter()->convertType(op.getType()); 1610 rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>( 1611@@ -148,8 +151,8 @@ void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter, 1612 MLIRContext *context = patterns.getContext(); 1613 // Rewrite rule 1614 patterns.add<StdSelectPattern>(typeConverter, context); 1615- target.addLegalOp<ReturnOp>(); // this is ok because all functions are inlined 1616- // by the frontend 1617+ target.addLegalOp<func::ReturnOp>(); // this is ok because all functions are 1618+ // inlined by the frontend 1619 } 1620 1621 void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter, 1622@@ -455,18 +458,19 @@ struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> { 1623 void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, 1624 RewritePatternSet &patterns) { 1625 MLIRContext *context = patterns.getContext(); 1626- patterns.add< // TODO: view should have custom pattern that views the layout 1627- TritonGenericPattern<triton::ViewOp>, 1628- TritonGenericPattern<triton::BitcastOp>, 1629- TritonGenericPattern<triton::FpToFpOp>, 1630- TritonGenericPattern<triton::IntToPtrOp>, 1631- TritonGenericPattern<triton::PtrToIntOp>, 1632- TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern, 1633- TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern, 1634- TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern, 1635- TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, 1636- TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern, 1637- TritonAtomicRMWPattern>(typeConverter, context); 1638+ patterns 1639+ .insert< // TODO: view should have custom pattern that views the layout 1640+ TritonGenericPattern<triton::ViewOp>, 1641+ TritonGenericPattern<triton::BitcastOp>, 1642+ TritonGenericPattern<triton::FpToFpOp>, 1643+ TritonGenericPattern<triton::IntToPtrOp>, 1644+ TritonGenericPattern<triton::PtrToIntOp>, 1645+ TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern, 1646+ TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern, 1647+ TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern, 1648+ TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, 1649+ TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern, 1650+ TritonAtomicRMWPattern>(typeConverter, context); 1651 } 1652 1653 // 1654@@ -623,29 +627,28 @@ void populateSCFPatterns(TritonGPUTypeConverter &typeConverter, 1655 1656 // CF 1657 1658-class CFBranchPattern : public OpConversionPattern<BranchOp> { 1659+class CFBranchPattern : public OpConversionPattern<cf::BranchOp> { 1660 public: 1661- using OpConversionPattern<BranchOp>::OpConversionPattern; 1662+ using OpConversionPattern<cf::BranchOp>::OpConversionPattern; 1663 1664 LogicalResult 1665- matchAndRewrite(BranchOp op, BranchOp::Adaptor adaptor, 1666+ matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor, 1667 ConversionPatternRewriter &rewriter) const override { 1668- auto converter = getTypeConverter(); 1669- auto newOp = rewriter.replaceOpWithNewOp<BranchOp>(op, op.getSuccessor(), 1670- adaptor.getOperands()); 1671+ auto newOp = rewriter.replaceOpWithNewOp<cf::BranchOp>( 1672+ op, op.getSuccessor(), adaptor.getOperands()); 1673 return success(); 1674 } 1675 }; 1676 1677-class CFCondBranchPattern : public OpConversionPattern<CondBranchOp> { 1678+class CFCondBranchPattern : public OpConversionPattern<cf::CondBranchOp> { 1679 public: 1680- using OpConversionPattern<CondBranchOp>::OpConversionPattern; 1681+ using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern; 1682 1683 LogicalResult 1684- matchAndRewrite(CondBranchOp op, CondBranchOp::Adaptor adaptor, 1685+ matchAndRewrite(cf::CondBranchOp op, cf::CondBranchOp::Adaptor adaptor, 1686 ConversionPatternRewriter &rewriter) const override { 1687 auto converter = getTypeConverter(); 1688- auto newOp = rewriter.replaceOpWithNewOp<CondBranchOp>( 1689+ auto newOp = rewriter.replaceOpWithNewOp<cf::CondBranchOp>( 1690 op, adaptor.getCondition(), op.getTrueDest(), 1691 adaptor.getTrueDestOperands(), op.getFalseDest(), 1692 adaptor.getFalseDestOperands()); 1693diff --git a/lib/Dialect/Triton/IR/CMakeLists.txt b/lib/Dialect/Triton/IR/CMakeLists.txt 1694index 2d679b21fd..705554ba6b 100644 1695--- a/lib/Dialect/Triton/IR/CMakeLists.txt 1696+++ b/lib/Dialect/Triton/IR/CMakeLists.txt 1697@@ -10,11 +10,7 @@ add_mlir_dialect_library(TritonIR 1698 1699 LINK_LIBS PUBLIC 1700 MLIRIR 1701- MLIRArithmetic 1702- MLIRSCF 1703- 1704- # Since LLVM 15 1705- # MLIRFunc 1706- # else 1707- MLIRStandard 1708+ MLIRArithmeticDialect 1709+ MLIRSCFDialect 1710+ MLIRFuncDialect 1711 ) 1712diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp 1713index 3aadbfa0c0..86570359c5 100644 1714--- a/lib/Dialect/Triton/IR/Ops.cpp 1715+++ b/lib/Dialect/Triton/IR/Ops.cpp 1716@@ -1,10 +1,9 @@ 1717-#include "triton/Dialect/Triton/IR/Dialect.h" 1718-#include "triton/Dialect/Triton/IR/Types.h" 1719- 1720 #include "mlir/IR/Builders.h" 1721 #include "mlir/IR/BuiltinAttributes.h" 1722 #include "mlir/IR/BuiltinTypes.h" 1723 #include "mlir/IR/OperationSupport.h" 1724+#include "triton/Dialect/Triton/IR/Dialect.h" 1725+#include "triton/Dialect/Triton/IR/Types.h" 1726 1727 namespace mlir { 1728 namespace triton { 1729@@ -38,8 +37,8 @@ static Type getPointerTypeSameShape(Type type) { 1730 } 1731 1732 // Parser & printer for assembly forms 1733-ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { 1734- SmallVector<OpAsmParser::OperandType, 4> allOperands; 1735+ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { 1736+ SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands; 1737 Type resultTypes[1]; 1738 SMLoc allOperandLoc = parser.getCurrentLocation(); 1739 if (parser.parseOperandList(allOperands) || 1740@@ -73,18 +72,18 @@ ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { 1741 return success(); 1742 } 1743 1744-void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) { 1745+void LoadOp::print(OpAsmPrinter &printer) { 1746 printer << " "; 1747- printer << loadOp.getOperation()->getOperands(); 1748+ printer << getOperation()->getOperands(); 1749 // "operand_segment_sizes" can be deduced, so we don't print it. 1750- printer.printOptionalAttrDict(loadOp->getAttrs(), 1751- {loadOp.operand_segment_sizesAttrName()}); 1752+ printer.printOptionalAttrDict(getOperation()->getAttrs(), 1753+ {operand_segment_sizesAttrName()}); 1754 printer << " : "; 1755- printer.printStrippedAttrOrType(loadOp.result().getType()); 1756+ printer.printStrippedAttrOrType(getResult().getType()); 1757 } 1758 1759-ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { 1760- SmallVector<OpAsmParser::OperandType, 4> allOperands; 1761+ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { 1762+ SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands; 1763 Type valueType; 1764 SMLoc allOperandLoc = parser.getCurrentLocation(); 1765 if (parser.parseOperandList(allOperands) || 1766@@ -104,12 +103,12 @@ ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { 1767 return success(); 1768 } 1769 1770-void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) { 1771+void StoreOp::print(OpAsmPrinter &printer) { 1772 printer << " "; 1773- printer << storeOp.getOperation()->getOperands(); 1774- printer.printOptionalAttrDict(storeOp->getAttrs(), /*elidedAttrs=*/{}); 1775+ printer << getOperation()->getOperands(); 1776+ printer.printOptionalAttrDict(getOperation()->getAttrs(), /*elidedAttrs=*/{}); 1777 printer << " : "; 1778- printer.printStrippedAttrOrType(storeOp.value().getType()); 1779+ printer.printStrippedAttrOrType(value().getType()); 1780 } 1781 1782 } // namespace triton 1783@@ -319,7 +318,8 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) { 1784 if (!constOperand) 1785 return {}; 1786 auto shapedType = getType().cast<ShapedType>(); 1787- auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()}); 1788+ auto ret = SplatElementsAttr::get( 1789+ shapedType, ArrayRef<Attribute>(constOperand.getValue())); 1790 return ret; 1791 } 1792 1793diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp 1794index 2261472170..11570283d6 100644 1795--- a/lib/Dialect/Triton/Transforms/Combine.cpp 1796+++ b/lib/Dialect/Triton/Transforms/Combine.cpp 1797@@ -57,13 +57,13 @@ DenseElementsAttr getConstantValue(Builder &builder, Attribute value, 1798 class CombineSelectMaskedLoadPattern : public mlir::RewritePattern { 1799 public: 1800 CombineSelectMaskedLoadPattern(mlir::MLIRContext *context) 1801- : mlir::RewritePattern(mlir::SelectOp::getOperationName(), 3, context, 1802- {triton::LoadOp::getOperationName()}) {} 1803+ : mlir::RewritePattern(mlir::arith::SelectOp::getOperationName(), 3, 1804+ context, {triton::LoadOp::getOperationName()}) {} 1805 1806 mlir::LogicalResult 1807 matchAndRewrite(mlir::Operation *op, 1808 mlir::PatternRewriter &rewriter) const override { 1809- auto selectOp = llvm::dyn_cast<mlir::SelectOp>(op); 1810+ auto selectOp = llvm::dyn_cast<mlir::arith::SelectOp>(op); 1811 if (!selectOp) 1812 return mlir::failure(); 1813 1814diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td 1815index 14f286b26e..ded0e346e6 100644 1816--- a/lib/Dialect/Triton/Transforms/Combine.td 1817+++ b/lib/Dialect/Triton/Transforms/Combine.td 1818@@ -1,9 +1,9 @@ 1819 #ifndef TRITON_PATTERNS 1820 #define TRITON_PATTERNS 1821 1822-include "mlir/Dialect/StandardOps/IR/Ops.td" 1823 include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td" 1824 include "triton/Dialect/Triton/IR/TritonOps.td" 1825+include "mlir/IR/PatternBase.td" 1826 1827 1828 // AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) 1829diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp 1830index 1fbc609e88..bfc3f3d3da 100644 1831--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp 1832+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp 1833@@ -1,14 +1,14 @@ 1834+#include "triton/Dialect/Triton/IR/Dialect.h" 1835+ 1836 #include <numeric> 1837 1838 #include "mlir/IR/DialectImplementation.h" 1839 #include "mlir/IR/OpImplementation.h" 1840 #include "triton/Analysis/Utility.h" 1841-#include "triton/Dialect/Triton/IR/Dialect.h" 1842+#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" 1843 #include "triton/Dialect/TritonGPU/IR/Dialect.h" 1844 #include "llvm/ADT/TypeSwitch.h" 1845 1846-#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" 1847- 1848 using namespace mlir; 1849 using namespace mlir::triton::gpu; 1850 1851@@ -366,7 +366,6 @@ template SmallVector<int64_t> 1852 SliceEncodingAttr::paddedShape<int64_t>(ArrayRef<int64_t> shape) const; 1853 1854 unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const { 1855- size_t rank = shape.size(); 1856 auto parent = getParent(); 1857 return ::getElemsPerThread(parent, paddedShape(shape)); 1858 } 1859@@ -655,9 +654,9 @@ void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const { 1860 // InsertSliceAsyncOp 1861 //===----------------------------------------------------------------------===// 1862 1863-ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser, 1864- OperationState &result) { 1865- SmallVector<OpAsmParser::OperandType, 8> allOperands; 1866+ParseResult InsertSliceAsyncOp::parse(OpAsmParser &parser, 1867+ OperationState &result) { 1868+ SmallVector<OpAsmParser::UnresolvedOperand, 8> allOperands; 1869 Type srcType, dstType; 1870 SMLoc allOperandLoc = parser.getCurrentLocation(); 1871 if (parser.parseOperandList(allOperands) || 1872@@ -696,18 +695,16 @@ ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser, 1873 return success(); 1874 } 1875 1876-void printInsertSliceAsyncOp(OpAsmPrinter &printer, 1877- InsertSliceAsyncOp insertSliceAsyncOp) { 1878+void InsertSliceAsyncOp::print(OpAsmPrinter &printer) { 1879 printer << " "; 1880- printer << insertSliceAsyncOp.getOperation()->getOperands(); 1881+ printer << getOperation()->getOperands(); 1882 // "operand_segment_sizes" can be deduced, so we don't print it. 1883- printer.printOptionalAttrDict( 1884- insertSliceAsyncOp->getAttrs(), 1885- {insertSliceAsyncOp.operand_segment_sizesAttrName()}); 1886+ printer.printOptionalAttrDict(getOperation()->getAttrs(), 1887+ {operand_segment_sizesAttrName()}); 1888 printer << " : "; 1889- printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType()); 1890+ printer.printStrippedAttrOrType(src().getType()); 1891 printer << " -> "; 1892- printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType()); 1893+ printer.printStrippedAttrOrType(result().getType()); 1894 } 1895 1896 //===----------------------------------------------------------------------===// 1897diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp 1898index 82407980d3..ee6009f44a 100644 1899--- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp 1900+++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp 1901@@ -27,7 +27,11 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> { 1902 auto origType = ptr.getType().cast<RankedTensorType>(); 1903 // Get the shape of the tensor. 1904 size_t rank = origType.getRank(); 1905- AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue(); 1906+ dataflow::Lattice<AxisInfo> *latticeElement = 1907+ axisInfo.getLatticeElement(ptr); 1908+ AxisInfo info = latticeElement && !latticeElement->isUninitialized() 1909+ ? latticeElement->getValue() 1910+ : AxisInfo(); 1911 // Get the contiguity order of `ptr` 1912 auto order = argSort(info.getContiguity()); 1913 // The desired divisibility is the maximum divisibility 1914@@ -40,7 +44,7 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> { 1915 for (Value val : op->getResults()) { 1916 if (val.getType() != origType) 1917 continue; 1918- auto valInfo = axisInfo.lookupLatticeElement(val); 1919+ auto valInfo = axisInfo.getLatticeElement(val); 1920 auto currOrder = argSort(valInfo->getValue().getContiguity()); 1921 if (order == currOrder) 1922 withSameOrder.insert(val); 1923@@ -55,7 +59,7 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> { 1924 unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); 1925 unsigned perThread = 1; 1926 for (Value val : withSameOrder) { 1927- AxisInfo info = axisInfo.lookupLatticeElement(val)->getValue(); 1928+ AxisInfo info = axisInfo.getLatticeElement(val)->getValue(); 1929 unsigned maxMultipleBytes = info.getDivisibility(order[0]); 1930 unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u); 1931 unsigned maxContig = info.getContiguity(order[0]); 1932@@ -123,8 +127,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> { 1933 void runOnOperation() override { 1934 Operation *op = getOperation(); 1935 // Run axis info analysis 1936- AxisInfoAnalysis axisInfo(&getContext()); 1937- axisInfo.run(op); 1938+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver(); 1939+ AxisInfoAnalysis *axisInfo = solver->load<AxisInfoAnalysis>(); 1940+ if (failed(solver->initializeAndRun(op))) 1941+ return signalPassFailure(); 1942 1943 // For each i/o operation, we determine what layout 1944 // the pointers should have for best memory coalescing 1945@@ -146,10 +152,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> { 1946 RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>(); 1947 if (!ty || !ty.getElementType().isa<PointerType>()) 1948 return; 1949- AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue(); 1950+ AxisInfo info = axisInfo->getLatticeElement(ptr)->getValue(); 1951 auto mod = curr->getParentOfType<ModuleOp>(); 1952 int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); 1953- auto convertType = getTypeConverter(axisInfo, ptr, numWarps); 1954+ auto convertType = getTypeConverter(*axisInfo, ptr, numWarps); 1955 layoutMap[ptr] = convertType; 1956 }); 1957 1958diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp 1959index efa37ff2dc..089ce3996c 100644 1960--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp 1961+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp 1962@@ -1,6 +1,6 @@ 1963 #include "Utility.h" 1964 #include "mlir/Analysis/SliceAnalysis.h" 1965-#include "mlir/Dialect/SCF/SCF.h" 1966+#include "mlir/Dialect/SCF/IR/SCF.h" 1967 #include "mlir/IR/BlockAndValueMapping.h" 1968 #include "mlir/IR/BuiltinAttributes.h" 1969 #include "mlir/IR/Matchers.h" 1970diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.td b/lib/Dialect/TritonGPU/Transforms/Combine.td 1971index 6bf1b14866..6a7b10dbcb 100644 1972--- a/lib/Dialect/TritonGPU/Transforms/Combine.td 1973+++ b/lib/Dialect/TritonGPU/Transforms/Combine.td 1974@@ -3,5 +3,6 @@ 1975 1976 include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td" 1977 include "triton/Dialect/Triton/IR/TritonOps.td" 1978+include "mlir/IR/PatternBase.td" 1979 1980 #endif 1981diff --git a/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp b/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp 1982index 4bd3bc76bf..b2f8defd81 100644 1983--- a/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp 1984+++ b/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp 1985@@ -1,5 +1,5 @@ 1986 #include "mlir/Analysis/SliceAnalysis.h" 1987-#include "mlir/Dialect/SCF/SCF.h" 1988+#include "mlir/Dialect/SCF/IR/SCF.h" 1989 #include "mlir/IR/BlockAndValueMapping.h" 1990 #include "mlir/IR/BuiltinAttributes.h" 1991 #include "mlir/IR/Matchers.h" 1992diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp 1993index 9b2f42231e..85f746c1dc 100644 1994--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp 1995+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp 1996@@ -2,6 +2,7 @@ 1997 #include "mlir/IR/BlockAndValueMapping.h" 1998 #include "mlir/IR/TypeUtilities.h" 1999 #include "triton/Analysis/AxisInfo.h" 2000+#include "triton/Analysis/Utility.h" 2001 #include "triton/Dialect/TritonGPU/IR/Dialect.h" 2002 #include "triton/Dialect/TritonGPU/Transforms/Passes.h" 2003 2004@@ -160,15 +161,18 @@ ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op, 2005 LogicalResult LoopPipeliner::initialize() { 2006 Block *loop = forOp.getBody(); 2007 2008- AxisInfoAnalysis axisInfoAnalysis(forOp.getContext()); 2009- axisInfoAnalysis.run(forOp->getParentOfType<ModuleOp>()); 2010+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver(); 2011+ AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>(); 2012+ if (failed(solver->initializeAndRun(forOp->getParentOfType<ModuleOp>()))) { 2013+ return failure(); 2014+ } 2015 2016 // can we use forOp.walk(...) here? 2017 SmallVector<triton::LoadOp, 2> allLoads; 2018 for (Operation &op : *loop) 2019 if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) { 2020 auto ptr = loadOp.ptr(); 2021- unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); 2022+ unsigned vec = axisInfoAnalysis->getPtrContiguity(ptr); 2023 auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>(); 2024 if (!tensorTy) 2025 continue; 2026diff --git a/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp b/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp 2027index 0e7dbe5264..b95a4f50a6 100644 2028--- a/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp 2029+++ b/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp 2030@@ -1,5 +1,5 @@ 2031 #include "mlir/Analysis/SliceAnalysis.h" 2032-#include "mlir/Dialect/SCF/SCF.h" 2033+#include "mlir/Dialect/SCF/IR/SCF.h" 2034 #include "mlir/IR/BlockAndValueMapping.h" 2035 #include "mlir/IR/BuiltinAttributes.h" 2036 #include "mlir/IR/Matchers.h" 2037diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp 2038index 37ac710995..762e887f36 100644 2039--- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp 2040+++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp 2041@@ -82,12 +82,12 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( 2042 scf::ReduceReturnOp>(); 2043 2044 addDynamicallyLegalDialect<arith::ArithmeticDialect, math::MathDialect, 2045- triton::TritonDialect, StandardOpsDialect, 2046- scf::SCFDialect>([&](Operation *op) { 2047- if (typeConverter.isLegal(op)) 2048- return true; 2049- return false; 2050- }); 2051+ triton::TritonDialect, scf::SCFDialect>( 2052+ [&](Operation *op) { 2053+ if (typeConverter.isLegal(op)) 2054+ return true; 2055+ return false; 2056+ }); 2057 2058 // We have requirements for the data layouts 2059 addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool { 2060diff --git a/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp b/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp 2061index c229104286..c911fd4a5c 100644 2062--- a/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp 2063+++ b/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp 2064@@ -1,5 +1,5 @@ 2065 #include "Utility.h" 2066-#include "mlir/Dialect/SCF/SCF.h" 2067+#include "mlir/Dialect/SCF/IR/SCF.h" 2068 #include "mlir/IR/Matchers.h" 2069 #include "mlir/IR/PatternMatch.h" 2070 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 2071@@ -118,8 +118,8 @@ void setOpResultType(Operation *op, ArrayRef<Type> newTypes) { 2072 .get("value") 2073 .dyn_cast<mlir::DenseElementsAttr>(); 2074 if (attr) { 2075- auto newAttr = mlir::DenseElementsAttr::getFromRawBuffer( 2076- newType, attr.getRawData(), true); 2077+ auto newAttr = 2078+ mlir::DenseElementsAttr::getFromRawBuffer(newType, attr.getRawData()); 2079 op->setAttr("value", newAttr); 2080 } 2081 } 2082diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2083index ed15f02f67..6400f1633a 100644 2084--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2085+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2086@@ -1,5 +1,5 @@ 2087 #include "Utility.h" 2088-#include "mlir/Dialect/SCF/SCF.h" 2089+#include "mlir/Dialect/SCF/IR/SCF.h" 2090 #include "mlir/IR/BlockAndValueMapping.h" 2091 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 2092 2093diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt 2094index f1bbd0bf4e..ac8973ad19 100644 2095--- a/lib/Target/LLVMIR/CMakeLists.txt 2096+++ b/lib/Target/LLVMIR/CMakeLists.txt 2097@@ -6,8 +6,7 @@ add_mlir_translation_library(TritonLLVMIR 2098 2099 LINK_LIBS PUBLIC 2100 MLIRIR 2101- MLIRLLVMIR 2102- MLIRSCFToStandard 2103+ MLIRLLVMDialect 2104 MLIRSupport 2105 MLIRTargetLLVMIRExport 2106 ) 2107diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp 2108index 4cb0d8193c..6a5453a6e7 100644 2109--- a/lib/Target/PTX/PTXTranslation.cpp 2110+++ b/lib/Target/PTX/PTXTranslation.cpp 2111@@ -1,11 +1,14 @@ 2112 #include "triton/Target/PTX/PTXTranslation.h" 2113 #include "triton/Target/LLVMIR/LLVMIRTranslation.h" 2114+#include <optional> 2115 2116 #include "llvm/IR/IRBuilder.h" 2117 #include "llvm/IR/LegacyPassManager.h" 2118 #include "llvm/IR/Module.h" 2119 #include "llvm/IR/Verifier.h" 2120 #include "llvm/MC/TargetRegistry.h" 2121+#include "llvm/Pass.h" 2122+#include "llvm/Support/CommandLine.h" 2123 #include "llvm/Support/TargetSelect.h" 2124 #include "llvm/Target/TargetMachine.h" 2125 2126diff --git a/python/setup.py b/python/setup.py 2127index 2ac3accd25..4530b36714 100644 2128--- a/python/setup.py 2129+++ b/python/setup.py 2130@@ -57,19 +57,10 @@ def get_pybind11_package_info(): 2131 def get_llvm_package_info(): 2132 # download if nothing is installed 2133 system = platform.system() 2134- if system == "Darwin": 2135- system_suffix = "apple-darwin" 2136- elif system == "Linux": 2137- vglibc = tuple(map(int, platform.libc_ver()[1].split('.'))) 2138- vglibc = vglibc[0] * 100 + vglibc[1] 2139- linux_suffix = 'ubuntu-18.04' if vglibc > 217 else 'centos-7' 2140- system_suffix = f"linux-gnu-{linux_suffix}" 2141- else: 2142- raise RuntimeError(f"unsupported system: {system}") 2143+ system_suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system] 2144 use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False") 2145- release_suffix = "assert" if use_assert_enabled_llvm else "release" 2146- name = f'llvm+mlir-14.0.6-x86_64-{system_suffix}-{release_suffix}' 2147- url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/llvm-14.0.6-f28c006a5895/{name}.tar.xz" 2148+ name = 'llvm+mlir-15.0.7-x86_64-{}-{}'.format(system_suffix, "assert" if use_assert_enabled_llvm else "release") 2149+ url = "https://github.com/ptillet/triton-llvm-releases/releases/download/llvm-15.0.7-8dfdcc7b7bf6/{}.tar.xz".format(name) 2150 return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH") 2151 2152 2153diff --git a/python/src/triton.cc b/python/src/triton.cc 2154index c40b117a55..f190eacc34 100644 2155--- a/python/src/triton.cc 2156+++ b/python/src/triton.cc 2157@@ -8,9 +8,10 @@ 2158 #include "mlir/Pass/PassManager.h" 2159 #include "mlir/Transforms/Passes.h" 2160 2161-#include "mlir/Parser.h" 2162+#include "mlir/Parser/Parser.h" 2163 #include "mlir/Support/FileUtilities.h" 2164 2165+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 2166 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 2167 #include "triton/Analysis/Allocation.h" 2168 #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h" 2169@@ -195,7 +196,7 @@ void init_triton_ir(py::module &&m) { 2170 std::string attrName = name + "_arg" + std::to_string(id); 2171 mlir::Block *owner = arg.getOwner(); 2172 if (owner->isEntryBlock() && 2173- !mlir::isa<mlir::FuncOp>(owner->getParentOp())) { 2174+ !mlir::isa<mlir::func::FuncOp>(owner->getParentOp())) { 2175 owner->getParentOp()->setAttr(attrName, attr); 2176 } 2177 } 2178@@ -348,7 +349,7 @@ void init_triton_ir(py::module &&m) { 2179 return str; 2180 }) 2181 .def("push_back", 2182- [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void { 2183+ [](mlir::ModuleOp &self, mlir::func::FuncOp &funcOp) -> void { 2184 self.push_back(funcOp); 2185 }) 2186 .def("has_function", 2187@@ -358,16 +359,18 @@ void init_triton_ir(py::module &&m) { 2188 return false; 2189 }) 2190 .def("get_function", 2191- [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp { 2192- return self.lookupSymbol<mlir::FuncOp>(funcName); 2193- }) 2194- .def("get_single_function", [](mlir::ModuleOp &self) -> mlir::FuncOp { 2195- llvm::SmallVector<mlir::FuncOp> funcs; 2196- self.walk([&](mlir::FuncOp func) { funcs.push_back(func); }); 2197- if (funcs.size() != 1) 2198- throw std::runtime_error("Expected a single function"); 2199- return funcs[0]; 2200- }); 2201+ [](mlir::ModuleOp &self, 2202+ std::string &funcName) -> mlir::func::FuncOp { 2203+ return self.lookupSymbol<mlir::func::FuncOp>(funcName); 2204+ }) 2205+ .def("get_single_function", 2206+ [](mlir::ModuleOp &self) -> mlir::func::FuncOp { 2207+ llvm::SmallVector<mlir::func::FuncOp> funcs; 2208+ self.walk([&](mlir::func::FuncOp func) { funcs.push_back(func); }); 2209+ if (funcs.size() != 1) 2210+ throw std::runtime_error("Expected a single function"); 2211+ return funcs[0]; 2212+ }); 2213 2214 m.def("make_attr", 2215 [](const std::vector<int> &values, mlir::MLIRContext &context) { 2216@@ -388,47 +391,48 @@ void init_triton_ir(py::module &&m) { 2217 registry.insert<mlir::triton::TritonDialect, 2218 mlir::triton::gpu::TritonGPUDialect, 2219 mlir::math::MathDialect, mlir::arith::ArithmeticDialect, 2220- mlir::StandardOpsDialect, mlir::scf::SCFDialect>(); 2221+ mlir::func::FuncDialect, mlir::scf::SCFDialect>(); 2222 context.appendDialectRegistry(registry); 2223 context.loadAllAvailableDialects(); 2224 2225 // parse module 2226- mlir::OwningOpRef<mlir::ModuleOp> module( 2227- mlir::parseSourceFile(inputFilename, &context)); 2228+ mlir::OwningOpRef<mlir::ModuleOp> module = 2229+ mlir::parseSourceFile<mlir::ModuleOp>(inputFilename, &context); 2230+ if (!module) 2231+ throw std::runtime_error("Parse MLIR file failed."); 2232 // locations are incompatible with ptx < 7.5 ! 2233 module->walk([](mlir::Operation *op) { 2234 op->setLoc(mlir::UnknownLoc::get(op->getContext())); 2235 }); 2236- if (!module) 2237- throw std::runtime_error("Parse MLIR file failed."); 2238 2239 return module->clone(); 2240 }, 2241 ret::take_ownership); 2242 2243- py::class_<mlir::FuncOp, mlir::OpState>(m, "function") 2244+ py::class_<mlir::func::FuncOp, mlir::OpState>(m, "function") 2245 // .def_property_readonly("attrs", &ir::function::attrs) 2246 // .def("add_attr", &ir::function::add_attr); 2247 .def("args", 2248- [](mlir::FuncOp &self, unsigned idx) -> mlir::BlockArgument { 2249+ [](mlir::func::FuncOp &self, unsigned idx) -> mlir::BlockArgument { 2250 return self.getArgument(idx); 2251 }) 2252 .def( 2253 "add_entry_block", 2254- [](mlir::FuncOp &self) -> mlir::Block * { 2255+ [](mlir::func::FuncOp &self) -> mlir::Block * { 2256 return self.addEntryBlock(); 2257 }, 2258 ret::reference) 2259 .def( 2260 "set_arg_attr", 2261- [](mlir::FuncOp &self, int arg_no, const std::string &name, int val) { 2262+ [](mlir::func::FuncOp &self, int arg_no, const std::string &name, 2263+ int val) { 2264 // set arg attributes "name" to value "val" 2265 auto attrTy = mlir::IntegerType::get(self.getContext(), 32); 2266 self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val)); 2267 }, 2268 ret::reference) 2269- .def_property_readonly("type", &mlir::FuncOp::getType) 2270- .def("reset_type", &mlir::FuncOp::setType); 2271+ .def_property_readonly("type", &mlir::func::FuncOp::getFunctionType) 2272+ .def("reset_type", &mlir::func::FuncOp::setType); 2273 2274 py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint"); 2275 2276@@ -445,13 +449,13 @@ void init_triton_ir(py::module &&m) { 2277 .def("ret", 2278 [](mlir::OpBuilder &self, std::vector<mlir::Value> &vals) -> void { 2279 auto loc = self.getUnknownLoc(); 2280- self.create<mlir::ReturnOp>(loc, vals); 2281+ self.create<mlir::func::ReturnOp>(loc, vals); 2282 }) 2283 .def("call", 2284- [](mlir::OpBuilder &self, mlir::FuncOp &func, 2285+ [](mlir::OpBuilder &self, mlir::func::FuncOp &func, 2286 std::vector<mlir::Value> &args) -> mlir::OpState { 2287 auto loc = self.getUnknownLoc(); 2288- return self.create<mlir::CallOp>(loc, func, args); 2289+ return self.create<mlir::func::CallOp>(loc, func, args); 2290 }) 2291 // insertion block/point 2292 .def("set_insertion_point_to_start", 2293@@ -618,15 +622,16 @@ void init_triton_ir(py::module &&m) { 2294 .def("get_or_insert_function", 2295 [](mlir::OpBuilder &self, mlir::ModuleOp &module, 2296 std::string &funcName, mlir::Type &funcType, 2297- std::string &visibility) -> mlir::FuncOp { 2298+ std::string &visibility) -> mlir::func::FuncOp { 2299 if (mlir::Operation *funcOperation = module.lookupSymbol(funcName)) 2300- return llvm::dyn_cast<mlir::FuncOp>(funcOperation); 2301+ return llvm::dyn_cast<mlir::func::FuncOp>(funcOperation); 2302 auto loc = self.getUnknownLoc(); 2303 if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) { 2304 llvm::SmallVector<mlir::NamedAttribute> attrs = { 2305 mlir::NamedAttribute(self.getStringAttr("sym_visibility"), 2306 self.getStringAttr(visibility))}; 2307- return self.create<mlir::FuncOp>(loc, funcName, funcTy, attrs); 2308+ return self.create<mlir::func::FuncOp>(loc, funcName, funcTy, 2309+ attrs); 2310 } 2311 throw std::runtime_error("invalid function type"); 2312 }) 2313@@ -658,15 +663,15 @@ void init_triton_ir(py::module &&m) { 2314 [](mlir::OpBuilder &self, mlir::Value condition, 2315 mlir::Block *trueDest, mlir::Block *falseDest) { 2316 auto loc = self.getUnknownLoc(); 2317- self.create<mlir::CondBranchOp>(loc, condition, trueDest, 2318- falseDest); 2319+ self.create<mlir::cf::CondBranchOp>(loc, condition, trueDest, 2320+ falseDest); 2321 return; 2322 }) 2323 .def("create_branch", 2324 [](mlir::OpBuilder &self, mlir::Block *dest, 2325 std::vector<mlir::Value> &args) { 2326 auto loc = self.getUnknownLoc(); 2327- self.create<mlir::BranchOp>(loc, dest, args); 2328+ self.create<mlir::cf::BranchOp>(loc, dest, args); 2329 return; 2330 }) 2331 // Structured control flow 2332@@ -792,14 +797,14 @@ void init_triton_ir(py::module &&m) { 2333 .def("create_to_index", 2334 [](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value { 2335 auto loc = self.getUnknownLoc(); 2336- return self.create<mlir::arith::IndexCastOp>(loc, input, 2337- self.getIndexType()); 2338+ return self.create<mlir::arith::IndexCastOp>( 2339+ loc, self.getIndexType(), input); 2340 }) 2341 .def("create_index_to_si", 2342 [](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value { 2343 auto loc = self.getUnknownLoc(); 2344- return self.create<mlir::arith::IndexCastOp>(loc, input, 2345- self.getI32Type()); 2346+ return self.create<mlir::arith::IndexCastOp>( 2347+ loc, self.getI32Type(), input); 2348 }) 2349 .def("create_fmul", 2350 [](mlir::OpBuilder &self, mlir::Value &lhs, 2351@@ -1316,8 +1321,8 @@ void init_triton_ir(py::module &&m) { 2352 [](mlir::OpBuilder &self, mlir::Value &condition, 2353 mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value { 2354 auto loc = self.getUnknownLoc(); 2355- return self.create<mlir::SelectOp>(loc, condition, trueValue, 2356- falseValue); 2357+ return self.create<mlir::arith::SelectOp>(loc, condition, 2358+ trueValue, falseValue); 2359 }) 2360 .def("create_printf", 2361 [](mlir::OpBuilder &self, const std::string &prefix, 2362@@ -1429,7 +1434,7 @@ void init_triton_ir(py::module &&m) { 2363 self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass()); 2364 }) 2365 .def("add_scf_to_cfg", [](mlir::PassManager &self) { 2366- self.addPass(mlir::createLowerToCFGPass()); 2367+ self.addPass(mlir::createConvertSCFToCFPass()); 2368 }); 2369 } 2370 2371diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py 2372index 432544a8a4..018f544714 100644 2373--- a/python/test/unit/language/test_core.py 2374+++ b/python/test/unit/language/test_core.py 2375@@ -1918,7 +1918,7 @@ def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'): 2376 #dst = {dst_layout} 2377 """ + """ 2378 module attributes {"triton_gpu.num-warps" = 4 : i32} { 2379- func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) { 2380+ func.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) { 2381 %cst = arith.constant dense<128> : tensor<128x1xi32, #src> 2382 %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>> 2383 %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>> 2384diff --git a/python/triton/compiler.py b/python/triton/compiler.py 2385index 5d167634df..c36589037c 100644 2386--- a/python/triton/compiler.py 2387+++ b/python/triton/compiler.py 2388@@ -1514,14 +1514,14 @@ def make_hash(fn, **kwargs): 2389 return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest() 2390 2391 2392-# - ^\s*func\s+ : match the start of the string, any leading whitespace, the keyword func, 2393+# - ^\s*func\.func\s+ : match the start of the string, any leading whitespace, the keyword func, 2394 # and any following whitespace 2395 # - (public\s+)? : optionally match the keyword public and any following whitespace 2396 # - (@\w+) : match an @ symbol followed by one or more word characters 2397 # (letters, digits, or underscores), and capture it as group 1 (the function name) 2398 # - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing 2399 # zero or more arguments separated by commas, and capture it as group 2 (the argument list) 2400-mlir_prototype_pattern = r'^\s*func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$' 2401+mlir_prototype_pattern = r'^\s*func\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$' 2402 ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" 2403 prototype_pattern = { 2404 "ttir": mlir_prototype_pattern, 2405diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir 2406index b3d5673f85..bb21615e68 100644 2407--- a/test/Analysis/test-alias.mlir 2408+++ b/test/Analysis/test-alias.mlir 2409@@ -11,7 +11,7 @@ 2410 2411 // CHECK-LABEL: matmul_loop 2412 // There shouldn't be any aliasing with the dot op encoding. 2413-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 2414+func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 2415 %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL> 2416 %b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL> 2417 %a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL> 2418@@ -36,7 +36,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B 2419 } 2420 2421 // CHECK-LABEL: alloc 2422-func @alloc(%A : !tt.ptr<f16>) { 2423+func.func @alloc(%A : !tt.ptr<f16>) { 2424 // CHECK: %cst -> %cst 2425 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 2426 %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> 2427@@ -46,7 +46,7 @@ func @alloc(%A : !tt.ptr<f16>) { 2428 } 2429 2430 // CHECK-LABEL: convert 2431-func @convert(%A : !tt.ptr<f16>) { 2432+func.func @convert(%A : !tt.ptr<f16>) { 2433 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> 2434 // CHECK: %0 -> %0 2435 %cst1 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED> 2436@@ -54,7 +54,7 @@ func @convert(%A : !tt.ptr<f16>) { 2437 } 2438 2439 // CHECK-LABEL: trans 2440-func @trans(%A : !tt.ptr<f16>) { 2441+func.func @trans(%A : !tt.ptr<f16>) { 2442 // CHECK: %cst -> %cst 2443 %tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED> 2444 // CHECK: %0 -> %cst 2445@@ -63,7 +63,7 @@ func @trans(%A : !tt.ptr<f16>) { 2446 } 2447 2448 // CHECK-LABEL: insert_slice_async 2449-func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) { 2450+func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) { 2451 %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL> 2452 %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> 2453 %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> 2454@@ -76,7 +76,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) { 2455 } 2456 2457 // CHECK-LABEL: insert_slice 2458-func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) { 2459+func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) { 2460 %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL> 2461 %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> 2462 %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> 2463@@ -90,7 +90,7 @@ func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) { 2464 } 2465 2466 // CHECK-LABEL: extract_slice 2467-func @extract_slice(%A : !tt.ptr<f16>) { 2468+func.func @extract_slice(%A : !tt.ptr<f16>) { 2469 // CHECK: %cst -> %cst 2470 %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED> 2471 %index = arith.constant 0 : index 2472@@ -100,7 +100,7 @@ func @extract_slice(%A : !tt.ptr<f16>) { 2473 } 2474 2475 // CHECK-LABEL: if_cat 2476-func @if_cat(%i1 : i1) { 2477+func.func @if_cat(%i1 : i1) { 2478 // CHECK: %cst -> %cst 2479 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 2480 // CHECK: %cst_0 -> %cst_0 2481@@ -119,7 +119,7 @@ func @if_cat(%i1 : i1) { 2482 } 2483 2484 // CHECK-LABEL: if_alias 2485-func @if_alias(%i1 : i1) { 2486+func.func @if_alias(%i1 : i1) { 2487 // CHECK: %cst -> %cst 2488 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 2489 // CHECK-NEXT: %cst_0 -> %cst_0 2490@@ -134,7 +134,7 @@ func @if_alias(%i1 : i1) { 2491 } 2492 2493 // CHECK-LABEL: for 2494-func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 2495+func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 2496 // CHECK: %cst -> %cst 2497 %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> 2498 // CHECK-NEXT: %cst_0 -> %cst_0 2499@@ -154,7 +154,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.p 2500 } 2501 2502 // CHECK-LABEL: for_if 2503-func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) { 2504+func.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) { 2505 // CHECK: %cst -> %cst 2506 %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> 2507 // CHECK-NEXT: %cst_0 -> %cst_0 2508@@ -180,7 +180,7 @@ func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !t 2509 } 2510 2511 // CHECK-LABEL: for_if_for 2512-func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) { 2513+func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) { 2514 // CHECK: %cst -> %cst 2515 %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> 2516 // CHECK-NEXT: %cst_0 -> %cst_0 2517diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir 2518index 0ab34c7a78..af8ea6f856 100644 2519--- a/test/Analysis/test-alignment.mlir 2520+++ b/test/Analysis/test-alignment.mlir 2521@@ -1,288 +1,288 @@ 2522-// RUN: triton-opt %s -test-print-alignment -split-input-file 2>&1 | FileCheck %s 2523+// RUN: triton-opt %s -test-print-alignment -split-input-file -o %t 2>&1 | FileCheck %s 2524 2525-// CHECK-LABEL: cast 2526-func @cast() { 2527- // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [1] 2528+// CHECK-LABEL: @cast 2529+func.func @cast() { 2530+ // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 2531 %cst = arith.constant 1 : i32 2532- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [1] 2533+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 2534 %0 = arith.extsi %cst : i32 to i64 2535- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] 2536+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 2537 %cst_tensor = arith.constant dense<1> : tensor<128xi32> 2538- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] 2539+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 2540 %1 = tt.bitcast %cst_tensor : tensor<128xi32> -> tensor<128xi64> 2541 return 2542 } 2543 2544 // ----- 2545 2546-// CHECK-LABEL: add 2547-func @add() { 2548- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] 2549+// CHECK-LABEL: @add 2550+func.func @add() { 2551+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none> 2552 %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> 2553- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] 2554+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 2555 %1 = arith.constant dense<1> : tensor<128xi32> 2556- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] 2557+ // CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = <none> 2558 %2 = arith.addi %0, %1 : tensor<128xi32> 2559- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [127] 2560+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 127 2561 %3 = arith.constant dense<127> : tensor<128xi32> 2562- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128] 2563+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 2564 %4 = arith.addi %1, %3 : tensor<128xi32> 2565 return 2566 } 2567 2568 // ----- 2569 2570-// CHECK-LABEL: sub 2571-func @sub() { 2572- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] 2573+// CHECK-LABEL: @sub 2574+func.func @sub() { 2575+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none> 2576 %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> 2577- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] 2578+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 2579 %1 = arith.constant dense<1> : tensor<128xi32> 2580- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] 2581+ // CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = <none> 2582 %2 = arith.subi %0, %1 : tensor<128xi32> 2583- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [129] 2584+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 129 2585 %3 = arith.constant dense<129> : tensor<128xi32> 2586- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128] 2587+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 2588 %4 = arith.subi %3, %1 : tensor<128xi32> 2589 return 2590 } 2591 2592 // ----- 2593 2594-// CHECK-LABEL: mul 2595-func @mul() { 2596- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] 2597+// CHECK-LABEL: @mul 2598+func.func @mul() { 2599+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none> 2600 %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> 2601- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] 2602+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 2603 %1 = arith.constant dense<1> : tensor<128xi32> 2604- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] 2605+ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none> 2606 %2 = arith.muli %0, %1 : tensor<128xi32> 2607- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128] 2608+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 2609 %3 = arith.constant dense<128> : tensor<128xi32> 2610- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128] 2611+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 2612 %4 = arith.muli %3, %1 : tensor<128xi32> 2613- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [2] 2614+ // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 2 2615 %5 = arith.constant dense<2> : tensor<128xi32> 2616- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [256] ; Constancy: [128] ; ConstantValue: [256] 2617+ // CHECK-NEXT: contiguity = [1], divisibility = [256], constancy = [128], constant_value = 256 2618 %6 = arith.muli %4, %5 : tensor<128xi32> 2619 return 2620 } 2621 2622 // ----- 2623 2624-// CHECK-LABEL: div 2625-func @div() { 2626- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] 2627+// CHECK-LABEL: @div 2628+func.func @div() { 2629+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none> 2630 %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> 2631- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] 2632+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 2633 %1 = arith.constant dense<1> : tensor<128xi32> 2634- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] 2635+ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none> 2636 %2 = arith.divsi %0, %1 : tensor<128xi32> 2637- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] 2638+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none> 2639 %3 = arith.divui %1, %0 : tensor<128xi32> 2640- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64] 2641+ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 2642 %4 = arith.constant dense<64> : tensor<128xi32> 2643- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [64] ; ConstantValue: [None] 2644+ // CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = <none> 2645 %5 = arith.divsi %0, %4 : tensor<128xi32> 2646- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] 2647+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none> 2648 %6 = arith.divsi %4, %0 : tensor<128xi32> 2649- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64] 2650+ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 2651 %7 = arith.divsi %4, %1 : tensor<128xi32> 2652- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [66] 2653+ // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66 2654 %8 = arith.constant dense<66> : tensor<128xi32> 2655- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [2] ; ConstantValue: [None] 2656+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [2], constant_value = <none> 2657 %9 = arith.divui %0, %8 : tensor<128xi32> 2658- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [8192] ; Constancy: [1] ; ConstantValue: [None] 2659+ // CHECK-NEXT: contiguity = [128], divisibility = [8192], constancy = [1], constant_value = <none> 2660 %10 = tt.make_range {end = 8320 : i32, start = 8192 : i32} : tensor<128xi32> 2661- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [64] ; ConstantValue: [None] 2662+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [64], constant_value = <none> 2663 %11 = arith.divsi %10, %4 : tensor<128xi32> 2664- return 2665+ return 2666 } 2667 2668 // ----- 2669 2670-// CHECK-LABEL: rem 2671-func @rem() { 2672- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] 2673+// CHECK-LABEL: @rem 2674+func.func @rem() { 2675+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none> 2676 %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> 2677- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] 2678+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 2679 %1 = arith.constant dense<1> : tensor<128xi32> 2680- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0] 2681+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 2682 %2 = arith.remsi %0, %1 : tensor<128xi32> 2683- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] 2684+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none> 2685 %3 = arith.remui %1, %0 : tensor<128xi32> 2686- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64] 2687+ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 2688 %4 = arith.constant dense<64> : tensor<128xi32> 2689- // CHECK-NEXT: Contiguity: [64] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None] 2690+ // CHECK-NEXT: contiguity = [64], divisibility = [64], constancy = [1], constant_value = <none> 2691 %5 = arith.remsi %0, %4 : tensor<128xi32> 2692- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None] 2693+ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [1], constant_value = <none> 2694 %6 = arith.remsi %4, %0 : tensor<128xi32> 2695- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [66] 2696+ // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66 2697 %7 = arith.constant dense<66> : tensor<128xi32> 2698- // CHECK-NEXT: Contiguity: [2] ; Divisibility: [2] ; Constancy: [1] ; ConstantValue: [None] 2699+ // CHECK-NEXT: contiguity = [2], divisibility = [2], constancy = [1], constant_value = <none> 2700 %8 = arith.remui %0, %7 : tensor<128xi32> 2701- return 2702+ return 2703 } 2704 2705 // ----- 2706 2707-// CHECK-LABEL: broadcast 2708-func @broadcast() { 2709- // CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64] 2710+// CHECK-LABEL: @broadcast 2711+func.func @broadcast() { 2712+ // CHECK: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 2713 %0 = arith.constant dense<64> : tensor<128xi32> 2714- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [64, 1] ; Constancy: [128, 1] ; ConstantValue: [64] 2715+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 1], constant_value = 64 2716 %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> 2717- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [64, 1] ; Constancy: [128, 128] ; ConstantValue: [64] 2718+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 128], constant_value = 64 2719 %2 = tt.broadcast %1 : (tensor<128x1xi32>) -> tensor<128x128xi32> 2720 return 2721 } 2722 2723 // ----- 2724 2725-// CHECK-LABEL: splat 2726-func @splat(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) { 2727- // CHECK: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 128] ; ConstantValue: [None] 2728+// CHECK-LABEL: @splat 2729+func.func @splat(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) { 2730+ // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = <none> 2731 %0 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>> 2732 return 2733 } 2734 2735 // ----- 2736 2737-// CHECK-LABEL: cmp 2738-func @cmp() { 2739- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] 2740+// CHECK-LABEL: @cmp 2741+func.func @cmp() { 2742+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none> 2743 %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> 2744- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0] 2745+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 2746 %1 = arith.constant dense<0> : tensor<128xi32> 2747- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] 2748+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none> 2749 %2 = arith.cmpi eq, %0, %1 : tensor<128xi32> 2750- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] 2751+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none> 2752 %3 = arith.cmpi slt, %0, %1 : tensor<128xi32> 2753- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] 2754+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none> 2755 %4 = arith.cmpi sle, %0, %1 : tensor<128xi32> 2756- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] 2757+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none> 2758 %5 = arith.cmpi sge, %0, %1 : tensor<128xi32> 2759- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8] 2760+ // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 2761 %6 = arith.constant dense<8> : tensor<128xi32> 2762- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None] 2763+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none> 2764 %7 = arith.cmpi sgt, %0, %6 : tensor<128xi32> 2765- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [0] 2766+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 0 2767 %8 = arith.cmpi sgt, %1, %6 : tensor<128xi32> 2768 return 2769 } 2770 2771 // ----- 2772 2773-// CHECK-LABEL: logic 2774-func @logic() { 2775- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] 2776+// CHECK-LABEL: @logic 2777+func.func @logic() { 2778+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none> 2779 %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> 2780- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64] 2781+ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 2782 %1 = arith.constant dense<64> : tensor<128xi32> 2783- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [64] ; ConstantValue: [None] 2784+ // CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = <none> 2785 %2 = arith.divsi %0, %1 : tensor<128xi32> 2786- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8] 2787+ // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 2788 %3 = arith.constant dense<8> : tensor<128xi32> 2789- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [134217728] ; Constancy: [8] ; ConstantValue: [None] 2790+ // CHECK-NEXT: contiguity = [1], divisibility = [134217728], constancy = [8], constant_value = <none> 2791 %4 = arith.divsi %0, %3 : tensor<128xi32> 2792- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] 2793+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none> 2794 %5 = arith.andi %0, %1 : tensor<128xi32> 2795- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] 2796+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none> 2797 %6 = arith.ori %0, %1 : tensor<128xi32> 2798- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] 2799+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none> 2800 %7 = arith.xori %0, %1 : tensor<128xi32> 2801- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None] 2802+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none> 2803 %8 = arith.andi %2, %4 : tensor<128xi32> 2804- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None] 2805+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none> 2806 %9 = arith.ori %2, %4 : tensor<128xi32> 2807- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None] 2808+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none> 2809 %10 = arith.xori %2, %4 : tensor<128xi32> 2810 return 2811 } 2812 2813 // ----- 2814 2815-// CHECK-LABEL: select 2816-func @select() { 2817- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] 2818+// CHECK-LABEL: @select 2819+func.func @select() { 2820+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none> 2821 %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> 2822- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0] 2823+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 2824 %1 = arith.constant dense<0> : tensor<128xi32> 2825- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] 2826+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none> 2827 %2 = arith.cmpi eq, %0, %1 : tensor<128xi32> 2828- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] 2829+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none> 2830 %3 = arith.cmpi slt, %0, %1 : tensor<128xi32> 2831- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [1] ; ConstantValue: [0] 2832+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0 2833 %4 = arith.constant 0 : i1 2834- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0] 2835+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 2836 %7 = tt.splat %4 : (i1) -> tensor<128xi1> 2837- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0] 2838- %5 = select %4, %3, %7 : tensor<128xi1> 2839- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] 2840+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 2841+ %5 = arith.select %4, %3, %7 : tensor<128xi1> 2842+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none> 2843 %8 = "triton_gpu.select"(%7, %3, %2) : (tensor<128xi1>, tensor<128xi1>, tensor<128xi1>) -> tensor<128xi1> 2844 return 2845 } 2846 2847 // ----- 2848 2849-func @shift() { 2850- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] 2851+func.func @shift() { 2852+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none> 2853 %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> 2854- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8] 2855+ // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 2856 %1 = arith.constant dense<8> : tensor<128xi32> 2857- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4] ; Constancy: [128] ; ConstantValue: [4] 2858+ // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4 2859 %2 = arith.constant dense<4> : tensor<128xi32> 2860- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [274877906944] ; Constancy: [1] ; ConstantValue: [None] 2861+ // CHECK-NEXT: contiguity = [1], divisibility = [274877906944], constancy = [1], constant_value = <none> 2862 %3 = arith.shli %0, %1 : tensor<128xi32> 2863- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [67108864] ; Constancy: [1] ; ConstantValue: [None] 2864+ // CHECK-NEXT: contiguity = [1], divisibility = [67108864], constancy = [1], constant_value = <none> 2865 %4 = arith.shrsi %0, %2 : tensor<128xi32> 2866- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128] 2867+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 2868 %5 = arith.shli %1, %2 : tensor<128xi32> 2869 return 2870 } 2871 2872 // ----- 2873 2874-func @max_min() { 2875- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] 2876+func.func @max_min() { 2877+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none> 2878 %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> 2879- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None] 2880+ // CHECK-NEXT: contiguity = [128], divisibility = [64], constancy = [1], constant_value = <none> 2881 %1 = tt.make_range {end = 192 : i32, start = 64 : i32} : tensor<128xi32> 2882- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] 2883+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none> 2884 %2 = arith.maxsi %0, %1 : tensor<128xi32> 2885- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] 2886+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none> 2887 %3 = arith.minsi %0, %1 : tensor<128xi32> 2888- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8] 2889+ // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 2890 %4 = arith.constant dense<8> : tensor<128xi32> 2891- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4] ; Constancy: [128] ; ConstantValue: [4] 2892+ // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4 2893 %5 = arith.constant dense<4> : tensor<128xi32> 2894- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [8] 2895+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 8 2896 %6 = arith.maxsi %4, %5 : tensor<128xi32> 2897 return 2898 } 2899 2900 // ----- 2901 2902-// CHECK-LABEL: for 2903-func @for() { 2904- // CHECK: Contiguity: [1, 1] ; Divisibility: [4611686018427387904, 4611686018427387904] ; Constancy: [128, 32] ; ConstantValue: [0] 2905+// CHECK-LABEL: @for 2906+func.func @for() { 2907+ // CHECK: contiguity = [1, 1], divisibility = [4611686018427387904, 4611686018427387904], constancy = [128, 32], constant_value = 0 2908 %a_init = arith.constant dense<0> : tensor<128x32xi32> 2909- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [1] 2910+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = 1 2911 %b_init = arith.constant dense<1> : tensor<128x32xi32> 2912- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [4, 4] ; Constancy: [128, 32] ; ConstantValue: [4] 2913+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4 2914 %c_init = arith.constant dense<4> : tensor<128x32xi32> 2915- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [128] 2916+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128 2917 %ub = arith.constant 128 : index 2918- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [1] ; ConstantValue: [0] 2919+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0 2920 %lb = arith.constant 0 : index 2921- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [16] 2922+ // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16 2923 %step = arith.constant 16 : index 2924 %a, %b, %c = scf.for %iv = %lb to %ub step %step iter_args(%a = %a_init, %b = %b_init, %c = %c_init) -> (tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>) { 2925- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None] 2926+ // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = <none> 2927 %t = arith.index_cast %iv : index to i32 2928- // CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [None] 2929- // CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [None] 2930- // CHECK: Contiguity: [1, 1] ; Divisibility: [4, 4] ; Constancy: [128, 32] ; ConstantValue: [4] 2931+ // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = <none> 2932+ // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = <none> 2933+ // CHECK: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4 2934 scf.yield %b, %a, %c : tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32> 2935 } 2936 return 2937@@ -290,53 +290,53 @@ func @for() { 2938 2939 // ----- 2940 2941-// CHECK-LABEL: permute_2d 2942-func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { 2943- // CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 128] ; ConstantValue: [1] 2944+// CHECK-LABEL: @permute_2d 2945+func.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { 2946+ // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 128], constant_value = 1 2947 %cst = arith.constant dense<true> : tensor<128x128xi1> 2948- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] ; ConstantValue: [None] 2949+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none> 2950 %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> 2951- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] 2952+ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none> 2953 %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> 2954- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] 2955+ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none> 2956 %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> 2957- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [1073741824, 1] ; Constancy: [1, 1] ; ConstantValue: [None] 2958+ // CHECK-NEXT: contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = <none> 2959 %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> 2960- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None] 2961+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none> 2962 %3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32> 2963- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [17179869184, 16] ; Constancy: [1, 1] ; ConstantValue: [None] 2964+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [17179869184, 16], constancy = [1, 1], constant_value = <none> 2965 %4 = arith.muli %2, %3 : tensor<128x1xi32> 2966- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None] 2967+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none> 2968 %5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>> 2969- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1] ; ConstantValue: [None] 2970+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = <none> 2971 %6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32> 2972- // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [1, 1] ; ConstantValue: [None] 2973+ // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = <none> 2974 %7 = tt.expand_dims %1 {axis = 0 : i32}: (tensor<128xi32>) -> tensor<1x128xi32> 2975- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] ; ConstantValue: [None] 2976+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = <none> 2977 %8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>> 2978- // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [128, 1] ; ConstantValue: [None] 2979+ // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [128, 1], constant_value = <none> 2980 %9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32> 2981- // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1] ; ConstantValue: [None] 2982+ // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 16], constancy = [1, 1], constant_value = <none> 2983 %10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32> 2984- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [1073741824, 1] ; Constancy: [1, 1] ; ConstantValue: [None] 2985+ // CHECK-NEXT: contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = <none> 2986 %11 = tt.expand_dims %0 {axis = 1 : i32}: (tensor<128xi32>) -> tensor<128x1xi32> 2987- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None] 2988+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none> 2989 %12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>> 2990- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ; ConstantValue: [None] 2991+ // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = <none> 2992 %13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32> 2993- // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [1, 1] ; ConstantValue: [None] 2994+ // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = <none> 2995 %14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32> 2996- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] ; ConstantValue: [None] 2997+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = <none> 2998 %15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32> 2999- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 17179869184] ; Constancy: [1, 1] ; ConstantValue: [None] 3000+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [1, 1], constant_value = <none> 3001 %16 = arith.muli %14, %15 : tensor<1x128xi32> 3002- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 128] ; ConstantValue: [None] 3003+ // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 128], constant_value = <none> 3004 %17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>> 3005- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 17179869184] ; Constancy: [128, 1] ; ConstantValue: [None] 3006+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [128, 1], constant_value = <none> 3007 %18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32> 3008- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ; ConstantValue: [None] 3009+ // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = <none> 3010 %19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32> 3011- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] ; ConstantValue: [None] 3012+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none> 3013 %20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32> 3014 tt.store %19, %20, %cst : tensor<128x128xf32> 3015 return 3016@@ -347,29 +347,29 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t 3017 module { 3018 3019 // This is a tiny test for verifying StoreOp-related alignment, It simply store a constant to a buffer. 3020-// CHECK-LABEL: store_constant_align 3021-func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) { 3022- // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] 3023+// CHECK-LABEL: @store_constant_align 3024+func.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) { 3025+ // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none> 3026 %pid = tt.get_program_id {axis = 0 : i32} : i32 3027- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [128] 3028+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128 3029 %c128_i32 = arith.constant 128 : i32 3030- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [None] 3031+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = <none> 3032 %1 = arith.muli %pid, %c128_i32 : i32 3033- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] 3034+ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none> 3035 %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> 3036- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [None] 3037+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = <none> 3038 %3 = tt.splat %1 : (i32) -> tensor<128xi32> 3039- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [None] 3040+ // CHECK-NEXT: contiguity = [128], divisibility = [128], constancy = [1], constant_value = <none> 3041 %4 = arith.addi %3, %2 : tensor<128xi32> 3042- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] ; ConstantValue: [None] 3043+ // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [128], constant_value = <none> 3044 %5 = tt.splat %addr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>> 3045- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None] 3046+ // CHECK-NEXT: contiguity = [128], divisibility = [16], constancy = [1], constant_value = <none> 3047 %6 = tt.addptr %5, %4 : tensor<128x!tt.ptr<f32>>, tensor<128xi32> 3048- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] ; ConstantValue: [None] 3049+ // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [128], constant_value = <none> 3050 %9 = tt.splat %n : (i32) -> tensor<128xi32> 3051- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [16] ; ConstantValue: [None] 3052+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [16], constant_value = <none> 3053 %mask = arith.cmpi slt, %4, %9 : tensor<128xi32> 3054- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] 3055+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none> 3056 %cst = arith.constant dense<0.0> : tensor<128xf32> 3057 tt.store %5, %cst, %mask : tensor<128xf32> 3058 return 3059@@ -381,8 +381,8 @@ func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: 3060 3061 // This IR is dumped from vecadd test. 3062 // Note, the hint {tt.divisibility = 16 : i32} for %n_elements affects the alignment of mask. 3063-// CHECK-LABEL: vecadd_mask_align_16 3064-func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { 3065+// CHECK-LABEL: @vecadd_mask_align_16 3066+func.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { 3067 %c64_i32 = arith.constant 64 : i32 3068 %0 = tt.get_program_id {axis = 0 : i32} : i32 3069 %1 = arith.muli %0, %c64_i32 : i32 3070@@ -394,13 +394,13 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ar 3071 %7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>> 3072 %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32> 3073 %9 = tt.splat %n_elements : (i32) -> tensor<64xi32> 3074- // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [16] ; ConstantValue: [None] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> ) 3075+ // CHECK: arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [16], constant_value = <none> 3076 %mask = arith.cmpi slt, %4, %9 : tensor<64xi32> 3077 %11 = tt.load %6, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32> 3078 %12 = tt.load %8, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32> 3079 %13 = arith.addf %11, %12 : tensor<64xf32> 3080 %14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>> 3081- // CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr<f32>>, tensor<64xi32> ) 3082+ // CHECK: tt.addptr %{{.*}} => contiguity = [64], divisibility = [16], constancy = [1], constant_value = <none> 3083 %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32> 3084 tt.store %15, %13, %mask : tensor<64xf32> 3085 return 3086@@ -410,8 +410,8 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ar 3087 3088 // This IR is dumped from vecadd test. 3089 // Note, there is no divisibility hint for %n_elements, Triton should assume its divisibility to be 1 by default. 3090-// CHECK-LABEL: vecadd_mask_align_1 3091-func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) { 3092+// CHECK-LABEL: @vecadd_mask_align_1 3093+func.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) { 3094 %c64_i32 = arith.constant 64 : i32 3095 %0 = tt.get_program_id {axis = 0 : i32} : i32 3096 %1 = arith.muli %0, %c64_i32 : i32 3097@@ -423,7 +423,7 @@ func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg 3098 %7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>> 3099 %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32> 3100 %9 = tt.splat %n_elements : (i32) -> tensor<64xi32> 3101- // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> ) 3102+ // CHECK: arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none> 3103 %10 = arith.cmpi slt, %4, %9 : tensor<64xi32> 3104 %11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32> 3105 %12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32> 3106diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir 3107index efb00c404d..f79222aa7b 100644 3108--- a/test/Analysis/test-allocation.mlir 3109+++ b/test/Analysis/test-allocation.mlir 3110@@ -13,7 +13,7 @@ 3111 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3112 3113 // CHECK-LABEL: matmul_loop 3114-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 3115+func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 3116 %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL> 3117 %b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL> 3118 3119@@ -46,7 +46,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B 3120 3121 // Shared memory is available after a tensor's liveness range ends 3122 // CHECK-LABEL: reusable 3123-func @reusable(%A : !tt.ptr<f16>) { 3124+func.func @reusable(%A : !tt.ptr<f16>) { 3125 %cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL> 3126 %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> 3127 %cst3 = arith.constant dense<true> : tensor<32x128xi1, #AL> 3128@@ -78,7 +78,7 @@ func @reusable(%A : !tt.ptr<f16>) { 3129 // %cst1->%cst4 3130 // %cst3->%g->%h->%i 3131 // CHECK-LABEL: preallocate 3132-func @preallocate(%A : !tt.ptr<f16>) { 3133+func.func @preallocate(%A : !tt.ptr<f16>) { 3134 // CHECK: offset = 0, size = 512 3135 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3136 // CHECK-NEXT: offset = 1024, size = 512 3137@@ -113,7 +113,7 @@ func @preallocate(%A : !tt.ptr<f16>) { 3138 3139 // Unused tensors are immediately released 3140 // CHECK-LABEL: unused 3141-func @unused(%A : !tt.ptr<f16>) { 3142+func.func @unused(%A : !tt.ptr<f16>) { 3143 // CHECK: offset = 0, size = 1024 3144 %cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A_SHARED> 3145 // CHECK-NEXT: offset = 0, size = 512 3146@@ -128,7 +128,7 @@ func @unused(%A : !tt.ptr<f16>) { 3147 3148 // cst0 is alive through the entire function, it cannot be released before the end of the function 3149 // CHECK-LABEL: longlive 3150-func @longlive(%A : !tt.ptr<f16>) { 3151+func.func @longlive(%A : !tt.ptr<f16>) { 3152 // CHECK: offset = 0, size = 512 3153 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3154 // CHECK-NEXT: offset = 512, size = 512 3155@@ -156,7 +156,7 @@ func @longlive(%A : !tt.ptr<f16>) { 3156 } 3157 3158 // CHECK-LABEL: alloc 3159-func @alloc(%A : !tt.ptr<f16>) { 3160+func.func @alloc(%A : !tt.ptr<f16>) { 3161 // CHECK: offset = 0, size = 512 3162 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3163 %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> 3164@@ -167,7 +167,7 @@ func @alloc(%A : !tt.ptr<f16>) { 3165 } 3166 3167 // CHECK-LABEL: scratch 3168-func @scratch() { 3169+func.func @scratch() { 3170 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> 3171 // CHECK: scratch offset = 0, size = 512 3172 %b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0> 3173@@ -176,7 +176,7 @@ func @scratch() { 3174 } 3175 3176 // CHECK-LABEL: trans 3177-func @trans(%A : !tt.ptr<f16>) { 3178+func.func @trans(%A : !tt.ptr<f16>) { 3179 // CHECK: offset = 0, size = 1024 3180 %tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED> 3181 %b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T> 3182@@ -184,7 +184,7 @@ func @trans(%A : !tt.ptr<f16>) { 3183 } 3184 3185 // CHECK-LABEL: insert_slice_async 3186-func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) { 3187+func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) { 3188 %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL> 3189 %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> 3190 %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> 3191@@ -197,7 +197,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) { 3192 } 3193 3194 // CHECK-LABEL: extract_slice 3195-func @extract_slice(%A : !tt.ptr<f16>) { 3196+func.func @extract_slice(%A : !tt.ptr<f16>) { 3197 // CHECK: offset = 0, size = 512 3198 %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED> 3199 %index = arith.constant 0 : index 3200@@ -209,7 +209,7 @@ func @extract_slice(%A : !tt.ptr<f16>) { 3201 // B0 -> (B1) -> B0 3202 // Memory used by B1 can be reused by B0. 3203 // CHECK-LABEL: if 3204-func @if(%i1 : i1) { 3205+func.func @if(%i1 : i1) { 3206 // CHECK: offset = 0, size = 512 3207 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3208 // CHECK-NEXT: offset = 512, size = 512 3209@@ -233,7 +233,7 @@ func @if(%i1 : i1) { 3210 // B0 -> (B1) -> (B2) -> B0 3211 // Memory used by B0 cannot be reused by B1 or B2. 3212 // CHECK-LABEL: if_else 3213-func @if_else(%i1 : i1) { 3214+func.func @if_else(%i1 : i1) { 3215 // CHECK: offset = 0, size = 512 3216 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3217 // CHECK-NEXT: offset = 512, size = 512 3218@@ -260,7 +260,7 @@ func @if_else(%i1 : i1) { 3219 // Block arguments and yields are memory aliases that do not trigger a new 3220 // allocation. 3221 // CHECK-LABEL: for 3222-func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 3223+func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 3224 // CHECK: offset = 0, size = 8192 3225 %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> 3226 // CHECK-NEXT: offset = 8192, size = 8192 3227@@ -275,7 +275,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.p 3228 } 3229 3230 // CHECK-LABEL: for_if_slice 3231-func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) { 3232+func.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) { 3233 // CHECK: offset = 0, size = 8192 3234 %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> 3235 // CHECK-NEXT: offset = 8192, size = 8192 3236@@ -296,7 +296,7 @@ func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, % 3237 3238 // c0 cannot be released in the loop 3239 // CHECK-LABEL: for_use_ancestor 3240-func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) { 3241+func.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) { 3242 // CHECK: offset = 0, size = 8192 3243 %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> 3244 // CHECK-NEXT: offset = 8192, size = 8192 3245@@ -316,7 +316,7 @@ func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16 3246 // a_shared_init, b_shared_init, and c_shared_init's liveness ranges are span over the entire function before cst2. 3247 // So they cannot be reused by cst0 and cst1, but can be reused by cst2. 3248 // CHECK-LABEL: for_if_for 3249-func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) { 3250+func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) { 3251 // CHECK: offset = 0, size = 8192 3252 %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> 3253 // CHECK-NEXT: offset = 8192, size = 8192 3254diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir 3255index 7199e5f53d..17880b2094 100644 3256--- a/test/Analysis/test-membar.mlir 3257+++ b/test/Analysis/test-membar.mlir 3258@@ -14,7 +14,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3259 3260 // CHECK-LABEL: matmul_loop 3261 // There shouldn't be any membar with the dot op encoding. 3262-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 3263+func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 3264 %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL> 3265 %b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL> 3266 3267@@ -42,7 +42,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B 3268 } 3269 3270 // CHECK-LABEL: raw_single_block 3271-func @raw_single_block(%A : !tt.ptr<f16>) { 3272+func.func @raw_single_block(%A : !tt.ptr<f16>) { 3273 %cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL> 3274 %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> 3275 %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL> 3276@@ -54,7 +54,7 @@ func @raw_single_block(%A : !tt.ptr<f16>) { 3277 } 3278 3279 // CHECK-LABEL: war_single_block 3280-func @war_single_block(%A : !tt.ptr<f16>) { 3281+func.func @war_single_block(%A : !tt.ptr<f16>) { 3282 %cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL> 3283 %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> 3284 %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL> 3285@@ -70,7 +70,7 @@ func @war_single_block(%A : !tt.ptr<f16>) { 3286 } 3287 3288 // CHECK-LABEL: scratch 3289-func @scratch() { 3290+func.func @scratch() { 3291 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3292 // CHECK: Membar 1 3293 %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> 3294@@ -81,7 +81,7 @@ func @scratch() { 3295 } 3296 3297 // CHECK-LABEL: async_wait 3298-func @async_wait() { 3299+func.func @async_wait() { 3300 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3301 // CHECK: Membar 1 3302 %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> 3303@@ -92,7 +92,7 @@ func @async_wait() { 3304 } 3305 3306 // CHECK-LABEL: alloc 3307-func @alloc() { 3308+func.func @alloc() { 3309 %cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED> 3310 %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> 3311 // CHECK: Membar 2 3312@@ -101,7 +101,7 @@ func @alloc() { 3313 } 3314 3315 // CHECK-LABEL: extract_slice 3316-func @extract_slice() { 3317+func.func @extract_slice() { 3318 %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED> 3319 %index = arith.constant 0 : index 3320 %cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED> 3321@@ -113,14 +113,14 @@ func @extract_slice() { 3322 } 3323 3324 // CHECK-LABEL: trans 3325-func @trans() { 3326+func.func @trans() { 3327 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED> 3328 %b = tt.trans %cst0 : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T> 3329 return 3330 } 3331 3332 // CHECK-LABEL: insert_slice_async 3333-func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) { 3334+func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) { 3335 %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL> 3336 %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> 3337 %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> 3338@@ -135,7 +135,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) { 3339 } 3340 3341 // CHECK-LABEL: insert_slice 3342-func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) { 3343+func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) { 3344 %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL> 3345 %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> 3346 %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> 3347@@ -153,7 +153,7 @@ func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) { 3348 3349 // If branch inserted a barrier for %cst0 and %cst1, but else didn't, then the barrier should be inserted in the parent region 3350 // CHECK-LABEL: multi_blocks 3351-func @multi_blocks(%i1 : i1) { 3352+func.func @multi_blocks(%i1 : i1) { 3353 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3354 %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3355 scf.if %i1 { 3356@@ -174,7 +174,7 @@ func @multi_blocks(%i1 : i1) { 3357 3358 // Both branches inserted a barrier for %cst0 and %cst1, then the barrier doesn't need to be inserted in the parent region 3359 // CHECK-LABEL: multi_blocks_join_barrier 3360-func @multi_blocks_join_barrier(%i1 : i1) { 3361+func.func @multi_blocks_join_barrier(%i1 : i1) { 3362 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3363 %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3364 scf.if %i1 { 3365@@ -192,7 +192,7 @@ func @multi_blocks_join_barrier(%i1 : i1) { 3366 3367 // Read yielded tensor requires a barrier 3368 // CHECK-LABEL: multi_blocks_yield 3369-func @multi_blocks_yield(%i1 : i1) { 3370+func.func @multi_blocks_yield(%i1 : i1) { 3371 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3372 %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3373 %a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) { 3374@@ -212,7 +212,7 @@ func @multi_blocks_yield(%i1 : i1) { 3375 3376 // Conservatively add a barrier as if the branch (%i1) is never taken 3377 // CHECK-LABEL: multi_blocks_noelse 3378-func @multi_blocks_noelse(%i1 : i1) { 3379+func.func @multi_blocks_noelse(%i1 : i1) { 3380 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3381 %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3382 scf.if %i1 { 3383@@ -226,7 +226,7 @@ func @multi_blocks_noelse(%i1 : i1) { 3384 3385 // Conservatively add a barrier as if the branch (%i2) is never taken 3386 // CHECK-LABEL: multi_blocks_nested_scf 3387-func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { 3388+func.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { 3389 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3390 %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> 3391 scf.if %i1 { 3392@@ -247,7 +247,7 @@ func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { 3393 } 3394 3395 // CHECK-LABEL: for 3396-func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 3397+func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 3398 %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> 3399 %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> 3400 %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> 3401@@ -262,7 +262,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.p 3402 // Although a_shared and b_shared are synced before entering the loop, 3403 // they are reassociated with aliases (c_shared) and thus require a barrier. 3404 // CHECK-LABEL: for_alias 3405-func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 3406+func.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 3407 %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> 3408 %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> 3409 // CHECK-NEXT: Membar 2 3410@@ -282,7 +282,7 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : 3411 // Although cst2 is not an argument of scf.yield, its memory is reused by cst1. 3412 // So we need a barrier both before and after cst1 3413 // CHECK-LABEL: for_reuse 3414-func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 3415+func.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 3416 %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> 3417 %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> 3418 // CHECK-NEXT: Membar 2 3419@@ -302,7 +302,7 @@ func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : 3420 3421 3422 // CHECK-LABEL: for_reuse_nested 3423-func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 3424+func.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 3425 %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> 3426 %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> 3427 // CHECK-NEXT: Membar 2 3428diff --git a/test/Conversion/triton_ops.mlir b/test/Conversion/triton_ops.mlir 3429index e9ee502435..0e979b148d 100644 3430--- a/test/Conversion/triton_ops.mlir 3431+++ b/test/Conversion/triton_ops.mlir 3432@@ -1,6 +1,6 @@ 3433 // RUN: triton-opt %s | FileCheck %s 3434 3435-func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) { 3436+func.func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) { 3437 // scalar -> scalar 3438 // CHECK: i64 -> !tt.ptr<f32> 3439 %0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr<f32> 3440@@ -35,7 +35,7 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) { 3441 return 3442 } 3443 3444-func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) { 3445+func.func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) { 3446 // scalar -> scalar 3447 // CHECK: !tt.ptr<f32> 3448 %0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr<f32>, i32 3449@@ -54,7 +54,7 @@ func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) { 3450 return 3451 } 3452 3453-func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) { 3454+func.func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) { 3455 // Test if Load/Store ops can handle scalar values 3456 %other = arith.constant 0.0e+0 : f32 3457 3458@@ -76,7 +76,7 @@ func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ma 3459 return 3460 } 3461 3462-func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) { 3463+func.func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) { 3464 // Test if reduce ops infer types correctly 3465 3466 // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32> 3467@@ -101,7 +101,7 @@ func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) { 3468 return 3469 } 3470 3471-func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) { 3472+func.func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) { 3473 // Test if reduce ops infer types correctly 3474 %v128x32 = tt.splat %v : (f32) -> tensor<128x32xf32> 3475 %v32x128 = tt.splat %v : (f32) -> tensor<32x128xf32> 3476diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir 3477index a160bc8815..b461ca542f 100644 3478--- a/test/Conversion/triton_to_tritongpu.mlir 3479+++ b/test/Conversion/triton_to_tritongpu.mlir 3480@@ -1,6 +1,6 @@ 3481 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s 3482 3483-func @ops() { 3484+func.func @ops() { 3485 // CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}} 3486 %a = arith.constant dense<1.00e+00> : tensor<128x32xf16> 3487 %b = arith.constant dense<2.00e+00> : tensor<32x128xf16> 3488@@ -11,7 +11,7 @@ func @ops() { 3489 3490 // ----- 3491 3492-func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) { 3493+func.func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) { 3494 // Test if LoadOp is lowered properly (see #771) 3495 %ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>> 3496 %mask = arith.constant dense<true> : tensor<128xi1> 3497@@ -30,7 +30,7 @@ func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) { 3498 3499 // ----- 3500 3501-func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) { 3502+func.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) { 3503 // Test if the total number of threadsPerWarp is 32 3504 // Test if the total number of warps is 2 3505 // CHECK: #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}> 3506diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir 3507index e9e7d5a340..507b362c99 100644 3508--- a/test/Conversion/tritongpu_to_llvm.mlir 3509+++ b/test/Conversion/tritongpu_to_llvm.mlir 3510@@ -4,7 +4,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3511 // CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>) 3512 // Here the 128 comes from the 4 in module attribute multiples 32 3513 // CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : i32} {{.*}} 3514- func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) { 3515+ func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) { 3516 // CHECK: llvm.return 3517 return 3518 } 3519@@ -15,7 +15,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3520 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> 3521 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3522 // CHECK-LABEL: basic_load 3523- func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { 3524+ func.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { 3525 // CHECK: llvm.inline_asm 3526 // CHECK: llvm.inline_asm 3527 %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> 3528@@ -28,7 +28,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3529 #blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> 3530 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3531 // CHECK-LABEL: vectorized_load 3532- func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { 3533+ func.func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { 3534 // CHECK: llvm.inline_asm 3535 // CHECK-SAME: ld.global.b32 3536 // CHECK: llvm.inline_asm 3537@@ -43,7 +43,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3538 #blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> 3539 module attributes {"triton_gpu.num-warps" = 1 : i32} { 3540 // CHECK-LABEL: vectorized_load_f16 3541- func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) { 3542+ func.func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) { 3543 // CHECK: llvm.inline_asm 3544 // CHECK-SAME: ld.global.b16 3545 // CHECK: llvm.inline_asm 3546@@ -59,7 +59,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { 3547 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> 3548 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3549 // CHECK-LABEL: masked_load_const_other 3550- func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) { 3551+ func.func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) { 3552 %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> 3553 %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> 3554 return 3555@@ -72,7 +72,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3556 #blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> 3557 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3558 // CHECK-LABEL: masked_load_const_other_vec 3559- func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) { 3560+ func.func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) { 3561 %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> 3562 %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> 3563 return 3564@@ -84,7 +84,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3565 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> 3566 module attributes {"triton_gpu.num-warps" = 2 : i32} { 3567 // CHECK-LABEL: global_load_store_no_vec 3568- func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) { 3569+ func.func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) { 3570 %c256_i32 = arith.constant 256 : i32 3571 %0 = tt.get_program_id {axis = 0 : i32} : i32 3572 %1 = arith.muli %0, %c256_i32 : i32 3573@@ -128,7 +128,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { 3574 #blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> 3575 module attributes {"triton_gpu.num-warps" = 2 : i32} { 3576 // CHECK-LABEL: global_load_store_vec4 3577- func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) { 3578+ func.func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) { 3579 %c256_i32 = arith.constant 256 : i32 3580 %0 = tt.get_program_id {axis = 0 : i32} : i32 3581 %1 = arith.muli %0, %c256_i32 : i32 3582@@ -165,7 +165,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { 3583 #blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> 3584 // Note, the %n_elements doesn't have a "tt.divisibility" hint, so Triton assumes it's divisibility is 1, this should effect the mask's alignment and further restrict the load/store ops' vector width to be 1. 3585 module attributes {"triton_gpu.num-warps" = 2 : i32} { 3586- func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) { 3587+ func.func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) { 3588 %c64_i32 = arith.constant 64 : i32 3589 %0 = tt.get_program_id {axis = 0 : i32} : i32 3590 %1 = arith.muli %0, %c64_i32 : i32 3591@@ -195,7 +195,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { 3592 #blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> 3593 module attributes {"triton_gpu.num-warps" = 1 : i32} { 3594 // CHECK-LABEL: global_load_store_vec2 3595- func @global_load_store_vec2(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg3: i32) { 3596+ func.func @global_load_store_vec2(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg3: i32) { 3597 %c256_i32 = arith.constant 256 : i32 3598 %0 = tt.get_program_id {axis = 0 : i32} : i32 3599 %1 = arith.muli %0, %c256_i32 : i32 3600@@ -240,7 +240,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { 3601 #blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> 3602 module attributes {"triton_gpu.num-warps" = 1 : i32} { 3603 // CHECK-LABEL: global_load_store_vec8 3604- func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) { 3605+ func.func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) { 3606 %c256_i32 = arith.constant 256 : i32 3607 %0 = tt.get_program_id {axis = 0 : i32} : i32 3608 %1 = arith.muli %0, %c256_i32 : i32 3609@@ -283,7 +283,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { 3610 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> 3611 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3612 // CHECK-LABEL: basic_view_broadcast 3613- func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) { 3614+ func.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) { 3615 // CHECK: llvm.mlir.undef 3616 // CHECK: %[[T0:.*]] = llvm.extractvalue 3617 // CHECK: %[[T1:.*]] = llvm.extractvalue 3618@@ -307,7 +307,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3619 #blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> 3620 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3621 // CHECK-LABEL: basic_make_range 3622- func @basic_make_range() { 3623+ func.func @basic_make_range() { 3624 // CHECK: nvvm.read.ptx.sreg.tid.x 3625 // CHECK: llvm.mlir.undef 3626 // CHECK: llvm.insertvalue 3627@@ -322,7 +322,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3628 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> 3629 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3630 // CHECK-LABEL: basic_addf 3631- func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) { 3632+ func.func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) { 3633 // CHECK: llvm.fadd 3634 // CHECK: llvm.fadd 3635 %1 = arith.addf %arg0, %arg1 : tensor<256xf32,#blocked0> 3636@@ -335,7 +335,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3637 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> 3638 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3639 // CHECK-LABEL: basic_addi 3640- func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { 3641+ func.func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { 3642 // CHECK: llvm.add 3643 // CHECK: llvm.add 3644 %1 = arith.addi %arg0, %arg1 : tensor<256xi32,#blocked0> 3645@@ -347,7 +347,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3646 3647 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3648 // CHECK-LABEL: basic_program_id 3649- func @basic_program_id() { 3650+ func.func @basic_program_id() { 3651 // CHECK: nvvm.read.ptx.sreg.ctaid.x : i32 3652 %0 = tt.get_program_id {axis = 0 : i32} : i32 3653 return 3654@@ -359,7 +359,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3655 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> 3656 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3657 // CHECK-LABEL: basic_addptr 3658- func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { 3659+ func.func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { 3660 // CHECK: llvm.getelementptr 3661 // CHECK: llvm.getelementptr 3662 %0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0> 3663@@ -373,7 +373,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3664 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3665 // CHECK: llvm.mlir.global external @global_smem 3666 // CHECK-LABEL: basic_alloc_tensor 3667- func @basic_alloc_tensor() { 3668+ func.func @basic_alloc_tensor() { 3669 // CHECK: llvm.mlir.addressof @global_smem 3670 // CHECK-NEXT: llvm.bitcast 3671 // CHECK-NEXT: llvm.mlir.constant 3672@@ -390,7 +390,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3673 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3674 // CHECK: llvm.mlir.global external @global_smem 3675 // CHECK-LABEL: basic_extract_slice 3676- func @basic_extract_slice() { 3677+ func.func @basic_extract_slice() { 3678 // CHECK: llvm.mlir.addressof @global_smem 3679 // CHECK: llvm.extractvalue 3680 // CHECK-NEXT: llvm.extractvalue 3681@@ -423,7 +423,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3682 3683 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3684 // CHECK-LABEL: basic_async_wait 3685- func @basic_async_wait() { 3686+ func.func @basic_async_wait() { 3687 // CHECK: cp.async.wait_group 0x4 3688 triton_gpu.async_wait {num = 4: i32} 3689 return 3690@@ -442,7 +442,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3691 #A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}> 3692 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3693 // CHECK-LABEL: basic_insert_slice_async_fallback 3694- func @basic_insert_slice_async_fallback(%arg0: !tt.ptr<f16> {tt.divisibility = 1 : i32}) { 3695+ func.func @basic_insert_slice_async_fallback(%arg0: !tt.ptr<f16> {tt.divisibility = 1 : i32}) { 3696 %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> 3697 %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0> 3698 %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2> 3699@@ -481,7 +481,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3700 #A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}> 3701 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3702 // CHECK-LABEL: basic_insert_slice_async_v4 3703- func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 32 : i32}) { 3704+ func.func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 32 : i32}) { 3705 %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> 3706 %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0> 3707 %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2> 3708@@ -523,7 +523,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3709 #A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}> 3710 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3711 // CHECK-LABEL: basic_insert_slice_async_v1 3712- func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) { 3713+ func.func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) { 3714 %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> 3715 %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0> 3716 %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2> 3717@@ -568,7 +568,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3718 #A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}> 3719 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3720 // CHECK-LABEL: basic_insert_slice_async_v1_multictas 3721- func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) { 3722+ func.func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) { 3723 %off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2d1> 3724 %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0> 3725 %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #slice2d1>) -> tensor<32x1xi32, #block2> 3726@@ -619,7 +619,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3727 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> 3728 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3729 // CHECK: basic_splat 3730- func @basic_splat(%ptr: !tt.ptr<f32>) { 3731+ func.func @basic_splat(%ptr: !tt.ptr<f32>) { 3732 // CHECK: llvm.mlir.undef 3733 // CHECK: llvm.insertvalue 3734 // CHECK: llvm.insertvalue 3735@@ -633,7 +633,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3736 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> 3737 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3738 // CHECK-LABEL: basic_store 3739- func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) { 3740+ func.func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) { 3741 // CHECK: llvm.inline_asm 3742 // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} }; 3743 // CHECK: llvm.inline_asm 3744@@ -650,7 +650,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3745 module attributes {"triton_gpu.num-warps" = 1 : i32} { 3746 // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> 3747 // CHECK-LABEL: convert_layout_blocked_blocked 3748- func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) { 3749+ func.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) { 3750 // CHECK: llvm.mlir.addressof @global_smem 3751 // CHECK: llvm.store 3752 // CHECK-SAME: !llvm.ptr<vector<1xf32>, 3> 3753@@ -697,7 +697,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { 3754 module attributes {"triton_gpu.num-warps" = 1 : i32} { 3755 // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> 3756 // CHECK-LABEL: convert_layout_blocked_blocked_vec 3757- func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) { 3758+ func.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) { 3759 // CHECK: llvm.mlir.addressof @global_smem 3760 // CHECK: llvm.store 3761 // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3> 3762@@ -720,7 +720,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { 3763 module attributes {"triton_gpu.num-warps" = 1 : i32} { 3764 // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> 3765 // CHECK-LABEL: convert_layout_blocked_blocked_multi_rep 3766- func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) { 3767+ func.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) { 3768 // CHECK: llvm.mlir.addressof @global_smem 3769 // CHECK: llvm.store 3770 // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3> 3771@@ -751,7 +751,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { 3772 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0}> 3773 module attributes {"triton_gpu.num-warps" = 1 : i32} { 3774 // CHECK-LABEL: convert_dot 3775- func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { 3776+ func.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { 3777 %AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0> 3778 %BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0> 3779 // CHECK: llvm.inline_asm 3780@@ -775,7 +775,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { 3781 // TODO: problems in MLIR's parser on slice layout 3782 // #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> 3783 // module attributes {"triton_gpu.num-warps" = 1 : i32} { 3784-// func @make_range_sliced_layout() { 3785+// func.func @make_range_sliced_layout() { 3786 // %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> 3787 // return 3788 // } 3789@@ -788,7 +788,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { 3790 module attributes {"triton_gpu.num-warps" = 1 : i32} { 3791 // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> 3792 // CHECK-LABEL: convert_layout_mmav2_block 3793- func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) { 3794+ func.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) { 3795 // CHECK: llvm.store 3796 // CHECK-SAME: !llvm.ptr<vector<2xf32>, 3> 3797 // CHECK: llvm.store 3798@@ -808,7 +808,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { 3799 module attributes {"triton_gpu.num-warps" = 1 : i32} { 3800 // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> 3801 // CHECK-LABEL: convert_layout_mmav1_block 3802- func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) { 3803+ func.func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) { 3804 // CHECK: llvm.store 3805 // CHECK-SAME: !llvm.ptr<vector<2xf32>, 3> 3806 // CHECK: llvm.store 3807@@ -831,7 +831,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { 3808 module attributes {"triton_gpu.num-warps" = 1 : i32} { 3809 // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> 3810 // CHECK-LABEL: convert_layout_blocked_shared 3811- func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) { 3812+ func.func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) { 3813 // CHECK: llvm.store 3814 // CHECK-SAME: !llvm.ptr<vector<8xf32>, 3> 3815 // CHECK: llvm.store 3816@@ -847,7 +847,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { 3817 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}> 3818 module attributes {"triton_gpu.num-warps" = 1 : i32} { 3819 // CHECK-LABEL: convert_blocked1d_to_slice0 3820- func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) { 3821+ func.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) { 3822 // CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3> 3823 %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> 3824 return 3825@@ -860,7 +860,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { 3826 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}> 3827 module attributes {"triton_gpu.num-warps" = 1 : i32} { 3828 // CHECK-LABEL: convert_blocked1d_to_slice1 3829- func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { 3830+ func.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { 3831 // CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3> 3832 %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> 3833 return 3834@@ -873,7 +873,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { 3835 #blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> 3836 module attributes {"triton_gpu.num-warps" = 1 : i32} { 3837 // CHECK-LABEL: convert_blocked_to_blocked_ptr 3838- func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr<f32>, #blocked0>) { 3839+ func.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr<f32>, #blocked0>) { 3840 // CHECK: llvm.ptrtoint 3841 // CHECK: llvm.store 3842 // CHECK: nvvm.barrier0 3843@@ -892,7 +892,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { 3844 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}> 3845 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}> 3846 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3847- func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32}, 3848+ func.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32}, 3849 %a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) { 3850 %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> 3851 // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 3852@@ -918,7 +918,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3853 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}> 3854 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}> 3855 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3856- func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32}, 3857+ func.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32}, 3858 %a:tensor<32x64xf16, #shared0>, %b:tensor<64x64xf16, #shared1>) { 3859 %cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #mma> 3860 // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 3861@@ -941,7 +941,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3862 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#blocked}> 3863 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> 3864 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3865- func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32}, 3866+ func.func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32}, 3867 %a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) { 3868 %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> 3869 // CHECK: llvm.intr.fmuladd 3870@@ -965,7 +965,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3871 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}> 3872 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3873 // CHECK-LABEL: matmul_tf32dot 3874- func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32}, 3875+ func.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32}, 3876 %a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) { 3877 %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> 3878 // CHECK: llvm.inline_asm 3879@@ -1000,7 +1000,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3880 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> 3881 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3882 // CHECK-LABEL: atomic_add_f32 3883- func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { 3884+ func.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { 3885 // CHECK: llvm.inline_asm 3886 // CHECK-SAME: atom.global.gpu.add.f32 3887 %0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0> 3888@@ -1012,7 +1012,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3889 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> 3890 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3891 3892-func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) { 3893+func.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) { 3894 %blockidx = tt.get_program_id {axis=0:i32} : i32 3895 %blockidy = tt.get_program_id {axis=1:i32} : i32 3896 %blockidz = tt.get_program_id {axis=2:i32} : i32 3897@@ -1032,7 +1032,7 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) { 3898 // ----- 3899 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> 3900 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3901- func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) { 3902+ func.func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) { 3903 // CHECK: nvvm.read.ptx.sreg.nctaid.x 3904 // CHECK: nvvm.read.ptx.sreg.nctaid.y 3905 // CHECK: nvvm.read.ptx.sreg.nctaid.z 3906@@ -1052,7 +1052,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3907 #blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> 3908 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3909 // CHECK-LABEL: test_index_cache 3910- func @test_index_cache() { 3911+ func.func @test_index_cache() { 3912 // CHECK: nvvm.read.ptx.sreg.tid.x 3913 %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> 3914 // CHECK-NOT: nvvm.read.ptx.sreg.tid.x 3915@@ -1066,7 +1066,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 3916 #shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> 3917 module attributes {"triton_gpu.num-warps" = 1 : i32} { 3918 // CHECK-LABEL: test_base_index_cache 3919- func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) { 3920+ func.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) { 3921 // CHECK: nvvm.read.ptx.sreg.tid.x 3922 %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0> 3923 // CHECK-NOT: nvvm.read.ptx.sreg.tid.x 3924@@ -1080,7 +1080,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { 3925 #shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> 3926 module attributes {"triton_gpu.num-warps" = 1 : i32} { 3927 // CHECK-LABEL: test_index_cache_different_block 3928- func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) { 3929+ func.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) { 3930 // CHECK: nvvm.read.ptx.sreg.tid.x 3931 %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0> 3932 scf.if %arg1 { 3933diff --git a/test/Target/tritongpu_to_llvmir.mlir b/test/Target/tritongpu_to_llvmir.mlir 3934index cafff3ca60..114d3a9eb2 100644 3935--- a/test/Target/tritongpu_to_llvmir.mlir 3936+++ b/test/Target/tritongpu_to_llvmir.mlir 3937@@ -4,11 +4,11 @@ 3938 // CHECK-LABEL: ; ModuleID = 'LLVMDialectModule' 3939 // CHECK: define void @test_empty_kernel 3940 // CHECK: !nvvm.annotations 3941-// CHECK: !{void (i32, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128} 3942+// CHECK: !{ptr @test_empty_kernel, !"maxntidx", i32 128} 3943 3944 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3945 3946-func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) { 3947+func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) { 3948 3949 return 3950 } 3951diff --git a/test/Target/tritongpu_to_ptx.mlir b/test/Target/tritongpu_to_ptx.mlir 3952index 404e970a29..12742ad9e2 100644 3953--- a/test/Target/tritongpu_to_ptx.mlir 3954+++ b/test/Target/tritongpu_to_ptx.mlir 3955@@ -6,7 +6,7 @@ 3956 3957 module attributes {"triton_gpu.num-warps" = 4 : i32} { 3958 3959-func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) { 3960+func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) { 3961 3962 return 3963 } 3964diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir 3965index 050a3f7565..5ef6790e69 100644 3966--- a/test/Triton/combine.mlir 3967+++ b/test/Triton/combine.mlir 3968@@ -2,10 +2,10 @@ 3969 // RUN: triton-opt %s -split-input-file -canonicalize -triton-combine | FileCheck %s 3970 3971 // CHECK-LABEL: @test_combine_dot_add_pattern 3972-func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) { 3973- // CHECK: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32> 3974- // CHECK: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32> 3975- // CHECK: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32> 3976+func.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) { 3977+ // CHECK-DAG: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32> 3978+ // CHECK-DAG: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32> 3979+ // CHECK-DAG: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32> 3980 %a = arith.constant dense<1.0> : tensor<128x128xf32> 3981 %b = arith.constant dense<2.0> : tensor<128x128xf32> 3982 %zero = arith.constant dense<0.0> : tensor<128x128xf32> 3983@@ -24,7 +24,7 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32 3984 3985 3986 // COM: CHECK-LABEL: @test_combine_addptr_pattern 3987-func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> { 3988+func.func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> { 3989 %off0 = arith.constant 10 : i32 3990 %off1 = arith.constant 15 : i32 3991 3992@@ -47,46 +47,46 @@ func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> 3993 3994 3995 // CHECK-LABEL: @test_combine_select_masked_load_pattern 3996-func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) { 3997+func.func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) { 3998 %mask = tt.broadcast %cond : (i1) -> tensor<8xi1> 3999 %false_val = arith.constant dense<0.0> : tensor<8xf32> 4000 4001 // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> 4002 %x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> 4003- %0 = select %cond, %x, %false_val : tensor<8xf32> 4004+ %0 = arith.select %cond, %x, %false_val : tensor<8xf32> 4005 4006 // CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> 4007 %y = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> 4008- %1 = select %cond, %y, %false_val : tensor<8xf32> 4009+ %1 = arith.select %cond, %y, %false_val : tensor<8xf32> 4010 4011 // CHECK: return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32> 4012 return %0, %1 : tensor<8xf32>, tensor<8xf32> 4013 } 4014 4015 // CHECK-LABEL: @test_combine_select_masked_load_fail_pattern 4016-func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { 4017+func.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { 4018 %false_val = arith.constant dense<0.0> : tensor<8xf32> 4019 4020 // Case 1: value at the "load" position is not an "op". Select should not be canonicalized. 4021- // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> 4022- %0 = select %cond0, %dummy_load, %false_val : tensor<8xf32> 4023+ // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> 4024+ %0 = arith.select %cond0, %dummy_load, %false_val : tensor<8xf32> 4025 4026 // Case 2: value at the "broadcast" position is not an "op". Select should not be canonicalized. 4027 %real_load0 = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> 4028- // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> 4029- %1 = select %cond0, %real_load0, %false_val : tensor<8xf32> 4030+ // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> 4031+ %1 = arith.select %cond0, %real_load0, %false_val : tensor<8xf32> 4032 4033 // Case 3: condition of "broadcast" is not the same as the condition of "select". Select should not be canonicalized. 4034 %cond0_ = tt.broadcast %cond0 : (i1) -> tensor<8xi1> 4035 %real_load1 = tt.load %ptr, %cond0_, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> 4036- // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> 4037- %2 = select %cond1, %real_load1, %false_val : tensor<8xf32> 4038+ // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> 4039+ %2 = arith.select %cond1, %real_load1, %false_val : tensor<8xf32> 4040 4041 return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32> 4042 } 4043 4044 // CHECK-LABEL: @test_combine_broadcast_constant_pattern 4045-func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> { 4046+func.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> { 4047 // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32> 4048 %const = arith.constant dense<1.0> : tensor<8xf32> 4049 %bst_out = tt.broadcast %const : (tensor<8xf32>) -> tensor<8x2xf32> 4050@@ -96,7 +96,7 @@ func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> { 4051 } 4052 4053 // CHECK-LABEL: @test_canonicalize_masked_load_pattern 4054-func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { 4055+func.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { 4056 %true_mask = arith.constant dense<true> : tensor<8xi1> 4057 %false_mask = arith.constant dense<false> : tensor<8xi1> 4058 %other_val = arith.constant dense<0.0> : tensor<8xf32> 4059@@ -117,7 +117,7 @@ func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (te 4060 } 4061 4062 // CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern 4063-func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) { 4064+func.func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) { 4065 %other_val = arith.constant dense<0.0> : tensor<8xf32> 4066 4067 // Case: value at the "mask" position is not an "op". Load should not be canonicalized. 4068@@ -130,7 +130,7 @@ func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, % 4069 } 4070 4071 // CHECK-LABEL: @test_canonicalize_masked_store_pattern 4072-func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) { 4073+func.func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) { 4074 %true_mask = arith.constant dense<true> : tensor<8xi1> 4075 %false_mask = arith.constant dense<false> : tensor<8xi1> 4076 4077@@ -144,7 +144,7 @@ func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: 4078 } 4079 4080 // CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern 4081-func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) { 4082+func.func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) { 4083 // Case: value at the "mask" position is not an "op". Store should not be canonicalized. 4084 // CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> 4085 tt.store %ptr, %val, %mask : tensor<8xf32> 4086diff --git a/test/Triton/vecadd.mlir b/test/Triton/vecadd.mlir 4087index 0b69ef3054..f5019b1cdd 100644 4088--- a/test/Triton/vecadd.mlir 4089+++ b/test/Triton/vecadd.mlir 4090@@ -1,7 +1,7 @@ 4091 // RUN: triton-opt %s -verify-diagnostics 4092 4093 module { 4094- func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) { 4095+ func.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) { 4096 %0 = tt.get_program_id {axis = 0 : i32} : i32 4097 %c256_i32 = arith.constant 256 : i32 4098 %1 = arith.muli %0, %c256_i32 : i32 4099@@ -43,7 +43,7 @@ module { 4100 } 4101 } 4102 // module { 4103-// func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) { 4104+// func.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) { 4105 // %c64 = arith.constant 64 : index 4106 // %c32 = arith.constant 32 : index 4107 // %c0 = arith.constant 0 : index 4108diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir 4109index 60e359f527..51cccccfbd 100644 4110--- a/test/TritonGPU/coalesce.mlir 4111+++ b/test/TritonGPU/coalesce.mlir 4112@@ -19,7 +19,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { 4113 // CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]> 4114 // CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]> 4115 // CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]] 4116-func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, 4117+func.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, 4118 %arg1: i32 {tt.divisibility = 16 : i32}, 4119 %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, 4120 %arg3: i32 {tt.divisibility = 16 : i32}) { 4121diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir 4122index 2c009ffa48..7e9cb9d504 100644 4123--- a/test/TritonGPU/combine.mlir 4124+++ b/test/TritonGPU/combine.mlir 4125@@ -9,7 +9,7 @@ 4126 // CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> 4127 // CHECK: [[col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> 4128 // CHECK-LABEL: cst 4129-func @cst() -> tensor<1024xi32, #layout1> { 4130+func.func @cst() -> tensor<1024xi32, #layout1> { 4131 %cst = arith.constant dense<0> : tensor<1024xi32, #layout0> 4132 %1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1> 4133 // CHECK-NOT: triton_gpu.convert_layout 4134@@ -18,7 +18,7 @@ func @cst() -> tensor<1024xi32, #layout1> { 4135 } 4136 4137 // CHECK-LABEL: range 4138-func @range() -> tensor<1024xi32, #layout1> { 4139+func.func @range() -> tensor<1024xi32, #layout1> { 4140 %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> 4141 %1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1> 4142 // CHECK-NOT: triton_gpu.convert_layout 4143@@ -27,7 +27,7 @@ func @range() -> tensor<1024xi32, #layout1> { 4144 } 4145 4146 // CHECK-LABEL: splat 4147-func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> { 4148+func.func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> { 4149 %0 = tt.splat %arg0 : (i32) -> tensor<1024xi32, #layout0> 4150 %1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1> 4151 // CHECK-NOT: triton_gpu.convert_layout 4152@@ -36,7 +36,7 @@ func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> { 4153 } 4154 4155 // CHECK-LABEL: remat 4156-func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> { 4157+func.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> { 4158 %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> 4159 %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> 4160 %2 = arith.muli %0, %1 : tensor<1024xi32, #layout0> 4161@@ -56,7 +56,7 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> { 4162 } 4163 4164 // CHECK-LABEL: remat_load_store 4165-func @remat_load_store(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4166+func.func @remat_load_store(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4167 %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0> 4168 %1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout0> 4169 %2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout0>, tensor<64xi32, #layout0> 4170@@ -70,7 +70,7 @@ func @remat_load_store(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4171 4172 // Don't rematerialize vectorized loads 4173 // CHECK-LABEL: remat_expensive 4174-func @remat_expensive(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4175+func.func @remat_expensive(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4176 %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout1> 4177 %1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout1> 4178 %2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout1>, tensor<64xi32, #layout1> 4179@@ -85,7 +85,7 @@ func @remat_expensive(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4180 4181 // Don't rematerialize loads when original and target layouts are different 4182 // CHECK-LABEL: remat_multi_layout 4183-func @remat_multi_layout(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4184+func.func @remat_multi_layout(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4185 %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0> 4186 %1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout0> 4187 %2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout0>, tensor<64xi32, #layout0> 4188@@ -100,7 +100,7 @@ func @remat_multi_layout(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4189 4190 // Always rematerialize single value loads 4191 // CHECK-LABEL: remat_single_value 4192-func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4193+func.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4194 %0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<1x!tt.ptr<i32>, #layout1> 4195 %1 = tt.load %0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1xi32, #layout1> 4196 // CHECK-NOT: triton_gpu.convert_layout 4197@@ -111,7 +111,7 @@ func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4198 } 4199 4200 // CHECK-LABEL: if 4201-func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4202+func.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4203 // CHECK-NOT: triton_gpu.convert_layout 4204 %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout1> 4205 %0 = tt.get_program_id {axis = 0 : i32} : i32 4206@@ -128,7 +128,7 @@ func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4207 } 4208 4209 // CHECK-LABEL: if_convert_else_not 4210-func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4211+func.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4212 %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0> 4213 %0 = tt.get_program_id {axis = 0 : i32} : i32 4214 %1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0> 4215@@ -149,7 +149,7 @@ func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 4216 } 4217 4218 // CHECK-LABEL: if_not_else_convert 4219-func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4220+func.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4221 %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0> 4222 %0 = tt.get_program_id {axis = 0 : i32} : i32 4223 %1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0> 4224@@ -170,7 +170,7 @@ func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 4225 } 4226 4227 // CHECK-LABEL: if_else_both_convert 4228-func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4229+func.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) { 4230 %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0> 4231 %0 = tt.get_program_id {axis = 0 : i32} : i32 4232 %1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0> 4233@@ -200,7 +200,7 @@ func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 4234 #blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> 4235 4236 // CHECK-LABEL: transpose 4237-func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { 4238+func.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { 4239 // CHECK-NOT: triton_gpu.convert_layout 4240 // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]> 4241 // CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]> 4242@@ -241,7 +241,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt 4243 } 4244 4245 // CHECK-LABEL: loop 4246-func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) { 4247+func.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) { 4248 // CHECK-NOT: triton_gpu.convert_layout 4249 // CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[row_layout]]>) 4250 // CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[row_layout]]> 4251@@ -295,7 +295,7 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar 4252 } 4253 4254 // CHECK-LABEL: vecadd 4255-func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) { 4256+func.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) { 4257 // CHECK-NOT: triton_gpu.convert_layout 4258 %c256_i32 = arith.constant 256 : i32 4259 %0 = tt.get_program_id {axis = 0 : i32} : i32 4260@@ -327,7 +327,7 @@ func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f3 4261 4262 // Select has args with different element types 4263 // CHECK-LABEL: select 4264-func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) { 4265+func.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) { 4266 // CHECK-NOT: triton_gpu.convert_layout 4267 %cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2> 4268 %cst_0 = arith.constant dense<30000> : tensor<1x512xi32, #blocked2> 4269@@ -378,7 +378,7 @@ func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f6 4270 4271 // Make sure the following IR doesn't hang the compiler. 4272 // CHECK-LABEL: long_func 4273-func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg14: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) { 4274+func.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg14: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) { 4275 %cst = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked0> 4276 %cst_0 = arith.constant dense<5.000000e-04> : tensor<1024xf32, #blocked0> 4277 %cst_1 = arith.constant dense<0.999499976> : tensor<1024xf32, #blocked0> 4278@@ -775,7 +775,7 @@ func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: 4279 // A mnist model from torch inductor. 4280 // Check if topological sort is working correct and there's no unnecessary convert 4281 // CHECK-LABEL: mnist 4282-func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) { 4283+func.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) { 4284 // CHECK-NOT: triton_gpu.convert_layout 4285 %cst = arith.constant dense<10> : tensor<16x1xi32, #blocked2> 4286 %cst_0 = arith.constant dense<10> : tensor<1x16xi32, #blocked3> 4287@@ -862,7 +862,7 @@ func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt. 4288 #blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> 4289 // cmpf and cmpi have different operands and result types 4290 // CHECK-LABEL: cmp 4291-func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { 4292+func.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { 4293 %c64 = arith.constant 64 : index 4294 %c2048 = arith.constant 2048 : index 4295 %c0 = arith.constant 0 : index 4296diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir 4297index 6ee3b15fbc..663f2da7b0 100644 4298--- a/test/TritonGPU/loop-pipeline.mlir 4299+++ b/test/TritonGPU/loop-pipeline.mlir 4300@@ -10,7 +10,7 @@ 4301 #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C}> 4302 #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C}> 4303 4304-// CHECK: func @matmul_loop 4305+// CHECK: func.func @matmul_loop 4306 // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 4307 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 4308 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 4309@@ -46,8 +46,8 @@ 4310 // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] 4311 // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] 4312 // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] 4313-func @matmul_loop(%lb : index, %ub : index, %step : index, 4314- %A : !tt.ptr<f16> {tt.divisibility = 16 : i32}, 4315+func.func @matmul_loop(%lb : index, %ub : index, %step : index, 4316+ %A : !tt.ptr<f16> {tt.divisibility = 16 : i32}, 4317 %B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) { 4318 // A ptrs 4319 %a_ptr_splat = tt.splat %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL> 4320@@ -61,7 +61,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, 4321 %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL> 4322 %b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL> 4323 %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL> 4324- 4325+ 4326 4327 %a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL> 4328 %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> 4329@@ -88,7 +88,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, 4330 } 4331 4332 4333-// CHECK: func @matmul_loop_nested 4334+// CHECK: func.func @matmul_loop_nested 4335 // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 4336 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 4337 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 4338@@ -118,8 +118,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, 4339 // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] 4340 // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] 4341 // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] 4342-func @matmul_loop_nested(%lb : index, %ub : index, %step : index, 4343- %A : !tt.ptr<f16> {tt.divisibility = 16 : i32}, 4344+func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, 4345+ %A : !tt.ptr<f16> {tt.divisibility = 16 : i32}, 4346 %B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) { 4347 scf.for %iv0 = %lb to %ub step %step { 4348 // A ptrs 4349@@ -134,7 +134,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, 4350 %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL> 4351 %b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL> 4352 %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL> 4353- 4354+ 4355 %a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL> 4356 %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> 4357 %b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL> 4358@@ -161,7 +161,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, 4359 } 4360 4361 4362-// CHECK: func @matmul_loop_single_pipeline 4363+// CHECK: func.func @matmul_loop_single_pipeline 4364 // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 4365 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 4366 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 4367@@ -183,8 +183,8 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, 4368 // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] 4369 // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] 4370 // CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] 4371-func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, 4372- %A : !tt.ptr<f16> {tt.divisibility = 16 : i32}, 4373+func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, 4374+ %A : !tt.ptr<f16> {tt.divisibility = 16 : i32}, 4375 %B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) { 4376 // A ptrs 4377 %a_ptr_splat = tt.splat %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL> 4378diff --git a/test/TritonGPU/matmul.mlir b/test/TritonGPU/matmul.mlir 4379index 9bd5318e1e..01dc3f0ab1 100644 4380--- a/test/TritonGPU/matmul.mlir 4381+++ b/test/TritonGPU/matmul.mlir 4382@@ -4,7 +4,7 @@ 4383 // CHECK: offset = 49152, size = 49152 4384 // CHECK: size = 98304 4385 module { 4386-func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) { 4387+func.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) { 4388 %cst = arith.constant dense<true> : tensor<64x64xi1> 4389 %c64 = arith.constant 64 : index 4390 %c0 = arith.constant 0 : index 4391@@ -22,7 +22,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6 4392 %7 = arith.muli %6, %c8_i32 : i32 4393 %8 = arith.subi %2, %7 : i32 4394 %9 = arith.cmpi slt, %8, %c8_i32 : i32 4395- %10 = select %9, %8, %c8_i32 : i32 4396+ %10 = arith.select %9, %8, %c8_i32 : i32 4397 %11 = arith.remsi %0, %10 : i32 4398 %12 = arith.addi %7, %11 : i32 4399 %13 = arith.remsi %0, %5 : i32 4400diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir 4401index 52b4dddec1..b427547890 100644 4402--- a/test/TritonGPU/prefetch.mlir 4403+++ b/test/TritonGPU/prefetch.mlir 4404@@ -11,7 +11,7 @@ 4405 #B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C}> 4406 4407 4408-// CHECK: func @matmul_loop 4409+// CHECK: func.func @matmul_loop 4410 // CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[A0:.*]][0, 0] [128, 16] 4411 // CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.convert_layout %[[A0_PREFETCH_SMEM]] 4412 // CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[B0:.*]][0, 0] [16, 128] 4413@@ -28,7 +28,7 @@ 4414 // CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = tensor.extract_slice {{.*}}[0, 0] [16, 128] 4415 // CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_B_PREFETCH_SMEM]] 4416 // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH]], %[[NEXT_B_PREFETCH]] 4417-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 4418+func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) { 4419 %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL> 4420 %b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL> 4421 4422diff --git a/test/TritonGPU/update-mma-for-volta.mlir b/test/TritonGPU/update-mma-for-volta.mlir 4423index d587fffcca..7571ec6185 100644 4424--- a/test/TritonGPU/update-mma-for-volta.mlir 4425+++ b/test/TritonGPU/update-mma-for-volta.mlir 4426@@ -15,7 +15,7 @@ 4427 // CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 2]}> 4428 module attributes {"triton_gpu.num-warps" = 16 : i32} { 4429 // CHECK-LABEL: dot_mmav1 4430- func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> { 4431+ func.func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> { 4432 %C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked0> 4433 %AA = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a> 4434 %BB = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b> 4435@@ -50,7 +50,7 @@ module attributes {"triton_gpu.num-warps" = 16 : i32} { 4436 4437 module attributes {"triton_gpu.num-warps" = 16 : i32} { 4438 // CHECK-LABEL: dot_mmav1 4439- func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> { 4440+ func.func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> { 4441 %C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked0> 4442 %AA = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a> 4443 %BB = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b> 4444diff --git a/test/lib/Analysis/TestAlias.cpp b/test/lib/Analysis/TestAlias.cpp 4445index 88a4118fe9..3fd0cfd0d3 100644 4446--- a/test/lib/Analysis/TestAlias.cpp 4447+++ b/test/lib/Analysis/TestAlias.cpp 4448@@ -9,10 +9,10 @@ using namespace mlir; 4449 namespace { 4450 4451 struct TestAliasPass 4452- : public PassWrapper<TestAliasPass, OperationPass<FuncOp>> { 4453+ : public PassWrapper<TestAliasPass, OperationPass<func::FuncOp>> { 4454+ 4455+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass); 4456 4457- // LLVM15+ 4458- // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass); 4459 static void print(StringRef name, SmallVector<std::string, 4> &vals, 4460 raw_ostream &os) { 4461 if (vals.empty()) 4462@@ -39,23 +39,24 @@ struct TestAliasPass 4463 auto opName = SymbolTable::getSymbolName(operation).getValue().str(); 4464 os << opName << "\n"; 4465 4466- SharedMemoryAliasAnalysis analysis(&getContext()); 4467- analysis.run(operation); 4468+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver(); 4469+ SharedMemoryAliasAnalysis *analysis = 4470+ solver->load<SharedMemoryAliasAnalysis>(); 4471+ if (failed(solver->initializeAndRun(operation))) 4472+ return signalPassFailure(); 4473 4474 AsmState state(operation->getParentOfType<ModuleOp>()); 4475 // Get operation ids of value's aliases 4476 auto getAllocOpNames = [&](Value value) { 4477- LatticeElement<AliasInfo> *latticeElement = 4478- analysis.lookupLatticeElement(value); 4479+ dataflow::Lattice<AliasInfo> *latticeElement = 4480+ analysis->getLatticeElement(value); 4481 SmallVector<std::string, 4> opNames; 4482- if (latticeElement) { 4483+ if (latticeElement && !latticeElement->isUninitialized()) { 4484 auto &info = latticeElement->getValue(); 4485- if (!info.getAllocs().empty()) { 4486- for (auto &alias : info.getAllocs()) { 4487- auto opName = 4488- getValueOperandName(alias.getDefiningOp()->getResult(0), state); 4489- opNames.push_back(std::move(opName)); 4490- } 4491+ for (auto &alias : info.getAllocs()) { 4492+ auto opName = 4493+ getValueOperandName(alias.getDefiningOp()->getResult(0), state); 4494+ opNames.push_back(std::move(opName)); 4495 } 4496 } 4497 // Ensure deterministic output 4498diff --git a/test/lib/Analysis/TestAllocation.cpp b/test/lib/Analysis/TestAllocation.cpp 4499index 84108c4d36..35e42242bd 100644 4500--- a/test/lib/Analysis/TestAllocation.cpp 4501+++ b/test/lib/Analysis/TestAllocation.cpp 4502@@ -6,10 +6,9 @@ using namespace mlir; 4503 namespace { 4504 4505 struct TestAllocationPass 4506- : public PassWrapper<TestAllocationPass, OperationPass<FuncOp>> { 4507+ : public PassWrapper<TestAllocationPass, OperationPass<func::FuncOp>> { 4508 4509- // LLVM15+ 4510- // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass); 4511+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass); 4512 4513 StringRef getArgument() const final { return "test-print-allocation"; } 4514 StringRef getDescription() const final { 4515diff --git a/test/lib/Analysis/TestAxisInfo.cpp b/test/lib/Analysis/TestAxisInfo.cpp 4516index a5205bb0a0..22347c32f0 100644 4517--- a/test/lib/Analysis/TestAxisInfo.cpp 4518+++ b/test/lib/Analysis/TestAxisInfo.cpp 4519@@ -1,25 +1,15 @@ 4520 #include "mlir/Pass/Pass.h" 4521 #include "triton/Analysis/AxisInfo.h" 4522+#include "triton/Analysis/Utility.h" 4523 4524 using namespace mlir; 4525 4526 namespace { 4527 4528 struct TestAxisInfoPass 4529- : public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>> { 4530+ : public PassWrapper<TestAxisInfoPass, OperationPass<func::FuncOp>> { 4531 4532- // LLVM15+ 4533- // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass); 4534- 4535- void print(const std::string &name, raw_ostream &os, ArrayRef<int64_t> vals) { 4536- os << name << ": ["; 4537- for (size_t d = 0; d < vals.size(); d++) { 4538- if (d != 0) 4539- os << ", "; 4540- os << vals[d]; 4541- } 4542- os << "]"; 4543- } 4544+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass); 4545 4546 StringRef getArgument() const final { return "test-print-alignment"; } 4547 StringRef getDescription() const final { 4548@@ -30,38 +20,19 @@ struct TestAxisInfoPass 4549 Operation *operation = getOperation(); 4550 auto &os = llvm::errs(); 4551 auto opName = SymbolTable::getSymbolName(operation).getValue().str(); 4552- os << opName << "\n"; 4553- AxisInfoAnalysis analysis(&getContext()); 4554- analysis.run(operation); 4555+ os << "@" << opName << "\n"; 4556+ 4557+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver(); 4558+ AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>(); 4559+ if (failed(solver->initializeAndRun(operation))) 4560+ return signalPassFailure(); 4561 operation->walk([&](Operation *op) { 4562 if (op->getNumResults() < 1) 4563 return; 4564 for (Value result : op->getResults()) { 4565- // std::ostringstream oss; 4566- // result.print(oss); 4567- // os << " => "; 4568- LatticeElement<AxisInfo> *latticeElement = 4569- analysis.lookupLatticeElement(result); 4570- if (!latticeElement) { 4571- os << "None\n"; 4572- return; 4573- } 4574- AxisInfo &info = latticeElement->getValue(); 4575- print("Contiguity", os, info.getContiguity()); 4576- os << " ; "; 4577- print("Divisibility", os, info.getDivisibility()); 4578- os << " ; "; 4579- print("Constancy", os, info.getConstancy()); 4580- os << " ; "; 4581- auto constantValue = info.getConstantValue(); 4582- os << "ConstantValue: ["; 4583- if (constantValue.has_value()) 4584- os << constantValue.value(); 4585- else 4586- os << "None"; 4587- os << "] ( "; 4588 result.print(os); 4589- os << " ) "; 4590+ os << " => "; 4591+ analysis->getLatticeElement(result)->getValue().print(os); 4592 os << "\n"; 4593 } 4594 }); 4595diff --git a/test/lib/Analysis/TestMembar.cpp b/test/lib/Analysis/TestMembar.cpp 4596index df4279fe24..ab9b9f3fb7 100644 4597--- a/test/lib/Analysis/TestMembar.cpp 4598+++ b/test/lib/Analysis/TestMembar.cpp 4599@@ -1,4 +1,4 @@ 4600-#include "mlir/Dialect/GPU/GPUDialect.h" 4601+#include "mlir/Dialect/GPU/IR/GPUDialect.h" 4602 #include "mlir/IR/Dialect.h" 4603 #include "mlir/Pass/Pass.h" 4604 #include "triton/Analysis/Allocation.h" 4605@@ -9,10 +9,9 @@ using namespace mlir; 4606 namespace { 4607 4608 struct TestMembarPass 4609- : public PassWrapper<TestMembarPass, OperationPass<FuncOp>> { 4610+ : public PassWrapper<TestMembarPass, OperationPass<func::FuncOp>> { 4611 4612- // LLVM15+ 4613- // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass); 4614+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass); 4615 4616 StringRef getArgument() const final { return "test-print-membar"; } 4617 StringRef getDescription() const final {