tvm
struct_info.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_RELAX_DISTRIBUTED_STRUCT_INFO_H_
26 #define TVM_RELAX_DISTRIBUTED_STRUCT_INFO_H_
27 
#include <tvm/relax/distributed/global_info.h>
#include <tvm/relax/struct_info.h>
30 namespace tvm {
31 namespace relax {
32 namespace distributed {
33 
/*!
 * \brief Kind of a placement spec along one dimension of the device mesh.
 *
 * In the text form of a Placement, a replica dim is written as "R" and a
 * sharding dim as "S[i]".
 */
enum class PlacementSpecKind : int {
  /*! \brief A tensor dimension (PlacementSpecNode::axis) is sharded along this mesh dim. */
  kSharding = 0,
  /*! \brief The tensor is replicated along this mesh dim. */
  kReplica = 1,
};
35 
37 class PlacementSpecNode : public Object {
38  public:
42  int axis;
43 
46 
47  static void RegisterReflection() {
48  namespace refl = tvm::ffi::reflection;
49  refl::ObjectDef<PlacementSpecNode>()
50  .def_ro("axis", &PlacementSpecNode::axis)
51  .def_ro("kind", &PlacementSpecNode::kind);
52  }
53 
54  static constexpr const char* _type_key = "relax.distributed.PlacementSpec";
55  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode;
57 };
58 
63 class PlacementSpec : public ObjectRef {
64  public:
65  TVM_DLL static PlacementSpec Sharding(int axis);
66 
67  TVM_DLL static PlacementSpec Replica();
68 
70 };
71 
73  public:
76 
77  static void RegisterReflection() {
78  namespace refl = tvm::ffi::reflection;
79  refl::ObjectDef<ShardingNode>().def_ro("sharding_dim", &ShardingNode::sharding_dim);
80  }
81 
83 };
84 
86 class PlacementNode : public Object {
87  public:
89  Array<PlacementSpec> dim_specs;
90 
91  String ToString() const;
92 
93  static void RegisterReflection() {
94  namespace refl = tvm::ffi::reflection;
95  refl::ObjectDef<PlacementNode>().def_ro("dim_specs", &PlacementNode::dim_specs);
96  }
97 
98  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode;
99  static constexpr const char* _type_key = "relax.distributed.Placement";
101 };
102 
107 class Placement : public ObjectRef {
108  public:
109  TVM_DLL explicit Placement(Array<PlacementSpec> dim_specs);
111  static Placement FromText(String text_repr);
113 };
114 
119  public:
132 
133  static void RegisterReflection() {
134  namespace refl = tvm::ffi::reflection;
135  refl::ObjectDef<DTensorStructInfoNode>()
136  .def_ro("device_mesh", &DTensorStructInfoNode::device_mesh)
137  .def_ro("placement", &DTensorStructInfoNode::placement)
138  .def_ro("tensor_sinfo", &DTensorStructInfoNode::tensor_sinfo);
139  }
140 
141  static constexpr const char* _type_key = "relax.DTensorStructInfo";
143 };
144 
150  public:
158  TVM_DLL DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh device_mesh,
159  Placement placement, Span span = Span());
160 
162 };
163 
164 } // namespace distributed
165 } // namespace relax
166 } // namespace tvm
167 
168 #endif // TVM_RELAX_DISTRIBUTED_STRUCT_INFO_H_
Container of constant int that adds more constructors.
Definition: expr.h:612
Definition: source_map.h:113
Base type of all structure information.
Definition: expr.h:110
Managed reference to StructInfoNode.
Definition: expr.h:135
Managed reference to TensorStructInfoNode.
Definition: struct_info.h:196
StructInfo of DTensor (Distributed Tensor).
Definition: struct_info.h:118
TensorStructInfo tensor_sinfo
The struct info inherited from TensorStructInfo.
Definition: struct_info.h:123
Placement placement
The placement of the tensor among the device mesh.
Definition: struct_info.h:131
TVM_DECLARE_FINAL_OBJECT_INFO(DTensorStructInfoNode, StructInfoNode)
DeviceMesh device_mesh
The device mesh of the tensor.
Definition: struct_info.h:127
static void RegisterReflection()
Definition: struct_info.h:133
static constexpr const char * _type_key
Definition: struct_info.h:141
Managed reference to DTensorStructInfoNode.
Definition: struct_info.h:149
DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh device_mesh, Placement placement, Span span=Span())
Constructs a DTensorStructInfo from a tensor struct info, a device mesh, and a placement.
TVM_DEFINE_OBJECT_REF_METHODS(DTensorStructInfo, StructInfo, DTensorStructInfoNode)
Managed reference to a DeviceMesh.
Definition: global_info.h:64
Describes how data is distributed in each dimension of the device mesh.
Definition: struct_info.h:86
TVM_DECLARE_FINAL_OBJECT_INFO(PlacementNode, Object)
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: struct_info.h:98
static constexpr const char * _type_key
Definition: struct_info.h:99
Array< PlacementSpec > dim_specs
Specs for each dimension of the device mesh.
Definition: struct_info.h:89
static void RegisterReflection()
Definition: struct_info.h:93
Describes how data is distributed in one dimension of the device mesh.
Definition: struct_info.h:37
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: struct_info.h:55
PlacementSpecKind kind
The kind of placement spec. Possible values: kSharding and kReplica.
Definition: struct_info.h:45
TVM_DECLARE_BASE_OBJECT_INFO(PlacementSpecNode, Object)
static void RegisterReflection()
Definition: struct_info.h:47
static constexpr const char * _type_key
Definition: struct_info.h:54
int axis
If the kind is sharding, this value represents the tensor dimension to shard. Otherwise,...
Definition: struct_info.h:42
Managed reference to PlacementSpecNode.
Definition: struct_info.h:63
TVM_DEFINE_OBJECT_REF_METHODS(PlacementSpec, ObjectRef, PlacementSpecNode)
static PlacementSpec Sharding(int axis)
Managed reference to a Placement.
Definition: struct_info.h:107
static Placement FromText(String text_repr)
replica dim is printed as "R" and sharding dim is printed as "S[i]".
Placement(Array< PlacementSpec > dim_specs)
TVM_DEFINE_OBJECT_REF_METHODS(Placement, ObjectRef, PlacementNode)
Definition: struct_info.h:72
TVM_DECLARE_FINAL_OBJECT_INFO(ShardingNode, PlacementSpecNode)
static void RegisterReflection()
Definition: struct_info.h:77
Integer sharding_dim
The tensor dimension being sharded.
Definition: struct_info.h:75
Definition: repr_printer.h:91
PlacementSpecKind
Definition: struct_info.h:34
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Data structure for distributed inference.