tvm
Classes | Namespaces | Typedefs | Functions
axis_group_graph.h File Reference
#include <tvm/arith/iter_affine_map.h>
#include <tvm/relax/distributed/struct_info.h>
#include <tvm/relax/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <algorithm>
#include <limits>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
Include dependency graph for axis_group_graph.h:

Go to the source code of this file.

Classes

class  tvm::tir::BufferAxisHash
 
class  tvm::tir::BufferAxisGraphExtractor
 Construct an axis group graph from a PrimFunc. Two buffer axes are connected if they are accessed by the same index. More...
 
struct  tvm::relax::distributed::Axis
 tensor axis More...
 
class  tvm::relax::distributed::AxisHash
 
class  tvm::relax::distributed::AxisGroupHash
 
class  tvm::relax::distributed::AxisShardingSpecEqual
 
class  tvm::relax::distributed::AxisShardingSpecHash
 
class  tvm::relax::distributed::AxisGroupGraph
 A graph whose nodes are tensor axes and whose edges indicate that information can be propagated between the two connected axes. Although it currently performs only sharding propagation, this data structure can be extended to perform any kind of propagation over tensor axes. More...
 

Namespaces

 tvm
 runtime implementation for LibTorch/TorchScript.
 
 tvm::tir
 
 tvm::relax
 
 tvm::relax::distributed
 

Typedefs

using tvm::tir::TIRVarAxis = std::pair< Var, int >
 
using tvm::tir::BufferAxis = std::pair< Buffer, int >
 
using tvm::relax::distributed::AxisGroup = std::unordered_set< Axis, AxisHash >
 
using tvm::relax::distributed::ShardingSpec = std::pair< DeviceMesh, Placement >
 
using tvm::relax::distributed::AxisShardingSpec = std::pair< DeviceMesh, int >
 
using tvm::relax::distributed::FBuildAxisGraph = std::function< void(const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)>
 

Functions

Var tvm::tir::GetShardingVarFromIndex (PrimExpr index, Map< Var, Range > var_range, arith::Analyzer *analyzer)
 Suppose we want to shard a buffer along a specific dimension. To do so, we need to know how to rewrite the buffer's access index. For simplicity, we only support the case where the access can be rewritten by changing the extent of an iter var. More...
 
void tvm::relax::distributed::BuildAxisGraphUnary (const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
 
void tvm::relax::distributed::BuildAxisGraphBinary (const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
 
void tvm::relax::distributed::BuildAxisGraphReduce (const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
 
void tvm::relax::distributed::BuildAxisGraphMatmul (const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
 
void tvm::relax::distributed::BuildAxisGraphPermuteDims (const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
 
void tvm::relax::distributed::BuildAxisGraphReshape (const Var &output_var, const Call &call, distributed::AxisGroupGraph *axis_group_graph)
 
void tvm::relax::distributed::BuildAxisGraphCallTIR (const Var &output_var, const Call &call, const tir::PrimFunc &func, distributed::AxisGroupGraph *axis_group_graph)