tvm
dataflow_pattern.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_DATAFLOW_PATTERN_H_
25 #define TVM_RELAX_DATAFLOW_PATTERN_H_
26 
27 #include <tvm/ffi/container/array.h>
28 #include <tvm/ffi/optional.h>
29 #include <tvm/ffi/reflection/registry.h>
30 #include <tvm/ir/expr.h>
31 #include <tvm/relax/expr.h>
32 #include <tvm/relax/type.h>
33 #include <tvm/support/with.h>
34 
35 #include <cstdint>
36 #include <functional>
37 #include <map>
38 #include <memory>
39 #include <string>
40 #include <tuple>
41 #include <utility>
42 #include <vector>
43 
44 namespace tvm {
45 
46 namespace arith {
47 class Analyzer;
48 }
49 
50 namespace relax {
51 
52 class PatternSeq;
53 class CallPattern;
54 class OrPattern;
55 class AndPattern;
56 class NotPattern;
57 class ShapePattern;
58 class StructInfoPattern;
59 class DataTypePattern;
60 class AttrPattern;
61 class SameShapeConstraint;
62 
/*!
 * \brief Create a used-by relationship between lhs[-1] and rhs[0], returning [*lhs, *rhs].
 * \param lhs The leading pattern sequence; its last pattern is the producer.
 * \param rhs The trailing pattern sequence; its first pattern is the user.
 * \param index Defaults to -1. NOTE(review): presumably stored as the PairCons index of the
 *        created edge (which argument slot of rhs[0] must receive lhs[-1]); confirm in the
 *        implementation.
 * \return The concatenated sequence [*lhs, *rhs].
 */
71 TVM_DLL PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index = -1);
/*!
 * \brief Operator form of UsedBy.
 * NOTE(review): presumably equivalent to UsedBy(lhs, rhs) with the default index; confirm.
 */
73 TVM_DLL PatternSeq operator^(const PatternSeq& lhs, const PatternSeq& rhs);
74 
/*!
 * \brief Create an only-used-by relationship between lhs[-1] and rhs[0], returning
 *        [*lhs, *rhs].
 * \param lhs The leading pattern sequence; its last pattern is the producer.
 * \param rhs The trailing pattern sequence; its first pattern is the only user.
 * \param index Defaults to -1. NOTE(review): presumably stored as the PairCons index of the
 *        created edge; confirm in the implementation.
 * \return The concatenated sequence [*lhs, *rhs].
 */
83 TVM_DLL PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index = -1);
/*!
 * \brief Operator form of OnlyUsedBy.
 * NOTE(review): presumably equivalent to OnlyUsedBy(lhs, rhs) with the default index; confirm.
 */
