20 #ifndef TVM_RELAX_DISTRIBUTED_AXIS_GROUP_GRAPH_H_
21 #define TVM_RELAX_DISTRIBUTED_AXIS_GROUP_GRAPH_H_
33 #include <unordered_map>
34 #include <unordered_set>
48 size_t const h2(std::hash<int>()(buffer_axis.second));
49 return h1 ^ (h2 << 1);
71 extractor(prim_func->body);
73 for (
const auto& pr : prim_func->buffer_map) {
74 inverse_buffer_map.
Set(pr.second, pr.first);
76 std::vector<std::vector<TIRVarAxis>> tir_var_axis_group_list;
77 std::unordered_set<BufferAxis, BufferAxisHash> visited;
78 for (
const auto& pr : prim_func->buffer_map) {
81 for (
int i = 0; i < static_cast<int>(buffer->shape.size()); i++) {
82 if (extractor.buffer_axis_graph_.count({buffer, i})) {
83 std::vector<BufferAxis> buffer_axis_group;
84 extractor.
DFSGraph({buffer, i}, &visited, &buffer_axis_group);
85 if (buffer_axis_group.size() <= 1) {
88 std::vector<TIRVarAxis> tir_var_axis_group;
89 for (
const auto& buffer_axis : buffer_axis_group) {
90 if (!inverse_buffer_map.
count(buffer_axis.first)) {
93 tir_var_axis_group.push_back(
94 {inverse_buffer_map[buffer_axis.first], buffer_axis.second});
96 tir_var_axis_group_list.push_back(tir_var_axis_group);
100 return tir_var_axis_group_list;
104 std::vector<BufferAxis>* buffer_axis_group) {
105 if (visited->count(cur)) {
108 visited->insert(cur);
109 buffer_axis_group->push_back(cur);
110 for (
const auto& next : buffer_axis_graph_[cur]) {
111 DFSGraph(next, visited, buffer_axis_group);
118 buffer_access_indices_.push_back({op->buffer, op->indices});
121 void VisitExpr_(
const BufferLoadNode* op)
final {
123 buffer_access_indices_.push_back({op->buffer, op->indices});
126 bool Match(PrimExpr a, PrimExpr buffer_shape_a, PrimExpr b, PrimExpr buffer_shape_b,
127 arith::Analyzer* analyzer) {
128 if (b.as<VarNode>()) {
130 std::swap(buffer_shape_a, buffer_shape_b);
132 if (!a.as<VarNode>()) {
135 Var
var = Downcast<Var>(a);
136 analyzer->Bind(iter_var_range_);
137 b = analyzer->Simplify(b);
140 if (!analyzer->CanProveEqual(buffer_shape_a, iter_var_range_[
var]->extent) ||
145 if (!matched_var.same_as(
var)) {
151 void VisitStmt_(
const BlockNode* op)
final {
152 if (op->name_hint ==
"root") {
156 buffer_access_indices_.clear();
158 iter_var_range_.clear();
159 for (
const auto& iter_var : op->iter_vars) {
160 iter_var_range_.Set(iter_var->var, iter_var->dom);
162 arith::Analyzer analyzer;
163 for (
const auto& access_pr : buffer_access_indices_) {
164 Buffer buffer = access_pr.first;
165 Array<PrimExpr> indices = access_pr.second;
166 for (
int i = 0; i < static_cast<int>(indices.size()); i++) {
167 for (
const auto& another_access_pr : buffer_access_indices_) {
168 if (another_access_pr.first.same_as(buffer)) {
171 Buffer another_buffer = another_access_pr.first;
172 Array<PrimExpr> another_indices = another_access_pr.second;
173 for (
int j = 0; j < static_cast<int>(another_indices.size()); j++) {
174 if (Match(indices[i], buffer->shape[i], another_indices[j], another_buffer->shape[j],
176 JoinBufferAxis({buffer, i}, {another_buffer, j});
185 if (!buffer_axis_graph_.count(axis1)) {
186 buffer_axis_graph_[axis1] = {};
188 if (!buffer_axis_graph_.count(axis2)) {
189 buffer_axis_graph_[axis2] = {};
191 buffer_axis_graph_[axis1].push_back(axis2);
192 buffer_axis_graph_[axis2].push_back(axis1);
195 std::vector<std::pair<Buffer, Array<PrimExpr>>> buffer_access_indices_;
196 std::unordered_map<BufferAxis, std::vector<BufferAxis>, BufferAxisHash> buffer_axis_graph_;
197 Map<Var, Range> iter_var_range_;
198 std::string func_name;
205 namespace distributed {
226 size_t const h1(std::hash<const ExprNode*>()(axis.
tensor));
227 size_t const h2(std::hash<int>()(axis.
dim));
228 size_t const h3(std::hash<int>()(axis.
tuple_index));
229 return h1 ^ (h2 << 1) ^ (h3 << 2);
239 for (
auto axis : axis_group) {
240 seed ^=
AxisHash()(axis) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
253 return StructuralEqual()(lhs.first, rhs.first) && lhs.second == rhs.second;
262 seed ^= std::hash<int>()(sharding_spec.second) << 1;
286 LOG(FATAL) <<
"Unreachable code";
290 static int GetEdgePriority(
EdgeType type) {
299 LOG(FATAL) <<
"Unreachable code";
303 struct AxisGraphEdge {
313 bool operator==(
const AxisGraphEdge& other)
const {
314 return src == other.src && dst == other.dst && type == other.type;
321 Path AddEdge(
EdgeType type) {
return {direction |= (1 << GetEdgePriority(type))}; }
323 int GetPriority()
const {
350 AddEdge(axis1, axis2, type);
351 AddEdge(axis2, axis1, ReverseEdgeType(type));
360 src_axis_sharding_spec_[axis] = spec;
367 axis_sharding_specs_priority_.clear();
368 for (
const auto& pr : src_axis_sharding_spec_) {
369 std::unordered_set<Axis, AxisHash> visited;
372 ChooseAxisShardingSpec();
382 cutpoint_axis_sharding_spec_[axis] = spec;
393 if (axis_sharding_specs_priority_.count(axis)) {
394 return {axis_sharding_specs_priority_[axis].begin()->first,
true};
402 if (!graph_.count(src)) {
405 graph_[src].push_back({src, dst, type});
409 std::unordered_set<Axis, AxisHash>* visited) {
410 if (cutpoint_axis_sharding_spec_.count(axis) ||
411 (src_axis_sharding_spec_.count(axis) &&
412 !AxisShardingSpecEqual()(src_axis_sharding_spec_[axis], spec)) ||
413 visited->count(axis)) {
416 visited->insert(axis);
417 if (!axis_sharding_specs_priority_.count(axis)) {
418 axis_sharding_specs_priority_[axis] = {};
420 axis_sharding_specs_priority_[axis][spec] = path.GetPriority();
421 for (
auto edge : graph_[axis]) {
426 void ChooseAxisShardingSpec() {
427 for (
auto& pr : axis_sharding_specs_priority_) {
428 auto& axis = pr.first;
429 auto& specs = pr.second;
431 for (
auto& pr2 : specs) {
432 max_priority =
std::max(max_priority, pr2.second);
434 for (
auto it = specs.begin(); it != specs.end();) {
435 if (it->second != max_priority) {
436 it = specs.erase(it);
441 ICHECK(specs.size() == 1) <<
"multiple possible sharding for axis: ("
442 << GetRef<Expr>(axis.tensor) <<
", " << axis.dim <<
")";
447 std::unordered_map<Axis, std::vector<AxisGraphEdge>, AxisHash> graph_;
448 std::unordered_map<Axis, AxisShardingSpec, AxisHash> src_axis_sharding_spec_;
449 std::unordered_map<Axis, AxisShardingSpec, AxisHash> cutpoint_axis_sharding_spec_;
451 Axis, std::unordered_map<AxisShardingSpec, int, AxisShardingSpecHash, AxisShardingSpecEqual>,
453 axis_sharding_specs_priority_;
Reference to PrimExprNode.
Definition: expr.h:115
static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span=Span())
Construct a new range with min and extent. The corresponding constructor is removed,...
Base node of all non-primitive expressions.
Definition: expr.h:362
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:114
Content-aware structural hashing.
Definition: structural_hash.h:94
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:629
Constant tensor.
Definition: expr.h:480
The variable class for all Relax bindings.
Definition: expr.h:389
A graph whose nodes are tensor axes, and the edge means some information can be propagated through th...
Definition: axis_group_graph.h:272
void JoinAxis(Axis axis1, Axis axis2, EdgeType type)
add edge between two axes
Definition: axis_group_graph.h:349
void AddSrcShardingPoint(Axis axis, AxisShardingSpec spec)
add a source shardingspec to propagate
Definition: axis_group_graph.h:359
std::tuple< AxisShardingSpec, bool > GetAxisShardingSpec(Axis axis)
Get the Sharding Spec of an axis after propagation.
Definition: axis_group_graph.h:392
void PropagateShardingSpec()
propagate sharding specs from source axes
Definition: axis_group_graph.h:366
EdgeType
Definition: axis_group_graph.h:274
void AddPropagationCutPoint(Axis axis, AxisShardingSpec spec)
add a cut point that stops the propagation of a certain sharding spec
Definition: axis_group_graph.h:381
Definition: axis_group_graph.h:235
size_t operator()(const AxisGroup &axis_group) const
Definition: axis_group_graph.h:237
Definition: axis_group_graph.h:223
size_t operator()(const Axis &axis) const
Definition: axis_group_graph.h:225
Definition: axis_group_graph.h:250
bool operator()(const AxisShardingSpec &lhs, const AxisShardingSpec &rhs) const
Definition: axis_group_graph.h:252
Definition: axis_group_graph.h:257
size_t operator()(const AxisShardingSpec &sharding_spec) const
Definition: axis_group_graph.h:259
Managed reference to a DeviceMesh.
Definition: global_info.h:81
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
size_t count(const K &key) const
Definition: map.h:1356
void Set(const K &key, const V &value)
set the Map.
Definition: map.h:1374
bool IsInstance() const
Definition: object.h:874
Definition: axis_group_graph.h:44
size_t operator()(const BufferAxis &buffer_axis) const
Definition: axis_group_graph.h:46
Store value to the high dimension buffer.
Definition: stmt.h:226
Buffer is a symbolic n-d array structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:174
void VisitExpr_(const VarNode *op) override
Managed reference to PrimFuncNode.
Definition: function.h:145
Visitor that recursively visit stmts and exprs on them.
Definition: stmt_functor.h:295
void VisitStmt_(const AttrStmtNode *op) override
a named variable in TIR
Definition: var.h:89
Struct info for DTensor (Distributed Tensor)
Iterator quasi-affine mapping patterns.
IntSet EvalSet(PrimExpr e, const Map< IterVar, IntSet > &dom_map)
Find a symbolic integer set that contains all possible values of e given the domain of each iteratio...
Map< Var, arith::IntSet > AsIntSet(const Map< Var, Range > &var_dom)
Converts the Ranges to IntSets.
void BuildAxisGraphBinary(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
std::unordered_set< Axis, AxisHash > AxisGroup
Definition: axis_group_graph.h:233
void BuildAxisGraphUnary(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
void BuildAxisGraphCallTIR(const Var &output_var, const Call &call, const tir::PrimFunc &func, distributed::AxisGroupGraph *axis_group_graph)
std::function< void(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)> FBuildAxisGraph
Definition: axis_group_graph.h:457
std::pair< DeviceMesh, int > AxisShardingSpec
Definition: axis_group_graph.h:249
void BuildAxisGraphMatmul(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
void BuildAxisGraphReshape(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
void BuildAxisGraphReduce(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
void BuildAxisGraphPermuteDims(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
std::pair< DeviceMesh, Placement > ShardingSpec
Definition: axis_group_graph.h:246
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
std::pair< Buffer, int > BufferAxis
Definition: axis_group_graph.h:43
Var GetShardingVarFromIndex(PrimExpr index, Map< Var, Range > var_range, arith::Analyzer *analyzer)
Suppose we want to shard a buffer along a specific dimension, we need to know how to rewrite the acce...
std::pair< Var, int > TIRVarAxis
Definition: axis_group_graph.h:41
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr operator==(PrimExpr a, PrimExpr b)
equal
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
Functors for tir stmts utility functions to call common functors.
tensor axis
Definition: axis_group_graph.h:208
Axis(const ExprNode *tensor, int dim, int tuple_index=0)
Definition: axis_group_graph.h:213
int dim
Definition: axis_group_graph.h:210
bool operator==(const Axis &other) const
Definition: axis_group_graph.h:218
int tuple_index
Definition: axis_group_graph.h:211
const ExprNode * tensor
Definition: axis_group_graph.h:209
ObjectRef hash functor.
Definition: object.h:655