tvm
axis_group_graph.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef TVM_RELAX_DISTRIBUTED_AXIS_GROUP_GRAPH_H_
21 #define TVM_RELAX_DISTRIBUTED_AXIS_GROUP_GRAPH_H_
22 
25 #include <tvm/relax/expr.h>
26 #include <tvm/tir/function.h>
27 #include <tvm/tir/stmt_functor.h>
28 
29 #include <algorithm>
30 #include <limits>
31 #include <string>
32 #include <tuple>
33 #include <unordered_map>
34 #include <unordered_set>
35 #include <utility>
36 #include <vector>
37 
38 namespace tvm {
39 namespace tir {
// (var, axis): a PrimFunc parameter var paired with the index of one axis of the
// buffer bound to that parameter.
using TIRVarAxis = std::pair<Var, int>;
// (buffer, axis): a buffer paired with the index of one of its axes.
using BufferAxis = std::pair<Buffer, int>;
45  public:
46  size_t operator()(const BufferAxis& buffer_axis) const {
47  size_t const h1(ObjectPtrHash()(buffer_axis.first));
48  size_t const h2(std::hash<int>()(buffer_axis.second));
49  return h1 ^ (h2 << 1);
50  }
51 };
/*!
 * \brief Select the variable of an access index that is associated with sharding the
 *        indexed buffer dimension.
 *
 * NOTE(review): only the declaration is visible in this header; the exact selection rule
 * and the value returned when no variable qualifies must be confirmed in the implementation.
 *
 * \param index The index expression used to access one buffer dimension.
 * \param var_range The ranges of the variables that may appear in `index`.
 * \param analyzer The arithmetic analyzer handed to the implementation.
 * \return The selected variable.
 */
Var GetShardingVarFromIndex(PrimExpr index, ffi::Map<Var, Range> var_range,
                            arith::Analyzer* analyzer);
63 
69  public:
70  static std::vector<std::vector<TIRVarAxis>> GetTIRVarAxisGraph(const PrimFunc& prim_func) {
71  BufferAxisGraphExtractor extractor;
72  extractor(prim_func->body);
73  ffi::Map<Buffer, Var> inverse_buffer_map;
74  for (const auto& pr : prim_func->buffer_map) {
75  inverse_buffer_map.Set(pr.second, pr.first);
76  }
77  std::vector<std::vector<TIRVarAxis>> tir_var_axis_group_list;
78  std::unordered_set<BufferAxis, BufferAxisHash> visited;
79  for (const auto& pr : prim_func->buffer_map) {
80  Var param = pr.first;
81  Buffer buffer = pr.second;
82  for (int i = 0; i < static_cast<int>(buffer->shape.size()); i++) {
83  if (extractor.buffer_axis_graph_.count({buffer, i})) {
84  std::vector<BufferAxis> buffer_axis_group;
85  extractor.DFSGraph({buffer, i}, &visited, &buffer_axis_group);
86  if (buffer_axis_group.size() <= 1) {
87  continue;
88  }
89  std::vector<TIRVarAxis> tir_var_axis_group;
90  for (const auto& buffer_axis : buffer_axis_group) {
91  if (!inverse_buffer_map.count(buffer_axis.first)) {
92  continue;
93  }
94  tir_var_axis_group.push_back(
95  {inverse_buffer_map[buffer_axis.first], buffer_axis.second});
96  }
97  tir_var_axis_group_list.push_back(tir_var_axis_group);
98  }
99  }
100  }
101  return tir_var_axis_group_list;
102  }
103 
104  void DFSGraph(BufferAxis cur, std::unordered_set<BufferAxis, BufferAxisHash>* visited,
105  std::vector<BufferAxis>* buffer_axis_group) {
106  if (visited->count(cur)) {
107  return;
108  }
109  visited->insert(cur);
110  buffer_axis_group->push_back(cur);
111  for (const auto& next : buffer_axis_graph_[cur]) {
112  DFSGraph(next, visited, buffer_axis_group);
113  }
114  }
115 
116  private:
117  void VisitStmt_(const BufferStoreNode* op) final {
119  buffer_access_indices_.push_back({op->buffer, op->indices});
120  }
121 
122  void VisitExpr_(const BufferLoadNode* op) final {
124  buffer_access_indices_.push_back({op->buffer, op->indices});
125  }
126 
127  bool Match(PrimExpr a, PrimExpr buffer_shape_a, PrimExpr b, PrimExpr buffer_shape_b,
128  arith::Analyzer* analyzer) {
129  if (b.as<VarNode>()) {
130  std::swap(a, b);
131  std::swap(buffer_shape_a, buffer_shape_b);
132  }
133  if (!a.as<VarNode>()) {
134  return false;
135  }
136  Var var = Downcast<Var>(a);
137  analyzer->Bind(iter_var_range_);
138  b = analyzer->Simplify(b);
139  // index var `a` must access whole range of a specific buffer dimension
140  arith::IntSet intset_b = arith::EvalSet(b, arith::AsIntSet(iter_var_range_));
141  if (!analyzer->CanProveEqual(buffer_shape_a, iter_var_range_[var]->extent) ||
142  !intset_b.MatchRange(Range::FromMinExtent(0, buffer_shape_b))) {
143  return false;
144  }
145  Var matched_var = GetShardingVarFromIndex(b, iter_var_range_, analyzer);
146  if (!matched_var.same_as(var)) {
147  return false;
148  }
149  return true;
150  }
151 
152  void VisitStmt_(const BlockNode* op) final {
153  if (op->name_hint == "root") {
155  return;
156  }
157  buffer_access_indices_.clear();
159  iter_var_range_.clear();
160  for (const auto& iter_var : op->iter_vars) {
161  iter_var_range_.Set(iter_var->var, iter_var->dom);
162  }
163  arith::Analyzer analyzer;
164  for (const auto& access_pr : buffer_access_indices_) {
165  Buffer buffer = access_pr.first;
166  ffi::Array<PrimExpr> indices = access_pr.second;
167  for (int i = 0; i < static_cast<int>(indices.size()); i++) {
168  for (const auto& another_access_pr : buffer_access_indices_) {
169  if (another_access_pr.first.same_as(buffer)) {
170  continue;
171  }
172  Buffer another_buffer = another_access_pr.first;
173  ffi::Array<PrimExpr> another_indices = another_access_pr.second;
174  for (int j = 0; j < static_cast<int>(another_indices.size()); j++) {
175  if (Match(indices[i], buffer->shape[i], another_indices[j], another_buffer->shape[j],
176  &analyzer)) {
177  JoinBufferAxis({buffer, i}, {another_buffer, j});
178  }
179  }
180  }
181  }
182  }
183  }
184 
185  void JoinBufferAxis(BufferAxis axis1, BufferAxis axis2) {
186  if (!buffer_axis_graph_.count(axis1)) {
187  buffer_axis_graph_[axis1] = {};
188  }
189  if (!buffer_axis_graph_.count(axis2)) {
190  buffer_axis_graph_[axis2] = {};
191  }
192  buffer_axis_graph_[axis1].push_back(axis2);
193  buffer_axis_graph_[axis2].push_back(axis1);
194  }
195 
// (buffer, indices) of every buffer load/store recorded by the visitor; cleared when a
// non-root block starts being processed.
std::vector<std::pair<Buffer, ffi::Array<PrimExpr>>> buffer_access_indices_;
// Undirected adjacency list over buffer axes; JoinBufferAxis adds each edge in both directions.
std::unordered_map<BufferAxis, std::vector<BufferAxis>, BufferAxisHash> buffer_axis_graph_;
// Iter var -> domain of the block currently being processed.
ffi::Map<Var, Range> iter_var_range_;
// NOTE(review): not referenced anywhere in the visible part of this header.
std::string func_name;
200 };
201 } // namespace tir
202 } // namespace tvm
203 
204 namespace tvm {
205 namespace relax {
206 namespace distributed {
207 
209 struct Axis {
210  const ExprNode* tensor;
211  int dim = 0;
212  int tuple_index = 0;
213 
214  Axis(const ExprNode* tensor, int dim, int tuple_index = 0)
216  ICHECK(tensor->IsInstance<ConstantNode>() || tensor->IsInstance<VarNode>());
217  }
218 
219  bool operator==(const Axis& other) const {
220  return tensor == other.tensor && dim == other.dim && tuple_index == other.tuple_index;
221  }
222 };
223 
224 class AxisHash {
225  public:
226  size_t operator()(const Axis& axis) const {
227  size_t const h1(std::hash<const ExprNode*>()(axis.tensor));
228  size_t const h2(std::hash<int>()(axis.dim));
229  size_t const h3(std::hash<int>()(axis.tuple_index));
230  return h1 ^ (h2 << 1) ^ (h3 << 2);
231  }
232 };
233 
// An unordered set of tensor axes.
using AxisGroup = std::unordered_set<Axis, AxisHash>;
235 
237  public:
238  size_t operator()(const AxisGroup& axis_group) const {
239  size_t seed = 0;
240  for (auto axis : axis_group) {
241  seed ^= AxisHash()(axis) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
242  }
243  return seed;
244  }
245 };
246 
// A device mesh together with a placement on it.
using ShardingSpec = std::pair<DeviceMesh, Placement>;

// device mesh and the device mesh axis that the tensor axis maps to
// (GetAxisShardingSpec uses -1 as the axis when no spec is available)
using AxisShardingSpec = std::pair<DeviceMesh, int>;
252  public:
253  bool operator()(const AxisShardingSpec& lhs, const AxisShardingSpec& rhs) const {
254  return StructuralEqual()(lhs.first, rhs.first) && lhs.second == rhs.second;
255  }
256 };
257 
259  public:
260  size_t operator()(const AxisShardingSpec& sharding_spec) const {
261  size_t seed = 0;
262  seed ^= StructuralHash()(sharding_spec.first);
263  seed ^= std::hash<int>()(sharding_spec.second) << 1;
264  return seed;
265  }
266 };
267 
274  public:
// Producer-consumer relationship between the source tensor and the destination tensor
// of an edge:
//   kAscend   - consumer -> producer
//   kDescend  - producer -> consumer
//   kSimbling - any other case (the enumerator name is spelled this way throughout the file)
enum class EdgeType { kAscend, kDescend, kSimbling };
276 
277  private:
278  static EdgeType ReverseEdgeType(EdgeType type) {
279  switch (type) {
280  case EdgeType::kAscend:
281  return EdgeType::kDescend;
282  case EdgeType::kDescend:
283  return EdgeType::kAscend;
284  case EdgeType::kSimbling:
285  return EdgeType::kSimbling;
286  }
287  LOG(FATAL) << "Unreachable code";
288  throw;
289  }
290 
291  static int GetEdgePriority(EdgeType type) {
292  switch (type) {
293  case EdgeType::kAscend:
294  return 0;
295  case EdgeType::kDescend:
296  return 2;
297  case EdgeType::kSimbling:
298  return 1;
299  }
300  LOG(FATAL) << "Unreachable code";
301  throw;
302  }
303 
304  struct AxisGraphEdge {
305  Axis src;
306  Axis dst;
307 
308  // the producer-consumer relationship between src tensor and dst tensor
309  // kAscend means consumer->producer
310  // kDescend means producer->consumer
311  // kSimbling means other cases
312  EdgeType type;
313 
314  bool operator==(const AxisGraphEdge& other) const {
315  return src == other.src && dst == other.dst && type == other.type;
316  }
317  };
318 
319  struct Path {
320  int direction = 0;
321 
322  Path AddEdge(EdgeType type) { return {direction |= (1 << GetEdgePriority(type))}; }
323 
324  int GetPriority() const {
325  switch (direction) {
326  case 1: // ascend only
327  return 0;
328  case 4: // descend only
329  return 2;
330  case 0: // empty path (source node)
331  return 3; // source node must have max priority
332  default: // mixed path
333  return 1;
334  }
335  }
336  };
337 
338  public:
339  AxisGroupGraph() = default;
340 
350  void JoinAxis(Axis axis1, Axis axis2, EdgeType type) {
351  AddEdge(axis1, axis2, type);
352  AddEdge(axis2, axis1, ReverseEdgeType(type));
353  }
354 
361  src_axis_sharding_spec_[axis] = spec;
362  }
363 
368  axis_sharding_specs_priority_.clear();
369  for (const auto& pr : src_axis_sharding_spec_) {
370  std::unordered_set<Axis, AxisHash> visited;
371  PropagateShardingSpec(pr.first, pr.second, Path(), &visited);
372  }
373  ChooseAxisShardingSpec();
374  }
375 
383  cutpoint_axis_sharding_spec_[axis] = spec;
384  }
385 
393  std::tuple<AxisShardingSpec, bool> GetAxisShardingSpec(Axis axis) {
394  if (axis_sharding_specs_priority_.count(axis)) {
395  return {axis_sharding_specs_priority_[axis].begin()->first, true};
396  } else {
397  return {{DeviceMesh(), -1}, false};
398  }
399  }
400 
401  private:
402  void AddEdge(Axis src, Axis dst, EdgeType type) {
403  if (!graph_.count(src)) {
404  graph_[src] = {};
405  }
406  graph_[src].push_back({src, dst, type});
407  }
408 
409  void PropagateShardingSpec(Axis axis, AxisShardingSpec spec, Path path,
410  std::unordered_set<Axis, AxisHash>* visited) {
411  if (cutpoint_axis_sharding_spec_.count(axis) ||
412  (src_axis_sharding_spec_.count(axis) &&
413  !AxisShardingSpecEqual()(src_axis_sharding_spec_[axis], spec)) ||
414  visited->count(axis)) {
415  return;
416  }
417  visited->insert(axis);
418  if (!axis_sharding_specs_priority_.count(axis)) {
419  axis_sharding_specs_priority_[axis] = {};
420  }
421  axis_sharding_specs_priority_[axis][spec] = path.GetPriority();
422  for (auto edge : graph_[axis]) {
423  PropagateShardingSpec(edge.dst, spec, path.AddEdge(edge.type), visited);
424  }
425  }
426 
427  void ChooseAxisShardingSpec() {
428  for (auto& pr : axis_sharding_specs_priority_) {
429  auto& axis = pr.first;
430  auto& specs = pr.second;
431  int max_priority = std::numeric_limits<int>::min();
432  for (auto& pr2 : specs) {
433  max_priority = std::max(max_priority, pr2.second);
434  }
435  for (auto it = specs.begin(); it != specs.end();) {
436  if (it->second != max_priority) {
437  it = specs.erase(it);
438  } else {
439  it++;
440  }
441  }
442  ICHECK(specs.size() == 1) << "multiple possible sharding for axis: ("
443  << ffi::GetRef<Expr>(axis.tensor) << ", " << axis.dim << ")";
444  }
445  }
446 
// Adjacency list of the axis graph: axis -> outgoing edges. JoinAxis stores every
// relation in both directions.
std::unordered_map<Axis, std::vector<AxisGraphEdge>, AxisHash> graph_;
// Source axes and the sharding spec that starts propagating from each of them.
std::unordered_map<Axis, AxisShardingSpec, AxisHash> src_axis_sharding_spec_;
// Axes registered as propagation cut points, with the spec given at registration.
std::unordered_map<Axis, AxisShardingSpec, AxisHash> cutpoint_axis_sharding_spec_;
// Axis -> (candidate spec -> priority of the path it arrived over). Filled during
// propagation and pruned by ChooseAxisShardingSpec.
std::unordered_map<
    Axis, std::unordered_map<AxisShardingSpec, int, AxisShardingSpecHash, AxisShardingSpecEqual>,
    AxisHash>
    axis_sharding_specs_priority_;
