24 #ifndef TVM_TOPI_CONTRIB_ROCBLAS_H_
25 #define TVM_TOPI_CONTRIB_ROCBLAS_H_
46 auto n = transa ? lhs->shape[1] : lhs->shape[0];
47 auto m = transb ? rhs->shape[0] : rhs->shape[1];
50 {{n, m}}, {lhs->dtype}, {lhs, rhs},
52 return call_packed({
StringImm(
"tvm.contrib.rocblas.matmul"), pack_buffer(ins[0]),
53 pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb});
68 auto batch_size = lhs->shape[0];
69 auto n = transa ? lhs->shape[2] : lhs->shape[1];
70 auto m = transb ? rhs->shape[1] : rhs->shape[2];
73 {{batch_size, n, m}}, {lhs->dtype}, {lhs, rhs},
75 return call_packed({
StringImm(
"tvm.contrib.rocblas.batch_matmul"), pack_buffer(ins[0]),
76 pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb});
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Managed reference to StringImmNode.
Definition: expr.h:78
Helpers for using external functions.
Tensor expression language DSL.
Definition: extracted_task.h:33
Tensor rocblas_matmul(const Tensor &lhs, const Tensor &rhs, bool transa, bool transb)
Create an op that multiplies lhs and rhs with rocBLAS.
Definition: rocblas.h:45
Tensor rocblas_batch_matmul(const Tensor &lhs, const Tensor &rhs, bool transa, bool transb)
Create an op that batch multiplies lhs and rhs with rocBLAS.
Definition: rocblas.h:67
Runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Operation node that can generate one or multiple Tensors.