tvm
expr.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 // Acknowledgement: Many low-level IR nodes originate from Halide.
25 #ifndef TVM_TIR_EXPR_H_
26 #define TVM_TIR_EXPR_H_
27 
28 #include <tvm/ir/expr.h>
29 #include <tvm/node/functor.h>
30 #include <tvm/node/node.h>
35 #include <tvm/runtime/data_type.h>
36 #include <tvm/tir/buffer.h>
37 #include <tvm/tir/var.h>
38 
39 #include <algorithm>
40 #include <iostream>
41 #include <limits>
42 #include <string>
43 #include <unordered_map>
44 #include <utility>
45 
46 namespace tvm {
47 namespace tir {
48 
51 
53 class StringImmNode : public PrimExprNode {
54  public:
57 
59  v->Visit("dtype", &dtype);
60  v->Visit("value", &value);
61  v->Visit("span", &span);
62  }
63 
64  bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const {
65  return equal(value, other->value);
66  }
67 
68  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
69 
70  static constexpr const char* _type_key = "tir.StringImm";
72 };
73 
78 class StringImm : public PrimExpr {
79  public:
80  TVM_DLL StringImm(String value, Span span = Span());
82 };
83 
88 class CastNode : public PrimExprNode {
89  public:
92 
94  v->Visit("dtype", &dtype);
95  v->Visit("value", &value);
96  v->Visit("span", &span);
97  }
98 
99  bool SEqualReduce(const CastNode* other, SEqualReducer equal) const {
100  return equal(dtype, other->dtype) && equal(value, other->value);
101  }
102 
103  void SHashReduce(SHashReducer hash_reduce) const {
104  hash_reduce(dtype);
105  hash_reduce(value);
106  }
107 
108  static constexpr const char* _type_key = "tir.Cast";
110 };
111 
116 class Cast : public PrimExpr {
117  public:
118  TVM_DLL Cast(DataType dtype, PrimExpr value, Span span = Span());
120 };
121 
126 template <typename T>
127 class BinaryOpNode : public PrimExprNode {
128  public:
133 
135  v->Visit("dtype", &(this->dtype));
136  v->Visit("a", &a);
137  v->Visit("b", &b);
138  v->Visit("span", &span);
139  }
140 
141  bool SEqualReduce(const T* other, SEqualReducer equal) const {
142  return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
143  }
144 
145  void SHashReduce(SHashReducer hash_reduce) const {
146  hash_reduce(dtype);
147  hash_reduce(a);
148  hash_reduce(b);
149  }
150 
152 };
153 
155 class AddNode : public BinaryOpNode<AddNode> {
156  public:
157  static constexpr const char* _type_key = "tir.Add";
158 };
159 
164 class Add : public PrimExpr {
165  public:
166  TVM_DLL Add(PrimExpr a, PrimExpr b, Span span = Span());
168 };
169 
171 class SubNode : public BinaryOpNode<SubNode> {
172  public:
173  static constexpr const char* _type_key = "tir.Sub";
174 };
175 
180 class Sub : public PrimExpr {
181  public:
182  TVM_DLL Sub(PrimExpr a, PrimExpr b, Span span = Span());
184 };
185 
187 class MulNode : public BinaryOpNode<MulNode> {
188  public:
189  static constexpr const char* _type_key = "tir.Mul";
190 };
191 
196 class Mul : public PrimExpr {
197  public:
198  TVM_DLL Mul(PrimExpr a, PrimExpr b, Span span = Span());
200 };
201 
206 class DivNode : public BinaryOpNode<DivNode> {
207  public:
208  static constexpr const char* _type_key = "tir.Div";
209 };
210 
215 class Div : public PrimExpr {
216  public:
217  TVM_DLL Div(PrimExpr a, PrimExpr b, Span span = Span());
219 };
220 
225 class ModNode : public BinaryOpNode<ModNode> {
226  public:
227  static constexpr const char* _type_key = "tir.Mod";
228 };
229 
234 class Mod : public PrimExpr {
235  public:
236  TVM_DLL Mod(PrimExpr a, PrimExpr b, Span span = Span());
238 };
239 
241 class FloorDivNode : public BinaryOpNode<FloorDivNode> {
242  public:
243  static constexpr const char* _type_key = "tir.FloorDiv";
244 };
245 
250 class FloorDiv : public PrimExpr {
251  public:
252  TVM_DLL FloorDiv(PrimExpr a, PrimExpr b, Span span = Span());
254 };
255 
257 class FloorModNode : public BinaryOpNode<FloorModNode> {
258  public:
259  static constexpr const char* _type_key = "tir.FloorMod";
260 };
261 
266 class FloorMod : public PrimExpr {
267  public:
268  TVM_DLL FloorMod(PrimExpr a, PrimExpr b, Span span = Span());
270 };
271 
273 class MinNode : public BinaryOpNode<MinNode> {
274  public:
275  static constexpr const char* _type_key = "tir.Min";
276 };
277 
282 class Min : public PrimExpr {
283  public:
284  TVM_DLL Min(PrimExpr a, PrimExpr b, Span span = Span());
286 };
287 
289 class MaxNode : public BinaryOpNode<MaxNode> {
290  public:
291  static constexpr const char* _type_key = "tir.Max";
292 };
293 
298 class Max : public PrimExpr {
299  public:
300  TVM_DLL Max(PrimExpr a, PrimExpr b, Span span = Span());
302 };
303 
308 template <typename T>
309 class CmpOpNode : public PrimExprNode {
310  public:
315 
317  v->Visit("dtype", &(this->dtype));
318  v->Visit("a", &a);
319  v->Visit("b", &b);
320  v->Visit("span", &span);
321  }
322 
323  bool SEqualReduce(const T* other, SEqualReducer equal) const {
324  return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
325  }
326 
327  void SHashReduce(SHashReducer hash_reduce) const {
328  hash_reduce(dtype);
329  hash_reduce(a);
330  hash_reduce(b);
331  }
332 
334 };
335 
337 class EQNode : public CmpOpNode<EQNode> {
338  public:
339  static constexpr const char* _type_key = "tir.EQ";
340 };
341 
346 class EQ : public PrimExpr {
347  public:
348  TVM_DLL EQ(PrimExpr a, PrimExpr b, Span span = Span());
350 };
351 
353 class NENode : public CmpOpNode<NENode> {
354  public:
355  static constexpr const char* _type_key = "tir.NE";
356 };
357 
362 class NE : public PrimExpr {
363  public:
364  TVM_DLL NE(PrimExpr a, PrimExpr b, Span span = Span());
366 };
367 
369 class LTNode : public CmpOpNode<LTNode> {
370  public:
371  static constexpr const char* _type_key = "tir.LT";
372 };
373 
378 class LT : public PrimExpr {
379  public:
380  TVM_DLL LT(PrimExpr a, PrimExpr b, Span span = Span());
382 };
383 
385 struct LENode : public CmpOpNode<LENode> {
386  public:
387  static constexpr const char* _type_key = "tir.LE";
388 };
389 
394 class LE : public PrimExpr {
395  public:
396  TVM_DLL LE(PrimExpr a, PrimExpr b, Span span = Span());
398 };
399 
401 class GTNode : public CmpOpNode<GTNode> {
402  public:
403  static constexpr const char* _type_key = "tir.GT";
404 };
405 
410 class GT : public PrimExpr {
411  public:
412  TVM_DLL GT(PrimExpr a, PrimExpr b, Span span = Span());
414 };
415 
417 class GENode : public CmpOpNode<GENode> {
418  public:
419  static constexpr const char* _type_key = "tir.GE";
420 };
421 
426 class GE : public PrimExpr {
427  public:
428  TVM_DLL GE(PrimExpr a, PrimExpr b, Span span = Span());
430 };
431 
433 class AndNode : public PrimExprNode {
434  public:
439 
441  v->Visit("dtype", &(this->dtype));
442  v->Visit("a", &a);
443  v->Visit("b", &b);
444  v->Visit("span", &span);
445  }
446 
447  bool SEqualReduce(const AndNode* other, SEqualReducer equal) const {
448  return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
449  }
450 
451  void SHashReduce(SHashReducer hash_reduce) const {
452  hash_reduce(dtype);
453  hash_reduce(a);
454  hash_reduce(b);
455  }
456 
457  static constexpr const char* _type_key = "tir.And";
459 };
460 
465 class And : public PrimExpr {
466  public:
467  TVM_DLL And(PrimExpr a, PrimExpr b, Span span = Span());
469 };
470 
472 class OrNode : public PrimExprNode {
473  public:
478 
480  v->Visit("dtype", &dtype);
481  v->Visit("a", &a);
482  v->Visit("b", &b);
483  v->Visit("span", &span);
484  }
485 
486  bool SEqualReduce(const OrNode* other, SEqualReducer equal) const {
487  return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
488  }
489 
490  void SHashReduce(SHashReducer hash_reduce) const {
491  hash_reduce(dtype);
492  hash_reduce(a);
493  hash_reduce(b);
494  }
495 
496  static constexpr const char* _type_key = "tir.Or";
498 };
499 
504 class Or : public PrimExpr {
505  public:
506  TVM_DLL Or(PrimExpr a, PrimExpr b, Span span = Span());
508 };
509 
511 class NotNode : public PrimExprNode {
512  public:
515 
517  v->Visit("dtype", &dtype);
518  v->Visit("a", &a);
519  v->Visit("span", &span);
520  }
521 
522  bool SEqualReduce(const NotNode* other, SEqualReducer equal) const {
523  return equal(dtype, other->dtype) && equal(a, other->a);
524  }
525 
526  void SHashReduce(SHashReducer hash_reduce) const {
527  hash_reduce(dtype);
528  hash_reduce(a);
529  }
530 
531  static constexpr const char* _type_key = "tir.Not";
533 };
534 
539 class Not : public PrimExpr {
540  public:
541  TVM_DLL Not(PrimExpr a, Span span = Span());
543 };
544 
552 class SelectNode : public PrimExprNode {
553  public:
560 
562  v->Visit("dtype", &dtype);
563  v->Visit("condition", &condition);
564  v->Visit("true_value", &true_value);
565  v->Visit("false_value", &false_value);
566  v->Visit("span", &span);
567  }
568 
569  bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const {
570  return equal(dtype, other->dtype) && equal(condition, other->condition) &&
571  equal(true_value, other->true_value) && equal(false_value, other->false_value);
572  }
573 
574  void SHashReduce(SHashReducer hash_reduce) const {
575  hash_reduce(dtype);
576  hash_reduce(condition);
577  hash_reduce(true_value);
578  hash_reduce(false_value);
579  }
580 
581  static constexpr const char* _type_key = "tir.Select";
583 };
584 
589 class Select : public PrimExpr {
590  public:
591  TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span = Span());
592 
594 };
595 
606 class BufferLoadNode : public PrimExprNode {
607  public:
612 
614  v->Visit("dtype", &(this->dtype));
615  v->Visit("buffer", &buffer);
616  v->Visit("indices", &indices);
617  v->Visit("span", &span);
618  }
619 
620  bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const {
621  return equal(dtype, other->dtype) && equal(buffer, other->buffer) &&
622  equal(indices, other->indices);
623  }
624 
625  void SHashReduce(SHashReducer hash_reduce) const {
626  hash_reduce(dtype);
627  hash_reduce(buffer);
628  hash_reduce(indices);
629  }
630 
631  static constexpr const char* _type_key = "tir.BufferLoad";
633 
634  private:
644  void LegalizeDType();
645  friend class BufferLoad;
646  friend class CustomDatatypesLowerer;
647  friend class VectorTypeRewriter;
648  friend class Vectorizer;
649 };
650 
655 class BufferLoad : public PrimExpr {
656  public:
657  TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span = Span());
660 };
661 
672  public:
677 
679  v->Visit("dtype", &(this->dtype));
680  v->Visit("producer", &producer);
681  v->Visit("indices", &indices);
682  v->Visit("span", &span);
683  }
684 
685  bool SEqualReduce(const ProducerLoadNode* other, SEqualReducer equal) const {
686  return equal(dtype, other->dtype) && equal(producer, other->producer) &&
687  equal(indices, other->indices);
688  }
689 
690  void SHashReduce(SHashReducer hash_reduce) const {
691  hash_reduce(dtype);
692  hash_reduce(producer);
693  hash_reduce(indices);
694  }
695 
696  static constexpr const char* _type_key = "tir.ProducerLoad";
698 };
699 
704 class ProducerLoad : public PrimExpr {
705  public:
706  TVM_DLL explicit ProducerLoad(DataProducer producer, Array<PrimExpr> indices, Span span = Span());
707 
709 };
710 
726 class LoadNode : public PrimExprNode {
727  public:
734 
736  v->Visit("dtype", &dtype);
737  v->Visit("buffer_var", &buffer_var);
738  v->Visit("index", &index);
739  v->Visit("predicate", &predicate);
740  v->Visit("span", &span);
741  }
742 
743  bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const {
744  return equal(dtype, other->dtype) && equal(buffer_var, other->buffer_var) &&
745  equal(index, other->index) && equal(predicate, other->predicate);
746  }
747 
748  void SHashReduce(SHashReducer hash_reduce) const {
749  hash_reduce(dtype);
750  hash_reduce(buffer_var);
751  hash_reduce(index);
752  hash_reduce(predicate);
753  }
754 
755  static constexpr const char* _type_key = "tir.Load";
757 };
758 
763 class Load : public PrimExpr {
764  public:
765  TVM_DLL Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate,
766  Span span = Span());
768 };
769 
779 class RampNode : public PrimExprNode {
780  public:
786  int lanes;
787 
789  v->Visit("dtype", &dtype);
790  v->Visit("base", &base);
791  v->Visit("stride", &stride);
792  v->Visit("lanes", &lanes);
793  v->Visit("span", &span);
794  }
795 
796  bool SEqualReduce(const RampNode* other, SEqualReducer equal) const {
797  return equal(dtype, other->dtype) && equal(base, other->base) && equal(stride, other->stride) &&
798  equal(lanes, other->lanes);
799  }
800 
801  void SHashReduce(SHashReducer hash_reduce) const {
802  hash_reduce(dtype);
803  hash_reduce(base);
804  hash_reduce(stride);
805  hash_reduce(lanes);
806  }
807 
808  static constexpr const char* _type_key = "tir.Ramp";
810 };
811 
816 class Ramp : public PrimExpr {
817  public:
818  TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span = Span());
820 };
821 
823 class BroadcastNode : public PrimExprNode {
824  public:
828  int lanes;
829 
831  v->Visit("dtype", &dtype);
832  v->Visit("value", &value);
833  v->Visit("lanes", &lanes);
834  v->Visit("span", &span);
835  }
836 
837  bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const {
838  return equal(dtype, other->dtype) && equal(value, other->value) && equal(lanes, other->lanes);
839  }
840 
841  void SHashReduce(SHashReducer hash_reduce) const {
842  hash_reduce(dtype);
843  hash_reduce(value);
844  hash_reduce(lanes);
845  }
846 
847  static constexpr const char* _type_key = "tir.Broadcast";
849 };
850 
855 class Broadcast : public PrimExpr {
856  public:
857  TVM_DLL Broadcast(PrimExpr value, int lanes, Span span = Span());
859 };
860 
864 class LetNode : public PrimExprNode {
865  public:
872 
874  v->Visit("dtype", &dtype);
875  v->Visit("var", &var);
876  v->Visit("value", &value);
877  v->Visit("body", &body);
878  v->Visit("span", &span);
879  }
880 
881  bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
882  return equal(dtype, other->dtype) && equal.DefEqual(var, other->var) &&
883  equal(value, other->value) && equal(body, other->body);
884  }
885 
886  void SHashReduce(SHashReducer hash_reduce) const {
887  hash_reduce(dtype);
888  hash_reduce.DefHash(var);
889  hash_reduce(value);
890  hash_reduce(body);
891  }
892 
893  static constexpr const char* _type_key = "tir.Let";
895 };
896 
901 class Let : public PrimExpr {
902  public:
903  TVM_DLL Let(Var var, PrimExpr value, PrimExpr body, Span span = Span());
905 };
906 
910 class CallNode : public PrimExprNode {
911  public:
919 
923  v->Visit("dtype", &dtype);
924  v->Visit("op", &op);
925  v->Visit("args", &args);
926  v->Visit("span", &span);
927  }
928 
929  bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
930  return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args);
931  }
932 
933  void SHashReduce(SHashReducer hash_reduce) const {
934  hash_reduce(dtype);
935  hash_reduce(op);
936  hash_reduce(args);
937  }
938 
939  static constexpr const char* _type_key = "tir.Call";
941 };
942 
947 class Call : public PrimExpr {
948  public:
949  TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, Span span = Span());
951 };
952 
958 class ShuffleNode : public PrimExprNode {
959  public:
964 
966  v->Visit("dtype", &dtype);
967  v->Visit("vectors", &vectors);
968  v->Visit("indices", &indices);
969  v->Visit("span", &span);
970  }
971 
972  bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const {
973  return equal(dtype, other->dtype) && equal(vectors, other->vectors) &&
974  equal(indices, other->indices);
975  }
976 
977  void SHashReduce(SHashReducer hash_reduce) const {
978  hash_reduce(dtype);
979  hash_reduce(vectors);
980  hash_reduce(indices);
981  }
982 
983  static constexpr const char* _type_key = "tir.Shuffle";
985 };
986 
991 class Shuffle : public PrimExpr {
992  public:
993  TVM_DLL Shuffle(Array<PrimExpr> vectors, Array<PrimExpr> indices, Span span = Span());
994  TVM_DLL static PrimExpr Concat(Array<PrimExpr> vectors, Span span = Span());
995  TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index, Span span = Span());
996 
998 };
999 
1000 // Reduce operator
1005 class CommReducerNode : public Object {
1006  public:
1020  Array<PrimExpr> operator()(Array<PrimExpr> a, Array<PrimExpr> b) const;
1025  mutable Span span;
1026 
1028  v->Visit("lhs", &lhs);
1029  v->Visit("rhs", &rhs);
1030  v->Visit("result", &result);
1031  v->Visit("identity_element", &identity_element);
1032  v->Visit("span", &span);
1033  }
1034 
1035  bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const {
1036  return equal.DefEqual(lhs, other->lhs) && equal.DefEqual(rhs, other->rhs) &&
1037  equal(result, other->result) && equal(identity_element, other->identity_element);
1038  }
1039 
1040  void SHashReduce(SHashReducer hash_reduce) const {
1041  hash_reduce.DefHash(lhs);
1042  hash_reduce.DefHash(rhs);
1043  hash_reduce(result);
1044  hash_reduce(identity_element);
1045  }
1046 
1047  static constexpr const char* _type_key = "tir.CommReducer";
1048  static constexpr const bool _type_has_method_sequal_reduce = true;
1049  static constexpr const bool _type_has_method_shash_reduce = true;
1051 };
1052 
1057 class CommReducer : public ObjectRef {
1058  public:
1059  TVM_DLL CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
1060  Array<PrimExpr> identity_element, Span span = Span());
1061 
1063 };
1064 
1066 class ReduceNode : public PrimExprNode {
1067  public:
1083 
1085  v->Visit("dtype", &dtype);
1086  v->Visit("combiner", &combiner);
1087  v->Visit("source", &source);
1088  v->Visit("init", &init);
1089  v->Visit("axis", &axis);
1090  v->Visit("condition", &condition);
1091  v->Visit("value_index", &value_index);
1092  v->Visit("span", &span);
1093  }
1094 
1095  bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const {
1096  // check axis first so IterVars can define the necessary variables.
1097  return equal(dtype, other->dtype) && equal(axis, other->axis) &&
1098  equal(combiner, other->combiner) && equal(source, other->source) &&
1099  equal(init, other->init) && equal(condition, other->condition) &&
1100  equal(value_index, other->value_index);
1101  }
1102 
1103  void SHashReduce(SHashReducer hash_reduce) const {
1104  hash_reduce(dtype);
1105  hash_reduce(axis);
1106  hash_reduce(combiner);
1107  hash_reduce(source);
1108  hash_reduce(init);
1109  hash_reduce(condition);
1110  hash_reduce(value_index);
1111  }
1112 
1113  static constexpr const char* _type_key = "tir.Reduce";
1115 };
1116 
1121 class Reduce : public PrimExpr {
1122  public:
1123  TVM_DLL Reduce(CommReducer combiner, Array<PrimExpr> src, Array<IterVar> rdom, PrimExpr condition,
1124  int value_index, Array<PrimExpr> init, Span span = Span());
1125 
1127 };
1128 
1130 class AnyNode : public PrimExprNode {
1131  public:
1133  v->Visit("dtype", &dtype);
1134  v->Visit("span", &span);
1135  }
1136 
1137  bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const {
1138  return equal(dtype, other->dtype);
1139  }
1140 
1141  void SHashReduce(SHashReducer hash_reduce) const {}
1142 
1144  Var ToVar() const { return Var("any_dim", DataType::Int(32)); }
1145 
1147  SizeVar ToSizeVar() const { return SizeVar("any_dim", DataType::Int(32)); }
1148 
1149  static constexpr const char* _type_key = "tir.Any";
1151 };
1152 
1157 class Any : public PrimExpr {
1158  public:
1159  TVM_DLL Any(Span span = Span());
1160 
1162 };
1163 
1164 /*!
1165  * \brief Template function to convert Map to unordered_map.
1166  * Sometimes useful for API gluing when the internal code uses unordered_map.
1167  * \param dmap The container map
1168  * \return The corresponding unordered_map.
1169  * \tparam K the key of the Map.
1170  * \tparam V the value of the Map.
1171  */
1172 template <typename K, typename V>
1173 inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
1174  std::unordered_map<K, V> ret;
1175  for (auto kv : dmap) {
1176  ret[kv.first] = kv.second;
1177  }
1178  return ret;
1179 }
1180 } // namespace tir
1181 } // namespace tvm
1182 
1183 namespace std {
1184 template <>
1185 struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {};
1186 } // namespace std
1187 #endif // TVM_TIR_EXPR_H_
Managed reference to MulNode.
Definition: expr.h:196
tvm::Span Span
Definition: base.h:65
Let binding. Bind var to value then evaluate body.
Definition: expr.h:864
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:440
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:1132
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:1141
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped. ...
Definition: structural_equal.h:165
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:327
Managed reference to CommReducerNode.
Definition: expr.h:1057
PrimExpr body
The result expression.
Definition: expr.h:871
Managed reference to CastNode.
Definition: expr.h:116
Var var
The variable.
Definition: expr.h:867
PrimExpr b
The right operand.
Definition: expr.h:132
PrimExpr predicate
The predicate to mask which lanes would be loaded.
Definition: expr.h:733
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:735
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:788
Definitions and helper macros for IR/AST nodes.
PrimExpr b
The right operand.
Definition: expr.h:477
Managed reference to ReduceNode.
Definition: expr.h:1121
Runtime String container types.
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:873
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:933
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:574
Array< Var > rhs
The right argument of reducer.
Definition: expr.h:1010
Var buffer_var
The buffer variable.
Definition: expr.h:729
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:102
PrimExpr a
The left operand.
Definition: expr.h:312
static constexpr const char * _type_key
Definition: expr.h:70
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:690
Base expr nodes in TVM.
static constexpr const bool _type_has_method_shash_reduce
Definition: expr.h:59
Managed reference to LTNode.
Definition: expr.h:378
PrimExpr a
The left operand.
Definition: expr.h:130
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:678
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:102
PrimExpr value
The value to be bound.
Definition: expr.h:869
Array< PrimExpr > init
The init operand.
Definition: expr.h:1073
Variables in the TIR.
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:801
String constants, only used in asserts.
Definition: expr.h:53
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
PrimExpr condition
Predicate on the reduction. Only add the body to the reduction if the condition is true.
Definition: expr.h:1080
SizeVar ToSizeVar() const
Convert to SizeVar.
Definition: expr.h:1147
a named variable in TIR
Definition: var.h:88
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:58
Constant floating point literals in the program.
Definition: expr.h:321
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:301
PrimExpr true_value
value to be returned when condition is true.
Definition: expr.h:557
Managed reference to MinNode.
Definition: expr.h:282
PrimExpr b
The right operand.
Definition: expr.h:438
Definition: loop_state.h:456
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:625
a * b
Definition: expr.h:187
Managed reference to GTNode.
Definition: expr.h:410
bool SEqualReduce(const CallNode *other, SEqualReducer equal) const
Definition: expr.h:929
Managed reference to CallNode.
Definition: expr.h:947
Managed reference to GENode.
Definition: expr.h:426
Array< PrimExpr > identity_element
The identity element of reducer, which leaves other elements unchanged when combined with it...
Definition: expr.h:1018
int lanes
The number of lanes.
Definition: expr.h:828
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:1027
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:316
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:965
bool SEqualReduce(const LoadNode *other, SEqualReducer equal) const
Definition: expr.h:743
bool SEqualReduce(const T *other, SEqualReducer equal) const
Definition: expr.h:141
base class of all object containers.
Definition: object.h:167
Any shape.
Definition: expr.h:1130
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:1025
bool SEqualReduce(const LetNode *other, SEqualReducer equal) const
Definition: expr.h:881
Shuffle instruction. vec = concat(vectors) result = (vec[indices[0]], vec[indices[1]] ...
Definition: expr.h:958
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:103
std::unordered_map< K, V > as_unordered_map(const Map< K, V > &dmap)
Definition: expr.h:1173
Managed reference to FloorModNode.
Definition: expr.h:266
a + b
Definition: expr.h:155
Managed reference to LENode.
Definition: expr.h:394
bool SEqualReduce(const OrNode *other, SEqualReducer equal) const
Definition: expr.h:486
Constant integer literals in the program.
Definition: expr.h:275
PrimExpr base
The base value.
Definition: expr.h:782
PrimExpr index
The index locations to be loaded.
Definition: expr.h:731
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:1084
Var ToVar() const
Convert to var.
Definition: expr.h:1144
PrimExpr a
The input operand.
Definition: expr.h:514
a || b
Definition: expr.h:472
Runtime Array container types.
tvm::tir::Any Any
Definition: type.h:45
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:479
bool SEqualReduce(const SelectNode *other, SEqualReducer equal) const
Definition: expr.h:569
Array< PrimExpr > vectors
the input vectors.
Definition: expr.h:961
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:748
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:68
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:922
Managed reference to AnyNode.
Definition: expr.h:1157
Definition: span.h:115
Managed reference to DivNode.
Definition: expr.h:215
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:145
Managed reference to NENode.
Definition: expr.h:362
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:55
PrimExpr stride
The stride of each step.
Definition: expr.h:784
Managed reference to BroadcastNode.
Definition: expr.h:855
Runtime primitive data type.
Definition: data_type.h:41
bool SEqualReduce(const CommReducerNode *other, SEqualReducer equal) const
Definition: expr.h:1035
Managed reference to ModNode.
Definition: expr.h:234
Base template to implement comparison ops.
Definition: expr.h:309
CommReducer combiner
The commutative combiner.
Definition: expr.h:1069
Array< PrimExpr > source
The source operand.
Definition: expr.h:1071
bool SEqualReduce(const NotNode *other, SEqualReducer equal) const
Definition: expr.h:522
PrimExpr a
The left operand.
Definition: expr.h:436
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:516
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
ObjectRef hash functor.
Definition: object.h:624
PrimExpr b
The right operand.
Definition: expr.h:314
static constexpr const bool _type_has_method_sequal_reduce
Definition: expr.h:58
bool SEqualReduce(const T *other, SEqualReducer equal) const
Definition: expr.h:323
Managed reference to AddNode.
Definition: expr.h:164
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:526
bool SEqualReduce(const ProducerLoadNode *other, SEqualReducer equal) const
Definition: expr.h:685
Managed reference to LetNode.
Definition: expr.h:901
Create a vector where all the elements are value.
Definition: expr.h:823
PrimExpr value
The original value to be cast.
Definition: expr.h:91
A commutative reducer node to represent a commutative binary operator with identity element...
Definition: expr.h:1005
Array< PrimExpr > indices
The indices of each element.
Definition: expr.h:963
Managed reference to OrNode.
Definition: expr.h:504
a > b
Definition: expr.h:401
tvm::IntImmNode IntImmNode
Definition: expr.h:49
int value_index
the index of this reduce node
Definition: expr.h:1082
Cast value from one data type to another.
Definition: expr.h:88
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:1040
Reference to string objects.
Definition: string.h:124
Managed reference to RelayExprNode.
Definition: expr.h:217
bool SEqualReduce(const ShuffleNode *other, SEqualReducer equal) const
Definition: expr.h:972
TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode)
Managed reference to MaxNode.
Definition: expr.h:298
bool SEqualReduce(const CastNode *other, SEqualReducer equal) const
Definition: expr.h:99
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
bool SEqualReduce(const BroadcastNode *other, SEqualReducer equal) const
Definition: expr.h:837
Managed reference to LoadNode.
Definition: expr.h:763
Array< Var > lhs
The left argument of reducer.
Definition: expr.h:1008
Managed reference to FloorDivNode.
Definition: expr.h:250
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:93
tvm::FloatImmNode FloatImmNode
Definition: expr.h:50
Array< PrimExpr > indices
The location arguments.
Definition: expr.h:676
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:613
RelayExpr op
The operator(function) being invoked.
Definition: expr.h:918
bool SEqualReduce(const RampNode *other, SEqualReducer equal) const
Definition: expr.h:796
Load the value from buffer_var.
Definition: expr.h:726
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Defines the Functor data structures.
Managed reference to SubNode.
Definition: expr.h:180
Base class of all object reference.
Definition: object.h:511
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:785
!a
Definition: expr.h:511
A named variable that represents a tensor index size.
Definition: var.h:144
max(a, b)
Definition: expr.h:289
Managed reference to DataProducerNode.
Definition: buffer.h:293
bool SEqualReduce(const AndNode *other, SEqualReducer equal) const
Definition: expr.h:447
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:886
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types...
Definition: buffer.h:160
bool SEqualReduce(const AnyNode *other, SEqualReducer equal) const
Definition: expr.h:1137
Array< IterVar > axis
The reduction axis.
Definition: expr.h:1075
Symbolic n-dimensional array, to represent a memory buffer.
String value
The constant value content.
Definition: expr.h:56
PrimExpr false_value
value to be returned when condition is false.
Definition: expr.h:559
Managed reference to SelectNode.
Definition: expr.h:589
Managed reference to RampNode.
Definition: expr.h:816
Array< PrimExpr > indices
The indices location to be loaded.
Definition: expr.h:611
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:561
Managed reference to ProducerLoadNode.
Definition: expr.h:704
Construct a vector with lanes elements where its i-th element equals base + i * stride. This is useful to construct an index for a contiguous vector load.
Definition: expr.h:779
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:134
Array< PrimExpr > result
The result of reducer.
Definition: expr.h:1012
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:490
DataType dtype
The runtime data type of the primitive expression.
Definition: expr.h:101
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:1103
min(a, b)
Definition: expr.h:273
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means map is mutable but copy will happen when the map is referenced in more than two places.
Definition: map.h:1268
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
The remainder of the floordiv.
Definition: expr.h:257
Managed reference to AndNode.
Definition: expr.h:465
Runtime Map container types.
a == b
Definition: expr.h:337
a && b
Definition: expr.h:433
DataProducer producer
The buffer producer.
Definition: expr.h:674
Base template to implement binary ops.
Definition: expr.h:127
Managed reference to ShuffleNode.
Definition: expr.h:991
Managed reference to NotNode.
Definition: expr.h:539
a < b
Definition: expr.h:369
Load value from the high dimension buffer.
Definition: expr.h:606
bool SEqualReduce(const ReduceNode *other, SEqualReducer equal) const
Definition: expr.h:1095
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:451
bool SEqualReduce(const BufferLoadNode *other, SEqualReducer equal) const
Definition: expr.h:620
PrimExpr value
The base value.
Definition: expr.h:826
a % b in the C semantics.
Definition: expr.h:225
Reference to PrimExprNode.
Definition: expr.h:112
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:977
PrimExpr a
The left operand.
Definition: expr.h:475
Floor division, floor(a/b)
Definition: expr.h:241
a - b
Definition: expr.h:171
Array< PrimExpr > args
The arguments.
Definition: expr.h:921
Call node.
Definition: expr.h:910
Managed reference to StringImmNode.
Definition: expr.h:78
Buffer buffer
The buffer variable.
Definition: expr.h:609
return true_value if condition is true, otherwise return false_value.
Definition: expr.h:552
Reduction operator.
Definition: expr.h:1066
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:728
a / b in the C semantics.
Definition: expr.h:206
a <= b
Definition: expr.h:385
Managed reference to EQNode.
Definition: expr.h:346
Load value from the result produced by the producer.
Definition: expr.h:671
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:841
bool SEqualReduce(const StringImmNode *other, SEqualReducer equal) const
Definition: expr.h:64
Array< T > Concat(Array< T > lhs, const Array< T > &rhs)
Concat two Arrays.
Definition: array.h:719
a != b
Definition: expr.h:353
Managed reference to BufferLoadNode.
Definition: expr.h:655
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:154
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:179
Base node of all primitive expressions.
Definition: expr.h:85
int lanes
Total number of lanes.
Definition: expr.h:786
a >= b
Definition: expr.h:417
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:830
PrimExpr condition
The condition.
Definition: expr.h:555