A graph whose nodes are tensor axes, where an edge means that some information can be propagated between the two axes. Although it currently performs only sharding propagation, this data structure can be extended to perform any kind of propagation on tensor axes.
#include <axis_group_graph.h>
◆ EdgeType
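Enumerator | Description
---|---
kAscend | The src tensor is a consumer of the dst tensor (consumer->producer).
kDescend | The src tensor is a producer of the dst tensor (producer->consumer).
kSimbling | Any other relationship between the src and dst tensors.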
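Because JoinAxis records an edge from a src axis to a dst axis, the same edge read in the opposite direction describes the opposite relationship. The sketch below assumes a standalone C++ `EdgeType` enum that reuses the documented enumerator names, plus a hypothetical helper `ReversedEdgeType`; neither is copied from the TVM source, and whether the real graph walks edges backwards this way is an assumption.

```cpp
#include <cassert>

// Hypothetical stand-in that reuses the documented enumerator names;
// it is not the TVM declaration itself.
enum class EdgeType { kAscend, kDescend, kSimbling };

// Hypothetical helper: the type an edge has when walked from dst back to src.
// Reversing consumer->producer (kAscend) yields producer->consumer (kDescend)
// and vice versa; kSimbling ("other cases") is assumed to be symmetric.
EdgeType ReversedEdgeType(EdgeType type) {
  switch (type) {
    case EdgeType::kAscend:
      return EdgeType::kDescend;
    case EdgeType::kDescend:
      return EdgeType::kAscend;
    case EdgeType::kSimbling:
      return EdgeType::kSimbling;
  }
  return type;  // not reached; keeps -Wreturn-type quiet
}
```

Under this reading, reversing twice returns the original type, and kSimbling is its own reverse because "other cases" carry no producer/consumer direction.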
◆ AxisGroupGraph()
tvm::relax::distributed::AxisGroupGraph::AxisGroupGraph ( ) [default]
◆ AddPropagationCutPoint()
void tvm::relax::distributed::AxisGroupGraph::AddPropagationCutPoint ( Axis axis, AxisShardingSpec spec ) [inline]
Add a cut point that stops the propagation of a specific sharding spec.
Parameters
axis | The axis at which propagation stops (the cut point).
spec | The sharding spec whose propagation is stopped.
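The description says a cut point stops the propagation of a certain spec, so a cut point is naturally keyed by the pair (axis, spec) rather than by the axis alone. The following C++ sketch illustrates that reading on a linear chain of axes; `Axis`, `Spec`, `CutPoints`, and `PropagateAlongChain` are hypothetical stand-ins, and treating the cut axis itself as unassigned is an assumption.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

using Axis = int;          // hypothetical stand-in for tvm's Axis
using Spec = std::string;  // hypothetical stand-in for AxisShardingSpec

// Cut points keyed by (axis, spec): a cut only blocks the spec it names.
using CutPoints = std::set<std::pair<Axis, Spec>>;

// Walks a linear chain of axes from `src` onward, assigning `spec` to each
// axis until it hits a cut point registered for that same spec.
// Assumption: the cut axis itself does not receive the spec.
std::map<Axis, Spec> PropagateAlongChain(const std::vector<Axis>& chain,
                                         Axis src, const Spec& spec,
                                         const CutPoints& cuts) {
  std::map<Axis, Spec> assigned;
  bool started = false;
  for (Axis axis : chain) {
    if (axis == src) started = true;
    if (!started) continue;
    if (cuts.count({axis, spec}) != 0) break;  // propagation stops here
    assigned[axis] = spec;
  }
  return assigned;
}
```

With a chain 0-1-2-3 and a cut at axis 2 for "S0", only axes 0 and 1 receive "S0", while a different spec such as "S1" passes through axis 2 and reaches all four axes.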
◆ AddSrcShardingPoint()
void tvm::relax::distributed::AxisGroupGraph::AddSrcShardingPoint ( Axis axis, AxisShardingSpec spec ) [inline]
Add a source sharding spec to propagate.
Parameters
axis | The source axis.
spec | The sharding spec of the source axis.
◆ GetAxisShardingSpec()
std::tuple<AxisShardingSpec, bool> tvm::relax::distributed::AxisGroupGraph::GetAxisShardingSpec ( Axis axis ) [inline]
Get the sharding spec of an axis after propagation.
Parameters
axis | The axis to query.
Returns
(axis_sharding_spec, true) if a sharding spec is found; otherwise, (a null axis_sharding_spec, false).
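Callers need to check the boolean before using the returned spec. The sketch below mimics that return convention with plain C++ standard-library types; `Axis`, `Spec`, `LookupSpec`, and `DescribeAxis` are hypothetical, and an empty string stands in for the null sharding spec.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <tuple>

using Axis = int;          // hypothetical stand-in for tvm's Axis
using Spec = std::string;  // hypothetical stand-in for AxisShardingSpec

// Follows the documented return convention: (spec, true) when the axis has a
// spec after propagation, otherwise (null spec, false). An empty string plays
// the role of the null spec here.
std::tuple<Spec, bool> LookupSpec(const std::map<Axis, Spec>& propagated,
                                  Axis axis) {
  auto it = propagated.find(axis);
  if (it == propagated.end()) return std::make_tuple(Spec(), false);
  return std::make_tuple(it->second, true);
}

// Typical caller: unpack the tuple and only trust the spec when found.
Spec DescribeAxis(const std::map<Axis, Spec>& propagated, Axis axis) {
  Spec spec;
  bool found = false;
  std::tie(spec, found) = LookupSpec(propagated, axis);
  return found ? spec : Spec("<none>");
}
```

With C++17, the result can also be unpacked as `auto [spec, found] = LookupSpec(propagated, axis);`.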
◆ JoinAxis()
void tvm::relax::distributed::AxisGroupGraph::JoinAxis ( Axis axis1, Axis axis2, EdgeType type ) [inline]
Add an edge between two axes.
Parameters
axis1 | The src axis.
axis2 | The dst axis.
type | The producer-consumer relationship between the src tensor and the dst tensor: kAscend means consumer->producer, kDescend means producer->consumer, and kSimbling covers all other cases.
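As an illustration of choosing the type argument, consider a matrix multiplication C[i, j] = sum_k A[i, k] * B[k, j], where A and B produce C. The C++ sketch below records the joins one might add for it; the `Axis`, `Edge`, and `EdgeLog` types are hypothetical stand-ins, and this particular set of joins is an assumed example rather than what TVM emits for matmul.

```cpp
#include <cassert>
#include <string>
#include <vector>

enum class EdgeType { kAscend, kDescend, kSimbling };  // documented names

// Hypothetical stand-in for tvm's Axis: a tensor name plus a dimension.
struct Axis {
  std::string tensor;
  int dim;
};

struct Edge {
  Axis src;
  Axis dst;
  EdgeType type;
};

// Toy recorder with the same argument order as JoinAxis(axis1, axis2, type).
struct EdgeLog {
  std::vector<Edge> edges;
  void JoinAxis(Axis axis1, Axis axis2, EdgeType type) {
    edges.push_back({axis1, axis2, type});
  }
};

// Illustrative joins for C[i, j] = sum_k A[i, k] * B[k, j]. A and B produce
// C, so input->output edges are kDescend (producer->consumer). The reduction
// axes A.1 and B.0 belong to sibling inputs, so that edge is kSimbling.
EdgeLog MatmulJoins() {
  EdgeLog g;
  g.JoinAxis({"A", 0}, {"C", 0}, EdgeType::kDescend);   // i
  g.JoinAxis({"B", 1}, {"C", 1}, EdgeType::kDescend);   // j
  g.JoinAxis({"A", 1}, {"B", 0}, EdgeType::kSimbling);  // k
  return g;
}
```

Joining the same pair in the opposite order, for example with C.0 as axis1 and A.0 as axis2, would use kAscend, since the src tensor is then the consumer.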
◆ PropagateShardingSpec()
void tvm::relax::distributed::AxisGroupGraph::PropagateShardingSpec ( ) [inline]
Propagate sharding specs from the source axes.
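Propagation from source axes can be pictured as a graph search that starts at every source and stops at cut points. The following C++ sketch shows one such breadth-first search over an undirected toy graph; the class `ToyPropagator` is hypothetical, it ignores edge types, and its conflict handling (an axis reached by two different specs reports no spec) is an assumption, not documented TVM behavior.

```cpp
#include <cassert>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

using Axis = int;          // hypothetical stand-in for tvm's Axis
using Spec = std::string;  // hypothetical stand-in for AxisShardingSpec

// Hypothetical toy graph; only the method names echo the documented API.
class ToyPropagator {
 public:
  void JoinAxis(Axis a, Axis b) {  // edge types omitted in this sketch
    adj_[a].push_back(b);
    adj_[b].push_back(a);
  }
  void AddSrcShardingPoint(Axis axis, Spec spec) { src_[axis] = spec; }
  void AddPropagationCutPoint(Axis axis, Spec spec) { cut_.insert({axis, spec}); }

  // Breadth-first search from every source. An axis that is a cut point for
  // a given spec is neither assigned that spec nor used to forward it.
  void PropagateShardingSpec() {
    reached_.clear();
    for (const auto& kv : src_) {
      const Spec& spec = kv.second;
      std::set<Axis> visited{kv.first};
      std::queue<Axis> frontier;
      frontier.push(kv.first);
      while (!frontier.empty()) {
        Axis cur = frontier.front();
        frontier.pop();
        reached_[cur].insert(spec);
        for (Axis next : adj_[cur]) {
          if (visited.count(next) != 0 || cut_.count({next, spec}) != 0) continue;
          visited.insert(next);
          frontier.push(next);
        }
      }
    }
  }

  // (spec, true) only when exactly one distinct spec reached the axis.
  std::tuple<Spec, bool> GetAxisShardingSpec(Axis axis) const {
    auto it = reached_.find(axis);
    if (it == reached_.end() || it->second.size() != 1) {
      return std::make_tuple(Spec(), false);
    }
    return std::make_tuple(*it->second.begin(), true);
  }

 private:
  std::map<Axis, std::vector<Axis>> adj_;
  std::map<Axis, Spec> src_;
  std::set<std::pair<Axis, Spec>> cut_;
  std::map<Axis, std::set<Spec>> reached_;
};
```

For a chain 0-1-2-3 with a source "S0" at axis 0 and a cut point for "S0" at axis 2, only axes 0 and 1 report a spec after PropagateShardingSpec() runs. Keeping sources and cut points in separate containers and running the search in one explicit step mirrors the split between the Add* methods and PropagateShardingSpec() documented above.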