nixpkgs mirror (for testing)
github.com/NixOS/nixpkgs
nix
1diff --git a/library/src/hipblaslt_host.cpp b/library/src/hipblaslt_host.cpp
2index 8080070c..97d5216e 100644
3--- a/library/src/hipblaslt_host.cpp
4+++ b/library/src/hipblaslt_host.cpp
5@@ -155,22 +155,22 @@ namespace
6 hipblaslt_compute_type<Tc>);
7
8 hipblaslt_ext::GemmProblemType problemType;
9- problemType.op_a = (hipblasOperation_t)prob.trans_a;
10- problemType.op_b = (hipblasOperation_t)prob.trans_b;
11- problemType.type_a = hipblaslt_datatype<TiA>;
12- problemType.type_b = hipblaslt_datatype<TiB>;
13- problemType.type_c = hipblaslt_datatype<To>;
14- problemType.type_d = hipblaslt_datatype<To>;
15- problemType.type_compute = hipblaslt_compute_type<Tc>;
16+ problemType.setOpA((hipblasOperation_t)prob.trans_a);
17+ problemType.setOpB((hipblasOperation_t)prob.trans_b);
18+ problemType.setTypeA(hipblaslt_datatype<TiA>);
19+ problemType.setTypeB(hipblaslt_datatype<TiB>);
20+ problemType.setTypeC(hipblaslt_datatype<To>);
21+ problemType.setTypeD(hipblaslt_datatype<To>);
22+ problemType.setTypeCompute(hipblaslt_compute_type<Tc>);
23
24 hipblaslt_ext::GemmEpilogue epilogue;
25 hipblaslt_ext::GemmInputs inputs;
26- inputs.a = (void*)(prob.A + prob.buffer_offset_a);
27- inputs.b = (void*)(prob.B + prob.buffer_offset_b);
28- inputs.c = (void*)(prob.C + prob.buffer_offset_c);
29- inputs.d = (void*)(prob.D + prob.buffer_offset_d);
30- inputs.alpha = (void*)prob.alpha;
31- inputs.beta = (void*)prob.beta;
32+ inputs.setA((void*)(prob.A + prob.buffer_offset_a));
33+ inputs.setB((void*)(prob.B + prob.buffer_offset_b));
34+ inputs.setC((void*)(prob.C + prob.buffer_offset_c));
35+ inputs.setD((void*)(prob.D + prob.buffer_offset_d));
36+ inputs.setAlpha((void*)prob.alpha);
37+ inputs.setBeta((void*)prob.beta);
38
39 gemm.setProblem(prob.m,
40 prob.n,
41@@ -214,13 +214,13 @@ namespace
42 hipblaslt_compute_type<Tc>);
43
44 hipblaslt_ext::GemmProblemType problemType;
45- problemType.op_a = (hipblasOperation_t)prob.trans_a;
46- problemType.op_b = (hipblasOperation_t)prob.trans_b;
47- problemType.type_a = hipblaslt_datatype<TiA>;
48- problemType.type_b = hipblaslt_datatype<TiB>;
49- problemType.type_c = hipblaslt_datatype<To>;
50- problemType.type_d = hipblaslt_datatype<To>;
51- problemType.type_compute = hipblaslt_compute_type<Tc>;
52+ problemType.setOpA((hipblasOperation_t)prob.trans_a);
53+ problemType.setOpB((hipblasOperation_t)prob.trans_b);
54+ problemType.setTypeA(hipblaslt_datatype<TiA>);
55+ problemType.setTypeB(hipblaslt_datatype<TiB>);
56+ problemType.setTypeC(hipblaslt_datatype<To>);
57+ problemType.setTypeD(hipblaslt_datatype<To>);
58+ problemType.setTypeCompute(hipblaslt_compute_type<Tc>);
59
60 std::vector<int64_t> Ms(prob.batch_count);
61 std::vector<int64_t> Ns(prob.batch_count);
62@@ -251,12 +251,12 @@ namespace
63 stridecs[batch] = prob.batch_stride_c;
64 strideds[batch] = prob.batch_stride_d;
65 batch_counts[batch] = 1;
66- inputs[batch].a = (void*)(prob.batch_A[batch] + prob.buffer_offset_a);
67- inputs[batch].b = (void*)(prob.batch_B[batch] + prob.buffer_offset_b);
68- inputs[batch].c = (void*)(prob.batch_C[batch] + prob.buffer_offset_c);
69- inputs[batch].d = (void*)(prob.batch_D[batch] + prob.buffer_offset_d);
70- inputs[batch].alpha = (void*)prob.alpha;
71- inputs[batch].beta = (void*)prob.beta;
72+ inputs[batch].setA((void*)(prob.batch_A[batch] + prob.buffer_offset_a));
73+ inputs[batch].setB((void*)(prob.batch_B[batch] + prob.buffer_offset_b));
74+ inputs[batch].setC((void*)(prob.batch_C[batch] + prob.buffer_offset_c));
75+ inputs[batch].setD((void*)(prob.batch_D[batch] + prob.buffer_offset_d));
76+ inputs[batch].setAlpha((void*)prob.alpha);
77+ inputs[batch].setBeta((void*)prob.beta);
78 }
79
80 gemm.setProblem(Ms,
81diff --git a/library/src/tensile_host.cpp b/library/src/tensile_host.cpp
82index 1b1289f3..ed463725 100644
83--- a/library/src/tensile_host.cpp
84+++ b/library/src/tensile_host.cpp
85@@ -271,14 +271,6 @@ namespace
86 {
87 return Tensile::LazyLoadingInit::gfx90a;
88 }
89- else if(deviceString.find("gfx940") != std::string::npos)
90- {
91- return Tensile::LazyLoadingInit::gfx940;
92- }
93- else if(deviceString.find("gfx941") != std::string::npos)
94- {
95- return Tensile::LazyLoadingInit::gfx941;
96- }
97 else if(deviceString.find("gfx942") != std::string::npos)
98 {
99 return Tensile::LazyLoadingInit::gfx942;