tvm
dataflow_pattern.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_DATAFLOW_PATTERN_H_
25 #define TVM_RELAX_DATAFLOW_PATTERN_H_
26 
#include <tvm/ir/expr.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/type.h>
#include <tvm/support/with.h>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
42 
43 namespace tvm {
44 
45 namespace arith {
46 class Analyzer;
47 }
48 
49 namespace relax {
50 
51 class PatternSeq;
52 class CallPattern;
53 class OrPattern;
54 class AndPattern;
55 class NotPattern;
56 class ShapePattern;
57 class StructInfoPattern;
58 class TypePattern;
59 class DataTypePattern;
60 class AttrPattern;
61 class SameShapeConstraint;
62 
// Create a used-by relationship between lhs[-1] and rhs[0]; returns [*lhs, *rhs].
// NOTE(review): `index` presumably ends up in the PairCons of the new edge (whose
// index also defaults to -1) — confirm against the implementation.
71 TVM_DLL PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index = -1);
// Syntax sugar for UsedBy(lhs, rhs, -1).
73 TVM_DLL PatternSeq operator^(const PatternSeq& lhs, const PatternSeq& rhs);
74 
// Create an only-used-by relationship between lhs[-1] and rhs[0]; returns [*lhs, *rhs].
83 TVM_DLL PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index = -1);
// Syntax sugar for OnlyUsedBy(lhs, rhs, -1).
85 TVM_DLL PatternSeq operator>>(const PatternSeq& lhs, const PatternSeq& rhs);
86 
91 class DFPatternNode : public Object {
92  public:
93  static constexpr const char* _type_key = "DFPatternNode";
95 };
96 
101 class DFPattern : public ObjectRef {
102  public:
104  template <typename... Args>
105  CallPattern operator()(Args&&... args) const;
107  TVM_DLL CallPattern operator()(const std::vector<DFPattern>& args) const;
109  TVM_DLL OrPattern operator|(const DFPattern& other) const;
111  TVM_DLL AndPattern operator&(const DFPattern& other) const;
113  TVM_DLL NotPattern operator~() const;
115  TVM_DLL AttrPattern HasAttr(const Map<String, ObjectRef>& attrs) const;
117  TVM_DLL StructInfoPattern HasStructInfo(const StructInfo& struct_info) const;
119  TVM_DLL TypePattern HasType(const Type& type) const;
121  TVM_DLL DataTypePattern HasDtype(const DataType& dtype) const;
123  TVM_DLL DataTypePattern HasDtype(const std::string& dtype) const;
127  TVM_DLL SameShapeConstraint HasSameShapeAs(const DFPattern& other) const;
129  TVM_DLL DFPattern dup() const;
130 
132  TVM_DLL operator PatternSeq() const;
133 
135 };
136 
138 struct PairCons {
140  enum Type {
144  int index = -1;
152  TVM_DLL explicit PairCons(Type t, int index = -1) : type(t), index(index) {}
153 
154  bool operator==(const PairCons& other) const {
155  return type == other.type && index == other.index;
156  }
157 };
158 
166 class DFConstraintNode : public Object {
167  public:
170 
198  virtual std::tuple<PrimExpr, bool> AsPrimExpr(
199  std::function<Optional<Var>(const DFPatternNode*)> match_state) const = 0;
200 
201  static constexpr const char* _type_key = "DFConstraintNode";
202  static constexpr const uint32_t _type_child_slots = 1;
204 };
205 
206 class DFConstraint : public ObjectRef {
207  public:
209 };
210 
215 class PatternSeqNode final : public Object {
216  public:
218  std::vector<PairCons> pair_constraints;
220  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("patterns", &patterns); }
221  static constexpr const char* _type_key = "relax.dpl.PatternSeq";
223 };
224 
229 class PatternSeq final : public ObjectRef {
230  public:
231  TVM_DLL explicit PatternSeq(DFPattern init_pattern);
232  TVM_DLL explicit PatternSeq(tvm::Array<DFPattern> patterns, bool only_used_by = false);
233 
234  PatternSeq UsedBy(PatternSeq other, int index = -1) const;
235  PatternSeq OnlyUsedBy(PatternSeq other, int index = -1) const;
236 
238  PatternSeq dup() const;
239 
240  // friend functions
241  friend PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index);
242  friend PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index);
243 
245 };
246 
251 class PatternContextNode : public Object {
252  public:
254  enum ExternUse {
258 
259  // src node -> <dst node, constraint type> constraints.
260  // Dst nodes are kept in a vector to keep them ordered.
261  std::map<DFPattern, std::vector<std::pair<DFPattern, std::vector<PairCons>>>> edge_constraints;
262 
263  // Underlying DFPattern nodes which the edge constraints may reference
264  // Kept as a separate vector of patterns to process constraints in a fixed order.
265  std::vector<DFPattern> src_ordered;
266 
267  // Non-edge constraints
268  std::vector<DFConstraint> validation_constraints;
269 
270  static constexpr const char* _type_key = "relax.dpl.PatternContext";
272 };
273 
278 class PatternContext : public ObjectRef {
279  public:
280  TVM_DLL explicit PatternContext(ObjectPtr<Object> n) : ObjectRef(n) {}
281  TVM_DLL explicit PatternContext(bool incremental = false);
282 
284  ICHECK(get() != nullptr);
285  return static_cast<const PatternContextNode*>(get());
286  }
287 
289  ICHECK(get() != nullptr);
290  return static_cast<PatternContextNode*>(get_mutable());
291  }
292 
300  void add_constraint(DFPattern producer, DFPattern consumer, PairCons cons) {
301  auto& pairs = (*this)->edge_constraints[producer];
302  auto it = std::find_if(pairs.begin(), pairs.end(),
303  [consumer](auto p) { return p.first == consumer; });
304  if (it == pairs.end()) {
305  pairs.emplace_back(consumer, std::vector{cons});
306  } else {
307  auto& vec = it->second;
308  ICHECK(std::find(vec.cbegin(), vec.cend(), cons) == vec.cend())
309  << "Constraint already exists";
310  vec.push_back(cons);
311  }
312 
313  auto& patterns = (*this)->src_ordered;
314  if (std::find(patterns.begin(), patterns.end(), producer) == patterns.end()) {
315  patterns.push_back(producer);
316  }
317  }
318 
324  void add_constraint(DFConstraint constraint) {
325  (*this)->validation_constraints.push_back(constraint);
326  }
327 
330 
332  TVM_DLL void EnterWithScope() const;
334  TVM_DLL void ExitWithScope() const;
335 
336  private:
337  friend class With<PatternContext>;
338 };
339 
345  public:
348  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr", &expr); }
349 
350  static constexpr const char* _type_key = "relax.dpl.ExprPattern";
352 };
353 
358 class ExprPattern : public DFPattern {
359  public:
360  TVM_DLL explicit ExprPattern(Expr expr);
362 };
363 
370  public:
372  const String& name_hint() const { return name; }
373  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); }
374 
375  static constexpr const char* _type_key = "relax.dpl.VarPattern";
377 };
378 
383 class VarPattern : public DFPattern {
384  public:
390  TVM_DLL VarPattern(String name_hint);
392 };
393 
399  public:
400  static constexpr const char* _type_key = "relax.dpl.DataflowVarPattern";
402 };
403 
409  public:
411  TVM_DLL DataflowVarPattern(String name_hint);
413 };
414 
420  public:
421  static constexpr const char* _type_key = "relax.dpl.GlobalVarPattern";
423 };
424 
429 class GlobalVarPattern : public DFPattern {
430  public:
431  TVM_DLL GlobalVarPattern(String name_hint);
433 };
434 
440  public:
442 
443  static constexpr const char* _type_key = "relax.dpl.ConstantPattern";
445 };
446 
451 class ConstantPattern : public DFPattern {
452  public:
454 };
455 
461  public:
477  // Todo(relax-team): Dataflow pattern for StructInfo, and match sinfo_args
478 
480  v->Visit("op", &op);
481  v->Visit("args", &args);
482  }
483 
484  static constexpr const char* _type_key = "relax.dpl.CallPattern";
486 };
487 
488 class CallPattern : public DFPattern {
489  public:
490  TVM_DLL CallPattern(DFPattern op, Array<DFPattern> args, bool varg_default_wildcard = false);
492 };
493 
500  public:
502  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }
503  static constexpr const char* _type_key = "relax.dpl.PrimArrPattern";
505 };
506 
511 class PrimArrPattern : public DFPattern {
512  public:
515 };
516 
523  public:
534  v->Visit("params", &params);
535  v->Visit("body", &body);
536  }
537 
538  static constexpr const char* _type_key = "relax.dpl.FunctionPattern";
540 };
541 
546 class FunctionPattern : public DFPattern {
547  public:
554 
556 };
557 
563  public:
566  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }
567 
568  static constexpr const char* _type_key = "relax.dpl.TuplePattern";
570 };
571 
576 class TuplePattern : public DFPattern {
577  public:
578  TVM_DLL explicit TuplePattern(tvm::Array<DFPattern> fields);
580 };
581 
587  public:
590  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }
591 
592  static constexpr const char* _type_key = "relax.dpl.UnorderedTuplePattern";
594 };
595 
601  public:
604 };
605 
612  public:
614  int index;
617  v->Visit("tuple", &tuple);
618  v->Visit("index", &index);
619  }
620 
621  static constexpr const char* _type_key = "relax.dpl.TupleGetItemPattern";
623 };
624 
630  public:
631  TVM_DLL TupleGetItemPattern(DFPattern tuple, int index);
633 };
634 
640  public:
645  v->Visit("left", &left);
646  v->Visit("right", &right);
647  }
648 
649  static constexpr const char* _type_key = "relax.dpl.AndPattern";
651 };
652 
657 class AndPattern : public DFPattern {
658  public:
659  TVM_DLL AndPattern(DFPattern lhs, DFPattern rhs);
661 };
662 
667 class OrPatternNode : public DFPatternNode {
668  public:
673  v->Visit("left", &left);
674  v->Visit("right", &right);
675  }
676 
677  static constexpr const char* _type_key = "relax.dpl.OrPattern";
679 };
680 
685 class OrPattern : public DFPattern {
686  public:
687  TVM_DLL OrPattern(DFPattern left, DFPattern right);
689 };
690 
696  public:
699  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("reject", &reject); }
700 
701  static constexpr const char* _type_key = "relax.dpl.NotPattern";
703 };
704 
709 class NotPattern : public DFPattern {
710  public:
711  TVM_DLL NotPattern(DFPattern reject);
713 };
714 
720  public:
722 
723  static constexpr const char* _type_key = "relax.dpl.WildcardPattern";
725 };
726 
731 class WildcardPattern : public DFPattern {
732  public:
734 
735  // Declaring WildcardPattern as non-nullable avoids the default
736  // zero-parameter ObjectRef constructor, which leaves `data_ =
737  // nullptr`. This allows a zero-parameter constructor to be
738  // declared here that creates a valid wildcard instance.
739 
741 };
742 
748  public:
753  v->Visit("pattern", &pattern);
754  v->Visit("type", &type);
755  }
756 
757  static constexpr const char* _type_key = "relax.dpl.TypePattern";
759 };
760 
765 class TypePattern : public DFPattern {
766  public:
767  TVM_DLL TypePattern(DFPattern pattern, Type type);
769 };
770 
776  public:
781  v->Visit("pattern", &pattern);
782  v->Visit("struct_info", &struct_info);
783  }
784 
785  static constexpr const char* _type_key = "relax.dpl.StructInfoPattern";
787 };
788 
789 class StructInfoPattern : public DFPattern {
790  public:
791  TVM_DLL StructInfoPattern(DFPattern pattern, StructInfo struct_info);
793 };
794 
800  public:
805  v->Visit("pattern", &pattern);
806  v->Visit("shape", &shape);
807  }
808 
809  static constexpr const char* _type_key = "relax.dpl.ShapePattern";
811 };
812 
817 class ShapePattern : public DFPattern {
818  public:
819  TVM_DLL ShapePattern(DFPattern pattern, Array<PrimExpr> type);
821 };
822 
828  public:
831  Array<DFPattern> GetDependentPatterns() const override { return args; }
832 
833  std::tuple<PrimExpr, bool> AsPrimExpr(
834  std::function<Optional<Var>(const DFPatternNode*)> match_state) const override;
835 
836  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("args", &args); }
837 
838  static constexpr const char* _type_key = "relax.dpl.SameShapeConstraint";
840 };
841 
847  public:
850 };
851 
857  public:
862  v->Visit("pattern", &pattern);
863  v->Visit("dtype", &dtype);
864  }
865 
866  static constexpr const char* _type_key = "relax.dpl.DataTypePattern";
868 };
869 
874 class DataTypePattern : public DFPattern {
875  public:
876  TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype);
878 };
879 
885  public:
890  v->Visit("pattern", &pattern);
891  v->Visit("attrs", &attrs);
892  }
893 
894  static constexpr const char* _type_key = "relax.dpl.AttrPattern";
896 };
897 
902 class AttrPattern : public DFPattern {
903  public:
904  TVM_DLL AttrPattern(DFPattern pattern, DictAttrs attrs);
906 };
907 
914  public:
918  const String& global_symbol() const { return global_symbol_; }
919  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("global_symbol", &global_symbol_); }
920 
921  static constexpr const char* _type_key = "relax.dpl.ExternFuncPattern";
923 };
924 
929 class ExternFuncPattern : public DFPattern {
930  public:
931  TVM_DLL ExternFuncPattern(String global_symbol);
933 };
934 
// Syntactic sugar for creating a VarPattern with a name.
936 VarPattern IsVar(const String& name);
// Syntactic sugar for creating an ExprPattern.
942 ExprPattern IsExpr(const Expr& expr);
// Syntactic sugar for creating an ExprPattern based on an Op, given the op's name.
944 ExprPattern IsOp(const String& op_name);
946 // Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo
// Syntactic sugar for a call_tir CallPattern whose arguments are given as a TuplePattern.
// NOTE(review): an IsCallTIR(name, Optional<TuplePattern>) overload is documented as
// returning a tensor; confirm what this `var_args` overload is meant to match.
949 CallPattern IsCallTIR(const String& name, TuplePattern var_args);
// Syntactic sugar for creating a TuplePattern, or an UnorderedTuplePattern when
// unordered=true (hence the DFPattern return type).
955 DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered = false);
// Syntactic sugar for creating a TupleGetItemPattern.
957 TupleGetItemPattern IsTupleGetItem(const DFPattern tuple, int index = -1);
958 
960 template <typename... Args>
961 CallPattern DFPattern::operator()(Args&&... args) const {
962  return CallPattern(GetRef<DFPattern>(this->get()),
963  Array<DFPattern>({std::forward<Args>(args)...}));
964 }
965 
966 } // namespace relax
967 } // namespace tvm
968 #endif // TVM_RELAX_DATAFLOW_PATTERN_H_
Runtime Array container types.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Managed reference to DictAttrsNode.
Definition: attrs.h:227
Managed reference to RelayExprNode.
Definition: expr.h:442
Managed reference to TypeNode.
Definition: type.h:93
RAII wrapper class to enter and exit a context object, similar to Python's with syntax.
Definition: with.h:58
Match a conjunction of other patterns.
Definition: dataflow_pattern.h:639
DFPattern left
Definition: dataflow_pattern.h:641
DFPattern right
Definition: dataflow_pattern.h:642
static constexpr const char * _type_key
Definition: dataflow_pattern.h:649
TVM_DECLARE_FINAL_OBJECT_INFO(AndPatternNode, DFPatternNode)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:644
Managed reference to AndPatternNode.
Definition: dataflow_pattern.h:657
AndPattern(DFPattern lhs, DFPattern rhs)
TVM_DEFINE_OBJECT_REF_METHODS(AndPattern, DFPattern, AndPatternNode)
A pattern asserting that a root pattern has certain attributes.
Definition: dataflow_pattern.h:884
DictAttrs attrs
Definition: dataflow_pattern.h:887
DFPattern pattern
Definition: dataflow_pattern.h:886
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:889
TVM_DECLARE_FINAL_OBJECT_INFO(AttrPatternNode, DFPatternNode)
static constexpr const char * _type_key
Definition: dataflow_pattern.h:894
Managed reference to AttrPatternNode.
Definition: dataflow_pattern.h:902
TVM_DEFINE_OBJECT_REF_METHODS(AttrPattern, DFPattern, AttrPatternNode)
AttrPattern(DFPattern pattern, DictAttrs attrs)
A pattern to match a callable node in Relax.
Definition: dataflow_pattern.h:460
TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode)
bool varg_default_wildcard
Definition: dataflow_pattern.h:475
static constexpr const char * _type_key
Definition: dataflow_pattern.h:484
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:479
tvm::Array< DFPattern > args
Definition: dataflow_pattern.h:468
DFPattern op
Definition: dataflow_pattern.h:467
Definition: dataflow_pattern.h:488
CallPattern(DFPattern op, Array< DFPattern > args, bool varg_default_wildcard=false)
TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode)
A Pattern to Match a Relax Constant.
Definition: dataflow_pattern.h:439
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:441
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPatternNode, DFPatternNode)
static constexpr const char * _type_key
Definition: dataflow_pattern.h:443
Managed reference to a ConstantPattern.
Definition: dataflow_pattern.h:451
TVM_DEFINE_OBJECT_REF_METHODS(ConstantPattern, DFPattern, ConstantPatternNode)
Additional constraints on the graph.
Definition: dataflow_pattern.h:166
TVM_DECLARE_BASE_OBJECT_INFO(DFConstraintNode, Object)
virtual Array< DFPattern > GetDependentPatterns() const =0
Return the patterns on which the constraint depends.
static constexpr const uint32_t _type_child_slots
Definition: dataflow_pattern.h:202
virtual std::tuple< PrimExpr, bool > AsPrimExpr(std::function< Optional< Var >(const DFPatternNode *)> match_state) const =0
Convert the constraint to a PrimExpr.
static constexpr const char * _type_key
Definition: dataflow_pattern.h:201
Definition: dataflow_pattern.h:206
TVM_DEFINE_OBJECT_REF_METHODS(DFConstraint, ObjectRef, DFConstraintNode)
Base type of all dataflow patterns.
Definition: dataflow_pattern.h:91
TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object)
static constexpr const char * _type_key
Definition: dataflow_pattern.h:93
Managed reference to dataflow patterns.
Definition: dataflow_pattern.h:101
CallPattern operator()(const std::vector< DFPattern > &args) const
Syntactic sugar for creating a CallPattern.
ShapePattern HasShape(const Array< PrimExpr > &shape) const
Syntactic sugar for creating a ShapePattern.
DataTypePattern HasDtype(const DataType &dtype) const
Syntactic sugar for creating a DataTypePattern with a DataType.
SameShapeConstraint HasSameShapeAs(const DFPattern &other) const
Syntactic sugar for creating a SameShapeConstraint.
OrPattern operator|(const DFPattern &other) const
Syntactic sugar for creating an OrPattern.
CallPattern operator()(Args &&... args) const
Syntactic sugar for creating a CallPattern.
Definition: dataflow_pattern.h:961
TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode)
TypePattern HasType(const Type &type) const
Syntactic sugar for creating a TypePattern.
StructInfoPattern HasStructInfo(const StructInfo &struct_info) const
Syntactic sugar for creating a StructInfoPattern.
AndPattern operator&(const DFPattern &other) const
Syntactic sugar for creating an AndPattern.
DFPattern dup() const
Syntactic sugar for duplicating the current pattern.
NotPattern operator~() const
Syntactic sugar for creating a NotPattern.
AttrPattern HasAttr(const Map< String, ObjectRef > &attrs) const
Syntactic sugar for creating an AttrPattern.
DataTypePattern HasDtype(const std::string &dtype) const
Syntactic sugar for creating a DataTypePattern from a data type's name.
A pattern asserting that a root pattern has a certain data type.
Definition: dataflow_pattern.h:856
DFPattern pattern
Definition: dataflow_pattern.h:858
TVM_DECLARE_FINAL_OBJECT_INFO(DataTypePatternNode, DFPatternNode)
DataType dtype
Definition: dataflow_pattern.h:859
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:861
static constexpr const char * _type_key
Definition: dataflow_pattern.h:866
Managed reference to DataTypePatternNode.
Definition: dataflow_pattern.h:874
DataTypePattern(DFPattern pattern, DataType dtype)
TVM_DEFINE_OBJECT_REF_METHODS(DataTypePattern, DFPattern, DataTypePatternNode)
A Pattern to Match a Relax Dataflow Variable.
Definition: dataflow_pattern.h:398
TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarPatternNode, DFPatternNode)
static constexpr const char * _type_key
Definition: dataflow_pattern.h:400
Managed reference to a DataflowVarPattern.
Definition: dataflow_pattern.h:408
DataflowVarPattern(String name_hint)
TVM_DEFINE_OBJECT_REF_METHODS(DataflowVarPattern, DFPattern, DataflowVarPatternNode)
Pattern for Relax Expression.
Definition: dataflow_pattern.h:344
Expr expr
Definition: dataflow_pattern.h:346
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:348
TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode)
static constexpr const char * _type_key
Definition: dataflow_pattern.h:350
Managed reference to an ExprPattern.
Definition: dataflow_pattern.h:358
TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode)
A pattern of external function.
Definition: dataflow_pattern.h:913
const String & global_symbol() const
The external function name.
Definition: dataflow_pattern.h:918
String global_symbol_
Definition: dataflow_pattern.h:915
TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncPatternNode, DFPatternNode)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:919
static constexpr const char * _type_key
Definition: dataflow_pattern.h:921
Managed reference to ExternFuncPatternNode.
Definition: dataflow_pattern.h:929
ExternFuncPattern(String global_symbol)
TVM_DEFINE_OBJECT_REF_METHODS(ExternFuncPattern, DFPattern, ExternFuncPatternNode)
A pattern to match a Relax Function.
Definition: dataflow_pattern.h:522
DFPattern body
Definition: dataflow_pattern.h:531
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:533
static constexpr const char * _type_key
Definition: dataflow_pattern.h:538
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPatternNode, DFPatternNode)
tvm::Array< DFPattern > params
Definition: dataflow_pattern.h:524
Managed reference to FunctionPatternNode.
Definition: dataflow_pattern.h:546
FunctionPattern(tvm::Array< DFPattern > params, DFPattern body)
Constructor.
TVM_DEFINE_OBJECT_REF_METHODS(FunctionPattern, DFPattern, FunctionPatternNode)
A Pattern to Match a Relax Global Variable.
Definition: dataflow_pattern.h:419
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarPatternNode, DFPatternNode)
static constexpr const char * _type_key
Definition: dataflow_pattern.h:421
Managed reference to a GlobalVarPattern.
Definition: dataflow_pattern.h:429
GlobalVarPattern(String name_hint)
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVarPattern, DFPattern, GlobalVarPatternNode)
Pattern for rejecting a certain pattern.
Definition: dataflow_pattern.h:695
DFPattern reject
Definition: dataflow_pattern.h:697
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:699
TVM_DECLARE_FINAL_OBJECT_INFO(NotPatternNode, DFPatternNode)
static constexpr const char * _type_key
Definition: dataflow_pattern.h:701
Managed reference to NotPatternNode.
Definition: dataflow_pattern.h:709
NotPattern(DFPattern reject)
TVM_DEFINE_OBJECT_REF_METHODS(NotPattern, DFPattern, NotPatternNode)
Match a disjunction of other patterns.
Definition: dataflow_pattern.h:667
DFPattern left
Definition: dataflow_pattern.h:669
DFPattern right
Definition: dataflow_pattern.h:670
TVM_DECLARE_FINAL_OBJECT_INFO(OrPatternNode, DFPatternNode)
static constexpr const char * _type_key
Definition: dataflow_pattern.h:677
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:672
Managed reference to OrPatternNode.
Definition: dataflow_pattern.h:685
TVM_DEFINE_OBJECT_REF_METHODS(OrPattern, DFPattern, OrPatternNode)
OrPattern(DFPattern left, DFPattern right)
A context to manage the graph-level pattern matching.
Definition: dataflow_pattern.h:251
ExternUse
Constrains the matched graph with assertions about external uses.
Definition: dataflow_pattern.h:254
@ kMustNot
Definition: dataflow_pattern.h:256
@ kMay
Definition: dataflow_pattern.h:255
TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextNode, Object)
std::vector< DFConstraint > validation_constraints
Definition: dataflow_pattern.h:268
static constexpr const char * _type_key
Definition: dataflow_pattern.h:270
enum tvm::relax::PatternContextNode::ExternUse allow_extern_use
std::map< DFPattern, std::vector< std::pair< DFPattern, std::vector< PairCons > > > > edge_constraints
Definition: dataflow_pattern.h:261
std::vector< DFPattern > src_ordered
Definition: dataflow_pattern.h:265
Managed reference to a pattern context.
Definition: dataflow_pattern.h:278
void add_constraint(DFConstraint constraint)
Add a validation constraint.
Definition: dataflow_pattern.h:324
void ExitWithScope() const
The RAII-like exit of a constraint context scope.
PatternContextNode * operator->()
Definition: dataflow_pattern.h:288
static Optional< PatternContext > Current()
Get the constraint context object on the top of the stack.
void add_constraint(DFPattern producer, DFPattern consumer, PairCons cons)
Build an edge constraint between two patterns (producer and consumer).
Definition: dataflow_pattern.h:300
const PatternContextNode * operator->() const
Definition: dataflow_pattern.h:283
PatternContext(bool incremental=false)
PatternContext(ObjectPtr< Object > n)
Definition: dataflow_pattern.h:280
void EnterWithScope() const
The RAII-like entry of a constraint context scope.
A sequence of DFPatterns in which each DFPattern is connected to the next one.
Definition: dataflow_pattern.h:215
TVM_DECLARE_BASE_OBJECT_INFO(PatternSeqNode, Object)
tvm::Array< DFPattern > patterns
Definition: dataflow_pattern.h:217
static constexpr const char * _type_key
Definition: dataflow_pattern.h:221
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:220
std::vector< PairCons > pair_constraints
Definition: dataflow_pattern.h:218
Managed reference to pattern sequences.
Definition: dataflow_pattern.h:229
PatternSeq(tvm::Array< DFPattern > patterns, bool only_used_by=false)
PatternSeq OnlyUsedBy(PatternSeq other, int index=-1) const
TVM_DEFINE_OBJECT_REF_METHODS(PatternSeq, ObjectRef, PatternSeqNode)
PatternSeq UsedBy(PatternSeq other, int index=-1) const
friend PatternSeq UsedBy(const PatternSeq &lhs, const PatternSeq &rhs, int index)
Create used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned.
PatternSeq(DFPattern init_pattern)
friend PatternSeq OnlyUsedBy(const PatternSeq &lhs, const PatternSeq &rhs, int index)
Create only-used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned.
PatternSeq dup() const
Syntactic sugar for duplicating the current pattern sequence.
A pattern to match an array of PrimExpr.
Definition: dataflow_pattern.h:499
TVM_DECLARE_FINAL_OBJECT_INFO(PrimArrPatternNode, DFPatternNode)
Array< PrimExpr > fields
Definition: dataflow_pattern.h:501
static constexpr const char * _type_key
Definition: dataflow_pattern.h:503
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:502
Managed reference to a PrimArrPattern.
Definition: dataflow_pattern.h:511
TVM_DEFINE_OBJECT_REF_METHODS(PrimArrPattern, DFPattern, PrimArrPatternNode)
PrimArrPattern(Array< PrimExpr > arr)
A constraint asserting that multiple root patterns have the same shape.
Definition: dataflow_pattern.h:827
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:836
Array< DFPattern > args
Definition: dataflow_pattern.h:829
TVM_DECLARE_FINAL_OBJECT_INFO(SameShapeConstraintNode, DFConstraintNode)
std::tuple< PrimExpr, bool > AsPrimExpr(std::function< Optional< Var >(const DFPatternNode *)> match_state) const override
Convert the constraint to a PrimExpr.
static constexpr const char * _type_key
Definition: dataflow_pattern.h:838
Array< DFPattern > GetDependentPatterns() const override
Return the patterns on which the constraint depends.
Definition: dataflow_pattern.h:831
Managed reference to SameShapeConstraintNode.
Definition: dataflow_pattern.h:846
TVM_DEFINE_OBJECT_REF_METHODS(SameShapeConstraint, DFConstraint, SameShapeConstraintNode)
SameShapeConstraint(Array< DFPattern > args)
A pattern asserting that a root pattern has a certain shape.
Definition: dataflow_pattern.h:799
static constexpr const char * _type_key
Definition: dataflow_pattern.h:809
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:804
TVM_DECLARE_FINAL_OBJECT_INFO(ShapePatternNode, DFPatternNode)
DFPattern pattern
Definition: dataflow_pattern.h:801
Array< PrimExpr > shape
Definition: dataflow_pattern.h:802
Managed reference to ShapePatternNode.
Definition: dataflow_pattern.h:817
TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode)
ShapePattern(DFPattern pattern, Array< PrimExpr > type)
Pattern for matching a certain struct info.
Definition: dataflow_pattern.h:775
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:780
StructInfo struct_info
Definition: dataflow_pattern.h:778
static constexpr const char * _type_key
Definition: dataflow_pattern.h:785
TVM_DECLARE_FINAL_OBJECT_INFO(StructInfoPatternNode, DFPatternNode)
DFPattern pattern
Definition: dataflow_pattern.h:777
Definition: dataflow_pattern.h:789
TVM_DEFINE_OBJECT_REF_METHODS(StructInfoPattern, DFPattern, StructInfoPatternNode)
StructInfoPattern(DFPattern pattern, StructInfo struct_info)
Managed reference to StructInfoNode.
Definition: expr.h:129
A pattern to match n'th indexing to a tuple.
Definition: dataflow_pattern.h:611
static constexpr const char * _type_key
Definition: dataflow_pattern.h:621
TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode)
DFPattern tuple
Definition: dataflow_pattern.h:613
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:616
int index
Definition: dataflow_pattern.h:614
Managed reference to TupleGetItemPatternNode.
Definition: dataflow_pattern.h:629
TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItemPattern, DFPattern, TupleGetItemPatternNode)
TupleGetItemPattern(DFPattern tuple, int index)
Pattern to match a tuple of ordered expressions.
Definition: dataflow_pattern.h:562
tvm::Array< DFPattern > fields
Definition: dataflow_pattern.h:564
TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:566
static constexpr const char * _type_key
Definition: dataflow_pattern.h:568
Managed reference to TuplePatternNode.
Definition: dataflow_pattern.h:576
TVM_DEFINE_OBJECT_REF_METHODS(TuplePattern, DFPattern, TuplePatternNode)
TuplePattern(tvm::Array< DFPattern > fields)
Pattern for matching a certain type.
Definition: dataflow_pattern.h:747
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:752
TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode)
DFPattern pattern
Definition: dataflow_pattern.h:749
Type type
Definition: dataflow_pattern.h:750
static constexpr const char * _type_key
Definition: dataflow_pattern.h:757
Managed reference to TypePatternNode.
Definition: dataflow_pattern.h:765
TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode)
TypePattern(DFPattern pattern, Type type)
A pattern to match multiple expressions in any order.
Definition: dataflow_pattern.h:586
tvm::Array< DFPattern > fields
Definition: dataflow_pattern.h:588
static constexpr const char * _type_key
Definition: dataflow_pattern.h:592
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:590
TVM_DECLARE_FINAL_OBJECT_INFO(UnorderedTuplePatternNode, DFPatternNode)
Managed reference to UnorderedTuplePatternNode.
Definition: dataflow_pattern.h:600
UnorderedTuplePattern(tvm::Array< DFPattern > fields)
TVM_DEFINE_OBJECT_REF_METHODS(UnorderedTuplePattern, DFPattern, UnorderedTuplePatternNode)
A Pattern to Match a Relax Variable.
Definition: dataflow_pattern.h:369
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:373
String name
Definition: dataflow_pattern.h:371
const String & name_hint() const
Definition: dataflow_pattern.h:372
static constexpr const char * _type_key
Definition: dataflow_pattern.h:375
TVM_DECLARE_BASE_OBJECT_INFO(VarPatternNode, DFPatternNode)
Managed reference to a VarPattern.
Definition: dataflow_pattern.h:383
TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode)
VarPattern(String name_hint)
Create a pattern matching by variable name.
Wildcard Pattern is a pattern that can match anything.
Definition: dataflow_pattern.h:719
static constexpr const char * _type_key
Definition: dataflow_pattern.h:723
void VisitAttrs(tvm::AttrVisitor *v)
Definition: dataflow_pattern.h:721
TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode)
Managed reference to WildcardPatternNode.
Definition: dataflow_pattern.h:731
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode)
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:43
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
const Object * get() const
Definition: object.h:554
Object * get_mutable() const
Definition: object.h:607
base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Base expr nodes in TVM.
CallPattern IsCallTIR(const String &name, Optional< TuplePattern > args=NullOpt)
Syntactic sugar for call_tir (returns a tensor).
PatternSeq operator^(const PatternSeq &lhs, const PatternSeq &rhs)
Syntax sugar of UsedBy(lhs, rhs, -1).
CallPattern IsCallDPSPacked(const String &name, Optional< TuplePattern > args=NullOpt)
Syntactic sugar for call_dps_packed (returns a tensor).
ExprPattern IsOp(const String &op_name)
Syntactic sugar for creating an ExprPattern based on an Op.
DFPattern IsTuple(const Array< DFPattern > &fields, bool unordered=false)
Syntactic sugar for creating a TuplePattern, or an UnorderedTuplePattern when unordered=true.
WildcardPattern Wildcard()
Syntactic sugar for creating a WildcardPattern.
ConstantPattern IsConst()
Syntactic sugar for creating a ConstantPattern.
PatternSeq UsedBy(const PatternSeq &lhs, const PatternSeq &rhs, int index=-1)
Create used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned.
PatternSeq OnlyUsedBy(const PatternSeq &lhs, const PatternSeq &rhs, int index=-1)
Create only-used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned.
ExprPattern IsExpr(const Expr &expr)
Syntactic sugar for creating an ExprPattern.
PatternSeq operator>>(const PatternSeq &lhs, const PatternSeq &rhs)
Syntax sugar of OnlyUsedBy(lhs, rhs, -1).
TupleGetItemPattern IsTupleGetItem(const DFPattern tuple, int index=-1)
Syntactic sugar for creating a TupleGetItemPattern.
VarPattern IsVar(const String &name)
Syntactic sugar for creating a VarPattern with a name.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1913
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
Runtime Optional container types.
Relax Types.
Constraint of a DFPattern edge (producer -> consumer) in graph-level matching.
Definition: dataflow_pattern.h:138
int index
Definition: dataflow_pattern.h:144
bool operator==(const PairCons &other) const
Definition: dataflow_pattern.h:154
enum tvm::relax::PairCons::Type type
Type
Constraint types of the edge.
Definition: dataflow_pattern.h:140
@ kOnlyUsedBy
Definition: dataflow_pattern.h:142
@ kUsedBy
Definition: dataflow_pattern.h:141
PairCons(Type t, int index=-1)
Construct a new PairCons object.
Definition: dataflow_pattern.h:152
RAII wrapper class to enter and exit a context object, similar to Python's with syntax.