24 #ifndef TVM_TOPI_CONTRIB_CUBLAS_H_
25 #define TVM_TOPI_CONTRIB_CUBLAS_H_
35 using namespace topi::detail;
47 auto n = transa ? lhs->shape[1] : lhs->shape[0];
48 auto m = transb ? rhs->shape[0] : rhs->shape[1];
51 {{n, m}}, {lhs->dtype}, {lhs, rhs},
53 return call_packed({
StringImm(
"tvm.contrib.cublas.matmul"), pack_buffer(ins[0]),
54 pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb});
71 auto b = lhs->shape[0];
72 auto n = transa ? lhs->shape[2] : lhs->shape[1];
73 auto m = transb ? rhs->shape[1] : rhs->shape[2];
76 {{b, n, m}}, {lhs->dtype}, {lhs, rhs},
78 return call_packed({
StringImm(
"tvm.contrib.cublas.batch_matmul"), pack_buffer(ins[0]),
79 pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb});
Array: a container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Tensor structure representing a possible input or intermediate computation result.
Definition: tensor.h:102
Managed reference to StringImmNode.
Definition: expr.h:78
Helpers for using external functions.
Tensor expression language DSL.
Definition: extracted_task.h:33
Tensor cublas_matmul(const Tensor &lhs, const Tensor &rhs, bool transa, bool transb)
Create an op that multiplies lhs and rhs with cuBLAS.
Definition: cublas.h:46
Tensor cublas_batch_matmul(const Tensor &lhs, const Tensor &rhs, bool transa, bool transb)
Create an op that multiplies batch matrices lhs and rhs with cuBLAS.
Definition: cublas.h:70
Runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Operation node can generate one or multiple Tensors.