85 TVM_DLL PatternSeq operator>>(const PatternSeq& lhs, const PatternSeq& rhs);
86 
91 class DFPatternNode : public Object {
92  public:
93  static constexpr const char* _type_key = "DFPatternNode";
94  static constexpr const uint32_t _type_child_slots = 21;
96 };
97 
102 class DFPattern : public ObjectRef {
103  public:
105  template <typename... Args>
106  CallPattern operator()(Args&&... args) const;
108  TVM_DLL CallPattern operator()(const std::vector<DFPattern>& args) const;
110  TVM_DLL OrPattern operator|(const DFPattern& other) const;
112  TVM_DLL AndPattern operator&(const DFPattern& other) const;
114  TVM_DLL NotPattern operator~() const;
116  TVM_DLL AttrPattern HasAttr(const Map<String, Any>& attrs) const;
118  TVM_DLL StructInfoPattern HasStructInfo(const StructInfo& struct_info) const;
120  TVM_DLL DataTypePattern HasDtype(const DataType& dtype) const;
122  TVM_DLL DataTypePattern HasDtype(const std::string& dtype) const;
124  TVM_DLL ShapePattern HasShape(const Array<PrimExpr>& shape) const;
126  TVM_DLL SameShapeConstraint HasSameShapeAs(const DFPattern& other) const;
128  TVM_DLL DFPattern dup() const;
129 
131  TVM_DLL operator PatternSeq() const;
132 
134 };
135 
137 struct PairCons {
139  enum Type {
143  int index = -1;
151  TVM_DLL explicit PairCons(Type t, int index = -1) : type(t), index(index) {}
152 
153  bool operator==(const PairCons& other) const {
154  return type == other.type && index == other.index;
155  }
156 };
157 
165 class DFConstraintNode : public Object {
166  public:
168  virtual Array<DFPattern> GetDependentPatterns() const = 0;
169 
197  virtual std::tuple<PrimExpr, bool> AsPrimExpr(
198  std::function<Optional<Var>(const DFPatternNode*)> match_state) const = 0;
199 
200  static constexpr const char* _type_key = "DFConstraintNode";
201  static constexpr const uint32_t _type_child_slots = 1;
203 };
204 
205 class DFConstraint : public ObjectRef {
206  public:
208 };
209 
214 class PatternSeqNode final : public Object {
215  public:
216  tvm::Array<DFPattern> patterns;
217  std::vector<PairCons> pair_constraints;
219  static void RegisterReflection() {
220  namespace refl = tvm::ffi::reflection;
221  refl::ObjectDef<PatternSeqNode>().def_ro("patterns", &PatternSeqNode::patterns);
222  }
223 
224  static constexpr const char* _type_key = "relax.dpl.PatternSeq";
226 };
227 
232 class PatternSeq final : public ObjectRef {
233  public:
234  TVM_DLL explicit PatternSeq(DFPattern init_pattern);
235  TVM_DLL explicit PatternSeq(tvm::Array<DFPattern> patterns, bool only_used_by = false);
236 
237  PatternSeq UsedBy(PatternSeq other, int index = -1) const;
238  PatternSeq OnlyUsedBy(PatternSeq other, int index = -1) const;
239 
241  PatternSeq dup() const;
242 
243  // friend functions
244  friend PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index);
245  friend PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index);
246 
248 };
249 
254 class PatternContextNode : public Object {
255  public:
257  enum ExternUse {
261 
262  // src node -> <dst node, constraint type> constraints.
263  // Dst nodes are kept in a vector to keep them ordered.
264  std::map<DFPattern, std::vector<std::pair<DFPattern, std::vector<PairCons>>>> edge_constraints;
265 
266  // Underlying DFPattern nodes which the edge constraints may reference
267  // Kept as a separate vector of patterns to process constraints in a fixed order.
268  std::vector<DFPattern> src_ordered;
269 
270  // Non-edge constraints
271  std::vector<DFConstraint> validation_constraints;
272 
273  static constexpr const char* _type_key = "relax.dpl.PatternContext";
275 };
276 
281 class PatternContext : public ObjectRef {
282  public:
283  TVM_DLL explicit PatternContext(ObjectPtr<Object> n) : ObjectRef(n) {}
284  TVM_DLL explicit PatternContext(bool incremental = false);
285 
287  ICHECK(get() != nullptr);
288  return static_cast<const PatternContextNode*>(get());
289  }
290 
292  ICHECK(get() != nullptr);
293  return static_cast<PatternContextNode*>(get_mutable());
294  }
295 
303  void add_constraint(DFPattern producer, DFPattern consumer, PairCons cons) {
304  auto& pairs = (*this)->edge_constraints[producer];
305  auto it = std::find_if(pairs.begin(), pairs.end(),
306  [consumer](auto p) { return p.first == consumer; });
307  if (it == pairs.end()) {
308  pairs.emplace_back(consumer, std::vector{cons});
309  } else {
310  auto& vec = it->second;
311  ICHECK(std::find(vec.cbegin(), vec.cend(), cons) == vec.cend())
312  << "Constraint already exists";
313  vec.push_back(cons);
314  }
315 
316  auto& patterns = (*this)->src_ordered;
317  if (std::find(patterns.begin(), patterns.end(), producer) == patterns.end()) {
318  patterns.push_back(producer);
319  }
320  }
321 
327  void add_constraint(DFConstraint constraint) {
328  (*this)->validation_constraints.push_back(constraint);
329  }
330 
332  TVM_DLL static Optional<PatternContext> Current();
333 
335  TVM_DLL void EnterWithScope() const;
337  TVM_DLL void ExitWithScope() const;
338 
339  private:
340  friend class With<PatternContext>;
341 };
342 
348  public:
351  static void RegisterReflection() {
352  namespace refl = tvm::ffi::reflection;
353  refl::ObjectDef<ExprPatternNode>().def_ro("expr", &ExprPatternNode::expr);
354  }
355 
356  static constexpr const char* _type_key = "relax.dpl.ExprPattern";
358 };
359 
364 class ExprPattern : public DFPattern {
365  public:
366  TVM_DLL explicit ExprPattern(Expr expr);
368 };
369 
376  public:
377  String name;
378  const String& name_hint() const { return name; }
379 
380  static void RegisterReflection() {
381  namespace refl = tvm::ffi::reflection;
382  refl::ObjectDef<VarPatternNode>().def_ro("name", &VarPatternNode::name);
383  }
384 
385  static constexpr const char* _type_key = "relax.dpl.VarPattern";
386  static constexpr const uint32_t _type_child_slots = 1;
388 };
389 
394 class VarPattern : public DFPattern {
395  public:
401  TVM_DLL VarPattern(String name_hint);
403 };
404 
410  public:
411  static void RegisterReflection() {
412  namespace refl = tvm::ffi::reflection;
413  refl::ObjectDef<DataflowVarPatternNode>();
414  }
415 
416  static constexpr const char* _type_key = "relax.dpl.DataflowVarPattern";
418 };
419 
425  public:
427  TVM_DLL DataflowVarPattern(String name_hint);
429 };
430 
436  public:
437  static constexpr const char* _type_key = "relax.dpl.GlobalVarPattern";
439 };
440 
445 class GlobalVarPattern : public DFPattern {
446  public:
447  TVM_DLL GlobalVarPattern(String name_hint);
449 };
450 
456  public:
457  static void RegisterReflection() {
458  namespace refl = tvm::ffi::reflection;
459  refl::ObjectDef<ConstantPatternNode>();
460  }
461 
462  static constexpr const char* _type_key = "relax.dpl.ConstantPattern";
464 };
465 
470 class ConstantPattern : public DFPattern {
471  public:
473 };
474 
480  public:
487  tvm::Array<DFPattern> args;
496  // Todo(relax-team): Dataflow pattern for StructInfo, and match sinfo_args
497 
498  static void RegisterReflection() {
499  namespace refl = tvm::ffi::reflection;
500  refl::ObjectDef<CallPatternNode>()
501  .def_ro("op", &CallPatternNode::op)
502  .def_ro("args", &CallPatternNode::args);
503  }
504 
505  static constexpr const char* _type_key = "relax.dpl.CallPattern";
507 };
508 
509 class CallPattern : public DFPattern {
510  public:
511  TVM_DLL CallPattern(DFPattern op, Array<DFPattern> args, bool varg_default_wildcard = false);
513 };
514 
521  public:
522  Array<PrimExpr> fields;
524  static void RegisterReflection() {
525  namespace refl = tvm::ffi::reflection;
526  refl::ObjectDef<PrimArrPatternNode>().def_ro("fields", &PrimArrPatternNode::fields);
527  }
528 
529  static constexpr const char* _type_key = "relax.dpl.PrimArrPattern";
531 };
532 
537 class PrimArrPattern : public DFPattern {
538  public:
539  TVM_DLL PrimArrPattern(Array<PrimExpr> arr);
541 };
542 
549  public:
550  tvm::Array<DFPattern> params;
559  static void RegisterReflection() {
560  namespace refl = tvm::ffi::reflection;
561  refl::ObjectDef<FunctionPatternNode>()
562  .def_ro("params", &FunctionPatternNode::params)
563  .def_ro("body", &FunctionPatternNode::body);
564  }
565 
566  static constexpr const char* _type_key = "relax.dpl.FunctionPattern";
568 };
569 
574 class FunctionPattern : public DFPattern {
575  public:
581  TVM_DLL FunctionPattern(tvm::Array<DFPattern> params, DFPattern body);
582 
584 };
585 
591  public:
592  tvm::Array<DFPattern> fields;
594  static void RegisterReflection() {
595  namespace refl = tvm::ffi::reflection;
596  refl::ObjectDef<TuplePatternNode>().def_ro("fields", &TuplePatternNode::fields);
597  }
598 
599  static constexpr const char* _type_key = "relax.dpl.TuplePattern";
601 };
602 
607 class TuplePattern : public DFPattern {
608  public:
609  TVM_DLL explicit TuplePattern(tvm::Array<DFPattern> fields);
611 };
612 
618  public:
619  tvm::Array<DFPattern> fields;
621  static void RegisterReflection() {
622  namespace refl = tvm::ffi::reflection;
623  refl::ObjectDef<UnorderedTuplePatternNode>().def_ro("fields",
625  }
626 
627  static constexpr const char* _type_key = "relax.dpl.UnorderedTuplePattern";
629 };
630 
636  public:
637  TVM_DLL explicit UnorderedTuplePattern(tvm::Array<DFPattern> fields);
639 };
640 
647  public:
649  int index;
651  static void RegisterReflection() {
652  namespace refl = tvm::ffi::reflection;
653  refl::ObjectDef<TupleGetItemPatternNode>()
654  .def_ro("tuple", &TupleGetItemPatternNode::tuple)
655  .def_ro("index", &TupleGetItemPatternNode::index);
656  }
657 
658  static constexpr const char* _type_key = "relax.dpl.TupleGetItemPattern";
660 };
661 
667  public:
668  TVM_DLL TupleGetItemPattern(DFPattern tuple, int index);
670 };
671 
677  public:
681  static void RegisterReflection() {
682  namespace refl = tvm::ffi::reflection;
683  refl::ObjectDef<AndPatternNode>()
684  .def_ro("left", &AndPatternNode::left)
685  .def_ro("right", &AndPatternNode::right);
686  }
687 
688  static constexpr const char* _type_key = "relax.dpl.AndPattern";
690 };
691 
696 class AndPattern : public DFPattern {
697  public:
698  TVM_DLL AndPattern(DFPattern lhs, DFPattern rhs);
700 };
701 
706 class OrPatternNode : public DFPatternNode {
707  public:
711  static void RegisterReflection() {
712  namespace refl = tvm::ffi::reflection;
713  refl::ObjectDef<OrPatternNode>()
714  .def_ro("left", &OrPatternNode::left)
715  .def_ro("right", &OrPatternNode::right);
716  }
717 
718  static constexpr const char* _type_key = "relax.dpl.OrPattern";
720 };
721 
726 class OrPattern : public DFPattern {
727  public:
728  TVM_DLL OrPattern(DFPattern left, DFPattern right);
730 };
731 
737  public:
740  static void RegisterReflection() {
741  namespace refl = tvm::ffi::reflection;
742  refl::ObjectDef<NotPatternNode>().def_ro("reject", &NotPatternNode::reject);
743  }
744 
745  static constexpr const char* _type_key = "relax.dpl.NotPattern";
747 };
748 
753 class NotPattern : public DFPattern {
754  public:
755  TVM_DLL NotPattern(DFPattern reject);
757 };
758 
764  public:
765  static void RegisterReflection() {
766  namespace refl = tvm::ffi::reflection;
767  refl::ObjectDef<WildcardPatternNode>();
768  }
769 
770  static constexpr const char* _type_key = "relax.dpl.WildcardPattern";
772 };
773 
778 class WildcardPattern : public DFPattern {
779  public:
781 
782  // Declaring WildcardPattern as non-nullable avoids the
783  // default zero-parameter constructor for ObjectRef with `data_ =
784  // nullptr`. This allows a zero-parameter constructor to be
785  // declared here that creates a valid wildcard instance.
786 
788 };
789 
795  public:
799  static void RegisterReflection() {
800  namespace refl = tvm::ffi::reflection;
801  refl::ObjectDef<StructInfoPatternNode>()
802  .def_ro("pattern", &StructInfoPatternNode::pattern)
803  .def_ro("struct_info", &StructInfoPatternNode::struct_info);
804  }
805 
806  static constexpr const char* _type_key = "relax.dpl.StructInfoPattern";
808 };
809 
810 class StructInfoPattern : public DFPattern {
811  public:
812  TVM_DLL StructInfoPattern(DFPattern pattern, StructInfo struct_info);
814 };
815 
821  public:
823  Array<PrimExpr> shape;
825  static void RegisterReflection() {
826  namespace refl = tvm::ffi::reflection;
827  refl::ObjectDef<ShapePatternNode>()
828  .def_ro("pattern", &ShapePatternNode::pattern)
829  .def_ro("shape", &ShapePatternNode::shape);
830  }
831 
832  static constexpr const char* _type_key = "relax.dpl.ShapePattern";
834 };
835 
840 class ShapePattern : public DFPattern {
841  public:
842  TVM_DLL ShapePattern(DFPattern pattern, Array<PrimExpr> type);
844 };
845 
851  public:
852  Array<DFPattern> args;
854  Array<DFPattern> GetDependentPatterns() const override { return args; }
855 
856  std::tuple<PrimExpr, bool> AsPrimExpr(
857  std::function<Optional<Var>(const DFPatternNode*)> match_state) const override;
858 
859  static void RegisterReflection() {
860  namespace refl = tvm::ffi::reflection;
861  refl::ObjectDef<SameShapeConstraintNode>().def_ro("args", &SameShapeConstraintNode::args);
862  }
863 
864  static constexpr const char* _type_key = "relax.dpl.SameShapeConstraint";
866 };
867 
873  public:
874  TVM_DLL SameShapeConstraint(Array<DFPattern> args);
876 };
877 
883  public:
887  static void RegisterReflection() {
888  namespace refl = tvm::ffi::reflection;
889  refl::ObjectDef<DataTypePatternNode>()
890  .def_ro("pattern", &DataTypePatternNode::pattern)
891  .def_ro("dtype", &DataTypePatternNode::dtype);
892  }
893 
894  static constexpr const char* _type_key = "relax.dpl.DataTypePattern";
896 };
897 
902 class DataTypePattern : public DFPattern {
903  public:
904  TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype);
906 };
907 
913  public:
917  static void RegisterReflection() {
918  namespace refl = tvm::ffi::reflection;
919  refl::ObjectDef<AttrPatternNode>()
920  .def_ro("pattern", &AttrPatternNode::pattern)
921  .def_ro("attrs", &AttrPatternNode::attrs);
922  }
923 
924  static constexpr const char* _type_key = "relax.dpl.AttrPattern";
926 };
927 
932 class AttrPattern : public DFPattern {
933  public:
934  TVM_DLL AttrPattern(DFPattern pattern, DictAttrs attrs);
936 };
937 
944  public:
945  String global_symbol_;
948  const String& global_symbol() const { return global_symbol_; }
949 
950  static void RegisterReflection() {
951  namespace refl = tvm::ffi::reflection;
952  refl::ObjectDef<ExternFuncPatternNode>().def_ro("global_symbol",
954  }
955 
956  static constexpr const char* _type_key = "relax.dpl.ExternFuncPattern";
958 };
959 
964 class ExternFuncPattern : public DFPattern {
965  public:
966  TVM_DLL ExternFuncPattern(String global_symbol);
968 };
969 
/*! \brief Returns a VarPattern built from `name` (presumably used as its name hint; confirm). */
971 VarPattern IsVar(const String& name);
/*! \brief Returns an ExprPattern for `expr` (presumably wrapping it directly; confirm). */
977 ExprPattern IsExpr(const Expr& expr);
/*!
 * \brief Returns an ExprPattern for the operator named `op_name`.
 * NOTE(review): presumably resolves the name through the op registry; confirm.
 */
