tvm
tvm::relax::distributed::AxisGroupGraph Class Reference

A graph whose nodes are tensor axes and whose edges indicate that information can be propagated between the two axes they connect. It currently performs only sharding propagation, but the data structure can be extended to other kinds of propagation over tensor axes.

#include <axis_group_graph.h>

Public Types

enum class  EdgeType { kAscend , kDescend , kSimbling }
 

Public Member Functions

 AxisGroupGraph ()=default
 
void JoinAxis (Axis axis1, Axis axis2, EdgeType type)
 Add an edge between two axes.
 
void AddSrcShardingPoint (Axis axis, AxisShardingSpec spec)
 Add a source sharding spec to propagate.
 
void PropagateShardingSpec ()
 Propagate sharding specs from the source axes.
 
void AddPropagationCutPoint (Axis axis, AxisShardingSpec spec)
 Add a cut point that stops the propagation of a given sharding spec.
 
std::tuple< AxisShardingSpec, bool > GetAxisShardingSpec (Axis axis)
 Get the sharding spec of an axis after propagation.
 

Detailed Description

A graph whose nodes are tensor axes and whose edges indicate that information can be propagated between the two axes they connect. It currently performs only sharding propagation, but the data structure can be extended to other kinds of propagation over tensor axes.

Member Enumeration Documentation

◆ EdgeType

Enumerator
kAscend 
kDescend 
kSimbling 

Constructor & Destructor Documentation

◆ AxisGroupGraph()

tvm::relax::distributed::AxisGroupGraph::AxisGroupGraph ( )
default

Member Function Documentation

◆ AddPropagationCutPoint()

void tvm::relax::distributed::AxisGroupGraph::AddPropagationCutPoint ( Axis  axis,
AxisShardingSpec  spec 
)
inline

Add a cut point that stops the propagation of a given sharding spec.

Parameters
axis  The cut point
spec  The spec whose propagation is stopped
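
The effect of a cut point can be illustrated with a small stand-alone model. This is only a sketch under stated assumptions, not the TVM implementation: `ToyCutGraph`, its members, the string axis names, and the integer "spec" are all hypothetical, and the sketch assumes that the cut axis itself does not receive the blocked spec.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy model, not the TVM class. Axes are plain strings and a
// "sharding spec" is reduced to an integer id.
struct ToyCutGraph {
  std::map<std::string, std::vector<std::string>> adj;
  std::set<std::pair<std::string, int>> cuts;  // (axis, spec) pairs that block flow
  std::map<std::string, int> result;           // spec assigned to each reached axis

  void Join(const std::string& a, const std::string& b) {
    assert(a != b);  // a self-edge would carry no information
    adj[a].push_back(b);
    adj[b].push_back(a);
  }

  void AddCut(const std::string& axis, int spec) { cuts.insert({axis, spec}); }

  // Depth-first propagation of one spec. Assumption: a cut axis neither
  // receives the blocked spec nor forwards it to its neighbours.
  void PropagateFrom(const std::string& axis, int spec) {
    if (cuts.count({axis, spec}) || result.count(axis)) return;
    result[axis] = spec;
    for (const std::string& next : adj[axis]) PropagateFrom(next, spec);
  }
};
```

In a chain a.0, b.0, c.0 with a cut at b.0 for spec 7, propagating spec 7 from a.0 reaches only a.0 in this model, while a different spec passes through b.0 unaffected.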

◆ AddSrcShardingPoint()

void tvm::relax::distributed::AxisGroupGraph::AddSrcShardingPoint ( Axis  axis,
AxisShardingSpec  spec 
)
inline

Add a source sharding spec to propagate.

Parameters
axis  The source axis
spec  The sharding spec of the source axis

◆ GetAxisShardingSpec()

std::tuple<AxisShardingSpec, bool> tvm::relax::distributed::AxisGroupGraph::GetAxisShardingSpec ( Axis  axis)
inline

Get the sharding spec of an axis after propagation.

Parameters
axis  The axis to query
Returns
(axis_sharding_spec, true) if a sharding spec is found; otherwise (null axis_sharding_spec, false).
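
A (spec, found) tuple of this shape is convenient to unpack with C++17 structured bindings. The sketch below shows only that calling pattern; `ToySpec`, `ToyLookup`, and `MeshDimOrDefault` are hypothetical stand-ins introduced for illustration and are not TVM types or functions.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <tuple>

// Hypothetical stand-in for AxisShardingSpec; not the TVM type.
struct ToySpec {
  std::string mesh;   // name of a device mesh
  int mesh_dim = -1;  // mesh dimension the axis is sharded along
};

// Hypothetical lookup table keyed by an axis name, mirroring the
// (spec, found) return convention described above.
class ToyLookup {
 public:
  void Set(const std::string& axis, const ToySpec& spec) {
    assert(spec.mesh_dim >= 0);  // -1 is reserved for "no spec" in this sketch
    specs_[axis] = spec;
  }

  std::tuple<ToySpec, bool> Get(const std::string& axis) const {
    auto it = specs_.find(axis);
    if (it == specs_.end()) return std::make_tuple(ToySpec{}, false);
    return std::make_tuple(it->second, true);
  }

 private:
  std::map<std::string, ToySpec> specs_;
};

// Returns the mesh dimension for an axis, or -1 when no spec was found.
int MeshDimOrDefault(const ToyLookup& table, const std::string& axis) {
  auto [spec, found] = table.Get(axis);  // C++17 structured binding
  return found ? spec.mesh_dim : -1;
}
```

Checking the boolean before touching the spec matters because, per the description above, the spec is null when nothing was found.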

◆ JoinAxis()

void tvm::relax::distributed::AxisGroupGraph::JoinAxis ( Axis  axis1,
Axis  axis2,
EdgeType  type 
)
inline

Add an edge between two axes.

Parameters
axis1  The src axis
axis2  The dst axis
type   The producer-consumer relationship between the src tensor and the dst tensor:
       kAscend means consumer->producer,
       kDescend means producer->consumer,
       kSimbling means all other cases.
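
The three edge types can be read as the direction of the edge relative to the dataflow. The following sketch is a toy model, not the TVM class: `ToyAxisGraph`, `Reverse`, and the string axis names are hypothetical, and it assumes (this page does not say so) that joining two axes also records the edge seen from the other endpoint, with kAscend and kDescend swapped.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Mirrors the enumerator names documented above, including the spelling kSimbling.
enum class EdgeType { kAscend, kDescend, kSimbling };

// Assumption: seen from the other endpoint, consumer->producer becomes
// producer->consumer and vice versa; the "other cases" type is symmetric.
EdgeType Reverse(EdgeType type) {
  switch (type) {
    case EdgeType::kAscend:
      return EdgeType::kDescend;
    case EdgeType::kDescend:
      return EdgeType::kAscend;
    case EdgeType::kSimbling:
      return EdgeType::kSimbling;
  }
  return type;  // not reached for valid enumerators
}

// Hypothetical toy graph: axes are plain strings such as "relu_out.0".
struct ToyAxisGraph {
  std::map<std::string, std::vector<std::pair<std::string, EdgeType>>> edges;

  void Join(const std::string& src, const std::string& dst, EdgeType type) {
    assert(src != dst);  // an axis is never joined with itself in this sketch
    edges[src].push_back({dst, type});
    edges[dst].push_back({src, Reverse(type)});
  }
};
```

For example, joining an axis of a matmul output (producer) to the matching axis of a following relu output (consumer) with kDescend stores a kAscend edge in the opposite direction in this model.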

◆ PropagateShardingSpec()

void tvm::relax::distributed::AxisGroupGraph::PropagateShardingSpec ( )
inline

Propagate sharding specs from the source axes.
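
Propagation over an axis graph can be pictured as a flood fill that starts at the source axes. The sketch below is a simplified stand-alone model and makes no claim about how TVM resolves conflicts or uses edge types: `ToyPropagator` and its methods are hypothetical, edges are untyped, a spec is an integer, and an axis simply keeps the first spec that reaches it.

```cpp
#include <cassert>
#include <map>
#include <queue>
#include <string>
#include <vector>

// Hypothetical toy model, not the TVM class.
class ToyPropagator {
 public:
  void Join(const std::string& a, const std::string& b) {
    adj_[a].push_back(b);
    adj_[b].push_back(a);
  }

  void AddSource(const std::string& axis, int spec) {
    assert(spec >= 0);  // -1 is reserved for "no spec" in this sketch
    sources_[axis] = spec;
  }

  // Breadth-first flood fill: an axis keeps the first spec that reaches it.
  void Propagate() {
    std::queue<std::string> work;
    for (const auto& entry : sources_) {
      result_[entry.first] = entry.second;
      work.push(entry.first);
    }
    while (!work.empty()) {
      std::string cur = work.front();
      work.pop();
      for (const std::string& next : adj_[cur]) {
        if (result_.count(next)) continue;  // already assigned
        result_[next] = result_[cur];
        work.push(next);
      }
    }
  }

  // Returns -1 when no spec reached the axis.
  int Get(const std::string& axis) const {
    auto it = result_.find(axis);
    return it == result_.end() ? -1 : it->second;
  }

 private:
  std::map<std::string, std::vector<std::string>> adj_;
  std::map<std::string, int> sources_;
  std::map<std::string, int> result_;
};
```

Axes in a connected component that contains no source stay unassigned, which corresponds to the "not found" case of a lookup after propagation.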


The documentation for this class was generated from the following file: