tvm
expr.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 // Acknowledgement: Many low-level IR nodes originate from Halide.
25 #ifndef TVM_TIR_EXPR_H_
26 #define TVM_TIR_EXPR_H_
27 
28 #include <tvm/ffi/container/array.h>
29 #include <tvm/ffi/container/map.h>
30 #include <tvm/ffi/string.h>
31 #include <tvm/ir/expr.h>
32 #include <tvm/node/functor.h>
33 #include <tvm/node/node.h>
34 #include <tvm/runtime/base.h>
35 #include <tvm/runtime/data_type.h>
36 #include <tvm/tir/buffer.h>
37 #include <tvm/tir/var.h>
38 
39 #include <algorithm>
40 #include <iostream>
41 #include <limits>
42 #include <string>
43 #include <unordered_map>
44 #include <utility>
45 
46 namespace tvm {
47 namespace tir {
48 
51 
53 class StringImmNode : public PrimExprNode {
54  public:
56  String value;
57 
58  static void RegisterReflection() {
59  namespace refl = tvm::ffi::reflection;
60  refl::ObjectDef<StringImmNode>().def_ro("value", &StringImmNode::value);
61  }
62 
63  static constexpr const char* _type_key = "tir.StringImm";
65 };
66 
71 class StringImm : public PrimExpr {
72  public:
73  TVM_DLL StringImm(String value, Span span = Span());
76 };
77 
82 class CastNode : public PrimExprNode {
83  public:
86 
87  static void RegisterReflection() {
88  namespace refl = tvm::ffi::reflection;
89  refl::ObjectDef<CastNode>().def_ro("value", &CastNode::value);
90  }
91 
92  static constexpr const char* _type_key = "tir.Cast";
94 };
95 
100 class Cast : public PrimExpr {
101  public:
102  TVM_DLL Cast(DataType dtype, PrimExpr value, Span span = Span());
105 };
106 
111 template <typename T>
112 class BinaryOpNode : public PrimExprNode {
113  public:
118 
119  static void RegisterReflection() {
120  namespace refl = tvm::ffi::reflection;
121  refl::ObjectDef<T>().def_ro("a", &T::a).def_ro("b", &T::b);
122  }
123 
125 };
126 
128 class AddNode : public BinaryOpNode<AddNode> {
129  public:
130  static constexpr const char* _type_key = "tir.Add";
131 };
132 
137 class Add : public PrimExpr {
138  public:
139  TVM_DLL Add(PrimExpr a, PrimExpr b, Span span = Span());
142 };
143 
145 class SubNode : public BinaryOpNode<SubNode> {
146  public:
147  static constexpr const char* _type_key = "tir.Sub";
148 };
149 
154 class Sub : public PrimExpr {
155  public:
156  TVM_DLL Sub(PrimExpr a, PrimExpr b, Span span = Span());
157 
160 };
161 
163 class MulNode : public BinaryOpNode<MulNode> {
164  public:
165  static constexpr const char* _type_key = "tir.Mul";
166 };
167 
172 class Mul : public PrimExpr {
173  public:
174  TVM_DLL Mul(PrimExpr a, PrimExpr b, Span span = Span());
177 };
178 
183 class DivNode : public BinaryOpNode<DivNode> {
184  public:
185  static constexpr const char* _type_key = "tir.Div";
186 };
187 
192 class Div : public PrimExpr {
193  public:
194  TVM_DLL Div(PrimExpr a, PrimExpr b, Span span = Span());
197 };
198 
203 class ModNode : public BinaryOpNode<ModNode> {
204  public:
205  static constexpr const char* _type_key = "tir.Mod";
206 };
207 
212 class Mod : public PrimExpr {
213  public:
214  TVM_DLL Mod(PrimExpr a, PrimExpr b, Span span = Span());
217 };
218 
220 class FloorDivNode : public BinaryOpNode<FloorDivNode> {
221  public:
222  static constexpr const char* _type_key = "tir.FloorDiv";
223 };
224 
229 class FloorDiv : public PrimExpr {
230  public:
231  TVM_DLL FloorDiv(PrimExpr a, PrimExpr b, Span span = Span());
234 };
235 
237 class FloorModNode : public BinaryOpNode<FloorModNode> {
238  public:
239  static constexpr const char* _type_key = "tir.FloorMod";
240 };
241 
246 class FloorMod : public PrimExpr {
247  public:
248  TVM_DLL FloorMod(PrimExpr a, PrimExpr b, Span span = Span());
251 };
252 
254 class MinNode : public BinaryOpNode<MinNode> {
255  public:
256  static constexpr const char* _type_key = "tir.Min";
257 };
258 
263 class Min : public PrimExpr {
264  public:
265  TVM_DLL Min(PrimExpr a, PrimExpr b, Span span = Span());
268 };
269 
271 class MaxNode : public BinaryOpNode<MaxNode> {
272  public:
273  static constexpr const char* _type_key = "tir.Max";
274 };
275 
280 class Max : public PrimExpr {
281  public:
282  TVM_DLL Max(PrimExpr a, PrimExpr b, Span span = Span());
285 };
286 
291 template <typename T>
292 class CmpOpNode : public PrimExprNode {
293  public:
298 
299  static void RegisterReflection() {
300  namespace refl = tvm::ffi::reflection;
301  refl::ObjectDef<T>().def_ro("a", &T::a).def_ro("b", &T::b);
302  }
303 
305 };
306 
308 class EQNode : public CmpOpNode<EQNode> {
309  public:
310  static constexpr const char* _type_key = "tir.EQ";
311 };
312 
317 class EQ : public PrimExpr {
318  public:
319  TVM_DLL EQ(PrimExpr a, PrimExpr b, Span span = Span());
322 };
323 
325 class NENode : public CmpOpNode<NENode> {
326  public:
327  static constexpr const char* _type_key = "tir.NE";
328 };
329 
334 class NE : public PrimExpr {
335  public:
336  TVM_DLL NE(PrimExpr a, PrimExpr b, Span span = Span());
339 };
340 
342 class LTNode : public CmpOpNode<LTNode> {
343  public:
344  static constexpr const char* _type_key = "tir.LT";
345 };
346 
351 class LT : public PrimExpr {
352  public:
353  TVM_DLL LT(PrimExpr a, PrimExpr b, Span span = Span());
356 };
357 
359 struct LENode : public CmpOpNode<LENode> {
360  public:
361  static constexpr const char* _type_key = "tir.LE";
362 };
363 
368 class LE : public PrimExpr {
369  public:
370  TVM_DLL LE(PrimExpr a, PrimExpr b, Span span = Span());
373 };
374 
376 class GTNode : public CmpOpNode<GTNode> {
377  public:
378  static constexpr const char* _type_key = "tir.GT";
379 };
380 
385 class GT : public PrimExpr {
386  public:
387  TVM_DLL GT(PrimExpr a, PrimExpr b, Span span = Span());
390 };
391 
393 class GENode : public CmpOpNode<GENode> {
394  public:
395  static constexpr const char* _type_key = "tir.GE";
396 };
397 
402 class GE : public PrimExpr {
403  public:
404  TVM_DLL GE(PrimExpr a, PrimExpr b, Span span = Span());
407 };
408 
410 class AndNode : public PrimExprNode {
411  public:
416 
417  static void RegisterReflection() {
418  namespace refl = tvm::ffi::reflection;
419  refl::ObjectDef<AndNode>().def_ro("a", &AndNode::a).def_ro("b", &AndNode::b);
420  }
421 
422  static constexpr const char* _type_key = "tir.And";
424 };
425 
430 class And : public PrimExpr {
431  public:
432  TVM_DLL And(PrimExpr a, PrimExpr b, Span span = Span());
435 };
436 
438 class OrNode : public PrimExprNode {
439  public:
444 
445  static void RegisterReflection() {
446  namespace refl = tvm::ffi::reflection;
447  refl::ObjectDef<OrNode>().def_ro("a", &OrNode::a).def_ro("b", &OrNode::b);
448  }
449 
450  static constexpr const char* _type_key = "tir.Or";
452 };
453 
458 class Or : public PrimExpr {
459  public:
460  TVM_DLL Or(PrimExpr a, PrimExpr b, Span span = Span());
463 };
464 
466 class NotNode : public PrimExprNode {
467  public:
470 
471  static void RegisterReflection() {
472  namespace refl = tvm::ffi::reflection;
473  refl::ObjectDef<NotNode>().def_ro("a", &NotNode::a);
474  }
475 
476  static constexpr const char* _type_key = "tir.Not";
478 };
479 
484 class Not : public PrimExpr {
485  public:
486  TVM_DLL Not(PrimExpr a, Span span = Span());
489 };
490 
498 class SelectNode : public PrimExprNode {
499  public:
506 
507  static void RegisterReflection() {
508  namespace refl = tvm::ffi::reflection;
509  refl::ObjectDef<SelectNode>()
510  .def_ro("condition", &SelectNode::condition)
511  .def_ro("true_value", &SelectNode::true_value)
512  .def_ro("false_value", &SelectNode::false_value);
513  }
514 
515  static constexpr const char* _type_key = "tir.Select";
517 };
518 
523 class Select : public PrimExpr {
524  public:
525  TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span = Span());
526 
529 };
530 
541 class BufferLoadNode : public PrimExprNode {
542  public:
546  Array<PrimExpr> indices;
548  Optional<PrimExpr> predicate;
549 
550  static void RegisterReflection() {
551  namespace refl = tvm::ffi::reflection;
552  refl::ObjectDef<BufferLoadNode>()
553  .def_ro("buffer", &BufferLoadNode::buffer)
554  .def_ro("indices", &BufferLoadNode::indices)
555  .def_ro("predicate", &BufferLoadNode::predicate);
556  }
557 
558  static constexpr const char* _type_key = "tir.BufferLoad";
560 
561  private:
571  void LegalizeDType();
572  friend class BufferLoad;
574  friend class VectorTypeRewriter;
575  friend class Vectorizer;
576 };
577 
582 class BufferLoad : public PrimExpr {
583  public:
584  TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices,
585  Optional<PrimExpr> predicate = std::nullopt, Span span = Span());
588 };
589 
600  public:
604  Array<PrimExpr> indices;
605 
606  static void RegisterReflection() {
607  namespace refl = tvm::ffi::reflection;
608  refl::ObjectDef<ProducerLoadNode>()
609  .def_ro("producer", &ProducerLoadNode::producer)
610  .def_ro("indices", &ProducerLoadNode::indices);
611  }
612 
613  static constexpr const char* _type_key = "tir.ProducerLoad";
615 };
616 
621 class ProducerLoad : public PrimExpr {
622  public:
623  TVM_DLL explicit ProducerLoad(DataProducer producer, Array<PrimExpr> indices, Span span = Span());
624 
627 };
628 
638 class RampNode : public PrimExprNode {
639  public:
646 
647  static void RegisterReflection() {
648  namespace refl = tvm::ffi::reflection;
649  refl::ObjectDef<RampNode>()
650  .def_ro("base", &RampNode::base)
651  .def_ro("stride", &RampNode::stride)
652  .def_ro("lanes", &RampNode::lanes);
653  }
654 
655  static constexpr const char* _type_key = "tir.Ramp";
657 };
658 
663 class Ramp : public PrimExpr {
664  public:
665  TVM_DLL Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span = Span());
668 };
669 
671 class BroadcastNode : public PrimExprNode {
672  public:
677 
678  static void RegisterReflection() {
679  namespace refl = tvm::ffi::reflection;
680  refl::ObjectDef<BroadcastNode>()
681  .def_ro("value", &BroadcastNode::value)
682  .def_ro("lanes", &BroadcastNode::lanes);
683  }
684 
685  static constexpr const char* _type_key = "tir.Broadcast";
687 };
688 
693 class Broadcast : public PrimExpr {
694  public:
695  TVM_DLL Broadcast(PrimExpr value, PrimExpr lanes, Span span = Span());
698 };
699 
703 class LetNode : public PrimExprNode {
704  public:
711 
712  static void RegisterReflection() {
713  namespace refl = tvm::ffi::reflection;
714  refl::ObjectDef<LetNode>()
715  .def_ro("var", &LetNode::var, refl::AttachFieldFlag::SEqHashDef())
716  .def_ro("value", &LetNode::value)
717  .def_ro("body", &LetNode::body);
718  }
719 
720  static constexpr const char* _type_key = "tir.Let";
722 };
723 
728 class Let : public PrimExpr {
729  public:
730  TVM_DLL Let(Var var, PrimExpr value, PrimExpr body, Span span = Span());
733 };
734 
738 class CallNode : public PrimExprNode {
739  public:
747 
749  Array<PrimExpr> args;
750 
751  static void RegisterReflection() {
752  namespace refl = tvm::ffi::reflection;
753  refl::ObjectDef<CallNode>().def_ro("op", &CallNode::op).def_ro("args", &CallNode::args);
754  }
755 
756  static constexpr const char* _type_key = "tir.Call";
758 };
759 
764 class Call : public PrimExpr {
765  public:
766  TVM_DLL Call(DataType dtype, RelaxExpr op, Array<PrimExpr> args, Span span = Span());
769 };
770 
776 class ShuffleNode : public PrimExprNode {
777  public:
779  Array<PrimExpr> vectors;
781  Array<PrimExpr> indices;
782 
783  static void RegisterReflection() {
784  namespace refl = tvm::ffi::reflection;
785  refl::ObjectDef<ShuffleNode>()
786  .def_ro("vectors", &ShuffleNode::vectors)
787  .def_ro("indices", &ShuffleNode::indices);
788  }
789 
790  static constexpr const char* _type_key = "tir.Shuffle";
792 };
793 
798 class Shuffle : public PrimExpr {
799  public:
800  TVM_DLL Shuffle(Array<PrimExpr> vectors, Array<PrimExpr> indices, Span span = Span());
801  TVM_DLL static PrimExpr Concat(Array<PrimExpr> vectors, Span span = Span());
802  TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index, Span span = Span());
803 
806 };
807 
808 // Reduce operator
813 class CommReducerNode : public Object {
814  public:
816  Array<Var> lhs;
818  Array<Var> rhs;
820  Array<PrimExpr> result;
826  Array<PrimExpr> identity_element;
828  Array<PrimExpr> operator()(Array<PrimExpr> a, Array<PrimExpr> b) const;
833  mutable Span span;
834 
835  static void RegisterReflection() {
836  namespace refl = tvm::ffi::reflection;
837  refl::ObjectDef<CommReducerNode>()
838  .def_ro("lhs", &CommReducerNode::lhs, refl::AttachFieldFlag::SEqHashDef())
839  .def_ro("rhs", &CommReducerNode::rhs, refl::AttachFieldFlag::SEqHashDef())
840  .def_ro("result", &CommReducerNode::result)
841  .def_ro("identity_element", &CommReducerNode::identity_element)
842  .def_ro("span", &CommReducerNode::span, refl::AttachFieldFlag::SEqHashIgnore());
843  }
844 
845  static constexpr const char* _type_key = "tir.CommReducer";
846  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
848 };
849 
854 class CommReducer : public ObjectRef {
855  public:
856  TVM_DLL CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
857  Array<PrimExpr> identity_element, Span span = Span());
858 
860 };
861 
863 class ReduceNode : public PrimExprNode {
864  public:
868  Array<PrimExpr> source;
870  Array<PrimExpr> init;
872  Array<IterVar> axis;
880 
881  static void RegisterReflection() {
882  namespace refl = tvm::ffi::reflection;
883  refl::ObjectDef<ReduceNode>()
884  .def_ro("combiner", &ReduceNode::combiner)
885  .def_ro("source", &ReduceNode::source)
886  .def_ro("init", &ReduceNode::init)
887  .def_ro("axis", &ReduceNode::axis)
888  .def_ro("condition", &ReduceNode::condition)
889  .def_ro("value_index", &ReduceNode::value_index);
890  }
891 
892  static constexpr const char* _type_key = "tir.Reduce";
894 };
895 
900 class Reduce : public PrimExpr {
901  public:
902  TVM_DLL Reduce(CommReducer combiner, Array<PrimExpr> src, Array<IterVar> rdom, PrimExpr condition,
903  int value_index, Array<PrimExpr> init, Span span = Span());
904 
907 };
908 
909 /*
910  * \brief Template function to convert Map to unordered_map
911  * Sometimes useful for API gluing when internal uses unordered_map
912  * \param dmap The container map
913  * \return The corresponding unordered_map.
914  * \tparam K the key of the Map.
915  * \tparam V the value of the Map.
916  */
917 template <typename K, typename V>
918 inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
919  std::unordered_map<K, V> ret;
920  for (auto kv : dmap) {
921  ret[kv.first] = kv.second;
922  }
923  return ret;
924 }
925 } // namespace tir
926 
927 namespace ffi {
928 
929 template <>
930 inline constexpr bool use_default_type_traits_v<tvm::tir::StringImm> = false;
931 
932 template <>
933 struct TypeTraits<tvm::tir::StringImm>
934  : public ObjectRefWithFallbackTraitsBase<tvm::tir::StringImm, String> {
935  TVM_FFI_INLINE static tvm::tir::StringImm ConvertFallbackValue(String value) {
936  return tvm::tir::StringImm(value);
937  }
938 };
939 } // namespace ffi
940 } // namespace tvm
941 
942 namespace std {
943 template <>
944 struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {};
945 } // namespace std
946 #endif // TVM_TIR_EXPR_H_
Symbolic n-dimensional array, to represent a memory buffer.
Constant floating point literals in the program.
Definition: expr.h:538
Constant integer literals in the program.
Definition: expr.h:501
Base node of all primitive expressions.
Definition: expr.h:95
Reference to PrimExprNode.
Definition: expr.h:129
DataType dtype() const
Definition: expr.h:143
Managed reference to RelaxExprNode.
Definition: expr.h:446
Definition: source_map.h:113
Runtime primitive data type.
Definition: data_type.h:47
a + b
Definition: expr.h:128
static constexpr const char * _type_key
Definition: expr.h:130
Managed reference to AddNode.
Definition: expr.h:137
TVM_DEFINE_OBJECT_REF_COW_METHOD(AddNode)
TVM_DEFINE_OBJECT_REF_METHODS(Add, PrimExpr, AddNode)
Add(PrimExpr a, PrimExpr b, Span span=Span())
a && b
Definition: expr.h:410
PrimExpr a
The left operand.
Definition: expr.h:413
static constexpr const char * _type_key
Definition: expr.h:422
PrimExpr b
The right operand.
Definition: expr.h:415
TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode)
static void RegisterReflection()
Definition: expr.h:417
Managed reference to AndNode.
Definition: expr.h:430
TVM_DEFINE_OBJECT_REF_COW_METHOD(AndNode)
And(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(And, PrimExpr, AndNode)
Base template to implement binary ops.
Definition: expr.h:112
PrimExpr b
The right operand.
Definition: expr.h:117
TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode)
static void RegisterReflection()
Definition: expr.h:119
PrimExpr a
The left operand.
Definition: expr.h:115
Create a vector where all the elements are value.
Definition: expr.h:671
static constexpr const char * _type_key
Definition: expr.h:685
TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode)
PrimExpr value
The base value.
Definition: expr.h:674
static void RegisterReflection()
Definition: expr.h:678
PrimExpr lanes
The number of lanes.
Definition: expr.h:676
Managed reference to BroadcastNode.
Definition: expr.h:693
Broadcast(PrimExpr value, PrimExpr lanes, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode)
Load value from the high dimension buffer.
Definition: expr.h:541
friend class VectorTypeRewriter
Definition: expr.h:574
friend class CustomDatatypesLowerer
Definition: expr.h:573
TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode)
Array< PrimExpr > indices
The indices location to be loaded.
Definition: expr.h:546
friend class Vectorizer
Definition: expr.h:575
Buffer buffer
The buffer variable.
Definition: expr.h:544
static constexpr const char * _type_key
Definition: expr.h:558
static void RegisterReflection()
Definition: expr.h:550
Optional< PrimExpr > predicate
The predicate mask for loading values.
Definition: expr.h:548
Managed reference to BufferLoadNode.
Definition: expr.h:582
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode)
BufferLoad(Buffer buffer, Array< PrimExpr > indices, Optional< PrimExpr > predicate=std::nullopt, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode)
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:157
Call node.
Definition: expr.h:738
RelaxExpr op
The operator(function) being invoked.
Definition: expr.h:746
static constexpr const char * _type_key
Definition: expr.h:756
Array< PrimExpr > args
The arguments.
Definition: expr.h:749
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode)
static void RegisterReflection()
Definition: expr.h:751
Managed reference to CallNode.
Definition: expr.h:764
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode)
TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode)
Call(DataType dtype, RelaxExpr op, Array< PrimExpr > args, Span span=Span())
Cast value from one data type to another.
Definition: expr.h:82
PrimExpr value
The value to be cast (it carries the original data type).
Definition: expr.h:85
static void RegisterReflection()
Definition: expr.h:87
TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode)
static constexpr const char * _type_key
Definition: expr.h:92
Managed reference to CastNode.
Definition: expr.h:100
TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode)
TVM_DEFINE_OBJECT_REF_METHODS(Cast, PrimExpr, CastNode)
Cast(DataType dtype, PrimExpr value, Span span=Span())
Base template to implement comparison ops.
Definition: expr.h:292
static void RegisterReflection()
Definition: expr.h:299
TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode)
PrimExpr a
The left operand.
Definition: expr.h:295
PrimExpr b
The right operand.
Definition: expr.h:297
A commutative reducer node to represent a commutative binary operator with identity element.
Definition: expr.h:813
Array< Var > rhs
The right argument of reducer.
Definition: expr.h:818
Array< Var > lhs
The left argument of reducer.
Definition: expr.h:816
static constexpr const char * _type_key
Definition: expr.h:845
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: expr.h:846
Array< PrimExpr > operator()(Array< PrimExpr > a, Array< PrimExpr > b) const
Function call operator to combine a and b.
Array< PrimExpr > result
The result of reducer.
Definition: expr.h:820
TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object)
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:833
static void RegisterReflection()
Definition: expr.h:835
Array< PrimExpr > identity_element
The identity element of reducer, which leaves other elements unchanged when combined with it,...
Definition: expr.h:826
Managed reference to CommReducerNode.
Definition: expr.h:854
CommReducer(Array< Var > lhs, Array< Var > rhs, Array< PrimExpr > result, Array< PrimExpr > identity_element, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(CommReducer, ObjectRef, CommReducerNode)
Managed reference to DataProducerNode.
Definition: buffer.h:288
a / b in the C semantics.
Definition: expr.h:183
static constexpr const char * _type_key
Definition: expr.h:185
Managed reference to DivNode.
Definition: expr.h:192
TVM_DEFINE_OBJECT_REF_COW_METHOD(DivNode)
Div(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Div, PrimExpr, DivNode)
a == b
Definition: expr.h:308
static constexpr const char * _type_key
Definition: expr.h:310
Managed reference to EQNode.
Definition: expr.h:317
TVM_DEFINE_OBJECT_REF_METHODS(EQ, PrimExpr, EQNode)
EQ(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(EQNode)
Floor division, floor(a/b)
Definition: expr.h:220
static constexpr const char * _type_key
Definition: expr.h:222
Managed reference to FloorDivNode.
Definition: expr.h:229
TVM_DEFINE_OBJECT_REF_METHODS(FloorDiv, PrimExpr, FloorDivNode)
FloorDiv(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode)
The remainder of the floordiv.
Definition: expr.h:237
static constexpr const char * _type_key
Definition: expr.h:239
Managed reference to FloorModNode.
Definition: expr.h:246
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorModNode)
TVM_DEFINE_OBJECT_REF_METHODS(FloorMod, PrimExpr, FloorModNode)
FloorMod(PrimExpr a, PrimExpr b, Span span=Span())
a >= b
Definition: expr.h:393
static constexpr const char * _type_key
Definition: expr.h:395
Managed reference to GENode.
Definition: expr.h:402
TVM_DEFINE_OBJECT_REF_METHODS(GE, PrimExpr, GENode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(GENode)
GE(PrimExpr a, PrimExpr b, Span span=Span())
a > b
Definition: expr.h:376
static constexpr const char * _type_key
Definition: expr.h:378
Managed reference to GTNode.
Definition: expr.h:385
TVM_DEFINE_OBJECT_REF_METHODS(GT, PrimExpr, GTNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(GTNode)
GT(PrimExpr a, PrimExpr b, Span span=Span())
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:298
Managed reference to LENode.
Definition: expr.h:368
LE(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(LENode)
TVM_DEFINE_OBJECT_REF_METHODS(LE, PrimExpr, LENode)
a < b
Definition: expr.h:342
static constexpr const char * _type_key
Definition: expr.h:344
Managed reference to LTNode.
Definition: expr.h:351
TVM_DEFINE_OBJECT_REF_COW_METHOD(LTNode)
LT(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(LT, PrimExpr, LTNode)
Let binding. Bind var to value then evaluate body.
Definition: expr.h:703
static void RegisterReflection()
Definition: expr.h:712
Var var
The variable.
Definition: expr.h:706
static constexpr const char * _type_key
Definition: expr.h:720
PrimExpr value
The value to be bound.
Definition: expr.h:708
TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode)
PrimExpr body
The result expression.
Definition: expr.h:710
Managed reference to LetNode.
Definition: expr.h:728
TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode)
Let(Var var, PrimExpr value, PrimExpr body, Span span=Span())
max(a, b)
Definition: expr.h:271
static constexpr const char * _type_key
Definition: expr.h:273
Managed reference to MaxNode.
Definition: expr.h:280
TVM_DEFINE_OBJECT_REF_METHODS(Max, PrimExpr, MaxNode)
Max(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(MaxNode)
min(a, b)
Definition: expr.h:254
static constexpr const char * _type_key
Definition: expr.h:256
Managed reference to MinNode.
Definition: expr.h:263
Min(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Min, PrimExpr, MinNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MinNode)
a % b in the C semantics.
Definition: expr.h:203
static constexpr const char * _type_key
Definition: expr.h:205
Managed reference to ModNode.
Definition: expr.h:212
TVM_DEFINE_OBJECT_REF_METHODS(Mod, PrimExpr, ModNode)
Mod(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ModNode)
a * b
Definition: expr.h:163
static constexpr const char * _type_key
Definition: expr.h:165
Managed reference to MulNode.
Definition: expr.h:172
Mul(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Mul, PrimExpr, MulNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MulNode)
a != b
Definition: expr.h:325
static constexpr const char * _type_key
Definition: expr.h:327
Managed reference to NENode.
Definition: expr.h:334
NE(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(NENode)
TVM_DEFINE_OBJECT_REF_METHODS(NE, PrimExpr, NENode)
!a
Definition: expr.h:466
TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode)
PrimExpr a
The input operand.
Definition: expr.h:469
static constexpr const char * _type_key
Definition: expr.h:476
static void RegisterReflection()
Definition: expr.h:471
Managed reference to NotNode.
Definition: expr.h:484
TVM_DEFINE_OBJECT_REF_COW_METHOD(NotNode)
Not(PrimExpr a, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Not, PrimExpr, NotNode)
a || b
Definition: expr.h:438
PrimExpr b
The right operand.
Definition: expr.h:443
PrimExpr a
The left operand.
Definition: expr.h:441
TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode)
static constexpr const char * _type_key
Definition: expr.h:450
static void RegisterReflection()
Definition: expr.h:445
Managed reference to OrNode.
Definition: expr.h:458
TVM_DEFINE_OBJECT_REF_METHODS(Or, PrimExpr, OrNode)
Or(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(OrNode)
Load value from the result produced by the producer.
Definition: expr.h:599
static constexpr const char * _type_key
Definition: expr.h:613
Array< PrimExpr > indices
The location arguments.
Definition: expr.h:604
static void RegisterReflection()
Definition: expr.h:606
DataProducer producer
The buffer producer.
Definition: expr.h:602
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode)
Managed reference to ProducerLoadNode.
Definition: expr.h:621
TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode)
ProducerLoad(DataProducer producer, Array< PrimExpr > indices, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode)
Construct a vector with lanes elements where its i-th element equals base + i * stride....
Definition: expr.h:638
static constexpr const char * _type_key
Definition: expr.h:655
PrimExpr stride
The stride of each step.
Definition: expr.h:643
TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode)
PrimExpr lanes
Total number of lanes.
Definition: expr.h:645
static void RegisterReflection()
Definition: expr.h:647
PrimExpr base
The base value.
Definition: expr.h:641
Managed reference to RampNode.
Definition: expr.h:663
TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode)
TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode)
Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span=Span())
Reduction operator.
Definition: expr.h:863
Array< PrimExpr > init
The init operand.
Definition: expr.h:870
int value_index
The index of this reduce node.
Definition: expr.h:879
static constexpr const char * _type_key
Definition: expr.h:892
TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode)
Array< IterVar > axis
The reduction axis.
Definition: expr.h:872
CommReducer combiner
The commutative combiner.
Definition: expr.h:866
static void RegisterReflection()
Definition: expr.h:881
Array< PrimExpr > source
The source operand.
Definition: expr.h:868
PrimExpr condition
Predicate on the reduction: only add the body to the reduction if condition is true.
Definition: expr.h:877
Managed reference to ReduceNode.
Definition: expr.h:900
TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode)
TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode)
Reduce(CommReducer combiner, Array< PrimExpr > src, Array< IterVar > rdom, PrimExpr condition, int value_index, Array< PrimExpr > init, Span span=Span())
return true_value if condition is true, otherwise return false_value.
Definition: expr.h:498
PrimExpr condition
The condition.
Definition: expr.h:501
PrimExpr true_value
value to be returned when condition is true.
Definition: expr.h:503
static constexpr const char * _type_key
Definition: expr.h:515
TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode)
PrimExpr false_value
value to be returned when condition is false.
Definition: expr.h:505
static void RegisterReflection()
Definition: expr.h:507
Managed reference to SelectNode.
Definition: expr.h:523
TVM_DEFINE_OBJECT_REF_COW_METHOD(SelectNode)
TVM_DEFINE_OBJECT_REF_METHODS(Select, PrimExpr, SelectNode)
Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Shuffle instruction. vec = concat(vectors) result = (vec[indices[0]], vec[indices[1]] ....
Definition: expr.h:776
Array< PrimExpr > indices
The indices of each element.
Definition: expr.h:781
static constexpr const char * _type_key
Definition: expr.h:790
static void RegisterReflection()
Definition: expr.h:783
TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode)
Array< PrimExpr > vectors
The input vectors.
Definition: expr.h:779
Managed reference to ShuffleNode.
Definition: expr.h:798
static PrimExpr Concat(Array< PrimExpr > vectors, Span span=Span())
Shuffle(Array< PrimExpr > vectors, Array< PrimExpr > indices, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShuffleNode)
TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode)
static PrimExpr ExtractElement(PrimExpr vector, int index, Span span=Span())
String constants, only used in asserts.
Definition: expr.h:53
String value
The constant value content.
Definition: expr.h:56
static constexpr const char * _type_key
Definition: expr.h:63
TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode)
static void RegisterReflection()
Definition: expr.h:58
Managed reference to StringImmNode.
Definition: expr.h:71
StringImm(String value, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode)
TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode)
a - b
Definition: expr.h:145
static constexpr const char * _type_key
Definition: expr.h:147
Managed reference to SubNode.
Definition: expr.h:154
Sub(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Sub, PrimExpr, SubNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(SubNode)
a named variable in TIR
Definition: var.h:78
Defines the Functor data structures.
Base expr nodes in TVM.
Definition: repr_printer.h:91
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
std::unordered_map< K, V > as_unordered_map(const Map< K, V > &dmap)
Definition: expr.h:918
tvm::FloatImmNode FloatImmNode
Definition: expr.h:50
tvm::IntImmNode IntImmNode
Definition: expr.h:49
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Definitions and helper macros for IR/AST nodes.
static TVM_FFI_INLINE tvm::tir::StringImm ConvertFallbackValue(String value)
Definition: expr.h:935
a <= b
Definition: expr.h:359
static constexpr const char * _type_key
Definition: expr.h:361
Variables in the TIR.