tvm
exec_context.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
28 #ifndef TVM_TIRX_EXEC_CONTEXT_H_
29 #define TVM_TIRX_EXEC_CONTEXT_H_
30 
31 #include <tvm/tirx/exec_scope.h>
32 #include <tvm/tirx/layout.h>
33 #include <tvm/tirx/var.h>
34 
35 #include <string>
36 #include <unordered_map>
37 #include <vector>
38 
39 namespace tvm {
40 namespace tirx {
41 
/*! \brief Warpgroup size in warps (hardware-fixed). */
43 constexpr int kWgSize = 4;
44 
/*! \brief Active slice offset + stride * [0, extent) encoded on one TileLayout axis. */
46 struct AxisRange {
 // NOTE(review): listing lines 47-49 were dropped from this rendering; the three members
 // below are restored from the member index of this page (type, name and definition line).
 // Any default initializers are not shown by the index -- confirm against the real header.
 /*! \brief Number of points in the slice, i.e. the bound of the [0, extent) index range. */
47  PrimExpr extent;
 /*! \brief Value of the first point of the slice. */
48  PrimExpr offset;
 /*! \brief Distance between consecutive points of the slice. */
49  PrimExpr stride;
50 
 /*!
  * \brief Intersect with [lo, hi). Returns false if the result is empty.
  * \param lo Lower bound of the interval (inclusive).
  * \param hi Upper bound of the interval (exclusive).
  * \param out Receives the narrowed range (presumably only meaningful when true is
  *        returned -- confirm against the implementation).
  */
52  bool Intersect(int64_t lo, int64_t hi, AxisRange* out) const;
53 
 /*!
  * \brief Intersect with values satisfying axis % modulus == residue.
  * \param modulus The modulus of the congruence.
  * \param residue The required remainder.
  * \param out Receives the narrowed range.
  * \return NOTE(review): presumably false when no value satisfies the congruence,
  *         mirroring Intersect -- confirm against the implementation.
  */
55  bool Modulo(int64_t modulus, int64_t residue, AxisRange* out) const;
56 };
57 
/*!
 * \brief Active thread set A. The source of truth is `layout`.
 *
 * NOTE(review): the original brief continues with a description of how shard/offset
 * entries of the layout encode the active axes; it is truncated on this page, so it
 * is not reproduced here.
 */
64 struct ActiveSet {
 // NOTE(review): listing line 65 was dropped from this rendering; the member below is
 // restored from the member index of this page ("TileLayout layout", exec_context.h:65).
 // A default initializer, if any, is not shown by the index -- confirm.
 /*! \brief Layout that encodes the active set. */
65  TileLayout layout;
66 
 // NOTE(review): none of the accessors below has a brief in the page index; the
 // descriptions are inferred from names and signatures only -- confirm.
 /*! \brief Size of the active set (presumably the number of active threads). */
67  int64_t size() const;
 /*! \brief Look up `axis`; on success the range is written to `out` (presumably returns false if absent). */
68  bool GetAxis(const std::string& axis, AxisRange* out) const;
 /*! \brief Whether `axis` is present in the active set (presumably). */
69  bool HasAxis(const std::string& axis) const;
 /*! \brief Return a new ActiveSet with `axis` bound to `range`; this object is const and not modified. */
70  ActiveSet WithAxis(const std::string& axis, const AxisRange& range) const;
 /*! \brief Names of the axes in the active set (ordering not documented here). */
71  std::vector<std::string> AxisNames() const;
72 };
73 
/*!
 * \brief One scope_switch split. Fields are sparse dicts keyed by active-set axis name.
 */
80 struct ExecSplit {
 /*! \brief The "inter" side of the split, keyed by active-set axis name. */
81  std::unordered_map<std::string, AxisRange> inter;
 /*! \brief The "intra" side of the split, keyed by active-set axis name. */
82  std::unordered_map<std::string, AxisRange> intra;
83 };
84 
/*!
 * \brief Initial A at T.kernel() entry: all threads active, offsets zero.
 * \param lane_ext Extent of the lane axis.
 * \param warp_ext Extent of the warp axis.
 * \param cta_ext Extent of the CTA axis.
 */
86 TVM_DLL ActiveSet InitialActiveSet(int64_t lane_ext, int64_t warp_ext, int64_t cta_ext);
/*!
 * \brief Overload of InitialActiveSet that additionally takes named CTA axes.
 * \param cta_axes (name, value) pairs; NOTE(review): presumably the factorized CTA axes
 *        (e.g. cbx/cby/cbz) with their extents -- confirm against the implementation.
 */
87 TVM_DLL ActiveSet InitialActiveSet(int64_t lane_ext, int64_t warp_ext, int64_t cta_ext,
88  const std::vector<std::pair<std::string, int64_t>>& cta_axes);
89 
/*!
 * \brief Narrow A on the lane bound to binding.
 * \param A The active set to narrow.
 * \param binding Selects which axis of A is narrowed.
 * \param lo Lower bound; NOTE(review): presumably the half-open interval [lo, hi) as in
 *        AxisRange::Intersect -- confirm.
 * \param hi Upper bound (see lo).
 * \param out Receives the narrowed active set.
 * \param err Receives an error message; NOTE(review): presumably set only when false is
 *        returned -- confirm against the implementation.
 */
101 TVM_DLL bool FilterNarrow(const ActiveSet& A, ScopeBinding binding, int64_t lo, int64_t hi,
102  ActiveSet* out, std::string* err);
103 
/*!
 * \brief Factor A into (inter, intra) for target scope_kind.
 * \param A The active set to factor.
 * \param scope_kind The target execution scope kind.
 * \param out Receives the resulting (inter, intra) split.
 * \param err Receives an error message; NOTE(review): presumably set only when false is
 *        returned -- confirm against the implementation.
 */
110 TVM_DLL bool ScopeSwitch(const ActiveSet& A, ScopeKind scope_kind, ExecSplit* out,
111  std::string* err);
112 
/*!
 * \brief Per-program-point ExecContext: active set + scope kind + split.
 *
 * NOTE(review): for every With* method below, `out` and `err` are output parameters;
 * presumably the new context is written to `*out` on success and a message to `*err`
 * when false is returned -- confirm against the implementation.
 */
114 struct ExecContext {
 // NOTE(review): listing lines 115-116 were dropped from this rendering; the two members
 // below are restored from the member index of this page (exec_context.h:115 and :116).
 // Default initializers, if any, are not shown by the index -- confirm.
 /*! \brief The active thread set at this program point. */
115  ActiveSet A;
 /*! \brief The current scope kind. */
116  ScopeKind scope_kind;
117  ExecSplit split; // (inter, intra) of current A under current scope_kind
118 
 /*! \brief Kernel-entry ctor. */
120  static ExecContext AtKernelEntry(int64_t lane_ext, int64_t warp_ext, int64_t cta_ext);
 /*! \brief Kernel-entry ctor taking additional named CTA axes as (name, value) pairs. */
121  static ExecContext AtKernelEntry(int64_t lane_ext, int64_t warp_ext, int64_t cta_ext,
122  const std::vector<std::pair<std::string, int64_t>>& cta_axes);
123 
 /*! \brief Apply filter; scope_kind preserved, split recomputed. */
125  bool WithFilter(ScopeBinding binding, int64_t lo, int64_t hi, ExecContext* out,
126  std::string* err) const;
127 
 /*! \brief Apply a unique-value selector filter on one scope id Var. */
 // NOTE(review): listing line 129 was dropped from this rendering, which left line 130 as
 // an orphaned parameter tail. The first line is restored from the signature given in the
 // member index; the exact line break is a guess.
129  bool WithSelector(ScopeBinding binding, PrimExpr selector, ExecContext* out,
130  std::string* err) const;
131 
 /*! \brief Apply filter on a factorized CTA axis such as cbx/cby/cbz. */
133  bool WithCtaAxisFilter(const std::string& axis, int64_t lo, int64_t hi, ExecContext* out,
134  std::string* err) const;
135 
 /*! \brief Apply modulo filter on a factorized CTA axis such as cbx/cby/cbz. */
137  bool WithCtaAxisModulo(const std::string& axis, int64_t modulus, int64_t residue,
138  ExecContext* out, std::string* err) const;
139 
 /*! \brief Apply scope_switch; A preserved, split recomputed for new scope_kind. */
141  bool WithScopeSwitch(ScopeKind new_scope_kind, ExecContext* out, std::string* err) const;
142 };
143 
/*!
 * \brief Encode one side of an ExecSplit (inter or intra) as the FFI map used by
 *        DispatchContextNode.
 * \param side One side of a split: a map from axis name to AxisRange.
 * \return Map from axis name to an array of PrimExpr. NOTE(review): the layout of each
 *         array (which AxisRange fields, in which order) is not visible here -- confirm
 *         against the implementation.
 */
149 TVM_DLL ffi::Map<ffi::String, ffi::Array<PrimExpr>> EncodeSplitSide(
150  const std::unordered_map<std::string, AxisRange>& side);
151 
152 } // namespace tirx
153 } // namespace tvm
154 
155 #endif // TVM_TIRX_EXEC_CONTEXT_H_
Reference to PrimExprNode.
Definition: expr.h:126
Definition: layout.h:396
Definition of layout.
const Op & selector()
Analysis-only active-thread selector.
ffi::Map< ffi::String, ffi::Array< PrimExpr > > EncodeSplitSide(const std::unordered_map< std::string, AxisRange > &side)
Encode one side of an ExecSplit (inter or intra) as the FFI map used by DispatchContextNode::{inter,...
ScopeKind
The target execution scope kind of an ExecScopeStmt.
Definition: exec_scope.h:45
constexpr int kWgSize
Warpgroup size in warps (hardware-fixed).
Definition: exec_context.h:43
ActiveSet InitialActiveSet(int64_t lane_ext, int64_t warp_ext, int64_t cta_ext)
Initial A at T.kernel() entry: all threads active, offsets zero.
bool ScopeSwitch(const ActiveSet &A, ScopeKind scope_kind, ExecSplit *out, std::string *err)
Factor A into (inter, intra) for target scope_kind.
ScopeBinding
The binding between a parent scope and a child scope as used by a ScopeIdDef. The closed enum of vali...
Definition: exec_scope.h:81
bool FilterNarrow(const ActiveSet &A, ScopeBinding binding, int64_t lo, int64_t hi, ActiveSet *out, std::string *err)
Narrow A on the lane bound to binding.
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
Active thread set A. The source of truth is layout: shard = active axes with extents offset = per-axi...
Definition: exec_context.h:64
ActiveSet WithAxis(const std::string &axis, const AxisRange &range) const
bool HasAxis(const std::string &axis) const
bool GetAxis(const std::string &axis, AxisRange *out) const
int64_t size() const
std::vector< std::string > AxisNames() const
TileLayout layout
Definition: exec_context.h:65
Active slice offset + stride * [0, extent) encoded on one TileLayout axis.
Definition: exec_context.h:46
PrimExpr stride
Definition: exec_context.h:49
PrimExpr extent
Definition: exec_context.h:47
bool Intersect(int64_t lo, int64_t hi, AxisRange *out) const
Intersect with [lo, hi). Returns false if the result is empty.
PrimExpr offset
Definition: exec_context.h:48
bool Modulo(int64_t modulus, int64_t residue, AxisRange *out) const
Intersect with values satisfying axis % modulus == residue.
Per-program-point ExecContext: active set + scope kind + split.
Definition: exec_context.h:114
bool WithSelector(ScopeBinding binding, PrimExpr selector, ExecContext *out, std::string *err) const
Apply a unique-value selector filter on one scope id Var.
ScopeKind scope_kind
Definition: exec_context.h:116
ExecSplit split
Definition: exec_context.h:117
static ExecContext AtKernelEntry(int64_t lane_ext, int64_t warp_ext, int64_t cta_ext)
Kernel-entry ctor.
bool WithScopeSwitch(ScopeKind new_scope_kind, ExecContext *out, std::string *err) const
Apply scope_switch; A preserved, split recomputed for new scope_kind.
bool WithCtaAxisFilter(const std::string &axis, int64_t lo, int64_t hi, ExecContext *out, std::string *err) const
Apply filter on a factorized CTA axis such as cbx/cby/cbz.
ActiveSet A
Definition: exec_context.h:115
static ExecContext AtKernelEntry(int64_t lane_ext, int64_t warp_ext, int64_t cta_ext, const std::vector< std::pair< std::string, int64_t >> &cta_axes)
bool WithCtaAxisModulo(const std::string &axis, int64_t modulus, int64_t residue, ExecContext *out, std::string *err) const
Apply modulo filter on a factorized CTA axis such as cbx/cby/cbz.
bool WithFilter(ScopeBinding binding, int64_t lo, int64_t hi, ExecContext *out, std::string *err) const
Apply filter; scope_kind preserved, split recomputed.
One scope_switch split. Fields are sparse dicts keyed by active-set axis name, e.g....
Definition: exec_context.h:80
std::unordered_map< std::string, AxisRange > inter
Definition: exec_context.h:81
std::unordered_map< std::string, AxisRange > intra
Definition: exec_context.h:82
Variables in the TIR.