24 #ifndef TVM_TOPI_CONTRIB_ROCBLAS_H_
25 #define TVM_TOPI_CONTRIB_ROCBLAS_H_
46 auto n = transa ? lhs->shape[1] : lhs->shape[0];
47 auto m = transb ? rhs->shape[0] : rhs->shape[1];
50 {{n, m}}, {lhs->dtype}, {lhs, rhs},
51 [&](Array<Buffer> ins, Array<Buffer> outs) {
52 return call_packed({
StringImm(
"tvm.contrib.rocblas.matmul"), pack_buffer(ins[0]),
53 pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb});
68 auto batch_size = lhs->shape[0];
69 auto n = transa ? lhs->shape[2] : lhs->shape[1];
70 auto m = transb ? rhs->shape[1] : rhs->shape[2];
73 {{batch_size, n, m}}, {lhs->dtype}, {lhs, rhs},
74 [&](Array<Buffer> ins, Array<Buffer> outs) {
75 return call_packed({
StringImm(
"tvm.contrib.rocblas.batch_matmul"), pack_buffer(ins[0]),
76 pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb});
Tensor structure representing a possible input or an intermediate computation result.
Definition: tensor.h:100
Managed reference to StringImmNode.
Definition: expr.h:71
Helpers for using external functions.
Tensor expression language DSL.
Definition: extracted_task.h:33
Tensor rocblas_matmul(const Tensor &lhs, const Tensor &rhs, bool transa, bool transb)
Create an op that multiplies lhs and rhs with rocBLAS.
Definition: rocblas.h:45
Tensor rocblas_batch_matmul(const Tensor &lhs, const Tensor &rhs, bool transa, bool transb)
Create an op that batch multiplies lhs and rhs with rocBLAS.
Definition: rocblas.h:67
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
An operation node can generate one or more Tensors.