diff --git a/library/src/hipblaslt_host.cpp b/library/src/hipblaslt_host.cpp index 8080070c..97d5216e 100644 --- a/library/src/hipblaslt_host.cpp +++ b/library/src/hipblaslt_host.cpp @@ -155,22 +155,22 @@ namespace hipblaslt_compute_type); hipblaslt_ext::GemmProblemType problemType; - problemType.op_a = (hipblasOperation_t)prob.trans_a; - problemType.op_b = (hipblasOperation_t)prob.trans_b; - problemType.type_a = hipblaslt_datatype; - problemType.type_b = hipblaslt_datatype; - problemType.type_c = hipblaslt_datatype; - problemType.type_d = hipblaslt_datatype; - problemType.type_compute = hipblaslt_compute_type; + problemType.setOpA((hipblasOperation_t)prob.trans_a); + problemType.setOpB((hipblasOperation_t)prob.trans_b); + problemType.setTypeA(hipblaslt_datatype); + problemType.setTypeB(hipblaslt_datatype); + problemType.setTypeC(hipblaslt_datatype); + problemType.setTypeD(hipblaslt_datatype); + problemType.setTypeCompute(hipblaslt_compute_type); hipblaslt_ext::GemmEpilogue epilogue; hipblaslt_ext::GemmInputs inputs; - inputs.a = (void*)(prob.A + prob.buffer_offset_a); - inputs.b = (void*)(prob.B + prob.buffer_offset_b); - inputs.c = (void*)(prob.C + prob.buffer_offset_c); - inputs.d = (void*)(prob.D + prob.buffer_offset_d); - inputs.alpha = (void*)prob.alpha; - inputs.beta = (void*)prob.beta; + inputs.setA((void*)(prob.A + prob.buffer_offset_a)); + inputs.setB((void*)(prob.B + prob.buffer_offset_b)); + inputs.setC((void*)(prob.C + prob.buffer_offset_c)); + inputs.setD((void*)(prob.D + prob.buffer_offset_d)); + inputs.setAlpha((void*)prob.alpha); + inputs.setBeta((void*)prob.beta); gemm.setProblem(prob.m, prob.n, @@ -214,13 +214,13 @@ namespace hipblaslt_compute_type); hipblaslt_ext::GemmProblemType problemType; - problemType.op_a = (hipblasOperation_t)prob.trans_a; - problemType.op_b = (hipblasOperation_t)prob.trans_b; - problemType.type_a = hipblaslt_datatype; - problemType.type_b = hipblaslt_datatype; - problemType.type_c = hipblaslt_datatype; - problemType.type_d = hipblaslt_datatype; - problemType.type_compute = hipblaslt_compute_type; + problemType.setOpA((hipblasOperation_t)prob.trans_a); + problemType.setOpB((hipblasOperation_t)prob.trans_b); + problemType.setTypeA(hipblaslt_datatype); + problemType.setTypeB(hipblaslt_datatype); + problemType.setTypeC(hipblaslt_datatype); + problemType.setTypeD(hipblaslt_datatype); + problemType.setTypeCompute(hipblaslt_compute_type); std::vector Ms(prob.batch_count); std::vector Ns(prob.batch_count); @@ -251,12 +251,12 @@ namespace stridecs[batch] = prob.batch_stride_c; strideds[batch] = prob.batch_stride_d; batch_counts[batch] = 1; - inputs[batch].a = (void*)(prob.batch_A[batch] + prob.buffer_offset_a); - inputs[batch].b = (void*)(prob.batch_B[batch] + prob.buffer_offset_b); - inputs[batch].c = (void*)(prob.batch_C[batch] + prob.buffer_offset_c); - inputs[batch].d = (void*)(prob.batch_D[batch] + prob.buffer_offset_d); - inputs[batch].alpha = (void*)prob.alpha; - inputs[batch].beta = (void*)prob.beta; + inputs[batch].setA((void*)(prob.batch_A[batch] + prob.buffer_offset_a)); + inputs[batch].setB((void*)(prob.batch_B[batch] + prob.buffer_offset_b)); + inputs[batch].setC((void*)(prob.batch_C[batch] + prob.buffer_offset_c)); + inputs[batch].setD((void*)(prob.batch_D[batch] + prob.buffer_offset_d)); + inputs[batch].setAlpha((void*)prob.alpha); + inputs[batch].setBeta((void*)prob.beta); } gemm.setProblem(Ms, diff --git a/library/src/tensile_host.cpp b/library/src/tensile_host.cpp index 1b1289f3..ed463725 100644 --- a/library/src/tensile_host.cpp +++ b/library/src/tensile_host.cpp @@ -271,14 +271,6 @@ namespace { return Tensile::LazyLoadingInit::gfx90a; } - else if(deviceString.find("gfx940") != std::string::npos) - { - return Tensile::LazyLoadingInit::gfx940; - } - else if(deviceString.find("gfx941") != std::string::npos) - { - return Tensile::LazyLoadingInit::gfx941; - } else if(deviceString.find("gfx942") != std::string::npos) { return Tensile::LazyLoadingInit::gfx942;