455 };
456 
/*!
 * \brief Signature of a function that adds the axis relations of one call to an axis graph.
 *
 * The callable receives the var the call result is bound to, the call itself, and the
 * graph to extend.
 */
using FBuildAxisGraph = std::function<void(const Var& output_var, const Call& call,
                                           distributed::AxisGroupGraph* axis_group_graph)>;

// Axis-graph builders. Only the declarations are visible in this header; the names suggest
// one builder per operator category (unary, binary, reduce, matmul, permute_dims, reshape)
// - confirm the exact rules in the implementation file.
void BuildAxisGraphUnary(const Var& output_var, const Call& call,
                         distributed::AxisGroupGraph* axis_group_graph);
void BuildAxisGraphBinary(const Var& output_var, const Call& call,
                          distributed::AxisGroupGraph* axis_group_graph);
void BuildAxisGraphReduce(const Var& output_var, const Call& call,
                          distributed::AxisGroupGraph* axis_group_graph);
void BuildAxisGraphMatmul(const Var& output_var, const Call& call,
                          distributed::AxisGroupGraph* axis_group_graph);
void BuildAxisGraphPermuteDims(const Var& output_var, const Call& call,
                               distributed::AxisGroupGraph* axis_group_graph);
void BuildAxisGraphReshape(const Var& output_var, const Call& call,
                           distributed::AxisGroupGraph* axis_group_graph);
