tvm
Namespaces | Classes | Typedefs | Enumerations | Functions
tvm::relax::distributed Namespace Reference

Namespaces

 transform
 

Classes

struct  Axis
 Tensor axis. More...
 
class  AxisHash
 
class  AxisGroupHash
 
class  AxisShardingSpecEqual
 
class  AxisShardingSpecHash
 
class  AxisGroupGraph
 A graph whose nodes are tensor axes and whose edges indicate that information can be propagated between the two connected axes. Although it currently performs only sharding propagation, this data structure can be extended to support any kind of propagation over tensor axes. More...
 
class  DeviceMeshNode
 
class  DeviceMesh
 Managed reference to a DeviceMesh. More...
 
class  PlacementSpecNode
 Describes how data is distributed in one dimension of the device mesh. More...
 
class  PlacementSpec
 Managed reference to PlacementSpecNode. More...
 
class  ShardingNode
 
class  PlacementNode
 Describes how data is distributed in each dimension of the device mesh. More...
 
class  Placement
 Managed reference to a Placement. More...
 
class  DTensorStructInfoNode
 StructInfo of DTensor (Distributed Tensor). More...
 
class  DTensorStructInfo
 Managed reference to DTensorStructInfoNode. More...
 

Typedefs

using AxisGroup = std::unordered_set< Axis, AxisHash >
 
using ShardingSpec = std::pair< DeviceMesh, Placement >
 
using AxisShardingSpec = std::pair< DeviceMesh, int >
 
using FBuildAxisGraph = std::function< void(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)>
 

Enumerations

enum class  PlacementSpecKind : int { kSharding = 0 , kReplica = 1 }
 

Functions

void BuildAxisGraphUnary (const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
 
void BuildAxisGraphBinary (const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
 
void BuildAxisGraphReduce (const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
 
void BuildAxisGraphMatmul (const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
 
void BuildAxisGraphPermuteDims (const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
 
void BuildAxisGraphReshape (const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
 
void BuildAxisGraphCallTIR (const Var &output_var, const Call &call, const tir::PrimFunc &func, distributed::AxisGroupGraph *axis_group_graph)
 

Typedef Documentation

◆ AxisGroup

using tvm::relax::distributed::AxisGroup = std::unordered_set<Axis, AxisHash>

◆ AxisShardingSpec

using tvm::relax::distributed::AxisShardingSpec = std::pair<DeviceMesh, int>

◆ FBuildAxisGraph

using tvm::relax::distributed::FBuildAxisGraph = std::function<void(const Var& output_var, const Call& call, distributed::AxisGroupGraph* axis_group_graph)>

◆ ShardingSpec

using tvm::relax::distributed::ShardingSpec = std::pair<DeviceMesh, Placement>

Enumeration Type Documentation

◆ PlacementSpecKind

enum class tvm::relax::distributed::PlacementSpecKind : int

Enumerator
kSharding = 0
kReplica = 1

Function Documentation

◆ BuildAxisGraphBinary()

void tvm::relax::distributed::BuildAxisGraphBinary ( const Var& output_var,
const Call& call,
distributed::AxisGroupGraph* axis_group_graph
)

◆ BuildAxisGraphCallTIR()

void tvm::relax::distributed::BuildAxisGraphCallTIR ( const Var& output_var,
const Call& call,
const tir::PrimFunc& func,
distributed::AxisGroupGraph* axis_group_graph
)

◆ BuildAxisGraphMatmul()

void tvm::relax::distributed::BuildAxisGraphMatmul ( const Var& output_var,
const Call& call,
distributed::AxisGroupGraph* axis_group_graph
)

◆ BuildAxisGraphPermuteDims()

void tvm::relax::distributed::BuildAxisGraphPermuteDims ( const Var& output_var,
const Call& call,
distributed::AxisGroupGraph* axis_group_graph
)

◆ BuildAxisGraphReduce()

void tvm::relax::distributed::BuildAxisGraphReduce ( const Var& output_var,
const Call& call,
distributed::AxisGroupGraph* axis_group_graph
)

◆ BuildAxisGraphReshape()

void tvm::relax::distributed::BuildAxisGraphReshape ( const Var& output_var,
const Call& call,
distributed::AxisGroupGraph* axis_group_graph
)

◆ BuildAxisGraphUnary()

void tvm::relax::distributed::BuildAxisGraphUnary ( const Var& output_var,
const Call& call,
distributed::AxisGroupGraph* axis_group_graph
)