979 ExprPattern IsOp(const String& op_name);
981 // Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo
/*!
 * \brief Returns a CallPattern for a call_tir call targeting `name`.
 * \param args Optional tuple of argument patterns (std::nullopt by default).
 * NOTE(review): presumably a missing `args` matches any arguments; confirm.
 */
982 CallPattern IsCallTIR(const String& name, Optional<TuplePattern> args = std::nullopt);
/*! \brief Overload of IsCallTIR taking the argument tuple pattern directly. */
984 CallPattern IsCallTIR(const String& name, TuplePattern var_args);
/*!
 * \brief Returns a CallPattern for a call_dps_packed call targeting `name`.
 * \param args Optional tuple of argument patterns (std::nullopt by default).
 * NOTE(review): presumably a missing `args` matches any arguments; confirm.
 */
986 CallPattern IsCallDPSPacked(const String& name, Optional<TuplePattern> args = std::nullopt);
/*! \brief Overload of IsCallDPSPacked taking the argument tuple pattern directly. */
988 CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args);
/*!
 * \brief Returns a tuple pattern over `fields`.
 * NOTE(review): the return type is DFPattern. Presumably this yields an UnorderedTuplePattern
 * when `unordered` is true and a TuplePattern otherwise; confirm in the implementation.
 */
990 DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered = false);
/*!
 * \brief Returns a TupleGetItemPattern on `tuple` at `index` (default -1).
 * NOTE(review): -1 presumably means "any index"; confirm.
 */
