tvm
bnn.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_X86_BNN_H_
25 #define TVM_TOPI_X86_BNN_H_
26 
28 #include <tvm/te/operation.h>
29 #include <tvm/topi/detail/fuse.h>
30 #include <tvm/topi/tags.h>
31 
32 namespace tvm {
33 namespace topi {
34 
35 using namespace tvm::te;
36 
37 namespace x86 {
46 inline Schedule schedule_binarize_pack(const Target& target, const Array<Tensor>& outs) {
47  Array<Operation> out_ops;
48  for (auto t : outs) {
49  out_ops.push_back(t->op);
50  }
51  auto s = create_schedule(out_ops);
52 
53  auto _schedule = [&](const Tensor& out) {
54  s[out].parallel(out->op.as<ComputeOpNode>()->axis[0]);
55  };
56 
57  std::function<void(Operation)> traverse;
58  traverse = [&](const Operation& op) {
59  if (op->tag == "binarize_pack") {
60  _schedule(op.output(0));
61  } else {
62  LOG(ERROR) << "Unsupported operator " << op->tag;
63  }
64  };
65 
66  traverse(outs[0]->op);
67  return s;
68 }
69 
/*!
 * \brief Create an x86 schedule for binary_dense.
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors.
 *
 * \return A schedule for the given ops.
 */
inline Schedule schedule_binary_dense(const Target& target, const Array<Tensor>& outs) {
  // Build a schedule over all output operations.
  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  auto s = create_schedule(out_ops);

  // Schedule one binary_dense stage C. A (data) and B (weight) are
  // accepted for uniformity but are currently unused by the body.
  auto _schedule = [&](const Tensor& A, const Tensor& B, const Tensor& C) {
    // Split the first reduction axis of C by a factor of 8, then
    // parallelize C over its outermost spatial axis.
    IterVar co, ci;
    s[C].split(s[C]->op.as<ComputeOpNode>()->reduce_axis[0], 8, &co, &ci);
    s[C].parallel(s[C]->op.as<ComputeOpNode>()->axis[0]);

    // If C is not itself a schedule output (i.e. further stages were
    // fused after it), apply the output-stage scheduling to the real
    // output tensor instead.
    Tensor out;
    if (detail::contains(s->outputs, C->op)) {
      out = C;
    } else {
      out = outs[0]->op.output(0);
    }

    // Split the second spatial axis of the output stage by 8 and
    // vectorize the inner lanes.
    IterVar xo, xi;
    s[out].split(out->op.as<ComputeOpNode>()->axis[1], 8, &xo, &xi);
    s[out].vectorize(xi);
  };

  std::function<void(Operation)> traverse;
  traverse = [&](const Operation& op) {
    // Inline all one-to-one-mapping operators except the last stage (output)
    if (is_broadcast(op->tag)) {
      if (!detail::contains(s->outputs, op)) {
        s[op].compute_inline();
      }
      // Recurse into producers that are themselves computed stages
      // (placeholders have no input tensors and are skipped).
      for (auto tensor : op->InputTensors()) {
        if (tensor->op->InputTensors().size() > 0) {
          traverse(tensor->op);
        }
      }
    } else if (op->tag == "binary_dense") {
      // binary_dense takes (data, weight) as its two inputs.
      auto output = op.output(0);
      auto data = op->InputTensors()[0];
      auto weight = op->InputTensors()[1];
      _schedule(data, weight, output);
    } else {
      LOG(ERROR) << "Unsupported operator " << op->tag;
    }
  };

  traverse(outs[0]->op);
  return s;
}
127 
128 } // namespace x86
129 } // namespace topi
130 } // namespace tvm
131 #endif // TVM_TOPI_X86_BNN_H_
Managed reference class to TargetNode.
Definition: target.h:200
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:910
Array< IterVar > axis
IterVar on each axis.
Definition: operation.h:207
Array< IterVar > reduce_axis
IterVar on each reduction axis, if the body is a Reduce.
Definition: operation.h:209
A Compute op that compute a tensor on certain domain.
Definition: operation.h:226
Operation that produces tensors.
Definition: tensor.h:47
Global schedule container for operations and all the operations they depend on. The schedule per Operation is named as stage.
Definition: schedule.h:326
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:315
Fuse operation.
Generic function that can be specialized on a per target basis.
Tensor expression language DSL.
Definition: extracted_task.h:33
Schedule create_schedule(Array< Operation > ops)
Create a schedule for array of ops(and their dependencies).
Definition: schedule.h:702
Schedule schedule_binarize_pack(const Target &target, const Array< Tensor > &outs)
Create a generic schedule for binarize_pack.
Definition: bnn.h:46
Schedule schedule_binary_dense(const Target &target, const Array< Tensor > &outs)
Create a generic schedule for binary_dense.
Definition: bnn.h:78
bool is_broadcast(std::string tag)
Definition: tags.h:47
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Operation node can generate one or multiple Tensors.
External function interface to rocBLAS libraries.