24 #ifndef TVM_TOPI_NN_DENSE_H_
25 #define TVM_TOPI_NN_DENSE_H_
50 ICHECK_EQ(data->shape.size(), 2) <<
"dense requires 2-D data";
51 ICHECK_EQ(weight->shape.size(), 2) <<
"dense requires 2-D weight";
53 ICHECK_EQ(bias->shape.size(), 1) <<
"dense requires 1-D bias";
56 auto batch = data->shape[0];
57 auto in_dim = data->shape[1];
58 auto out_dim = weight->shape[0];
Range container
Definition: expr.h:725
Runtime primitive data type.
Definition: data_type.h:43
bool defined() const
Definition: object.h:552
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
A named variable in TIR.
Definition: var.h:89
Tensor expression language DSL.
Definition: extracted_task.h:33
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis).
tvm::te::Tensor dense(const tvm::te::Tensor &data, const tvm::te::Tensor &weight, const tvm::te::Tensor &bias, const DataType &out_dtype)
Creates an operation that calculates data * weight^T + bias.
Definition: dense.h:48
constexpr auto kBroadcast
Definition: tags.h:36
tvm::te::Tensor matmul(const tvm::te::Tensor &A, const tvm::te::Tensor &B, bool trans_a=false, bool trans_b=false, std::string name="T_matmul", std::string tag=kMatMul)
Creates an operation that calculates a matrix multiplication (row-major notation): A(i, k) * B(k, j), where trans_a and trans_b select whether A and/or B are transposed before multiplying.
Definition: transform.h:1557
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
Cast value to type.
PrimExpr sum(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
Sum of source expression over axis.
Operation node can generate one or multiple Tensors.