nixpkgs mirror (for testing) github.com/NixOS/nixpkgs
nix
at python-updates 99 lines 5.0 kB view raw
1diff --git a/library/src/hipblaslt_host.cpp b/library/src/hipblaslt_host.cpp 2index 8080070c..97d5216e 100644 3--- a/library/src/hipblaslt_host.cpp 4+++ b/library/src/hipblaslt_host.cpp 5@@ -155,22 +155,22 @@ namespace 6 hipblaslt_compute_type<Tc>); 7 8 hipblaslt_ext::GemmProblemType problemType; 9- problemType.op_a = (hipblasOperation_t)prob.trans_a; 10- problemType.op_b = (hipblasOperation_t)prob.trans_b; 11- problemType.type_a = hipblaslt_datatype<TiA>; 12- problemType.type_b = hipblaslt_datatype<TiB>; 13- problemType.type_c = hipblaslt_datatype<To>; 14- problemType.type_d = hipblaslt_datatype<To>; 15- problemType.type_compute = hipblaslt_compute_type<Tc>; 16+ problemType.setOpA((hipblasOperation_t)prob.trans_a); 17+ problemType.setOpB((hipblasOperation_t)prob.trans_b); 18+ problemType.setTypeA(hipblaslt_datatype<TiA>); 19+ problemType.setTypeB(hipblaslt_datatype<TiB>); 20+ problemType.setTypeC(hipblaslt_datatype<To>); 21+ problemType.setTypeD(hipblaslt_datatype<To>); 22+ problemType.setTypeCompute(hipblaslt_compute_type<Tc>); 23 24 hipblaslt_ext::GemmEpilogue epilogue; 25 hipblaslt_ext::GemmInputs inputs; 26- inputs.a = (void*)(prob.A + prob.buffer_offset_a); 27- inputs.b = (void*)(prob.B + prob.buffer_offset_b); 28- inputs.c = (void*)(prob.C + prob.buffer_offset_c); 29- inputs.d = (void*)(prob.D + prob.buffer_offset_d); 30- inputs.alpha = (void*)prob.alpha; 31- inputs.beta = (void*)prob.beta; 32+ inputs.setA((void*)(prob.A + prob.buffer_offset_a)); 33+ inputs.setB((void*)(prob.B + prob.buffer_offset_b)); 34+ inputs.setC((void*)(prob.C + prob.buffer_offset_c)); 35+ inputs.setD((void*)(prob.D + prob.buffer_offset_d)); 36+ inputs.setAlpha((void*)prob.alpha); 37+ inputs.setBeta((void*)prob.beta); 38 39 gemm.setProblem(prob.m, 40 prob.n, 41@@ -214,13 +214,13 @@ namespace 42 hipblaslt_compute_type<Tc>); 43 44 hipblaslt_ext::GemmProblemType problemType; 45- problemType.op_a = (hipblasOperation_t)prob.trans_a; 46- problemType.op_b = (hipblasOperation_t)prob.trans_b; 47- problemType.type_a = hipblaslt_datatype<TiA>; 48- problemType.type_b = hipblaslt_datatype<TiB>; 49- problemType.type_c = hipblaslt_datatype<To>; 50- problemType.type_d = hipblaslt_datatype<To>; 51- problemType.type_compute = hipblaslt_compute_type<Tc>; 52+ problemType.setOpA((hipblasOperation_t)prob.trans_a); 53+ problemType.setOpB((hipblasOperation_t)prob.trans_b); 54+ problemType.setTypeA(hipblaslt_datatype<TiA>); 55+ problemType.setTypeB(hipblaslt_datatype<TiB>); 56+ problemType.setTypeC(hipblaslt_datatype<To>); 57+ problemType.setTypeD(hipblaslt_datatype<To>); 58+ problemType.setTypeCompute(hipblaslt_compute_type<Tc>); 59 60 std::vector<int64_t> Ms(prob.batch_count); 61 std::vector<int64_t> Ns(prob.batch_count); 62@@ -251,12 +251,12 @@ namespace 63 stridecs[batch] = prob.batch_stride_c; 64 strideds[batch] = prob.batch_stride_d; 65 batch_counts[batch] = 1; 66- inputs[batch].a = (void*)(prob.batch_A[batch] + prob.buffer_offset_a); 67- inputs[batch].b = (void*)(prob.batch_B[batch] + prob.buffer_offset_b); 68- inputs[batch].c = (void*)(prob.batch_C[batch] + prob.buffer_offset_c); 69- inputs[batch].d = (void*)(prob.batch_D[batch] + prob.buffer_offset_d); 70- inputs[batch].alpha = (void*)prob.alpha; 71- inputs[batch].beta = (void*)prob.beta; 72+ inputs[batch].setA((void*)(prob.batch_A[batch] + prob.buffer_offset_a)); 73+ inputs[batch].setB((void*)(prob.batch_B[batch] + prob.buffer_offset_b)); 74+ inputs[batch].setC((void*)(prob.batch_C[batch] + prob.buffer_offset_c)); 75+ inputs[batch].setD((void*)(prob.batch_D[batch] + prob.buffer_offset_d)); 76+ inputs[batch].setAlpha((void*)prob.alpha); 77+ inputs[batch].setBeta((void*)prob.beta); 78 } 79 80 gemm.setProblem(Ms, 81diff --git a/library/src/tensile_host.cpp b/library/src/tensile_host.cpp 82index 1b1289f3..ed463725 100644 83--- a/library/src/tensile_host.cpp 84+++ b/library/src/tensile_host.cpp 85@@ -271,14 +271,6 @@ namespace 86 { 87 return Tensile::LazyLoadingInit::gfx90a; 88 } 89- else if(deviceString.find("gfx940") != std::string::npos) 90- { 91- return Tensile::LazyLoadingInit::gfx940; 92- } 93- else if(deviceString.find("gfx941") != std::string::npos) 94- { 95- return Tensile::LazyLoadingInit::gfx941; 96- } 97 else if(deviceString.find("gfx942") != std::string::npos) 98 { 99 return Tensile::LazyLoadingInit::gfx942;