// Variant that additionally receives a TIR PrimFunc, so it does not fit FBuildAxisGraph.
void BuildAxisGraphCallTIR(const Var& output_var, const Call& call, const tir::PrimFunc& func,
                           distributed::AxisGroupGraph* axis_group_graph);
474 
475 } // namespace distributed
476 } // namespace relax
477 } // namespace tvm
478 
479 #endif // TVM_RELAX_DISTRIBUTED_AXIS_GROUP_GRAPH_H_
Reference to PrimExprNode.
Definition: expr.h:124
static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span=Span())
Construct a new range with min and extent. The corresponding constructor is removed,...
Base node of all non-primitive expressions.
Definition: expr.h:416
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:97
Content-aware structural hashing.
Definition: structural_hash.h:116
Analyzer that contains bunch of sub-analyzers.
Definition: analyzer.h:634
Definition: expr.h:176
Constant tensor.
Definition: expr.h:422
The variable class for all Relax bindings.
Definition: expr.h:340
Definition: expr.h:377
A graph whose nodes are tensor axes, and the edge means some information can be propagated through th...
Definition: axis_group_graph.h:273
void JoinAxis(Axis axis1, Axis axis2, EdgeType type)
add edge between two axes
Definition: axis_group_graph.h:350
void AddSrcShardingPoint(Axis axis, AxisShardingSpec spec)
add a source shardingspec to propagate
Definition: axis_group_graph.h:360
std::tuple< AxisShardingSpec, bool > GetAxisShardingSpec(Axis axis)
Get the Sharding Spec of an axis after propagation.
Definition: axis_group_graph.h:393
void PropagateShardingSpec()
propagate sharding specs from source axes
Definition: axis_group_graph.h:367
EdgeType
Definition: axis_group_graph.h:275
void AddPropagationCutPoint(Axis axis, AxisShardingSpec spec)
add a cut point that stops the propagation of a certain sharding spec
Definition: axis_group_graph.h:382
Definition: axis_group_graph.h:236
size_t operator()(const AxisGroup &axis_group) const
Definition: axis_group_graph.h:238
Definition: axis_group_graph.h:224
size_t operator()(const Axis &axis) const
Definition: axis_group_graph.h:226
Definition: axis_group_graph.h:251
bool operator()(const AxisShardingSpec &lhs, const AxisShardingSpec &rhs) const
Definition: axis_group_graph.h:253
Definition: axis_group_graph.h:258
size_t operator()(const AxisShardingSpec &sharding_spec) const
Definition: axis_group_graph.h:260
Managed reference to a DeviceMesh.
Definition: global_info.h:62
Construct an axis group graph from a PrimFunc. Two buffer axes are connected if they are accessed by ...
Definition: axis_group_graph.h:68
static std::vector< std::vector< TIRVarAxis > > GetTIRVarAxisGraph(const PrimFunc &prim_func)
Definition: axis_group_graph.h:70
void DFSGraph(BufferAxis cur, std::unordered_set< BufferAxis, BufferAxisHash > *visited, std::vector< BufferAxis > *buffer_axis_group)
Definition: axis_group_graph.h:104
Definition: axis_group_graph.h:44
size_t operator()(const BufferAxis &buffer_axis) const
Definition: axis_group_graph.h:46
Store value to the high dimension buffer.
Definition: stmt.h:194
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:156
void VisitExpr_(const VarNode *op) override
Managed reference to PrimFuncNode.
Definition: function.h:129
Visitor that recursively visits stmts and the exprs inside them.
Definition: stmt_functor.h:285
void VisitStmt_(const AttrStmtNode *op) override
a named variable in TIR
Definition: var.h:77
Struct info for DTensor (Distributed Tensor)
Iterator quasi-affine mapping patterns.
ffi::Map< Var, arith::IntSet > AsIntSet(const ffi::Map< Var, Range > &var_dom)
Converts the Ranges to IntSets.
IntSet EvalSet(PrimExpr e, const ffi::Map< IterVar, IntSet > &dom_map)
Find a symbolic integer set that contains all possible values of e given the domain of each iteratio...
void BuildAxisGraphBinary(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
std::unordered_set< Axis, AxisHash > AxisGroup
Definition: axis_group_graph.h:234
void BuildAxisGraphUnary(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
void BuildAxisGraphCallTIR(const Var &output_var, const Call &call, const tir::PrimFunc &func, distributed::AxisGroupGraph *axis_group_graph)
std::function< void(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)> FBuildAxisGraph
Definition: axis_group_graph.h:458
std::pair< DeviceMesh, int > AxisShardingSpec
Definition: axis_group_graph.h:250
void BuildAxisGraphMatmul(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
void BuildAxisGraphReshape(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
void BuildAxisGraphReduce(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
void BuildAxisGraphPermuteDims(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
std::pair< DeviceMesh, Placement > ShardingSpec
Definition: axis_group_graph.h:247
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
std::pair< Buffer, int > BufferAxis
Definition: axis_group_graph.h:43
Var GetShardingVarFromIndex(PrimExpr index, ffi::Map< Var, Range > var_range, arith::Analyzer *analyzer)
Suppose we want to shard a buffer along a specific dimension, we need to know how to rewrite the acce...
std::pair< Var, int > TIRVarAxis
Definition: axis_group_graph.h:41
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr operator==(PrimExpr a, PrimExpr b)
equal
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
Functors for tir stmts utility functions to call common functors.
tensor axis
Definition: axis_group_graph.h:209
Axis(const ExprNode *tensor, int dim, int tuple_index=0)
Definition: axis_group_graph.h:214
int dim
Definition: axis_group_graph.h:211
bool operator==(const Axis &other) const
Definition: axis_group_graph.h:219
int tuple_index
Definition: axis_group_graph.h:212
const ExprNode * tensor
Definition: axis_group_graph.h:210
TIR Function.