template<typename T>
class tvm::relax::NestedMsg< T >
Container that stores possibly nested message with leaf message type T.
NestedMsg is a helper structure to store intermediate message state in pass analysis, so that we can robustly handle message passing in the presence of nested tuple types.
Under the hood, NestedMsg[T] = Union[T, NullOpt, Array[NestedMsg[T]]]. Each nested message corresponds to the same nesting structure as the nested tuple types when we encounter them in analysis.
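The union above can be sketched in plain Python (a conceptual analogue only; TVM's NestedMsg is a C++ class, and the names below are illustrative, not TVM's API). A plain value is a leaf message, None plays the role of NullOpt, and a list nests messages:

```python
def nesting_depth(msg):
    """Nesting depth of a message: 0 for a leaf or for NullOpt (None)."""
    if isinstance(msg, list):
        return 1 + max((nesting_depth(m) for m in msg), default=0)
    return 0

leaf_msg = "A"                         # a leaf message of type T
null_msg = None                        # NullOpt: no message attached
tuple_msg = [["A", "B"], ["C"], "A"]   # mirrors t = ((v0, v1), (v2,), v0)
```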
Relax supports nested tuple structures in the IR. Nested tuples are important for expressing advanced groupings, for example in gradient calculation and similar scenarios.
The possible presence of nested tuples means that an analysis must robustly handle nested tuple structures in a dataflow graph.
```
t = ((v0, v1), (v2,), v0)
t1 = t[0]
v3 = concat(t1)
v4 = t[2]
```
Consider the above code sequence that contains a mixture of tuple nesting and normal operations. A common message-passing-based analysis will track messages attached to each intermediate variable.
Because an intermediate value can contain nested tuples, we need the ability to nest messages according to the tuple structure and propagate them along the way. In Python, this simply corresponds to using a tuple to hold nested messages. This class provides a helper wrapper in C++ to represent such possibly nested messages for a given leaf message type.
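The correspondence between a nested tuple value and its nested message can be sketched as follows (a hypothetical Python analogue of MapToNestedMsg; the function and names are illustrative, not TVM's API):

```python
def map_to_nested_msg(value, lookup):
    """Build a nested message mirroring a (possibly nested) tuple value.

    `value` is a leaf name or a tuple of values; `lookup` maps leaf names
    to their message. A missing entry yields None, which plays the role
    of NullOpt.
    """
    if isinstance(value, tuple):
        return [map_to_nested_msg(v, lookup) for v in value]
    return lookup.get(value)

# t = ((v0, v1), (v2,), v0); only v0 and v2 carry a message here.
t = (("v0", "v1"), ("v2",), "v0")
msgs = map_to_nested_msg(t, {"v0": "A", "v2": "C"})
```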
This design pattern is necessary for handling tuple values regardless of the IR's normal-form design: it allows each tuple component to carry a different message, without forcing all tuple elements to share the same one.
Please consider the following patterns in our pass:
On a forward propagation message passing analysis:
- Create map [leafnode=>NestedMsg<T>], scan forward
- input_msg = [MapToNestedMsg<T>(x, lookup_map) for x in call->args]
- output_msg = ForwardProp[call->op](input_msg, call)
- map[binding->var] = output_msg
- Use MapToNestedMsg to remap the remaining body.
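The forward recipe above can be sketched as a toy Python pass (the ForwardProp rule for a hypothetical concat op and all helper names here are illustrative assumptions, not TVM's API):

```python
def map_to_nested_msg(value, lookup):
    """Mirror MapToNestedMsg conceptually: build a nested message that
    follows value's tuple structure, looking leaves up in `lookup`."""
    if isinstance(value, tuple):
        return [map_to_nested_msg(v, lookup) for v in value]
    return lookup.get(value)  # None plays the role of NullOpt

def forward_prop_concat(input_msgs):
    """Toy ForwardProp rule for a hypothetical concat op: the output
    message is the set of all leaf messages found in the inputs."""
    leaves = set()
    def visit(m):
        if isinstance(m, list):
            for x in m:
                visit(x)
        elif m is not None:
            leaves.add(m)
    for m in input_msgs:
        visit(m)
    return frozenset(leaves)

lookup_map = {"v0": "A", "v1": "B"}
# Forward step for: t1 = (v0, v1); v3 = concat(t1)
input_msg = [map_to_nested_msg(("v0", "v1"), lookup_map)]
lookup_map["v3"] = forward_prop_concat(input_msg)
```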
On a backward propagation message passing analysis:
- Create map [leafnode=>NestedMsg<T>], scan backward
- output_msg = lookup_map(binding->var)
- handle the case when output_msg is null
- input_msg = BackProp[call->op](output_msg, call)
- for arg, msg in zip(call->args, input_msg): DecomposeNestedMsg(arg, msg, lambda node, m: update_map(node, m))
- update_map(node, m) => CombineNestedMsg(map[node], m)
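The decompose-and-combine steps of the backward recipe can be sketched like this (a conceptual Python analogue of DecomposeNestedMsg and a CombineNestedMsg-style merge; names and the set-union combiner are illustrative assumptions):

```python
def decompose_nested_msg(value, msg, callback):
    """Walk a (possibly nested) tuple value and a message with the same
    structure in lockstep, invoking callback on each leaf."""
    if isinstance(value, tuple):
        assert isinstance(msg, list) and len(msg) == len(value)
        for v, m in zip(value, msg):
            decompose_nested_msg(v, m, callback)
    else:
        callback(value, msg)

msg_map = {}

def update_map(node, m):
    """CombineNestedMsg-style merge: set union, with None as identity."""
    old = msg_map.get(node)
    msg_map[node] = m if old is None else (old if m is None else old | m)

# Back-propagate an output message through t = ((v0, v1), (v2,), v0).
out_msg = [[frozenset({"g0"}), frozenset({"g1"})], [None], frozenset({"g2"})]
decompose_nested_msg((("v0", "v1"), ("v2",), "v0"), out_msg, update_map)
```

Note that v0 appears twice in the tuple, so the two messages reaching it are combined into one.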
Here a leaf node is a node that you would like to propagate messages to, such as a constant or var; it should not include tuples.
We also recommend writing unit-test cases that involve nested tuple composition and decomposition.
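A round-trip unit test of that kind might look like the following sketch (again a Python analogue with illustrative helper names, not TVM's test suite): compose a nested message from a leaf map, decompose it, and check that every leaf message comes back.

```python
def map_to_nested_msg(value, lookup):
    """Compose: build a nested message mirroring value's tuple structure."""
    if isinstance(value, tuple):
        return [map_to_nested_msg(v, lookup) for v in value]
    return lookup.get(value)

def decompose_nested_msg(value, msg, callback):
    """Decompose: visit each leaf of value together with its message."""
    if isinstance(value, tuple):
        for v, m in zip(value, msg):
            decompose_nested_msg(v, m, callback)
    else:
        callback(value, msg)

def test_roundtrip():
    value = (("v0", "v1"), ("v2",), "v0")
    lookup = {"v0": "A", "v1": "B", "v2": "C"}
    msg = map_to_nested_msg(value, lookup)
    seen = {}
    decompose_nested_msg(value, msg, lambda n, m: seen.setdefault(n, m))
    assert seen == lookup

test_roundtrip()
```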
- See also
- MapToNestedMsg, DecomposeNestedMsg, CombineNestedMsg, ForEachLeaf, Equal
- Note
- If you want to write robust message passing-based analysis for programs that can contain nested tuples, you likely need to use this class or logic of a similar kind.