20 #ifndef TVM_RELAX_DISTRIBUTED_AXIS_GROUP_GRAPH_H_
21 #define TVM_RELAX_DISTRIBUTED_AXIS_GROUP_GRAPH_H_
33 #include <unordered_map>
34 #include <unordered_set>
47 size_t const h1(ObjectPtrHash()(buffer_axis.first));
48 size_t const h2(std::hash<int>()(buffer_axis.second));
49 return h1 ^ (h2 << 1);
72 extractor(prim_func->body);
73 ffi::Map<Buffer, Var> inverse_buffer_map;
74 for (
const auto& pr : prim_func->buffer_map) {
75 inverse_buffer_map.Set(pr.second, pr.first);
77 std::vector<std::vector<TIRVarAxis>> tir_var_axis_group_list;
78 std::unordered_set<BufferAxis, BufferAxisHash> visited;
79 for (
const auto& pr : prim_func->buffer_map) {
82 for (
int i = 0; i < static_cast<int>(buffer->shape.size()); i++) {
83 if (extractor.buffer_axis_graph_.count({buffer, i})) {
84 std::vector<BufferAxis> buffer_axis_group;
85 extractor.
DFSGraph({buffer, i}, &visited, &buffer_axis_group);
86 if (buffer_axis_group.size() <= 1) {
89 std::vector<TIRVarAxis> tir_var_axis_group;
90 for (
const auto& buffer_axis : buffer_axis_group) {
91 if (!inverse_buffer_map.count(buffer_axis.first)) {
94 tir_var_axis_group.push_back(
95 {inverse_buffer_map[buffer_axis.first], buffer_axis.second});
97 tir_var_axis_group_list.push_back(tir_var_axis_group);
101 return tir_var_axis_group_list;
105 std::vector<BufferAxis>* buffer_axis_group) {
106 if (visited->count(cur)) {
109 visited->insert(cur);
110 buffer_axis_group->push_back(cur);
111 for (
const auto& next : buffer_axis_graph_[cur]) {
112 DFSGraph(next, visited, buffer_axis_group);
119 buffer_access_indices_.push_back({op->buffer, op->indices});
122 void VisitExpr_(
const BufferLoadNode* op)
final {
124 buffer_access_indices_.push_back({op->buffer, op->indices});
127 bool Match(PrimExpr a, PrimExpr buffer_shape_a, PrimExpr b, PrimExpr buffer_shape_b,
128 arith::Analyzer* analyzer) {
129 if (b.as<VarNode>()) {
131 std::swap(buffer_shape_a, buffer_shape_b);
133 if (!a.as<VarNode>()) {
136 Var
var = Downcast<Var>(a);
137 analyzer->Bind(iter_var_range_);
138 b = analyzer->Simplify(b);
141 if (!analyzer->CanProveEqual(buffer_shape_a, iter_var_range_[
var]->extent) ||
146 if (!matched_var.same_as(
var)) {
152 void VisitStmt_(
const BlockNode* op)
final {
153 if (op->name_hint ==
"root") {
157 buffer_access_indices_.clear();
159 iter_var_range_.clear();
160 for (
const auto& iter_var : op->iter_vars) {
161 iter_var_range_.Set(iter_var->var, iter_var->dom);
163 arith::Analyzer analyzer;
164 for (
const auto& access_pr : buffer_access_indices_) {
165 Buffer buffer = access_pr.first;
166 ffi::Array<PrimExpr> indices = access_pr.second;
167 for (
int i = 0; i < static_cast<int>(indices.size()); i++) {
168 for (
const auto& another_access_pr : buffer_access_indices_) {
169 if (another_access_pr.first.same_as(buffer)) {
172 Buffer another_buffer = another_access_pr.first;
173 ffi::Array<PrimExpr> another_indices = another_access_pr.second;
174 for (
int j = 0; j < static_cast<int>(another_indices.size()); j++) {
175 if (Match(indices[i], buffer->shape[i], another_indices[j], another_buffer->shape[j],
177 JoinBufferAxis({buffer, i}, {another_buffer, j});
186 if (!buffer_axis_graph_.count(axis1)) {
187 buffer_axis_graph_[axis1] = {};
189 if (!buffer_axis_graph_.count(axis2)) {
190 buffer_axis_graph_[axis2] = {};
192 buffer_axis_graph_[axis1].push_back(axis2);
193 buffer_axis_graph_[axis2].push_back(axis1);
196 std::vector<std::pair<Buffer, ffi::Array<PrimExpr>>> buffer_access_indices_;
197 std::unordered_map<BufferAxis, std::vector<BufferAxis>, BufferAxisHash> buffer_axis_graph_;
198 ffi::Map<Var, Range> iter_var_range_;
199 std::string func_name;
206 namespace distributed {
227 size_t const h1(std::hash<const ExprNode*>()(axis.
tensor));
228 size_t const h2(std::hash<int>()(axis.
dim));
229 size_t const h3(std::hash<int>()(axis.
tuple_index));
230 return h1 ^ (h2 << 1) ^ (h3 << 2);
240 for (
auto axis : axis_group) {
241 seed ^=
AxisHash()(axis) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
254 return StructuralEqual()(lhs.first, rhs.first) && lhs.second == rhs.second;
263 seed ^= std::hash<int>()(sharding_spec.second) << 1;
287 LOG(FATAL) <<
"Unreachable code";
291 static int GetEdgePriority(
EdgeType type) {
300 LOG(FATAL) <<
"Unreachable code";
304 struct AxisGraphEdge {
314 bool operator==(
const AxisGraphEdge& other)
const {
315 return src == other.src && dst == other.dst && type == other.type;
322 Path AddEdge(
EdgeType type) {
return {direction |= (1 << GetEdgePriority(type))}; }
324 int GetPriority()
const {
351 AddEdge(axis1, axis2, type);
352 AddEdge(axis2, axis1, ReverseEdgeType(type));
361 src_axis_sharding_spec_[axis] = spec;
368 axis_sharding_specs_priority_.clear();
369 for (
const auto& pr : src_axis_sharding_spec_) {
370 std::unordered_set<Axis, AxisHash> visited;
373 ChooseAxisShardingSpec();
383 cutpoint_axis_sharding_spec_[axis] = spec;
394 if (axis_sharding_specs_priority_.count(axis)) {
395 return {axis_sharding_specs_priority_[axis].begin()->first,
true};
403 if (!graph_.count(src)) {
406 graph_[src].push_back({src, dst, type});
410 std::unordered_set<Axis, AxisHash>* visited) {
411 if (cutpoint_axis_sharding_spec_.count(axis) ||
412 (src_axis_sharding_spec_.count(axis) &&
413 !AxisShardingSpecEqual()(src_axis_sharding_spec_[axis], spec)) ||
414 visited->count(axis)) {
417 visited->insert(axis);
418 if (!axis_sharding_specs_priority_.count(axis)) {
419 axis_sharding_specs_priority_[axis] = {};
421 axis_sharding_specs_priority_[axis][spec] = path.GetPriority();
422 for (
auto edge : graph_[axis]) {
427 void ChooseAxisShardingSpec() {
428 for (
auto& pr : axis_sharding_specs_priority_) {
429 auto& axis = pr.first;
430 auto& specs = pr.second;
432 for (
auto& pr2 : specs) {
433 max_priority =
std::max(max_priority, pr2.second);
435 for (
auto it = specs.begin(); it != specs.end();) {
436 if (it->second != max_priority) {
437 it = specs.erase(it);
442 ICHECK(specs.size() == 1) <<
"multiple possible sharding for axis: ("
443 << ffi::GetRef<Expr>(axis.tensor) <<
", " << axis.dim <<
")";
448 std::unordered_map<Axis, std::vector<AxisGraphEdge>, AxisHash> graph_;
449 std::unordered_map<Axis, AxisShardingSpec, AxisHash> src_axis_sharding_spec_;
450 std::unordered_map<Axis, AxisShardingSpec, AxisHash> cutpoint_axis_sharding_spec_;
452 Axis, std::unordered_map<AxisShardingSpec, int, AxisShardingSpecHash, AxisShardingSpecEqual>,
454 axis_sharding_specs_priority_;
Reference to PrimExprNode.
Definition: expr.h:124
static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span=Span())
Construct a new range with min and extent. The corresponding constructor is removed,...
Base node of all non-primitive expressions.
Definition: expr.h:416
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:97
Content-aware structural hashing.
Definition: structural_hash.h:116
Analyzer that contains bunch of sub-analyzers.
Definition: analyzer.h:634
Constant tensor.
Definition: expr.h:422
The variable class for all Relax bindings.
Definition: expr.h:340
A graph whose nodes are tensor axes, and the edge means some information can be propagated through th...
Definition: axis_group_graph.h:273
void JoinAxis(Axis axis1, Axis axis2, EdgeType type)
Add an edge between two axes.
Definition: axis_group_graph.h:350
void AddSrcShardingPoint(Axis axis, AxisShardingSpec spec)
Add a source sharding spec to propagate.
Definition: axis_group_graph.h:360
std::tuple< AxisShardingSpec, bool > GetAxisShardingSpec(Axis axis)
Get the Sharding Spec of an axis after propagation.
Definition: axis_group_graph.h:393
void PropagateShardingSpec()
propagate sharding specs from source axes
Definition: axis_group_graph.h:367
EdgeType
Definition: axis_group_graph.h:275
void AddPropagationCutPoint(Axis axis, AxisShardingSpec spec)
add a cut point that stops the propagation of a certain sharding spec
Definition: axis_group_graph.h:382
Definition: axis_group_graph.h:236
size_t operator()(const AxisGroup &axis_group) const
Definition: axis_group_graph.h:238
Definition: axis_group_graph.h:224
size_t operator()(const Axis &axis) const
Definition: axis_group_graph.h:226
Definition: axis_group_graph.h:251
bool operator()(const AxisShardingSpec &lhs, const AxisShardingSpec &rhs) const
Definition: axis_group_graph.h:253
Definition: axis_group_graph.h:258
size_t operator()(const AxisShardingSpec &sharding_spec) const
Definition: axis_group_graph.h:260
Managed reference to a DeviceMesh.
Definition: global_info.h:62
Definition: axis_group_graph.h:44
size_t operator()(const BufferAxis &buffer_axis) const
Definition: axis_group_graph.h:46
Store value to the high dimension buffer.
Definition: stmt.h:194
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:156
void VisitExpr_(const VarNode *op) override
Managed reference to PrimFuncNode.
Definition: function.h:129
Visitor that recursively visit stmts and exprs on them.
Definition: stmt_functor.h:285
void VisitStmt_(const AttrStmtNode *op) override
a named variable in TIR
Definition: var.h:77
Struct info for DTensor (Distributed Tensor)
Iterator quasi-affine mapping patterns.
ffi::Map< Var, arith::IntSet > AsIntSet(const ffi::Map< Var, Range > &var_dom)
Converts the Ranges to IntSets.
IntSet EvalSet(PrimExpr e, const ffi::Map< IterVar, IntSet > &dom_map)
Find a symbolic integer set that contains all possible values of e given the domain of each iteration variable.
void BuildAxisGraphBinary(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
std::unordered_set< Axis, AxisHash > AxisGroup
Definition: axis_group_graph.h:234
void BuildAxisGraphUnary(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
void BuildAxisGraphCallTIR(const Var &output_var, const Call &call, const tir::PrimFunc &func, distributed::AxisGroupGraph *axis_group_graph)
std::function< void(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)> FBuildAxisGraph
Definition: axis_group_graph.h:458
std::pair< DeviceMesh, int > AxisShardingSpec
Definition: axis_group_graph.h:250
void BuildAxisGraphMatmul(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
void BuildAxisGraphReshape(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
void BuildAxisGraphReduce(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
void BuildAxisGraphPermuteDims(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
std::pair< DeviceMesh, Placement > ShardingSpec
Definition: axis_group_graph.h:247
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
std::pair< Buffer, int > BufferAxis
Definition: axis_group_graph.h:43
Var GetShardingVarFromIndex(PrimExpr index, ffi::Map< Var, Range > var_range, arith::Analyzer *analyzer)
Suppose we want to shard a buffer along a specific dimension, we need to know how to rewrite the acce...
std::pair< Var, int > TIRVarAxis
Definition: axis_group_graph.h:41
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr operator==(PrimExpr a, PrimExpr b)
equal
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
Functors for tir stmts utility functions to call common functors.
tensor axis
Definition: axis_group_graph.h:209
Axis(const ExprNode *tensor, int dim, int tuple_index=0)
Definition: axis_group_graph.h:214
int dim
Definition: axis_group_graph.h:211
bool operator==(const Axis &other) const
Definition: axis_group_graph.h:219
int tuple_index
Definition: axis_group_graph.h:212
const ExprNode * tensor
Definition: axis_group_graph.h:210