tvm
tirx_op.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 #ifndef TVM_TIRX_TIRX_OP_H_
24 #define TVM_TIRX_TIRX_OP_H_
25 
26 #include <tvm/ir/op.h>
27 #include <tvm/target/target.h>
28 #include <tvm/tirx/exec_scope.h>
29 #include <tvm/tirx/stmt.h>
30 #include <tvm/tirx/tirx_stmt.h>
31 
32 namespace tvm {
33 namespace tirx {
34 
40 using FArgSanitizer = ffi::TypedFunction<void(tvm::Op, ffi::Array<ffi::ObjectRef>)>;  // Type of the function that sanitizes the arguments of a TIRX operator: takes the op and its argument array, returns nothing.
41 
42 namespace callback {  // String names of the callback kinds; presumably the keys of the contexts' `callbacks` maps -- TODO confirm.
44 constexpr const char* kPrivateAlloc = "private_alloc";  // The buffers allocated by the operator.
48 constexpr const char* kDeviceInitStmt = "device_init_stmt";  // The initialization statement of the operator, which will be inserted at the beginning of the kernel.
52 constexpr const char* kHostInitStmt = "host_init_stmt";  // NOTE(review): documented with the same wording as kDeviceInitStmt; presumably the host-side counterpart -- confirm.
56 constexpr const char* kPostBufferDefStmt = "post_buffer_def_stmt";  // Statements to be inserted after a specific buffer's definition (DeclBuffer/AllocBuffer).
57 } // namespace callback
58 
62 class ScheduleContextNode : public ffi::Object {  // The context information of the kernel required by op schedule.
63  public:
65  Target target;  // The target of the kernel. (Declaration restored: it was missing from this listing but is registered below.)
67  ExecScope exec_scope;  // The exec scope of the operator. (Declaration restored: it was missing from this listing but is registered below.)
69  ffi::Map<ffi::String, IterVar> launch_params;  // The kernel launch parameters.
71  ffi::Map<Var, Range> var_range_map;  // A map from loop variables to their ranges.
73  bool alloc_only;  // Whether the schedule context is only used for buffer allocation.
75  ffi::Map<ffi::String, ffi::ObjectRef> callbacks;  // Callback to be handled when the operator is scheduled.
76 
77  static void RegisterReflection() {  // Registers every field above as a read-only reflected attribute.
78  namespace refl = tvm::ffi::reflection;
79  refl::ObjectDef<ScheduleContextNode>()
80  .def_ro("target", &ScheduleContextNode::target)
81  .def_ro("exec_scope", &ScheduleContextNode::exec_scope)
82  .def_ro("launch_params", &ScheduleContextNode::launch_params)
83  .def_ro("var_range_map", &ScheduleContextNode::var_range_map)
84  .def_ro("alloc_only", &ScheduleContextNode::alloc_only)
85  .def_ro("callbacks", &ScheduleContextNode::callbacks);
86  }
87 
89  void AddAllocBuffer(Buffer buffer);  // Add a buffer to be allocated in the kernel.
90 
97  void AddInitStmt(Stmt stmt, bool host = false);  // Add an initialization statement to be inserted. NOTE(review): `host` presumably selects the host-side list instead of the device-side one -- confirm against the definition.
98 
99  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.ScheduleContext", ScheduleContextNode, ffi::Object);
100 };
101 
105 class ScheduleContext : public ffi::ObjectRef {  // Managed reference to ScheduleContextNode.
106  public:
116  TVM_DLL ScheduleContext(Target target, ExecScope exec_scope,  // Constructor; every argument after exec_scope is optional.
117  ffi::Map<ffi::String, IterVar> launch_params = {},
118  ffi::Map<Var, Range> var_range_map = {}, bool alloc_only = false,
119  ffi::Map<ffi::String, ffi::ObjectRef> callbacks = {});
120 
121  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ScheduleContext, ffi::ObjectRef, ScheduleContextNode);  // Restored: this line was missing from the listing.
122 };
123 
130 using FOpScheduler = ffi::TypedFunction<Stmt(tvm::Op, ffi::Array<ffi::ObjectRef>, ScheduleContext)>;  // Type of the function that schedules a TIRX operator: takes the op, its argument array and a ScheduleContext, returns a Stmt.
131 
135 class DispatchContextNode : public ffi::Object {  // The context information of the kernel required by op dispatch.
136  public:
138  Target target;  // The target of the kernel. (Declaration restored: it was missing from this listing but is registered below.)
140  ExecScope exec_scope;  // The exec scope of the operator. (Declaration restored: it was missing from this listing but is registered below.)
142  ffi::Map<ffi::String, IterVar> launch_params;  // The kernel launch parameters.
144  ffi::Map<Var, Range> var_range_map;  // A map from loop variables to their ranges.
146  bool alloc_only;  // Whether the dispatch context is only used for buffer allocation. (Declaration restored: it was missing from this listing but is registered below.)
148  ffi::Map<ffi::String, ffi::ObjectRef> callbacks;  // Callback to be handled when the operator is scheduled.
150  ffi::Map<ffi::String, ffi::ObjectRef> shared_state;  // Shared state that persists across dispatch calls within a single lowering pass.
160  ffi::Map<ffi::String, ffi::Array<PrimExpr>> inter;  // ExecContext inter-team view at this op site.
162  ffi::Map<ffi::String, ffi::Array<PrimExpr>> intra;  // ExecContext intra-team view. Same encoding as inter.
164  ffi::String scope_kind;  // Scope kind string ("kernel"/"cta"/"warpgroup"/"warp"/"thread"/"cluster").
165 
166  static void RegisterReflection() {  // Registers every field above as a read-only reflected attribute.
167  namespace refl = tvm::ffi::reflection;
168  refl::ObjectDef<DispatchContextNode>()
169  .def_ro("target", &DispatchContextNode::target)
170  .def_ro("exec_scope", &DispatchContextNode::exec_scope)
171  .def_ro("launch_params", &DispatchContextNode::launch_params)
172  .def_ro("var_range_map", &DispatchContextNode::var_range_map)
173  .def_ro("alloc_only", &DispatchContextNode::alloc_only)
174  .def_ro("callbacks", &DispatchContextNode::callbacks)
175  .def_ro("shared_state", &DispatchContextNode::shared_state)
176  .def_ro("inter", &DispatchContextNode::inter)
177  .def_ro("intra", &DispatchContextNode::intra)
178  .def_ro("scope_kind", &DispatchContextNode::scope_kind);
179  }
180 
182  void AddAllocBuffer(Buffer buffer);  // Add a buffer to be allocated in the kernel.
183 
185  void AddInitStmt(Stmt stmt, bool host = false);  // Add an initialization statement to be inserted. NOTE(review): `host` presumably selects the host-side list instead of the device-side one -- confirm against the definition.
186 
188  void AddPostBufferDefStmt(Buffer buffer, Stmt stmt);  // Add a statement to be inserted after a buffer's definition.
189 
191  void SharedStateSet(ffi::String key, ffi::ObjectRef value);  // Set a value in the shared state cache.
192 
194  ffi::Optional<ffi::ObjectRef> SharedStateGet(ffi::String key);  // Get a value from the shared state cache; the Optional return presumably is empty when the key is absent -- confirm.
195 
196  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.DispatchContext", DispatchContextNode, ffi::Object);
197 };
198 
202 class DispatchContext : public ffi::ObjectRef {  // Managed reference to DispatchContextNode.
203  public:
204  TVM_DLL DispatchContext(Target target, ExecScope exec_scope,  // Constructor; every argument after exec_scope is optional.
205  ffi::Map<ffi::String, IterVar> launch_params = {},
206  ffi::Map<Var, Range> var_range_map = {}, bool alloc_only = false,
207  ffi::Map<ffi::String, ffi::ObjectRef> callbacks = {},
208  ffi::Map<ffi::String, ffi::ObjectRef> shared_state = {},
209  ffi::Map<ffi::String, ffi::Array<PrimExpr>> inter = {},
210  ffi::Map<ffi::String, ffi::Array<PrimExpr>> intra = {},
211  ffi::String scope_kind = "");
212 
213  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DispatchContext, ffi::ObjectRef, DispatchContextNode);  // Restored: this line was missing from the listing.
214 };
215 
221 TVM_DLL const Op& cast();  // Op accessors start here: each declaration takes no arguments and returns a const reference to an Op.
222 
228 TVM_DLL const Op& permute_dims();  // Index brief: "See pseudo code below"; the pseudo code itself is not part of this listing.
229 
235 TVM_DLL const Op& copy();  // Index brief: "See pseudo code below".
236 
242 TVM_DLL const Op& copy_async();  // Index brief: "See pseudo code below".
243 
249 TVM_DLL const Op& fill();  // Index brief: "See pseudo code below".
250 
256 TVM_DLL const Op& gemm();  // Index brief: "See pseudo code below".
257 
264 TVM_DLL const Op& gemm_async();  // Index brief: "See pseudo code below".
265 
266 TVM_DLL const Op& zero();  // From here to select(): no brief is recorded for these accessors in the cross-reference index.
267 
268 TVM_DLL const Op& sqrt();
269 
270 TVM_DLL const Op& exp();
271 
272 TVM_DLL const Op& add();
273 
274 TVM_DLL const Op& sub();
275 
276 TVM_DLL const Op& mul();
277 
278 TVM_DLL const Op& fdiv();
279 
280 TVM_DLL const Op& minimum();
281 
282 TVM_DLL const Op& maximum();
283 
284 TVM_DLL const Op& reciprocal();
285 
286 TVM_DLL const Op& sum();
287 
288 TVM_DLL const Op& max();
289 
290 TVM_DLL const Op& min();
291 
292 TVM_DLL const Op& memset();
293 
294 TVM_DLL const Op& reduce_negate();
295 
296 TVM_DLL const Op& binary_reduce();
297 
298 TVM_DLL const Op& unary_reduce();
299 
300 TVM_DLL const Op& binary_chain();
301 
302 TVM_DLL const Op& select();
303 
309 TVM_DLL const Op& tvm_kernel_replace_point();  // Index brief: "See pseudo code below".
310 
311 } // namespace tirx
312 } // namespace tvm
313 
314 #endif // TVM_TIRX_TIRX_OP_H_
Managed reference class to OpNode.
Definition: op.h:131
Managed reference class to TargetNode.
Definition: target.h:135
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:172
The context information of the kernel required by op dispatch.
Definition: tirx_op.h:135
bool alloc_only
Whether the dispatch context is only used for buffer allocation.
Definition: tirx_op.h:146
ffi::Map< ffi::String, ffi::ObjectRef > shared_state
Shared state that persists across dispatch calls within a single lowering pass.
Definition: tirx_op.h:150
ffi::Map< ffi::String, ffi::ObjectRef > callbacks
Callback to be handled when the operator is scheduled.
Definition: tirx_op.h:148
ExecScope exec_scope
The exec scope of the operator.
Definition: tirx_op.h:140
static void RegisterReflection()
Definition: tirx_op.h:166
ffi::Map< ffi::String, IterVar > launch_params
The kernel launch parameters.
Definition: tirx_op.h:142
ffi::String scope_kind
Scope kind string ("kernel"/"cta"/"warpgroup"/"warp"/"thread"/"cluster").
Definition: tirx_op.h:164
Target target
The target of the kernel.
Definition: tirx_op.h:138
void SharedStateSet(ffi::String key, ffi::ObjectRef value)
Set a value in the shared state cache.
ffi::Map< ffi::String, ffi::Array< PrimExpr > > inter
ExecContext inter-team view at this op site.
Definition: tirx_op.h:160
ffi::Map< Var, Range > var_range_map
A map from loop variables to their ranges.
Definition: tirx_op.h:144
void AddAllocBuffer(Buffer buffer)
Add a buffer to be allocated in the kernel.
ffi::Map< ffi::String, ffi::Array< PrimExpr > > intra
ExecContext intra-team view. Same encoding as inter.
Definition: tirx_op.h:162
void AddPostBufferDefStmt(Buffer buffer, Stmt stmt)
Add a statement to be inserted after a buffer's definition.
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.DispatchContext", DispatchContextNode, ffi::Object)
ffi::Optional< ffi::ObjectRef > SharedStateGet(ffi::String key)
Get a value from the shared state cache.
void AddInitStmt(Stmt stmt, bool host=false)
Add an initialization statement to be inserted.
Managed reference to DispatchContextNode.
Definition: tirx_op.h:202
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DispatchContext, ffi::ObjectRef, DispatchContextNode)
DispatchContext(Target target, ExecScope exec_scope, ffi::Map< ffi::String, IterVar > launch_params={}, ffi::Map< Var, Range > var_range_map={}, bool alloc_only=false, ffi::Map< ffi::String, ffi::ObjectRef > callbacks={}, ffi::Map< ffi::String, ffi::ObjectRef > shared_state={}, ffi::Map< ffi::String, ffi::Array< PrimExpr >> inter={}, ffi::Map< ffi::String, ffi::Array< PrimExpr >> intra={}, ffi::String scope_kind="")
Definition: exec_scope.h:234
The context information of the kernel required by op schedule.
Definition: tirx_op.h:62
ffi::Map< ffi::String, ffi::ObjectRef > callbacks
Callback to be handled when the operator is scheduled.
Definition: tirx_op.h:75
ExecScope exec_scope
The exec scope of the operator.
Definition: tirx_op.h:67
void AddAllocBuffer(Buffer buffer)
Add a buffer to be allocated in the kernel.
ffi::Map< Var, Range > var_range_map
A map from loop variables to their ranges.
Definition: tirx_op.h:71
Target target
The target of the kernel.
Definition: tirx_op.h:65
bool alloc_only
Whether the schedule context is only used for buffer allocation.
Definition: tirx_op.h:73
ffi::Map< ffi::String, IterVar > launch_params
The kernel launch parameters.
Definition: tirx_op.h:69
static void RegisterReflection()
Definition: tirx_op.h:77
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.ScheduleContext", ScheduleContextNode, ffi::Object)
void AddInitStmt(Stmt stmt, bool host=false)
Add an initialization statement to be inserted.
Managed reference to ScheduleContextNode.
Definition: tirx_op.h:105
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ScheduleContext, ffi::ObjectRef, ScheduleContextNode)
ScheduleContext(Target target, ExecScope exec_scope, ffi::Map< ffi::String, IterVar > launch_params={}, ffi::Map< Var, Range > var_range_map={}, bool alloc_only=false, ffi::Map< ffi::String, ffi::ObjectRef > callbacks={})
Constructor.
Container of all statements.
Definition: stmt.h:67
Primitive operators (builtin intrinsics) and registry for them.
constexpr const char * kPrivateAlloc
The buffers allocated by the operator.
Definition: tirx_op.h:44
constexpr const char * kPostBufferDefStmt
Statements to be inserted after a specific buffer's definition (DeclBuffer/AllocBuffer)....
Definition: tirx_op.h:56
constexpr const char * kDeviceInitStmt
The initialization statement of the operator, which will be inserted at the beginning of the kernel.
Definition: tirx_op.h:48
constexpr const char * kHostInitStmt
The initialization statement of the operator, which will be inserted at the beginning of the kernel.
Definition: tirx_op.h:52
const Op & binary_chain()
const Op & copy_async()
See pseudo code below:
const Op & unary_reduce()
const Op & sqrt()
const Op & mul()
const Op & reciprocal()
const Op & permute_dims()
See pseudo code below:
const Op & reduce_negate()
const Op & cast()
See pseudo code below:
const Op & gemm_async()
See pseudo code below:
const Op & select()
const Op & minimum()
const Op & binary_reduce()
const Op & fill()
See pseudo code below:
ffi::TypedFunction< Stmt(tvm::Op, ffi::Array< ffi::ObjectRef >, ScheduleContext)> FOpScheduler
The type of the function that schedules a TIRX operator.
Definition: tirx_op.h:130
const Op & add()
const Op & gemm()
See pseudo code below:
const Op & sum()
const Op & copy()
See pseudo code below:
const Op & fdiv()
const Op & zero()
const Op & exp()
const Op & tvm_kernel_replace_point()
See pseudo code below:
ffi::TypedFunction< void(tvm::Op, ffi::Array< ffi::ObjectRef >)> FArgSanitizer
The type of the function that sanitizes the arguments of a TIRX operator.
Definition: tirx_op.h:40
const Op & memset()
const Op & max()
const Op & maximum()
const Op & sub()
const Op & min()
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
Compilation target object.
TIR statements.