tvm
struct_info.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_RELAX_DISTRIBUTED_STRUCT_INFO_H_
26 #define TVM_RELAX_DISTRIBUTED_STRUCT_INFO_H_
27 
29 #include <tvm/relax/struct_info.h>
30 namespace tvm {
31 namespace relax {
32 namespace distributed {
33 
34 enum class PlacementSpecKind : int { kSharding = 0, kReplica = 1 };
35 
37 class PlacementSpecNode : public Object {
38  public:
42  int axis;
43 
46 
48  v->Visit("axis", &axis);
49  v->Visit("kind", &kind);
50  }
51  bool SEqualReduce(const PlacementSpecNode* other, SEqualReducer equal) const {
52  return equal(axis, other->axis) && equal(kind, other->kind);
53  }
54 
55  void SHashReduce(SHashReducer hash_reduce) const {
56  hash_reduce(axis);
57  hash_reduce(static_cast<int>(kind));
58  }
59 
60  static constexpr const char* _type_key = "relax.distributed.PlacementSpec";
61  static constexpr const bool _type_has_method_sequal_reduce = true;
62  static constexpr const bool _type_has_method_shash_reduce = true;
64 };
65 
70 class PlacementSpec : public ObjectRef {
71  public:
72  TVM_DLL static PlacementSpec Sharding(int axis);
73 
74  TVM_DLL static PlacementSpec Replica();
75 
77 };
78 
80  public:
83 
84  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("sharding_dim", &sharding_dim); }
85 
86  bool SEqualReduce(const ShardingNode* other, SEqualReducer equal) const {
87  return equal(sharding_dim, other->sharding_dim);
88  }
89 
90  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(sharding_dim); }
91  static constexpr const char* _type_key = "relax.distributed.Sharding";
93 };
94 
96 class PlacementNode : public Object {
97  public:
100 
101  String ToString() const;
102 
103  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("dim_specs", &dim_specs); }
104 
105  bool SEqualReduce(const PlacementNode* other, SEqualReducer equal) const {
106  return equal(dim_specs, other->dim_specs);
107  }
108 
109  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dim_specs); }
110 
111  static constexpr const bool _type_has_method_sequal_reduce = true;
112  static constexpr const bool _type_has_method_shash_reduce = true;
113  static constexpr const char* _type_key = "relax.distributed.Placement";
115 };
116 
121 class Placement : public ObjectRef {
122  public:
123  TVM_DLL explicit Placement(Array<PlacementSpec> dim_specs);
125  static Placement FromText(String text_repr);
127 };
128 
133  public:
146 
148  v->Visit("device_mesh", &device_mesh);
149  v->Visit("placement", &placement);
150  v->Visit("tensor_sinfo", &tensor_sinfo);
151  v->Visit("span", &span);
152  }
153 
155  return equal(tensor_sinfo, other->tensor_sinfo) && equal(device_mesh, other->device_mesh) &&
156  equal(placement, other->placement);
157  }
158 
159  void SHashReduce(SHashReducer hash_reduce) const {
160  hash_reduce(tensor_sinfo);
161  hash_reduce(device_mesh);
162  hash_reduce(placement);
163  }
164 
165  static constexpr const char* _type_key = "relax.DTensorStructInfo";
167 };
168 
174  public:
182  TVM_DLL DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh device_mesh,
183  Placement placement, Span span = Span());
184 
186 };
187 
188 } // namespace distributed
189 } // namespace relax
190 } // namespace tvm
191 
192 #endif // TVM_RELAX_DISTRIBUTED_STRUCT_INFO_H_
Visitor class to get the attributes of an AST/IR node. Its Visit method is called once for each field of the node.
Definition: reflection.h:52
Container of constant int that adds more constructors.
Definition: expr.h:632
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
Definition: source_map.h:120
Base type of all structure information.
Definition: expr.h:110
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:116
Managed reference to StructInfoNode.
Definition: expr.h:129
Managed reference to TensorStructInfoNode.
Definition: struct_info.h:223
StructInfo of DTensor (Distributed Tensor).
Definition: struct_info.h:132
void VisitAttrs(AttrVisitor *v)
Definition: struct_info.h:147
TensorStructInfo tensor_sinfo
The TensorStructInfo of the underlying tensor (stored as a member, not inherited).
Definition: struct_info.h:137
Placement placement
The placement of the tensor among the device mesh.
Definition: struct_info.h:145
bool SEqualReduce(const DTensorStructInfoNode *other, SEqualReducer equal) const
Definition: struct_info.h:154
TVM_DECLARE_FINAL_OBJECT_INFO(DTensorStructInfoNode, StructInfoNode)
DeviceMesh device_mesh
The device mesh of the tensor.
Definition: struct_info.h:141
void SHashReduce(SHashReducer hash_reduce) const
Definition: struct_info.h:159
static constexpr const char * _type_key
Definition: struct_info.h:165
Managed reference to DTensorStructInfoNode.
Definition: struct_info.h:173
DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh device_mesh, Placement placement, Span span=Span())
Constructs a DTensorStructInfo from a tensor struct info, a device mesh, and a placement.
TVM_DEFINE_OBJECT_REF_METHODS(DTensorStructInfo, StructInfo, DTensorStructInfoNode)
Managed reference to a DeviceMesh.
Definition: global_info.h:81
Describes how data is distributed in each dimension of the device mesh.
Definition: struct_info.h:96
TVM_DECLARE_FINAL_OBJECT_INFO(PlacementNode, Object)
static constexpr const bool _type_has_method_sequal_reduce
Definition: struct_info.h:111
static constexpr const char * _type_key
Definition: struct_info.h:113
void SHashReduce(SHashReducer hash_reduce) const
Definition: struct_info.h:109
Array< PlacementSpec > dim_specs
Specs for each dimension of the device mesh.
Definition: struct_info.h:99
static constexpr const bool _type_has_method_shash_reduce
Definition: struct_info.h:112
void VisitAttrs(tvm::AttrVisitor *v)
Definition: struct_info.h:103
bool SEqualReduce(const PlacementNode *other, SEqualReducer equal) const
Definition: struct_info.h:105
Describes how data is distributed in one dimension of the device mesh.
Definition: struct_info.h:37
static constexpr const bool _type_has_method_shash_reduce
Definition: struct_info.h:62
PlacementSpecKind kind
The kind of placement spec. Possible values: kSharding and kReplica.
Definition: struct_info.h:45
TVM_DECLARE_BASE_OBJECT_INFO(PlacementSpecNode, Object)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: struct_info.h:47
void SHashReduce(SHashReducer hash_reduce) const
Definition: struct_info.h:55
static constexpr const bool _type_has_method_sequal_reduce
Definition: struct_info.h:61
static constexpr const char * _type_key
Definition: struct_info.h:60
bool SEqualReduce(const PlacementSpecNode *other, SEqualReducer equal) const
Definition: struct_info.h:51
int axis
If the kind is sharding, this value represents the tensor dimension to shard; otherwise, axis is -1.
Definition: struct_info.h:42
Managed reference to PlacementSpecNode.
Definition: struct_info.h:70
TVM_DEFINE_OBJECT_REF_METHODS(PlacementSpec, ObjectRef, PlacementSpecNode)
static PlacementSpec Sharding(int axis)
Managed reference to a Placement.
Definition: struct_info.h:121
static Placement FromText(String text_repr)
Parses a Placement from its text representation, in which a replica dim is written as "R" and a sharding dim as "S[i]".
Placement(Array< PlacementSpec > dim_specs)
TVM_DEFINE_OBJECT_REF_METHODS(Placement, ObjectRef, PlacementNode)
Definition: struct_info.h:79
TVM_DECLARE_FINAL_OBJECT_INFO(ShardingNode, PlacementSpecNode)
bool SEqualReduce(const ShardingNode *other, SEqualReducer equal) const
Definition: struct_info.h:86
void VisitAttrs(tvm::AttrVisitor *v)
Definition: struct_info.h:84
void SHashReduce(SHashReducer hash_reduce) const
Definition: struct_info.h:90
static constexpr const char * _type_key
Definition: struct_info.h:91
Integer sharding_dim
The tensor dimension to shard.
Definition: struct_info.h:82
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Base class of all object reference.
Definition: object.h:519
base class of all object containers.
Definition: object.h:171
Reference to string objects.
Definition: string.h:98
PlacementSpecKind
Definition: struct_info.h:34
tvm::Span Span
Definition: base.h:65
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Data structure for distributed inference.