992 TupleGetItemPattern IsTupleGetItem(const DFPattern tuple, int index = -1);
993 
995 template <typename... Args>
996 CallPattern DFPattern::operator()(Args&&... args) const {
997  return CallPattern(GetRef<DFPattern>(this->get()),
998  Array<DFPattern>({std::forward<Args>(args)...}));
999 }
1000 
1001 } // namespace relax
1002 } // namespace tvm
1003 #endif // TVM_RELAX_DATAFLOW_PATTERN_H_
Managed reference to DictAttrsNode.
Definition: attrs.h:166
Managed reference to RelaxExprNode.
Definition: expr.h:446
RAII wrapper to enter and exit a context object, similar to Python's `with` syntax.
Definition: with.h:58
Match a conjunction of other patterns.
Definition: dataflow_pattern.h:676
DFPattern left
Definition: dataflow_pattern.h:678
DFPattern right
Definition: dataflow_pattern.h:679
static constexpr const char * _type_key
Definition: dataflow_pattern.h:688
TVM_DECLARE_FINAL_OBJECT_INFO(AndPatternNode, DFPatternNode)
static void RegisterReflection()
Definition: dataflow_pattern.h:681
Managed reference to AndPatternNode.
Definition: dataflow_pattern.h:696
AndPattern(DFPattern lhs, DFPattern rhs)
TVM_DEFINE_OBJECT_REF_METHODS(AndPattern, DFPattern, AndPatternNode)
A pattern asserting that a root pattern has certain attributes.
Definition: dataflow_pattern.h:912
DictAttrs attrs
Definition: dataflow_pattern.h:915
DFPattern pattern
Definition: dataflow_pattern.h:914
static void RegisterReflection()
Definition: dataflow_pattern.h:917
TVM_DECLARE_FINAL_OBJECT_INFO(AttrPatternNode, DFPatternNode)
static constexpr const char * _type_key
Definition: dataflow_pattern.h:924
Managed reference to AttrPatternNode.
Definition: dataflow_pattern.h:932
TVM_DEFINE_OBJECT_REF_METHODS(AttrPattern, DFPattern, AttrPatternNode)
AttrPattern(DFPattern pattern, DictAttrs attrs)
A pattern to match a callable node in Relax.
Definition: dataflow_pattern.h:479
TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode)
bool varg_default_wildcard
Definition: dataflow_pattern.h:494
static constexpr const char * _type_key
Definition: dataflow_pattern.h:505
static void RegisterReflection()
Definition: dataflow_pattern.h:498
tvm::Array< DFPattern > args
Definition: dataflow_pattern.h:487
DFPattern op
Definition: dataflow_pattern.h:486
Definition: dataflow_pattern.h:509
CallPattern(DFPattern op, Array< DFPattern > args, bool varg_default_wildcard=false)
TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode)
A Pattern to Match a Relax Constant.
Definition: dataflow_pattern.h:455
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPatternNode, DFPatternNode)
static constexpr const char * _type_key
Definition: dataflow_pattern.h:462
static void RegisterReflection()
Definition: dataflow_pattern.h:457
Managed reference to a ConstantPattern.
Definition: dataflow_pattern.h:470
TVM_DEFINE_OBJECT_REF_METHODS(ConstantPattern, DFPattern, ConstantPatternNode)
Additional constraints on the graph.
Definition: dataflow_pattern.h:165
TVM_DECLARE_BASE_OBJECT_INFO(DFConstraintNode, Object)
virtual Array< DFPattern > GetDependentPatterns() const =0
Return the patterns on which the constraint depends.
static constexpr const uint32_t _type_child_slots
Definition: dataflow_pattern.h:201
virtual std::tuple< PrimExpr, bool > AsPrimExpr(std::function< Optional< Var >(const DFPatternNode *)> match_state) const =0
Convert the constraint to a PrimExpr.
static constexpr const char * _type_key
Definition: dataflow_pattern.h:200
Definition: dataflow_pattern.h:205
TVM_DEFINE_OBJECT_REF_METHODS(DFConstraint, ObjectRef, DFConstraintNode)
Base type of all dataflow patterns.
Definition: dataflow_pattern.h:91
static constexpr const uint32_t _type_child_slots
Definition: dataflow_pattern.h:94
TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object)
static constexpr const char * _type_key
Definition: dataflow_pattern.h:93
Managed reference to dataflow patterns.
Definition: dataflow_pattern.h:102
CallPattern operator()(const std::vector< DFPattern > &args) const
Syntactic sugar for creating a CallPattern.
ShapePattern HasShape(const Array< PrimExpr > &shape) const
Syntactic sugar for creating a ShapePattern.
DataTypePattern HasDtype(const DataType &dtype) const
Syntactic sugar for creating a DataTypePattern with a DataType.
SameShapeConstraint HasSameShapeAs(const DFPattern &other) const
Syntactic sugar for creating a SameShapeConstraint.
AttrPattern HasAttr(const Map< String, Any > &attrs) const
Syntactic sugar for creating an AttrPattern.
OrPattern operator|(const DFPattern &other) const
Syntactic sugar for creating an OrPattern.
CallPattern operator()(Args &&... args) const
Syntactic sugar for creating a CallPattern.
Definition: dataflow_pattern.h:996
TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode)
StructInfoPattern HasStructInfo(const StructInfo &struct_info) const
Syntactic sugar for creating a StructInfoPattern.
AndPattern operator&(const DFPattern &other) const
Syntactic sugar for creating an AndPattern.
DFPattern dup() const
Syntactic sugar for duplicating the current pattern.
NotPattern operator~() const
Syntactic sugar for creating a NotPattern.
DataTypePattern HasDtype(const std::string &dtype) const
Syntactic sugar for creating a DataTypePattern from a data type's name.
A pattern asserting that a root pattern has a certain data type.
Definition: dataflow_pattern.h:882
DFPattern pattern
Definition: dataflow_pattern.h:884
TVM_DECLARE_FINAL_OBJECT_INFO(DataTypePatternNode, DFPatternNode)
DataType dtype
Definition: dataflow_pattern.h:885
static constexpr const char * _type_key
Definition: dataflow_pattern.h:894
static void RegisterReflection()
Definition: dataflow_pattern.h:887
Managed reference to DataTypePatternNode.
Definition: dataflow_pattern.h:902
DataTypePattern(DFPattern pattern, DataType dtype)
TVM_DEFINE_OBJECT_REF_METHODS(DataTypePattern, DFPattern, DataTypePatternNode)
A Pattern to Match a Relax Dataflow Variable.
Definition: dataflow_pattern.h:409
static void RegisterReflection()
Definition: dataflow_pattern.h:411
TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarPatternNode, VarPatternNode)
static constexpr const char * _type_key
Definition: dataflow_pattern.h:416
Managed reference to a DataflowVarPattern.
Definition: dataflow_pattern.h:424
DataflowVarPattern(String name_hint)
TVM_DEFINE_OBJECT_REF_METHODS(DataflowVarPattern, DFPattern, DataflowVarPatternNode)
Pattern for Relax Expression.
Definition: dataflow_pattern.h:347
Expr expr
Definition: dataflow_pattern.h:349
TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode)
static constexpr const char * _type_key
Definition: dataflow_pattern.h:356
static void RegisterReflection()
Definition: dataflow_pattern.h:351
Managed reference to an ExprPattern.
Definition: dataflow_pattern.h:364
TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode)
A pattern of external function.
Definition: dataflow_pattern.h:943
const String & global_symbol() const
The external function name.
Definition: dataflow_pattern.h:948
String global_symbol_
Definition: dataflow_pattern.h:945
TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncPatternNode, DFPatternNode)
static void RegisterReflection()
Definition: dataflow_pattern.h:950
static constexpr const char * _type_key
Definition: dataflow_pattern.h:956
Managed reference to ExternFuncPatternNode.
Definition: dataflow_pattern.h:964
ExternFuncPattern(String global_symbol)
TVM_DEFINE_OBJECT_REF_METHODS(ExternFuncPattern, DFPattern, ExternFuncPatternNode)
A pattern to match a Relax Function.
Definition: dataflow_pattern.h:548
static void RegisterReflection()
Definition: dataflow_pattern.h:559
DFPattern body
Definition: dataflow_pattern.h:557
static constexpr const char * _type_key
Definition: dataflow_pattern.h:566
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPatternNode, DFPatternNode)
tvm::Array< DFPattern > params
Definition: dataflow_pattern.h:550
Managed reference to FunctionPatternNode.
Definition: dataflow_pattern.h:574
FunctionPattern(tvm::Array< DFPattern > params, DFPattern body)
Constructor.
TVM_DEFINE_OBJECT_REF_METHODS(FunctionPattern, DFPattern, FunctionPatternNode)
A Pattern to Match a Relax Global Variable.
Definition: dataflow_pattern.h:435
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarPatternNode, DFPatternNode)
static constexpr const char * _type_key
Definition: dataflow_pattern.h:437
Managed reference to a GlobalVarPattern.
Definition: dataflow_pattern.h:445
GlobalVarPattern(String name_hint)
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVarPattern, DFPattern, GlobalVarPatternNode)
Pattern for rejecting a certain pattern.
Definition: dataflow_pattern.h:736
DFPattern reject
Definition: dataflow_pattern.h:738
TVM_DECLARE_FINAL_OBJECT_INFO(NotPatternNode, DFPatternNode)
static void RegisterReflection()
Definition: dataflow_pattern.h:740
static constexpr const char * _type_key
Definition: dataflow_pattern.h:745
Managed reference to NotPatternNode.
Definition: dataflow_pattern.h:753
NotPattern(DFPattern reject)
TVM_DEFINE_OBJECT_REF_METHODS(NotPattern, DFPattern, NotPatternNode)
Match a disjunction of other patterns.
Definition: dataflow_pattern.h:706
DFPattern left
Definition: dataflow_pattern.h:708
DFPattern right
Definition: dataflow_pattern.h:709
TVM_DECLARE_FINAL_OBJECT_INFO(OrPatternNode, DFPatternNode)
static void RegisterReflection()
Definition: dataflow_pattern.h:711
static constexpr const char * _type_key
Definition: dataflow_pattern.h:718
Managed reference to OrPatternNode.
Definition: dataflow_pattern.h:726
TVM_DEFINE_OBJECT_REF_METHODS(OrPattern, DFPattern, OrPatternNode)
OrPattern(DFPattern left, DFPattern right)
A context to manage the graph-level pattern matching.
Definition: dataflow_pattern.h:254
ExternUse
Constrains the matched graph with assertions about external uses.
Definition: dataflow_pattern.h:257
@ kMustNot
Definition: dataflow_pattern.h:259
@ kMay
Definition: dataflow_pattern.h:258
TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextNode, Object)
std::vector< DFConstraint > validation_constraints
Definition: dataflow_pattern.h:271
static constexpr const char * _type_key
Definition: dataflow_pattern.h:273
enum tvm::relax::PatternContextNode::ExternUse allow_extern_use
std::map< DFPattern, std::vector< std::pair< DFPattern, std::vector< PairCons > > > > edge_constraints
Definition: dataflow_pattern.h:264
std::vector< DFPattern > src_ordered
Definition: dataflow_pattern.h:268
Managed reference to a pattern context.
Definition: dataflow_pattern.h:281
void add_constraint(DFConstraint constraint)
Add a validation constraint.
Definition: dataflow_pattern.h:327
void ExitWithScope() const
The RAII-like exit of a constraint context scope.
PatternContextNode * operator->()
Definition: dataflow_pattern.h:291
static Optional< PatternContext > Current()
Get the constraint context object on the top of the stack.
void add_constraint(DFPattern producer, DFPattern consumer, PairCons cons)
Build an edge constraint between two patterns (producer and consumer).
Definition: dataflow_pattern.h:303
const PatternContextNode * operator->() const
Definition: dataflow_pattern.h:286
PatternContext(bool incremental=false)
PatternContext(ObjectPtr< Object > n)
Definition: dataflow_pattern.h:283
void EnterWithScope() const
The RAII-like entry of a constraint context scope.
A sequence of DFPatterns in which each DFPattern is connected to the next one.
Definition: dataflow_pattern.h:214
TVM_DECLARE_BASE_OBJECT_INFO(PatternSeqNode, Object)
tvm::Array< DFPattern > patterns
Definition: dataflow_pattern.h:216
static void RegisterReflection()
Definition: dataflow_pattern.h:219
static constexpr const char * _type_key
Definition: dataflow_pattern.h:224
std::vector< PairCons > pair_constraints
Definition: dataflow_pattern.h:217
Managed reference to pattern sequences.
Definition: dataflow_pattern.h:232
PatternSeq(tvm::Array< DFPattern > patterns, bool only_used_by=false)
PatternSeq OnlyUsedBy(PatternSeq other, int index=-1) const
TVM_DEFINE_OBJECT_REF_METHODS(PatternSeq, ObjectRef, PatternSeqNode)
PatternSeq UsedBy(PatternSeq other, int index=-1) const
friend PatternSeq UsedBy(const PatternSeq &lhs, const PatternSeq &rhs, int index)
Create used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned.
PatternSeq(DFPattern init_pattern)
friend PatternSeq OnlyUsedBy(const PatternSeq &lhs, const PatternSeq &rhs, int index)
Create only-used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned.
PatternSeq dup() const
Syntactic sugar for duplicating the current pattern sequence.
A pattern to match an array of PrimExpr.
Definition: dataflow_pattern.h:520
TVM_DECLARE_FINAL_OBJECT_INFO(PrimArrPatternNode, DFPatternNode)
Array< PrimExpr > fields
Definition: dataflow_pattern.h:522
static void RegisterReflection()
Definition: dataflow_pattern.h:524
static constexpr const char * _type_key
Definition: dataflow_pattern.h:529
Managed reference to a PrimArrPattern.
Definition: dataflow_pattern.h:537
TVM_DEFINE_OBJECT_REF_METHODS(PrimArrPattern, DFPattern, PrimArrPatternNode)
PrimArrPattern(Array< PrimExpr > arr)
A constraint asserting that multiple root patterns have the same shape.
Definition: dataflow_pattern.h:850
Array< DFPattern > args
Definition: dataflow_pattern.h:852
TVM_DECLARE_FINAL_OBJECT_INFO(SameShapeConstraintNode, DFConstraintNode)
static void RegisterReflection()
Definition: dataflow_pattern.h:859
std::tuple< PrimExpr, bool > AsPrimExpr(std::function< Optional< Var >(const DFPatternNode *)> match_state) const override
Convert the constraint to a PrimExpr.
static constexpr const char * _type_key
Definition: dataflow_pattern.h:864
Array< DFPattern > GetDependentPatterns() const override
Return the patterns on which the constraint depends.
Definition: dataflow_pattern.h:854
Managed reference to SameShapeConstraintNode.
Definition: dataflow_pattern.h:872
TVM_DEFINE_OBJECT_REF_METHODS(SameShapeConstraint, DFConstraint, SameShapeConstraintNode)
SameShapeConstraint(Array< DFPattern > args)
A pattern asserting that a root pattern has a certain shape.
Definition: dataflow_pattern.h:820
static constexpr const char * _type_key
Definition: dataflow_pattern.h:832
TVM_DECLARE_FINAL_OBJECT_INFO(ShapePatternNode, DFPatternNode)
DFPattern pattern
Definition: dataflow_pattern.h:822
Array< PrimExpr > shape
Definition: dataflow_pattern.h:823
static void RegisterReflection()
Definition: dataflow_pattern.h:825
Managed reference to ShapePatternNode.
Definition: dataflow_pattern.h:840
TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode)
ShapePattern(DFPattern pattern, Array< PrimExpr > type)
Pattern for matching a certain struct info.
Definition: dataflow_pattern.h:794
StructInfo struct_info
Definition: dataflow_pattern.h:797
static void RegisterReflection()
Definition: dataflow_pattern.h:799
static constexpr const char * _type_key
Definition: dataflow_pattern.h:806
TVM_DECLARE_FINAL_OBJECT_INFO(StructInfoPatternNode, DFPatternNode)
DFPattern pattern
Definition: dataflow_pattern.h:796
Definition: dataflow_pattern.h:810
TVM_DEFINE_OBJECT_REF_METHODS(StructInfoPattern, DFPattern, StructInfoPatternNode)
StructInfoPattern(DFPattern pattern, StructInfo struct_info)
Managed reference to StructInfoNode.
Definition: expr.h:135
A pattern to match n'th indexing to a tuple.
Definition: dataflow_pattern.h:646
static constexpr const char * _type_key
Definition: dataflow_pattern.h:658
static void RegisterReflection()
Definition: dataflow_pattern.h:651
TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode)
DFPattern tuple
Definition: dataflow_pattern.h:648
int index
Definition: dataflow_pattern.h:649
Managed reference to TupleGetItemPatternNode.
Definition: dataflow_pattern.h:666
TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItemPattern, DFPattern, TupleGetItemPatternNode)
TupleGetItemPattern(DFPattern tuple, int index)
Pattern to match a tuple of ordered expressions.
Definition: dataflow_pattern.h:590
tvm::Array< DFPattern > fields
Definition: dataflow_pattern.h:592
TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode)
static void RegisterReflection()
Definition: dataflow_pattern.h:594
static constexpr const char * _type_key
Definition: dataflow_pattern.h:599
Managed reference to TuplePatternNode.
Definition: dataflow_pattern.h:607
TVM_DEFINE_OBJECT_REF_METHODS(TuplePattern, DFPattern, TuplePatternNode)
TuplePattern(tvm::Array< DFPattern > fields)
A pattern to match multiple expressions in any order.
Definition: dataflow_pattern.h:617
tvm::Array< DFPattern > fields
Definition: dataflow_pattern.h:619
static constexpr const char * _type_key
Definition: dataflow_pattern.h:627
TVM_DECLARE_FINAL_OBJECT_INFO(UnorderedTuplePatternNode, DFPatternNode)
static void RegisterReflection()
Definition: dataflow_pattern.h:621
Managed reference to UnorderedTuplePatternNode.
Definition: dataflow_pattern.h:635
UnorderedTuplePattern(tvm::Array< DFPattern > fields)
TVM_DEFINE_OBJECT_REF_METHODS(UnorderedTuplePattern, DFPattern, UnorderedTuplePatternNode)
A Pattern to Match a Relax Variable.
Definition: dataflow_pattern.h:375
static constexpr const uint32_t _type_child_slots
Definition: dataflow_pattern.h:386
String name
Definition: dataflow_pattern.h:377
static void RegisterReflection()
Definition: dataflow_pattern.h:380
const String & name_hint() const
Definition: dataflow_pattern.h:378
static constexpr const char * _type_key
Definition: dataflow_pattern.h:385
TVM_DECLARE_BASE_OBJECT_INFO(VarPatternNode, DFPatternNode)
Managed reference to a VarPattern.
Definition: dataflow_pattern.h:394
TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode)
VarPattern(String name_hint)
Create a pattern matching by variable name.
Wildcard Pattern is a pattern that can match anything.
Definition: dataflow_pattern.h:763
static void RegisterReflection()
Definition: dataflow_pattern.h:765
static constexpr const char * _type_key
Definition: dataflow_pattern.h:770
TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode)
Managed reference to WildcardPatternNode.
Definition: dataflow_pattern.h:778
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode)
Runtime primitive data type.
Definition: data_type.h:47
Base expr nodes in TVM.
Definition: repr_printer.h:91
CallPattern IsCallDPSPacked(const String &name, Optional< TuplePattern > args=std::nullopt)
Syntactic Sugar for call_dps_packed (returns a tensor)
PatternSeq operator^(const PatternSeq &lhs, const PatternSeq &rhs)
Syntax sugar of UsedBy(lhs, rhs, -1).
ExprPattern IsOp(const String &op_name)
Syntactic Sugar for creating an ExprPattern based on an Op.
DFPattern IsTuple(const Array< DFPattern > &fields, bool unordered=false)
Syntactic Sugar for creating a TuplePattern, or an UnorderedTuplePattern when unordered=true.
WildcardPattern Wildcard()
Syntactic Sugar for creating a WildcardPattern.
ConstantPattern IsConst()
Syntactic Sugar for creating a ConstantPattern.
PatternSeq UsedBy(const PatternSeq &lhs, const PatternSeq &rhs, int index=-1)
Create used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned.
PatternSeq OnlyUsedBy(const PatternSeq &lhs, const PatternSeq &rhs, int index=-1)
Create only-used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned.
ExprPattern IsExpr(const Expr &expr)
Syntactic Sugar for creating an ExprPattern.
PatternSeq operator>>(const PatternSeq &lhs, const PatternSeq &rhs)
Syntax sugar of OnlyUsedBy(lhs, rhs, -1).
CallPattern IsCallTIR(const String &name, Optional< TuplePattern > args=std::nullopt)
Syntactic Sugar for call_tir (returns a tensor)
TupleGetItemPattern IsTupleGetItem(const DFPattern tuple, int index=-1)
Syntactic Sugar for creating a TupleGetItemPattern.
VarPattern IsVar(const String &name)
Syntactic Sugar for creating a VarPattern with a name.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1945
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Relax Types.
Constraint of a DFPattern edge (producer -> consumer) in graph-level matching.
Definition: dataflow_pattern.h:137
int index
Definition: dataflow_pattern.h:143
bool operator==(const PairCons &other) const
Definition: dataflow_pattern.h:153
enum tvm::relax::PairCons::Type type
Type
Constraint types of the edge.
Definition: dataflow_pattern.h:139
@ kOnlyUsedBy
Definition: dataflow_pattern.h:141
@ kUsedBy
Definition: dataflow_pattern.h:140
PairCons(Type t, int index=-1)
Construct a new PairCons object.
Definition: dataflow_pattern.h:151
RAII wrapper function to enter and exit a context object similar to python's with syntax.