tvm
axis_group_graph.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef TVM_RELAX_DISTRIBUTED_AXIS_GROUP_GRAPH_H_
21 #define TVM_RELAX_DISTRIBUTED_AXIS_GROUP_GRAPH_H_
22 
25 #include <tvm/relax/expr.h>
26 #include <tvm/tir/function.h>
27 #include <tvm/tir/stmt_functor.h>
28 
#include <algorithm>
#include <functional>
#include <limits>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
37 
38 namespace tvm {
39 namespace tir {
40 // (var, axis)
41 using TIRVarAxis = std::pair<Var, int>;
42 // (buffer, axis)
43 using BufferAxis = std::pair<Buffer, int>;
45  public:
46  size_t operator()(const BufferAxis& buffer_axis) const {
47  size_t const h1(ObjectPtrHash()(buffer_axis.first));
48  size_t const h2(std::hash<int>()(buffer_axis.second));
49  return h1 ^ (h2 << 1);
50  }
51 };
62 
68  public:
69  static std::vector<std::vector<TIRVarAxis>> GetTIRVarAxisGraph(const PrimFunc& prim_func) {
70  BufferAxisGraphExtractor extractor;
71  extractor(prim_func->body);
72  Map<Buffer, Var> inverse_buffer_map;
73  for (const auto& pr : prim_func->buffer_map) {
74  inverse_buffer_map.Set(pr.second, pr.first);
75  }
76  std::vector<std::vector<TIRVarAxis>> tir_var_axis_group_list;
77  std::unordered_set<BufferAxis, BufferAxisHash> visited;
78  for (const auto& pr : prim_func->buffer_map) {
79  Var param = pr.first;
80  Buffer buffer = pr.second;
81  for (int i = 0; i < static_cast<int>(buffer->shape.size()); i++) {
82  if (extractor.buffer_axis_graph_.count({buffer, i})) {
83  std::vector<BufferAxis> buffer_axis_group;
84  extractor.DFSGraph({buffer, i}, &visited, &buffer_axis_group);
85  if (buffer_axis_group.size() <= 1) {
86  continue;
87  }
88  std::vector<TIRVarAxis> tir_var_axis_group;
89  for (const auto& buffer_axis : buffer_axis_group) {
90  if (!inverse_buffer_map.count(buffer_axis.first)) {
91  continue;
92  }
93  tir_var_axis_group.push_back(
94  {inverse_buffer_map[buffer_axis.first], buffer_axis.second});
95  }
96  tir_var_axis_group_list.push_back(tir_var_axis_group);
97  }
98  }
99  }
100  return tir_var_axis_group_list;
101  }
102 
103  void DFSGraph(BufferAxis cur, std::unordered_set<BufferAxis, BufferAxisHash>* visited,
104  std::vector<BufferAxis>* buffer_axis_group) {
105  if (visited->count(cur)) {
106  return;
107  }
108  visited->insert(cur);
109  buffer_axis_group->push_back(cur);
110  for (const auto& next : buffer_axis_graph_[cur]) {
111  DFSGraph(next, visited, buffer_axis_group);
112  }
113  }
114 
115  private:
116  void VisitStmt_(const BufferStoreNode* op) final {
118  buffer_access_indices_.push_back({op->buffer, op->indices});
119  }
120 
121  void VisitExpr_(const BufferLoadNode* op) final {
123  buffer_access_indices_.push_back({op->buffer, op->indices});
124  }
125 
126  bool Match(PrimExpr a, PrimExpr buffer_shape_a, PrimExpr b, PrimExpr buffer_shape_b,
127  arith::Analyzer* analyzer) {
128  if (b.as<VarNode>()) {
129  std::swap(a, b);
130  std::swap(buffer_shape_a, buffer_shape_b);
131  }
132  if (!a.as<VarNode>()) {
133  return false;
134  }
135  Var var = Downcast<Var>(a);
136  analyzer->Bind(iter_var_range_);
137  b = analyzer->Simplify(b);
138  // index var `a` must access whole range of a specific buffer dimension
139  arith::IntSet intset_b = arith::EvalSet(b, arith::AsIntSet(iter_var_range_));
140  if (!analyzer->CanProveEqual(buffer_shape_a, iter_var_range_[var]->extent) ||
141  !intset_b.MatchRange(Range::FromMinExtent(0, buffer_shape_b))) {
142  return false;
143  }
144  Var matched_var = GetShardingVarFromIndex(b, iter_var_range_, analyzer);
145  if (!matched_var.same_as(var)) {
146  return false;
147  }
148  return true;
149  }
150 
151  void VisitStmt_(const BlockNode* op) final {
152  if (op->name_hint == "root") {
154  return;
155  }
156  buffer_access_indices_.clear();
158  iter_var_range_.clear();
159  for (const auto& iter_var : op->iter_vars) {
160  iter_var_range_.Set(iter_var->var, iter_var->dom);
161  }
162  arith::Analyzer analyzer;
163  for (const auto& access_pr : buffer_access_indices_) {
164  Buffer buffer = access_pr.first;
165  Array<PrimExpr> indices = access_pr.second;
166  for (int i = 0; i < static_cast<int>(indices.size()); i++) {
167  for (const auto& another_access_pr : buffer_access_indices_) {
168  if (another_access_pr.first.same_as(buffer)) {
169  continue;
170  }
171  Buffer another_buffer = another_access_pr.first;
172  Array<PrimExpr> another_indices = another_access_pr.second;
173  for (int j = 0; j < static_cast<int>(another_indices.size()); j++) {
174  if (Match(indices[i], buffer->shape[i], another_indices[j], another_buffer->shape[j],
175  &analyzer)) {
176  JoinBufferAxis({buffer, i}, {another_buffer, j});
177  }
178  }
179  }
180  }
181  }
182  }
183 
184  void JoinBufferAxis(BufferAxis axis1, BufferAxis axis2) {
185  if (!buffer_axis_graph_.count(axis1)) {
186  buffer_axis_graph_[axis1] = {};
187  }
188  if (!buffer_axis_graph_.count(axis2)) {
189  buffer_axis_graph_[axis2] = {};
190  }
191  buffer_axis_graph_[axis1].push_back(axis2);
192  buffer_axis_graph_[axis2].push_back(axis1);
193  }
194 
195  std::vector<std::pair<Buffer, Array<PrimExpr>>> buffer_access_indices_;
196  std::unordered_map<BufferAxis, std::vector<BufferAxis>, BufferAxisHash> buffer_axis_graph_;
197  Map<Var, Range> iter_var_range_;
198  std::string func_name;
199 };
200 } // namespace tir
201 } // namespace tvm
202 
203 namespace tvm {
204 namespace relax {
205 namespace distributed {
206 
208 struct Axis {
209  const ExprNode* tensor;
210  int dim = 0;
211  int tuple_index = 0;
212 
213  Axis(const ExprNode* tensor, int dim, int tuple_index = 0)
216  }
217 
218  bool operator==(const Axis& other) const {
219  return tensor == other.tensor && dim == other.dim && tuple_index == other.tuple_index;
220  }
221 };
222 
223 class AxisHash {
224  public:
225  size_t operator()(const Axis& axis) const {
226  size_t const h1(std::hash<const ExprNode*>()(axis.tensor));
227  size_t const h2(std::hash<int>()(axis.dim));
228  size_t const h3(std::hash<int>()(axis.tuple_index));
229  return h1 ^ (h2 << 1) ^ (h3 << 2);
230  }
231 };
232 
233 using AxisGroup = std::unordered_set<Axis, AxisHash>;
234 
236  public:
237  size_t operator()(const AxisGroup& axis_group) const {
238  size_t seed = 0;
239  for (auto axis : axis_group) {
240  seed ^= AxisHash()(axis) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
241  }
242  return seed;
243  }
244 };
245 
246 using ShardingSpec = std::pair<DeviceMesh, Placement>;
247 
248 // device mesh and the device mesh axis that the tensor axis maps to
249 using AxisShardingSpec = std::pair<DeviceMesh, int>;
251  public:
252  bool operator()(const AxisShardingSpec& lhs, const AxisShardingSpec& rhs) const {
253  return StructuralEqual()(lhs.first, rhs.first) && lhs.second == rhs.second;
254  }
255 };
256 
258  public:
259  size_t operator()(const AxisShardingSpec& sharding_spec) const {
260  size_t seed = 0;
261  seed ^= StructuralHash()(sharding_spec.first);
262  seed ^= std::hash<int>()(sharding_spec.second) << 1;
263  return seed;
264  }
265 };
266 
273  public:
274  enum class EdgeType { kAscend, kDescend, kSimbling };
275 
276  private:
277  static EdgeType ReverseEdgeType(EdgeType type) {
278  switch (type) {
279  case EdgeType::kAscend:
280  return EdgeType::kDescend;
281  case EdgeType::kDescend:
282  return EdgeType::kAscend;
283  case EdgeType::kSimbling:
284  return EdgeType::kSimbling;
285  }
286  LOG(FATAL) << "Unreachable code";
287  throw;
288  }
289 
290  static int GetEdgePriority(EdgeType type) {
291  switch (type) {
292  case EdgeType::kAscend:
293  return 0;
294  case EdgeType::kDescend:
295  return 2;
296  case EdgeType::kSimbling:
297  return 1;
298  }
299  LOG(FATAL) << "Unreachable code";
300  throw;
301  }
302 
303  struct AxisGraphEdge {
304  Axis src;
305  Axis dst;
306 
307  // the producer-consumer relationship between src tensor and dst tensor
308  // kAscend means consumer->producer
309  // kDescend means producer->consumer
310  // kSimbling means other cases
311  EdgeType type;
312 
313  bool operator==(const AxisGraphEdge& other) const {
314  return src == other.src && dst == other.dst && type == other.type;
315  }
316  };
317 
318  struct Path {
319  int direction = 0;
320 
321  Path AddEdge(EdgeType type) { return {direction |= (1 << GetEdgePriority(type))}; }
322 
323  int GetPriority() const {
324  switch (direction) {
325  case 1: // ascend only
326  return 0;
327  case 4: // descend only
328  return 2;
329  case 0: // empty path (source node)
330  return 3; // source node must have max priority
331  default: // mixed path
332  return 1;
333  }
334  }
335  };
336 
337  public:
338  AxisGroupGraph() = default;
339 
349  void JoinAxis(Axis axis1, Axis axis2, EdgeType type) {
350  AddEdge(axis1, axis2, type);
351  AddEdge(axis2, axis1, ReverseEdgeType(type));
352  }
353 
360  src_axis_sharding_spec_[axis] = spec;
361  }
362 
367  axis_sharding_specs_priority_.clear();
368  for (const auto& pr : src_axis_sharding_spec_) {
369  std::unordered_set<Axis, AxisHash> visited;
370  PropagateShardingSpec(pr.first, pr.second, Path(), &visited);
371  }
372  ChooseAxisShardingSpec();
373  }
374 
382  cutpoint_axis_sharding_spec_[axis] = spec;
383  }
384 
392  std::tuple<AxisShardingSpec, bool> GetAxisShardingSpec(Axis axis) {
393  if (axis_sharding_specs_priority_.count(axis)) {
394  return {axis_sharding_specs_priority_[axis].begin()->first, true};
395  } else {
396  return {{DeviceMesh(), -1}, false};
397  }
398  }
399 
400  private:
401  void AddEdge(Axis src, Axis dst, EdgeType type) {
402  if (!graph_.count(src)) {
403  graph_[src] = {};
404  }
405  graph_[src].push_back({src, dst, type});
406  }
407 
408  void PropagateShardingSpec(Axis axis, AxisShardingSpec spec, Path path,
409  std::unordered_set<Axis, AxisHash>* visited) {
410  if (cutpoint_axis_sharding_spec_.count(axis) ||
411  (src_axis_sharding_spec_.count(axis) &&
412  !AxisShardingSpecEqual()(src_axis_sharding_spec_[axis], spec)) ||
413  visited->count(axis)) {
414  return;
415  }
416  visited->insert(axis);
417  if (!axis_sharding_specs_priority_.count(axis)) {
418  axis_sharding_specs_priority_[axis] = {};
419  }
420  axis_sharding_specs_priority_[axis][spec] = path.GetPriority();
421  for (auto edge : graph_[axis]) {
422  PropagateShardingSpec(edge.dst, spec, path.AddEdge(edge.type), visited);
423  }
424  }
425 
426  void ChooseAxisShardingSpec() {
427  for (auto& pr : axis_sharding_specs_priority_) {
428  auto& axis = pr.first;
429  auto& specs = pr.second;
430  int max_priority = std::numeric_limits<int>::min();
431  for (auto& pr2 : specs) {
432  max_priority = std::max(max_priority, pr2.second);
433  }
434  for (auto it = specs.begin(); it != specs.end();) {
435  if (it->second != max_priority) {
436  it = specs.erase(it);
437  } else {
438  it++;
439  }
440  }
441  ICHECK(specs.size() == 1) << "multiple possible sharding for axis: ("
442  << GetRef<Expr>(axis.tensor) << ", " << axis.dim << ")";
443  }
444  }
445 
446  // union set
447  std::unordered_map<Axis, std::vector<AxisGraphEdge>, AxisHash> graph_;
448  std::unordered_map<Axis, AxisShardingSpec, AxisHash> src_axis_sharding_spec_;
449  std::unordered_map<Axis, AxisShardingSpec, AxisHash> cutpoint_axis_sharding_spec_;
450  std::unordered_map<
451  Axis, std::unordered_map<AxisShardingSpec, int, AxisShardingSpecHash, AxisShardingSpecEqual>,
452  AxisHash>
453  axis_sharding_specs_priority_;
454 };
455 
456 using FBuildAxisGraph = std::function<void(const Var& output_var, const Call& call,
457  distributed::AxisGroupGraph* axis_group_graph)>;
458 
459 void BuildAxisGraphUnary(const Var& output_var, const Call& call,
460  distributed::AxisGroupGraph* axis_group_graph);
461 void BuildAxisGraphBinary(const Var& output_var, const Call& call,
462  distributed::AxisGroupGraph* axis_group_graph);
463 void BuildAxisGraphReduce(const Var& output_var, const Call& call,
464  distributed::AxisGroupGraph* axis_group_graph);
465 void BuildAxisGraphMatmul(const Var& output_var, const Call& call,
466  distributed::AxisGroupGraph* axis_group_graph);
467 void BuildAxisGraphPermuteDims(const Var& output_var, const Call& call,
468  distributed::AxisGroupGraph* axis_group_graph);
469 void BuildAxisGraphReshape(const Var& output_var, const Call& call,
470  distributed::AxisGroupGraph* axis_group_graph);
471 void BuildAxisGraphCallTIR(const Var& output_var, const Call& call, const tir::PrimFunc& func,
472  distributed::AxisGroupGraph* axis_group_graph);
473 
474 } // namespace distributed
475 } // namespace relax
476 } // namespace tvm
477 
478 #endif // TVM_RELAX_DISTRIBUTED_AXIS_GROUP_GRAPH_H_
Reference to PrimExprNode.
Definition: expr.h:115
static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span=Span())
construct a new range with min and extent The corresponding constructor is removed,...
Base node of all non-primitive expressions.
Definition: expr.h:362
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:114
Content-aware structural hashing.
Definition: structural_hash.h:94
Analyzer that contains bunch of sub-analyzers.
Definition: analyzer.h:629
Definition: expr.h:190
Constant tensor.
Definition: expr.h:480
The variable class for all Relax bindings.
Definition: expr.h:389
Definition: expr.h:422
A graph whose nodes are tensor axes, and the edge means some information can be propagated through th...
Definition: axis_group_graph.h:272
void JoinAxis(Axis axis1, Axis axis2, EdgeType type)
add edge between two axes
Definition: axis_group_graph.h:349
void AddSrcShardingPoint(Axis axis, AxisShardingSpec spec)
add a source shardingspec to propagate
Definition: axis_group_graph.h:359
std::tuple< AxisShardingSpec, bool > GetAxisShardingSpec(Axis axis)
Get the Sharding Spec of an axis after propagation.
Definition: axis_group_graph.h:392
void PropagateShardingSpec()
propagate sharding specs from source axes
Definition: axis_group_graph.h:366
EdgeType
Definition: axis_group_graph.h:274
void AddPropagationCutPoint(Axis axis, AxisShardingSpec spec)
add a cut point that stops the propagation of a certain sharding spec
Definition: axis_group_graph.h:381
Definition: axis_group_graph.h:235
size_t operator()(const AxisGroup &axis_group) const
Definition: axis_group_graph.h:237
Definition: axis_group_graph.h:223
size_t operator()(const Axis &axis) const
Definition: axis_group_graph.h:225
Definition: axis_group_graph.h:250
bool operator()(const AxisShardingSpec &lhs, const AxisShardingSpec &rhs) const
Definition: axis_group_graph.h:252
Definition: axis_group_graph.h:257
size_t operator()(const AxisShardingSpec &sharding_spec) const
Definition: axis_group_graph.h:259
Managed reference to a DeviceMesh.
Definition: global_info.h:81
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
size_t count(const K &key) const
Definition: map.h:1356
void Set(const K &key, const V &value)
set the Map.
Definition: map.h:1374
bool IsInstance() const
Definition: object.h:874
Construct an axis group graph from a PrimFunc. Two buffer axis are connected if they are accessed by ...
Definition: axis_group_graph.h:67
static std::vector< std::vector< TIRVarAxis > > GetTIRVarAxisGraph(const PrimFunc &prim_func)
Definition: axis_group_graph.h:69
void DFSGraph(BufferAxis cur, std::unordered_set< BufferAxis, BufferAxisHash > *visited, std::vector< BufferAxis > *buffer_axis_group)
Definition: axis_group_graph.h:103
Definition: axis_group_graph.h:44
size_t operator()(const BufferAxis &buffer_axis) const
Definition: axis_group_graph.h:46
Store value to the high dimension buffer.
Definition: stmt.h:226
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:174
void VisitExpr_(const VarNode *op) override
Managed reference to PrimFuncNode.
Definition: function.h:145
Visitor that recursively visit stmts and exprs on them.
Definition: stmt_functor.h:295
void VisitStmt_(const AttrStmtNode *op) override
a named variable in TIR
Definition: var.h:89
Struct info for DTensor (Distributed Tensor)
Iterator quasi-affine mapping patterns.
IntSet EvalSet(PrimExpr e, const Map< IterVar, IntSet > &dom_map)
Find an symbolic integer set that contains all possible values of e given the domain of each iteratio...
Map< Var, arith::IntSet > AsIntSet(const Map< Var, Range > &var_dom)
Converts the Ranges to IntSets.
void BuildAxisGraphBinary(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
std::unordered_set< Axis, AxisHash > AxisGroup
Definition: axis_group_graph.h:233
void BuildAxisGraphUnary(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
void BuildAxisGraphCallTIR(const Var &output_var, const Call &call, const tir::PrimFunc &func, distributed::AxisGroupGraph *axis_group_graph)
std::function< void(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)> FBuildAxisGraph
Definition: axis_group_graph.h:457
std::pair< DeviceMesh, int > AxisShardingSpec
Definition: axis_group_graph.h:249
void BuildAxisGraphMatmul(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
void BuildAxisGraphReshape(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
void BuildAxisGraphReduce(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
void BuildAxisGraphPermuteDims(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
std::pair< DeviceMesh, Placement > ShardingSpec
Definition: axis_group_graph.h:246
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
std::pair< Buffer, int > BufferAxis
Definition: axis_group_graph.h:43
Var GetShardingVarFromIndex(PrimExpr index, Map< Var, Range > var_range, arith::Analyzer *analyzer)
Suppose we want to shard a buffer along a specific dimension, we need to know how to rewrite the acce...
std::pair< Var, int > TIRVarAxis
Definition: axis_group_graph.h:41
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr operator==(PrimExpr a, PrimExpr b)
equal
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
Functors for tir stmts utility functions to call common functors.
tensor axis
Definition: axis_group_graph.h:208
Axis(const ExprNode *tensor, int dim, int tuple_index=0)
Definition: axis_group_graph.h:213
int dim
Definition: axis_group_graph.h:210
bool operator==(const Axis &other) const
Definition: axis_group_graph.h:218
int tuple_index
Definition: axis_group_graph.h:211
const ExprNode * tensor
Definition: axis_group_graph.h:209
ObjectRef hash functor.
Definition: object.h:655
TIR Function.