tvm
dataflow_pattern.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_DATAFLOW_PATTERN_H_
25 #define TVM_RELAX_DATAFLOW_PATTERN_H_
26 
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/expr.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/type.h>
#include <tvm/support/with.h>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
43 
44 namespace tvm {
45 
46 namespace arith {
47 class Analyzer;
48 }
49 
50 namespace relax {
51 
52 class PatternSeq;
53 class CallPattern;
54 class OrPattern;
55 class AndPattern;
56 class NotPattern;
57 class ShapePattern;
58 class StructInfoPattern;
59 class DataTypePattern;
60 class AttrPattern;
61 class SameShapeConstraint;
62 
71 TVM_DLL PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index = -1);
73 TVM_DLL PatternSeq operator^(const PatternSeq& lhs, const PatternSeq& rhs);
74 
83 TVM_DLL PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index = -1);
85 TVM_DLL PatternSeq operator>>(const PatternSeq& lhs, const PatternSeq& rhs);
86 
91 class DFPatternNode : public Object {
92  public:
93  static constexpr const uint32_t _type_child_slots = 21;
94  TVM_FFI_DECLARE_OBJECT_INFO("relax.dpl.DFPattern", DFPatternNode, Object);
95 };
96 
101 class DFPattern : public ObjectRef {
102  public:
104  template <typename... Args>
105  CallPattern operator()(Args&&... args) const;
107  TVM_DLL CallPattern operator()(const std::vector<DFPattern>& args) const;
109  TVM_DLL OrPattern operator|(const DFPattern& other) const;
111  TVM_DLL AndPattern operator&(const DFPattern& other) const;
113  TVM_DLL NotPattern operator~() const;
115  TVM_DLL AttrPattern HasAttr(const ffi::Map<ffi::String, Any>& attrs) const;
117  TVM_DLL StructInfoPattern HasStructInfo(const StructInfo& struct_info) const;
119  TVM_DLL DataTypePattern HasDtype(const DataType& dtype) const;
121  TVM_DLL DataTypePattern HasDtype(const std::string& dtype) const;
123  TVM_DLL ShapePattern HasShape(const ffi::Array<PrimExpr>& shape) const;
125  TVM_DLL SameShapeConstraint HasSameShapeAs(const DFPattern& other) const;
127  TVM_DLL DFPattern dup() const;
128 
130  TVM_DLL operator PatternSeq() const;
131 
133 };
134 
136 struct PairCons {
138  enum Type {
142  int index = -1;
150  TVM_DLL explicit PairCons(Type t, int index = -1) : type(t), index(index) {}
151 
152  bool operator==(const PairCons& other) const {
153  return type == other.type && index == other.index;
154  }
155 };
156 
164 class DFConstraintNode : public Object {
165  public:
167  virtual ffi::Array<DFPattern> GetDependentPatterns() const = 0;
168 
196  virtual std::tuple<PrimExpr, bool> AsPrimExpr(
197  std::function<ffi::Optional<Var>(const DFPatternNode*)> match_state) const = 0;
198 
199  static constexpr const uint32_t _type_child_slots = 1;
200  TVM_FFI_DECLARE_OBJECT_INFO("relax.dpl.DFConstraint", DFConstraintNode, Object);
201 };
202 
203 class DFConstraint : public ObjectRef {
204  public:
206 };
207 
212 class PatternSeqNode final : public Object {
213  public:
214  tvm::ffi::Array<DFPattern> patterns;
215  std::vector<PairCons> pair_constraints;
217  static void RegisterReflection() {
218  namespace refl = tvm::ffi::reflection;
219  refl::ObjectDef<PatternSeqNode>().def_ro("patterns", &PatternSeqNode::patterns);
220  }
221  TVM_FFI_DECLARE_OBJECT_INFO("relax.dpl.PatternSeq", PatternSeqNode, Object);
222 };
223 
228 class PatternSeq final : public ObjectRef {
229  public:
230  TVM_DLL explicit PatternSeq(DFPattern init_pattern);
231  TVM_DLL explicit PatternSeq(tvm::ffi::Array<DFPattern> patterns, bool only_used_by = false);
232 
233  PatternSeq UsedBy(PatternSeq other, int index = -1) const;
234  PatternSeq OnlyUsedBy(PatternSeq other, int index = -1) const;
235 
237  PatternSeq dup() const;
238 
239  // friend functions
240  friend PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index);
241  friend PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index);
242 
244 };
245 
250 class PatternContextNode : public Object {
251  public:
253  enum ExternUse {
257 
258  // src node -> <dst node, constraint type> constraints.
259  // Dst nodes are kept in a vector to keep them ordered.
260  std::map<DFPattern, std::vector<std::pair<DFPattern, std::vector<PairCons>>>> edge_constraints;
261 
262  // Underlying DFPattern nodes which the edge constraints may reference
263  // Kept as a separate vector of patterns to process constraints in a fixed order.
264  std::vector<DFPattern> src_ordered;
265 
266  // Non-edge constraints
267  std::vector<DFConstraint> validation_constraints;
268  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.PatternContext", PatternContextNode, Object);
269 };
270 
275 class PatternContext : public ObjectRef {
276  public:
277  explicit PatternContext(ffi::UnsafeInit tag) : ObjectRef(tag) {}
278  TVM_DLL explicit PatternContext(ObjectPtr<Object> n) : ObjectRef(n) {}
279  TVM_DLL explicit PatternContext(bool incremental = false);
280 
282  ICHECK(get() != nullptr);
283  return static_cast<const PatternContextNode*>(get());
284  }
285 
287  ICHECK(get() != nullptr);
288  return static_cast<PatternContextNode*>(get_mutable());
289  }
290 
298  void add_constraint(DFPattern producer, DFPattern consumer, PairCons cons) {
299  auto& pairs = (*this)->edge_constraints[producer];
300  auto it = std::find_if(pairs.begin(), pairs.end(),
301  [consumer](auto p) { return p.first == consumer; });
302  if (it == pairs.end()) {
303  pairs.emplace_back(consumer, std::vector{cons});
304  } else {
305  auto& vec = it->second;
306  ICHECK(std::find(vec.cbegin(), vec.cend(), cons) == vec.cend())
307  << "Constraint already exists";
308  vec.push_back(cons);
309  }
310 
311  auto& patterns = (*this)->src_ordered;
312  if (std::find(patterns.begin(), patterns.end(), producer) == patterns.end()) {
313  patterns.push_back(producer);
314  }
315  }
316 
322  void add_constraint(DFConstraint constraint) {
323  (*this)->validation_constraints.push_back(constraint);
324  }
325 
327  TVM_DLL static ffi::Optional<PatternContext> Current();
328 
330  TVM_DLL void EnterWithScope() const;
332  TVM_DLL void ExitWithScope() const;
333 
334  private:
335  friend class With<PatternContext>;
336 };
337 
343  public:
346  static void RegisterReflection() {
347  namespace refl = tvm::ffi::reflection;
348  refl::ObjectDef<ExprPatternNode>().def_ro("expr", &ExprPatternNode::expr);
349  }
351 };
352 
357 class ExprPattern : public DFPattern {
358  public:
359  TVM_DLL explicit ExprPattern(Expr expr);
361 };
362 
369  public:
370  ffi::String name;
371  const ffi::String& name_hint() const { return name; }
372 
373  static void RegisterReflection() {
374  namespace refl = tvm::ffi::reflection;
375  refl::ObjectDef<VarPatternNode>().def_ro("name", &VarPatternNode::name);
376  }
377 
378  static constexpr const uint32_t _type_child_slots = 1;
380 };
381 
386 class VarPattern : public DFPattern {
387  public:
393  TVM_DLL VarPattern(ffi::String name_hint);
395 };
396 
402  public:
403  static void RegisterReflection() {
404  namespace refl = tvm::ffi::reflection;
405  refl::ObjectDef<DataflowVarPatternNode>();
406  }
409 };
410 
416  public:
418  TVM_DLL DataflowVarPattern(ffi::String name_hint);
420 };
421 
427  public:
429  DFPatternNode);
430 };
431 
436 class GlobalVarPattern : public DFPattern {
437  public:
438  TVM_DLL GlobalVarPattern(ffi::String name_hint);
440 };
441 
447  public:
448  static void RegisterReflection() {
449  namespace refl = tvm::ffi::reflection;
450  refl::ObjectDef<ConstantPatternNode>();
451  }
453  DFPatternNode);
454 };
455 
460 class ConstantPattern : public DFPattern {
461  public:
463 };
464 
470  public:
477  tvm::ffi::Array<DFPattern> args;
486  // Todo(relax-team): Dataflow pattern for StructInfo, and match sinfo_args
487 
488  static void RegisterReflection() {
489  namespace refl = tvm::ffi::reflection;
490  refl::ObjectDef<CallPatternNode>()
491  .def_ro("op", &CallPatternNode::op)
492  .def_ro("args", &CallPatternNode::args);
493  }
495 };
496 
497 class CallPattern : public DFPattern {
498  public:
499  TVM_DLL CallPattern(DFPattern op, ffi::Array<DFPattern> args, bool varg_default_wildcard = false);
501 };
502 
509  public:
510  ffi::Array<PrimExpr> fields;
512  static void RegisterReflection() {
513  namespace refl = tvm::ffi::reflection;
514  refl::ObjectDef<PrimArrPatternNode>().def_ro("fields", &PrimArrPatternNode::fields);
515  }
517 };
518 
523 class PrimArrPattern : public DFPattern {
524  public:
525  TVM_DLL PrimArrPattern(ffi::Array<PrimExpr> arr);
527 };
528 
535  public:
536  tvm::ffi::Array<DFPattern> params;
545  static void RegisterReflection() {
546  namespace refl = tvm::ffi::reflection;
547  refl::ObjectDef<FunctionPatternNode>()
548  .def_ro("params", &FunctionPatternNode::params)
549  .def_ro("body", &FunctionPatternNode::body);
550  }
552  DFPatternNode);
553 };
554 
559 class FunctionPattern : public DFPattern {
560  public:
566  TVM_DLL FunctionPattern(tvm::ffi::Array<DFPattern> params, DFPattern body);
567 
569 };
570 
576  public:
577  tvm::ffi::Array<DFPattern> fields;
579  static void RegisterReflection() {
580  namespace refl = tvm::ffi::reflection;
581  refl::ObjectDef<TuplePatternNode>().def_ro("fields", &TuplePatternNode::fields);
582  }
584 };
585 
590 class TuplePattern : public DFPattern {
591  public:
592  TVM_DLL explicit TuplePattern(tvm::ffi::Array<DFPattern> fields);
594 };
595 
601  public:
602  tvm::ffi::Array<DFPattern> fields;
604  static void RegisterReflection() {
605  namespace refl = tvm::ffi::reflection;
606  refl::ObjectDef<UnorderedTuplePatternNode>().def_ro("fields",
608  }
610  DFPatternNode);
611 };
612 
618  public:
619  TVM_DLL explicit UnorderedTuplePattern(tvm::ffi::Array<DFPattern> fields);
622 };
623 
630  public:
632  int index;
634  static void RegisterReflection() {
635  namespace refl = tvm::ffi::reflection;
636  refl::ObjectDef<TupleGetItemPatternNode>()
637  .def_ro("tuple", &TupleGetItemPatternNode::tuple)
638  .def_ro("index", &TupleGetItemPatternNode::index);
639  }
641  DFPatternNode);
642 };
643 
649  public:
650  TVM_DLL TupleGetItemPattern(DFPattern tuple, int index);
653 };
654 
660  public:
664  static void RegisterReflection() {
665  namespace refl = tvm::ffi::reflection;
666  refl::ObjectDef<AndPatternNode>()
667  .def_ro("left", &AndPatternNode::left)
668  .def_ro("right", &AndPatternNode::right);
669  }
671 };
672 
677 class AndPattern : public DFPattern {
678  public:
679  TVM_DLL AndPattern(DFPattern lhs, DFPattern rhs);
681 };
682 
687 class OrPatternNode : public DFPatternNode {
688  public:
692  static void RegisterReflection() {
693  namespace refl = tvm::ffi::reflection;
694  refl::ObjectDef<OrPatternNode>()
695  .def_ro("left", &OrPatternNode::left)
696  .def_ro("right", &OrPatternNode::right);
697  }
699 };
700 
705 class OrPattern : public DFPattern {
706  public:
707  TVM_DLL OrPattern(DFPattern left, DFPattern right);
709 };
710 
716  public:
719  static void RegisterReflection() {
720  namespace refl = tvm::ffi::reflection;
721  refl::ObjectDef<NotPatternNode>().def_ro("reject", &NotPatternNode::reject);
722  }
724 };
725 
730 class NotPattern : public DFPattern {
731  public:
732  TVM_DLL NotPattern(DFPattern reject);
734 };
735 
741  public:
742  static void RegisterReflection() {
743  namespace refl = tvm::ffi::reflection;
744  refl::ObjectDef<WildcardPatternNode>();
745  }
747  DFPatternNode);
748 };
749 
754 class WildcardPattern : public DFPattern {
755  public:
757  explicit WildcardPattern(ObjectPtr<WildcardPatternNode> data) : DFPattern(ffi::UnsafeInit{}) {
758  TVM_FFI_ICHECK(data != nullptr);
759  data_ = std::move(data);
760  }
761 
762  // Declaring WildcardPattern declared as non-nullable avoids the
763  // default zero-parameter constructor for ObjectRef with `data_ =
764  // nullptr`. This allows a zero-parameter constructor to be
765  // declared here, to create a valid wildcard instance.
766 
768 };
769 
775  public:
779  static void RegisterReflection() {
780  namespace refl = tvm::ffi::reflection;
781  refl::ObjectDef<StructInfoPatternNode>()
782  .def_ro("pattern", &StructInfoPatternNode::pattern)
783  .def_ro("struct_info", &StructInfoPatternNode::struct_info);
784  }
786  DFPatternNode);
787 };
788 
789 class StructInfoPattern : public DFPattern {
790  public:
791  TVM_DLL StructInfoPattern(DFPattern pattern, StructInfo struct_info);
793 };
794 
800  public:
802  ffi::Array<PrimExpr> shape;
804  static void RegisterReflection() {
805  namespace refl = tvm::ffi::reflection;
806  refl::ObjectDef<ShapePatternNode>()
807  .def_ro("pattern", &ShapePatternNode::pattern)
808  .def_ro("shape", &ShapePatternNode::shape);
809  }
811 };
812 
817 class ShapePattern : public DFPattern {
818  public:
819  TVM_DLL ShapePattern(DFPattern pattern, ffi::Array<PrimExpr> type);
821 };
822 
828  public:
829  ffi::Array<DFPattern> args;
831  ffi::Array<DFPattern> GetDependentPatterns() const override { return args; }
832 
833  std::tuple<PrimExpr, bool> AsPrimExpr(
834  std::function<ffi::Optional<Var>(const DFPatternNode*)> match_state) const override;
835 
836  static void RegisterReflection() {
837  namespace refl = tvm::ffi::reflection;
838  refl::ObjectDef<SameShapeConstraintNode>().def_ro("args", &SameShapeConstraintNode::args);
839  }
842 };
843 
849  public:
850  TVM_DLL SameShapeConstraint(ffi::Array<DFPattern> args);
853 };
854 
860  public:
864  static void RegisterReflection() {
865  namespace refl = tvm::ffi::reflection;
866  refl::ObjectDef<DataTypePatternNode>()
867  .def_ro("pattern", &DataTypePatternNode::pattern)
868  .def_ro("dtype", &DataTypePatternNode::dtype);
869  }
871  DFPatternNode);
872 };
873 
878 class DataTypePattern : public DFPattern {
879  public:
880  TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype);
882 };
883 
889  public:
893  static void RegisterReflection() {
894  namespace refl = tvm::ffi::reflection;
895  refl::ObjectDef<AttrPatternNode>()
896  .def_ro("pattern", &AttrPatternNode::pattern)
897  .def_ro("attrs", &AttrPatternNode::attrs);
898  }
900 };
901 
906 class AttrPattern : public DFPattern {
907  public:
908  TVM_DLL AttrPattern(DFPattern pattern, DictAttrs attrs);
910 };
911 
918  public:
919  ffi::String global_symbol_;
922  const ffi::String& global_symbol() const { return global_symbol_; }
923 
924  static void RegisterReflection() {
925  namespace refl = tvm::ffi::reflection;
926  refl::ObjectDef<ExternFuncPatternNode>().def_ro("global_symbol",
928  }
930  DFPatternNode);
931 };
932 
937 class ExternFuncPattern : public DFPattern {
938  public:
939  TVM_DLL ExternFuncPattern(ffi::String global_symbol);
941 };
942 
944 VarPattern IsVar(const ffi::String& name);
950 ExprPattern IsExpr(const Expr& expr);
952 ExprPattern IsOp(const ffi::String& op_name);
954 // Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo
955 CallPattern IsCallTIR(const ffi::String& name, ffi::Optional<TuplePattern> args = std::nullopt);
957 CallPattern IsCallTIR(const ffi::String& name, TuplePattern var_args);
959 CallPattern IsCallDPSPacked(const ffi::String& name,
960  ffi::Optional<TuplePattern> args = std::nullopt);
962 CallPattern IsCallDPSPacked(const ffi::String& name, TuplePattern var_args);
964 DFPattern IsTuple(const ffi::Array<DFPattern>& fields, bool unordered = false);
966 TupleGetItemPattern IsTupleGetItem(const DFPattern tuple, int index = -1);
967 
969 template <typename... Args>
970 CallPattern DFPattern::operator()(Args&&... args) const {
971  return CallPattern(ffi::GetRef<DFPattern>(this->get()),
972  ffi::Array<DFPattern>({std::forward<Args>(args)...}));
973 }
974 
975 } // namespace relax
976 } // namespace tvm
977 #endif // TVM_RELAX_DATAFLOW_PATTERN_H_
Managed reference to DictAttrsNode.
Definition: attrs.h:162
Managed reference to RelaxExprNode.
Definition: expr.h:439
RAII wrapper function to enter and exit a context object similar to python's with syntax.
Definition: with.h:58
Match a conjunction of other patterns.
Definition: dataflow_pattern.h:659
DFPattern left
Definition: dataflow_pattern.h:661
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.AndPattern", AndPatternNode, DFPatternNode)
DFPattern right
Definition: dataflow_pattern.h:662
static void RegisterReflection()
Definition: dataflow_pattern.h:664
Managed reference to AndPatternNode.
Definition: dataflow_pattern.h:677
AndPattern(DFPattern lhs, DFPattern rhs)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AndPattern, DFPattern, AndPatternNode)
A pattern asserting that a root pattern has certain attributes.
Definition: dataflow_pattern.h:888
DictAttrs attrs
Definition: dataflow_pattern.h:891
DFPattern pattern
Definition: dataflow_pattern.h:890
static void RegisterReflection()
Definition: dataflow_pattern.h:893
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.AttrPattern", AttrPatternNode, DFPatternNode)
Managed reference to AttrPatternNode.
Definition: dataflow_pattern.h:906
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AttrPattern, DFPattern, AttrPatternNode)
AttrPattern(DFPattern pattern, DictAttrs attrs)
A pattern to match a callable node in Relax.
Definition: dataflow_pattern.h:469
bool varg_default_wildcard
Definition: dataflow_pattern.h:484
static void RegisterReflection()
Definition: dataflow_pattern.h:488
tvm::ffi::Array< DFPattern > args
Definition: dataflow_pattern.h:477
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.CallPattern", CallPatternNode, DFPatternNode)
DFPattern op
Definition: dataflow_pattern.h:476
Definition: dataflow_pattern.h:497
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CallPattern, DFPattern, CallPatternNode)
CallPattern(DFPattern op, ffi::Array< DFPattern > args, bool varg_default_wildcard=false)
A Pattern to Match a Relax Constant.
Definition: dataflow_pattern.h:446
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.ConstantPattern", ConstantPatternNode, DFPatternNode)
static void RegisterReflection()
Definition: dataflow_pattern.h:448
Managed reference to a ConstantPattern.
Definition: dataflow_pattern.h:460
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ConstantPattern, DFPattern, ConstantPatternNode)
Additional constraints on the graph.
Definition: dataflow_pattern.h:164
TVM_FFI_DECLARE_OBJECT_INFO("relax.dpl.DFConstraint", DFConstraintNode, Object)
virtual std::tuple< PrimExpr, bool > AsPrimExpr(std::function< ffi::Optional< Var >(const DFPatternNode *)> match_state) const =0
Convert the constraint to a PrimExpr.
virtual ffi::Array< DFPattern > GetDependentPatterns() const =0
Return the patterns on which the constraint depends.
static constexpr const uint32_t _type_child_slots
Definition: dataflow_pattern.h:199
Definition: dataflow_pattern.h:203
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DFConstraint, ObjectRef, DFConstraintNode)
Base type of all dataflow patterns.
Definition: dataflow_pattern.h:91
static constexpr const uint32_t _type_child_slots
Definition: dataflow_pattern.h:93
TVM_FFI_DECLARE_OBJECT_INFO("relax.dpl.DFPattern", DFPatternNode, Object)
Managed reference to dataflow patterns.
Definition: dataflow_pattern.h:101
CallPattern operator()(const std::vector< DFPattern > &args) const
Syntactic sugar for creating a CallPattern.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DFPattern, ObjectRef, DFPatternNode)
DataTypePattern HasDtype(const DataType &dtype) const
Syntactic sugar for creating a DataTypePattern with a DataType.
SameShapeConstraint HasSameShapeAs(const DFPattern &other) const
Syntactic sugar for creating a SameShapeConstraint.
OrPattern operator|(const DFPattern &other) const
Syntactic sugar for creating an OrPattern.
CallPattern operator()(Args &&... args) const
Syntactic sugar for creating a CallPattern.
Definition: dataflow_pattern.h:970
AttrPattern HasAttr(const ffi::Map< ffi::String, Any > &attrs) const
Syntactic sugar for creating an AttrPattern.
StructInfoPattern HasStructInfo(const StructInfo &struct_info) const
Syntactic sugar for creating a StructInfoPattern.
ShapePattern HasShape(const ffi::Array< PrimExpr > &shape) const
Syntactic sugar for creating a ShapePattern.
AndPattern operator&(const DFPattern &other) const
Syntactic sugar for creating an AndPattern.
DFPattern dup() const
Syntactic sugar for duplicating the current pattern.
NotPattern operator~() const
Syntactic sugar for creating a NotPattern.
DataTypePattern HasDtype(const std::string &dtype) const
Syntactic sugar for creating a DataTypePattern from a data type's name.
A pattern asserting that a root pattern has a certain data type.
Definition: dataflow_pattern.h:859
DFPattern pattern
Definition: dataflow_pattern.h:861
DataType dtype
Definition: dataflow_pattern.h:862
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.DataTypePattern", DataTypePatternNode, DFPatternNode)
static void RegisterReflection()
Definition: dataflow_pattern.h:864
Managed reference to DataTypePatternNode.
Definition: dataflow_pattern.h:878
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataTypePattern, DFPattern, DataTypePatternNode)
DataTypePattern(DFPattern pattern, DataType dtype)
A Pattern to Match a Relax Dataflow Variable.
Definition: dataflow_pattern.h:401
static void RegisterReflection()
Definition: dataflow_pattern.h:403
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.DataflowVarPattern", DataflowVarPatternNode, VarPatternNode)
Managed reference to a DataflowVarPattern.
Definition: dataflow_pattern.h:415
DataflowVarPattern(ffi::String name_hint)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataflowVarPattern, DFPattern, DataflowVarPatternNode)
Pattern for Relax Expression.
Definition: dataflow_pattern.h:342
Expr expr
Definition: dataflow_pattern.h:344
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.ExprPattern", ExprPatternNode, DFPatternNode)
static void RegisterReflection()
Definition: dataflow_pattern.h:346
Managed reference to an ExprPattern.
Definition: dataflow_pattern.h:357
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExprPattern, DFPattern, ExprPatternNode)
A pattern of external function.
Definition: dataflow_pattern.h:917
const ffi::String & global_symbol() const
The external function name.
Definition: dataflow_pattern.h:922
static void RegisterReflection()
Definition: dataflow_pattern.h:924
ffi::String global_symbol_
Definition: dataflow_pattern.h:919
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.ExternFuncPattern", ExternFuncPatternNode, DFPatternNode)
Managed reference to ExternFuncPatternNode.
Definition: dataflow_pattern.h:937
ExternFuncPattern(ffi::String global_symbol)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExternFuncPattern, DFPattern, ExternFuncPatternNode)
A pattern to match a Relax Function.
Definition: dataflow_pattern.h:534
static void RegisterReflection()
Definition: dataflow_pattern.h:545
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.FunctionPattern", FunctionPatternNode, DFPatternNode)
DFPattern body
Definition: dataflow_pattern.h:543
tvm::ffi::Array< DFPattern > params
Definition: dataflow_pattern.h:536
Managed reference to FunctionPatternNode.
Definition: dataflow_pattern.h:559
FunctionPattern(tvm::ffi::Array< DFPattern > params, DFPattern body)
Constructor.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FunctionPattern, DFPattern, FunctionPatternNode)
A Pattern to Match a Relax Global Variable.
Definition: dataflow_pattern.h:426
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.GlobalVarPattern", GlobalVarPatternNode, DFPatternNode)
Managed reference to a GlobalVarPattern.
Definition: dataflow_pattern.h:436
GlobalVarPattern(ffi::String name_hint)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GlobalVarPattern, DFPattern, GlobalVarPatternNode)
Pattern for rejecting a certain pattern.
Definition: dataflow_pattern.h:715
DFPattern reject
Definition: dataflow_pattern.h:717
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.NotPattern", NotPatternNode, DFPatternNode)
static void RegisterReflection()
Definition: dataflow_pattern.h:719
Managed reference to NotPatternNode.
Definition: dataflow_pattern.h:730
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(NotPattern, DFPattern, NotPatternNode)
NotPattern(DFPattern reject)
Match a disjunction of other patterns.
Definition: dataflow_pattern.h:687
DFPattern left
Definition: dataflow_pattern.h:689
DFPattern right
Definition: dataflow_pattern.h:690
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.OrPattern", OrPatternNode, DFPatternNode)
static void RegisterReflection()
Definition: dataflow_pattern.h:692
Managed reference to OrPatternNode.
Definition: dataflow_pattern.h:705
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(OrPattern, DFPattern, OrPatternNode)
OrPattern(DFPattern left, DFPattern right)
A context to manage the graph-level pattern matching.
Definition: dataflow_pattern.h:250
ExternUse
Constrains the matched graph with an assertion about external uses.
Definition: dataflow_pattern.h:253
@ kMustNot
Definition: dataflow_pattern.h:255
@ kMay
Definition: dataflow_pattern.h:254
std::vector< DFConstraint > validation_constraints
Definition: dataflow_pattern.h:267
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.PatternContext", PatternContextNode, Object)
enum tvm::relax::PatternContextNode::ExternUse allow_extern_use
std::map< DFPattern, std::vector< std::pair< DFPattern, std::vector< PairCons > > > > edge_constraints
Definition: dataflow_pattern.h:260
std::vector< DFPattern > src_ordered
Definition: dataflow_pattern.h:264
Managed reference to a pattern context.
Definition: dataflow_pattern.h:275
void add_constraint(DFConstraint constraint)
Add a validation constraint.
Definition: dataflow_pattern.h:322
void ExitWithScope() const
The RAII-like exit of a constraint context scope.
static ffi::Optional< PatternContext > Current()
Get the constraint context object on the top of the stack.
PatternContextNode * operator->()
Definition: dataflow_pattern.h:286
void add_constraint(DFPattern producer, DFPattern consumer, PairCons cons)
Build an edge constraint between two patterns (producer and consumer).
Definition: dataflow_pattern.h:298
const PatternContextNode * operator->() const
Definition: dataflow_pattern.h:281
PatternContext(bool incremental=false)
PatternContext(ffi::UnsafeInit tag)
Definition: dataflow_pattern.h:277
PatternContext(ObjectPtr< Object > n)
Definition: dataflow_pattern.h:278
void EnterWithScope() const
The RAII-like entry of a constraint context scope.
A sequence of DFPatterns in which each DFPattern is connected to the next one.
Definition: dataflow_pattern.h:212
TVM_FFI_DECLARE_OBJECT_INFO("relax.dpl.PatternSeq", PatternSeqNode, Object)
static void RegisterReflection()
Definition: dataflow_pattern.h:217
std::vector< PairCons > pair_constraints
Definition: dataflow_pattern.h:215
tvm::ffi::Array< DFPattern > patterns
Definition: dataflow_pattern.h:214
Managed reference to pattern sequences.
Definition: dataflow_pattern.h:228
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PatternSeq, ObjectRef, PatternSeqNode)
PatternSeq OnlyUsedBy(PatternSeq other, int index=-1) const
PatternSeq UsedBy(PatternSeq other, int index=-1) const
PatternSeq(tvm::ffi::Array< DFPattern > patterns, bool only_used_by=false)
friend PatternSeq UsedBy(const PatternSeq &lhs, const PatternSeq &rhs, int index)
Create used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned.
PatternSeq(DFPattern init_pattern)
friend PatternSeq OnlyUsedBy(const PatternSeq &lhs, const PatternSeq &rhs, int index)
Create only-used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned.
PatternSeq dup() const
Syntactic sugar for duplicating the current pattern sequence.
A pattern to match an array of PrimExpr.
Definition: dataflow_pattern.h:508
ffi::Array< PrimExpr > fields
Definition: dataflow_pattern.h:510
static void RegisterReflection()
Definition: dataflow_pattern.h:512
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.PrimArrPattern", PrimArrPatternNode, DFPatternNode)
Managed reference to a PrimArrPattern.
Definition: dataflow_pattern.h:523
PrimArrPattern(ffi::Array< PrimExpr > arr)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimArrPattern, DFPattern, PrimArrPatternNode)
A constraint asserting that multiple root patterns have the same shape.
Definition: dataflow_pattern.h:827
ffi::Array< DFPattern > GetDependentPatterns() const override
Return the patterns on which the constraint depends.
Definition: dataflow_pattern.h:831
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.SameShapeConstraint", SameShapeConstraintNode, DFConstraintNode)
static void RegisterReflection()
Definition: dataflow_pattern.h:836
std::tuple< PrimExpr, bool > AsPrimExpr(std::function< ffi::Optional< Var >(const DFPatternNode *)> match_state) const override
Convert the constraint to a PrimExpr.
ffi::Array< DFPattern > args
Definition: dataflow_pattern.h:829
Managed reference to SameShapeConstraintNode.
Definition: dataflow_pattern.h:848
SameShapeConstraint(ffi::Array< DFPattern > args)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SameShapeConstraint, DFConstraint, SameShapeConstraintNode)
A pattern asserting that a root pattern has a certain shape.
Definition: dataflow_pattern.h:799
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.ShapePattern", ShapePatternNode, DFPatternNode)
ffi::Array< PrimExpr > shape
Definition: dataflow_pattern.h:802
DFPattern pattern
Definition: dataflow_pattern.h:801
static void RegisterReflection()
Definition: dataflow_pattern.h:804
Managed reference to ShapePatternNode.
Definition: dataflow_pattern.h:817
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ShapePattern, DFPattern, ShapePatternNode)
ShapePattern(DFPattern pattern, ffi::Array< PrimExpr > type)
Pattern for matching a certain struct info.
Definition: dataflow_pattern.h:774
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.StructInfoPattern", StructInfoPatternNode, DFPatternNode)
StructInfo struct_info
Definition: dataflow_pattern.h:777
static void RegisterReflection()
Definition: dataflow_pattern.h:779
DFPattern pattern
Definition: dataflow_pattern.h:776
Definition: dataflow_pattern.h:789
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StructInfoPattern, DFPattern, StructInfoPatternNode)
StructInfoPattern(DFPattern pattern, StructInfo struct_info)
Managed reference to StructInfoNode.
Definition: expr.h:132
A pattern to match n'th indexing to a tuple.
Definition: dataflow_pattern.h:629
static void RegisterReflection()
Definition: dataflow_pattern.h:634
DFPattern tuple
Definition: dataflow_pattern.h:631
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.TupleGetItemPattern", TupleGetItemPatternNode, DFPatternNode)
int index
Definition: dataflow_pattern.h:632
Managed reference to TupleGetItemPatternNode.
Definition: dataflow_pattern.h:648
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TupleGetItemPattern, DFPattern, TupleGetItemPatternNode)
TupleGetItemPattern(DFPattern tuple, int index)
Pattern to match a tuple of ordered expressions.
Definition: dataflow_pattern.h:575
tvm::ffi::Array< DFPattern > fields
Definition: dataflow_pattern.h:577
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.TuplePattern", TuplePatternNode, DFPatternNode)
static void RegisterReflection()
Definition: dataflow_pattern.h:579
Managed reference to TuplePatternNode.
Definition: dataflow_pattern.h:590
TuplePattern(tvm::ffi::Array< DFPattern > fields)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TuplePattern, DFPattern, TuplePatternNode)
A pattern to match multiple expressions in any order.
Definition: dataflow_pattern.h:600
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.UnorderedTuplePattern", UnorderedTuplePatternNode, DFPatternNode)
tvm::ffi::Array< DFPattern > fields
Definition: dataflow_pattern.h:602
static void RegisterReflection()
Definition: dataflow_pattern.h:604
Managed reference to UnorderedTuplePatternNode.
Definition: dataflow_pattern.h:617
UnorderedTuplePattern(tvm::ffi::Array< DFPattern > fields)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(UnorderedTuplePattern, DFPattern, UnorderedTuplePatternNode)
A Pattern to Match a Relax Variable.
Definition: dataflow_pattern.h:368
static constexpr const uint32_t _type_child_slots
Definition: dataflow_pattern.h:378
ffi::String name
Definition: dataflow_pattern.h:370
const ffi::String & name_hint() const
Definition: dataflow_pattern.h:371
TVM_FFI_DECLARE_OBJECT_INFO("relax.dpl.VarPattern", VarPatternNode, DFPatternNode)
static void RegisterReflection()
Definition: dataflow_pattern.h:373
Managed reference to a VarPattern.
Definition: dataflow_pattern.h:386
VarPattern(ffi::String name_hint)
Create a pattern matching by variable name.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(VarPattern, DFPattern, VarPatternNode)
Wildcard Pattern is a pattern that can match anything.
Definition: dataflow_pattern.h:740
static void RegisterReflection()
Definition: dataflow_pattern.h:742
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.WildcardPattern", WildcardPatternNode, DFPatternNode)
Managed reference to WildcardPatternNode.
Definition: dataflow_pattern.h:754
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(WildcardPattern, DFPattern, WildcardPatternNode)
WildcardPattern(ObjectPtr< WildcardPatternNode > data)
Definition: dataflow_pattern.h:757
Runtime primitive data type.
Definition: data_type.h:47
Base expr nodes in TVM.
Definition: repr_printer.h:91
PatternSeq operator^(const PatternSeq &lhs, const PatternSeq &rhs)
Syntax sugar of UsedBy(lhs, rhs, -1).
ExprPattern IsOp(const ffi::String &op_name)
Syntactic sugar for creating an ExprPattern based on an Op.
DFPattern IsTuple(const ffi::Array< DFPattern > &fields, bool unordered=false)
Syntactic sugar for creating a TuplePattern, or an UnorderedTuplePattern when unordered=true.
CallPattern IsCallTIR(const ffi::String &name, ffi::Optional< TuplePattern > args=std::nullopt)
Syntactic sugar for call_tir (returns a tensor).
CallPattern IsCallDPSPacked(const ffi::String &name, ffi::Optional< TuplePattern > args=std::nullopt)
Syntactic sugar for call_dps_packed (returns a tensor).
WildcardPattern Wildcard()
Syntactic sugar for creating a WildcardPattern.
VarPattern IsVar(const ffi::String &name)
Syntactic sugar for creating a VarPattern with a name.
ConstantPattern IsConst()
Syntactic sugar for creating a ConstantPattern.
PatternSeq UsedBy(const PatternSeq &lhs, const PatternSeq &rhs, int index=-1)
Create used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned.
PatternSeq OnlyUsedBy(const PatternSeq &lhs, const PatternSeq &rhs, int index=-1)
Create only-used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned.
ExprPattern IsExpr(const Expr &expr)
Syntactic sugar for creating an ExprPattern.
PatternSeq operator>>(const PatternSeq &lhs, const PatternSeq &rhs)
Syntax sugar of OnlyUsedBy(lhs, rhs, -1).
TupleGetItemPattern IsTupleGetItem(const DFPattern tuple, int index=-1)
Syntactic sugar for creating a TupleGetItemPattern.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1960
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Relax Types.
Constraint of a DFPattern edge (producer -> consumer) in graph-level matching.
Definition: dataflow_pattern.h:136
int index
Definition: dataflow_pattern.h:142
bool operator==(const PairCons &other) const
Definition: dataflow_pattern.h:152
enum tvm::relax::PairCons::Type type
Type
Constraint types of the edge.
Definition: dataflow_pattern.h:138
@ kOnlyUsedBy
Definition: dataflow_pattern.h:140
@ kUsedBy
Definition: dataflow_pattern.h:139
PairCons(Type t, int index=-1)
Construct a new PairCons object.
Definition: dataflow_pattern.h:150
RAII wrapper function to enter and exit a context object similar to python's with syntax.