tvm
struct_info.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_RELAX_DISTRIBUTED_STRUCT_INFO_H_
26 #define TVM_RELAX_DISTRIBUTED_STRUCT_INFO_H_
27 
29 #include <tvm/relax/struct_info.h>
30 namespace tvm {
31 namespace relax {
32 namespace distributed {
33 
/*! \brief Kind of a placement spec: kSharding (0) or kReplica (1). */
34 enum class PlacementSpecKind : int { kSharding = 0, kReplica = 1 };
35 
/*! \brief Describes how data is distributed in one dimension of the device mesh. */
37 class PlacementSpecNode : public Object {
38  public:
 /*!
  * \brief If the kind is sharding, the tensor dimension to shard.
  *        The meaning for other kinds is not shown in this listing.
  */
42  int axis;
43 
 // NOTE(review): `kind` is registered in RegisterReflection() below, but its
 // declaration is not visible in this listing (the member index places
 // `PlacementSpecKind kind` at struct_info.h:45) - confirm against the header.
46 
 /*! \brief Registers the `axis` and `kind` fields on ObjectDef<PlacementSpecNode> via def_ro. */
47  static void RegisterReflection() {
48  namespace refl = tvm::ffi::reflection;
49  refl::ObjectDef<PlacementSpecNode>()
50  .def_ro("axis", &PlacementSpecNode::axis)
51  .def_ro("kind", &PlacementSpecNode::kind);
52  }
53 
 // Structural equality/hash kind for this node type is kTVMFFISEqHashKindConstTreeNode.
54  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode;
 // Declares the object type info under the key "relax.distributed.PlacementSpec" with parent Object.
55  TVM_FFI_DECLARE_OBJECT_INFO("relax.distributed.PlacementSpec", PlacementSpecNode, Object);
56 };
57 
/*! \brief Managed reference to PlacementSpecNode. */
62 class PlacementSpec : public ObjectRef {
63  public:
 /*!
  * \brief Static factory taking an `axis`; only the declaration is visible here.
  *        NOTE(review): presumably builds a kSharding spec on that tensor axis - confirm in the implementation.
  */
64  TVM_DLL static PlacementSpec Sharding(int axis);
65 
 /*!
  * \brief Static factory taking no arguments; only the declaration is visible here.
  *        NOTE(review): presumably builds a kReplica spec - confirm in the implementation.
  */
66  TVM_DLL static PlacementSpec Replica();
67 
69 };
70 
72  public:
75 
 /*! \brief Registers the `sharding_dim` field on ObjectDef<ShardingNode> via def_ro. */
76  static void RegisterReflection() {
77  namespace refl = tvm::ffi::reflection;
78  refl::ObjectDef<ShardingNode>().def_ro("sharding_dim", &ShardingNode::sharding_dim);
79  }
80 
82 };
83 
/*! \brief Describes how data is distributed in each dimension of the device mesh. */
85 class PlacementNode : public Object {
86  public:
 /*! \brief Specs for each dim of the device mesh. */
88  ffi::Array<PlacementSpec> dim_specs;
89 
 /*!
  * \brief Returns a string form of this placement; only the declaration is visible here.
  *        NOTE(review): the exact format is defined by the implementation - confirm there.
  */
90  ffi::String ToString() const;
91 
 /*! \brief Registers the `dim_specs` field on ObjectDef<PlacementNode> via def_ro. */
92  static void RegisterReflection() {
93  namespace refl = tvm::ffi::reflection;
94  refl::ObjectDef<PlacementNode>().def_ro("dim_specs", &PlacementNode::dim_specs);
95  }
96 
 // Structural equality/hash kind for this node type is kTVMFFISEqHashKindConstTreeNode.
97  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode;
 // Declares the (final) object type info under the key "relax.distributed.Placement" with parent Object.
98  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.distributed.Placement", PlacementNode, Object);
99 };
100 
/*! \brief Managed reference to PlacementNode. */
105 class Placement : public ObjectRef {
106  public:
 /*!
  * \brief Constructs a placement from an array of per-dimension specs.
  * \param dim_specs The placement specs; stored per dim of the device mesh (see PlacementNode::dim_specs).
  */
107  TVM_DLL explicit Placement(ffi::Array<PlacementSpec> dim_specs);
 /*!
  * \brief Builds a placement from its text representation.
  *        A replica dim is printed as "R" and a sharding dim is printed as "S[i]".
  *        NOTE(review): separators and error handling for malformed text are not visible here - confirm in the implementation.
  * \param text_repr The text representation to parse.
  */
109  static Placement FromText(ffi::String text_repr);
111 };
112 
117  public:
130 
 /*!
  * \brief Registers the `device_mesh`, `placement` and `tensor_sinfo` fields
  *        on ObjectDef<DTensorStructInfoNode> via def_ro.
  */
131  static void RegisterReflection() {
132  namespace refl = tvm::ffi::reflection;
133  refl::ObjectDef<DTensorStructInfoNode>()
134  .def_ro("device_mesh", &DTensorStructInfoNode::device_mesh)
135  .def_ro("placement", &DTensorStructInfoNode::placement)
136  .def_ro("tensor_sinfo", &DTensorStructInfoNode::tensor_sinfo);
137  }
140 };
141 
147  public:
 /*!
  * \brief Construction with device mesh and placement.
  * \param tensor_sinfo The tensor struct info.
  * \param device_mesh The device mesh of the tensor.
  * \param placement The placement of the tensor among the device mesh.
  * \param span Source span; defaults to an empty Span().
  */
155  TVM_DLL DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh device_mesh,
156  Placement placement, Span span = Span());
157 
159 };
160 
161 } // namespace distributed
162 } // namespace relax
163 } // namespace tvm
164 
165 #endif // TVM_RELAX_DISTRIBUTED_STRUCT_INFO_H_
Container of constant int that adds more constructors.
Definition: expr.h:600
Definition: source_map.h:111
Base type of all structure information.
Definition: expr.h:108
Managed reference to StructInfoNode.
Definition: expr.h:132
Managed reference to TensorStructInfoNode.
Definition: struct_info.h:190
StructInfo of DTensor (Distributed Tensor).
Definition: struct_info.h:116
TensorStructInfo tensor_sinfo
The struct info inherited from TensorStructInfo.
Definition: struct_info.h:121
Placement placement
The placement of the tensor among the device mesh.
Definition: struct_info.h:129
DeviceMesh device_mesh
The device mesh of the tensor.
Definition: struct_info.h:125
static void RegisterReflection()
Definition: struct_info.h:131
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.DTensorStructInfo", DTensorStructInfoNode, StructInfoNode)
Managed reference to DTensorStructInfoNode.
Definition: struct_info.h:146
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DTensorStructInfo, StructInfo, DTensorStructInfoNode)
DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh device_mesh, Placement placement, Span span=Span())
Construction with device mesh and placement.
Managed reference to a DeviceMesh.
Definition: global_info.h:62
Describes how data is distributed in each dimension of the device mesh.
Definition: struct_info.h:85
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: struct_info.h:97
ffi::Array< PlacementSpec > dim_specs
specs for each dim of device mesh.
Definition: struct_info.h:88
static void RegisterReflection()
Definition: struct_info.h:92
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.distributed.Placement", PlacementNode, Object)
Describes how data is distributed in one dimension of the device mesh.
Definition: struct_info.h:37
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: struct_info.h:54
PlacementSpecKind kind
The kind of placement spec. Possible values: kSharding and kReplica.
Definition: struct_info.h:45
TVM_FFI_DECLARE_OBJECT_INFO("relax.distributed.PlacementSpec", PlacementSpecNode, Object)
static void RegisterReflection()
Definition: struct_info.h:47
int axis
If the kind is sharding, this value represents the tensor dimension to shard. otherwise,...
Definition: struct_info.h:42
Managed reference to PlacementSpecNode.
Definition: struct_info.h:62
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PlacementSpec, ObjectRef, PlacementSpecNode)
static PlacementSpec Sharding(int axis)
Managed reference to a Placement.
Definition: struct_info.h:105
static Placement FromText(ffi::String text_repr)
Replica dim is printed as "R" and sharding dim is printed as "S[i]".
Placement(ffi::Array< PlacementSpec > dim_specs)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Placement, ObjectRef, PlacementNode)
Definition: struct_info.h:71
static void RegisterReflection()
Definition: struct_info.h:76
Integer sharding_dim
The dimension of tensor we shard.
Definition: struct_info.h:74
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.distributed.Sharding", ShardingNode, PlacementSpecNode)
Definition: repr_printer.h:91
PlacementSpecKind
Definition: struct_info.h:34
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Data structure for distributed inference.