tvm
expr.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 // Acknowledgement: Many low-level IR nodes originate from Halide.
25 #ifndef TVM_TIR_EXPR_H_
26 #define TVM_TIR_EXPR_H_
27 
28 #include <tvm/ffi/container/array.h>
29 #include <tvm/ffi/container/map.h>
30 #include <tvm/ffi/string.h>
31 #include <tvm/ir/expr.h>
32 #include <tvm/node/functor.h>
33 #include <tvm/node/node.h>
34 #include <tvm/runtime/base.h>
35 #include <tvm/runtime/data_type.h>
36 #include <tvm/tir/buffer.h>
37 #include <tvm/tir/var.h>
38 
39 #include <algorithm>
40 #include <iostream>
41 #include <limits>
42 #include <string>
43 #include <unordered_map>
44 #include <utility>
45 
46 namespace tvm {
47 namespace tir {
48 
51 
/*! \brief ffi::String constants, only used in asserts. */
53 class StringImmNode : public PrimExprNode {
54  public:
  /*! \brief The constant value content. */
56  ffi::String value;
57 
  /*! \brief Registers `value` as a reflected field (def_ro) of StringImmNode. */
58  static void RegisterReflection() {
59  namespace refl = tvm::ffi::reflection;
60  refl::ObjectDef<StringImmNode>().def_ro("value", &StringImmNode::value);
61  }
63 };
64 
/*! \brief Managed reference to StringImmNode. */
69 class StringImm : public PrimExpr {
70  public:
  /*! \brief Construct from a string value; `span` defaults to a default-constructed Span. */
71  TVM_DLL StringImm(ffi::String value, Span span = Span());
74 };
75 
/*! \brief Cast value from one data type to another. */
80 class CastNode : public PrimExprNode {
81  public:
84 
  /*! \brief Registers the `value` member of CastNode as a reflected field (def_ro). */
85  static void RegisterReflection() {
86  namespace refl = tvm::ffi::reflection;
87  refl::ObjectDef<CastNode>().def_ro("value", &CastNode::value);
88  }
90 };
91 
/*! \brief Managed reference to CastNode. */
96 class Cast : public PrimExpr {
97  public:
  /*! \brief Construct from a data type and a value expression; `span` defaults to Span(). */
98  TVM_DLL Cast(DataType dtype, PrimExpr value, Span span = Span());
101 };
102 
/*!
 * \brief Base template to implement binary ops.
 * \tparam T The concrete node type that derives from this template
 *           (each derived node passes itself as T, e.g. BinaryOpNode<AddNode>).
 */
107 template <typename T>
108 class BinaryOpNode : public PrimExprNode {
109  public:
114 
  /*! \brief Registers the operands `a` and `b` of the derived type T as reflected fields. */
115  static void RegisterReflection() {
116  namespace refl = tvm::ffi::reflection;
117  refl::ObjectDef<T>().def_ro("a", &T::a).def_ro("b", &T::b);
118  }
119 
  // Type metadata constants: zero child slots, and the type is flagged final.
  // [[maybe_unused]] keeps compilers quiet when the constants are not referenced.
120  static const constexpr int _type_child_slots [[maybe_unused]] = 0;
121  static const constexpr bool _type_final [[maybe_unused]] = true;
123 };
124 
/*! \brief a + b */
126 class AddNode : public BinaryOpNode<AddNode> {
127  public:
  /*! \brief Type key string of this node type. */
128  static constexpr const char* _type_key = "tir.Add";
129 };
130 
/*! \brief Managed reference to AddNode. */
135 class Add : public PrimExpr {
136  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
137  TVM_DLL Add(PrimExpr a, PrimExpr b, Span span = Span());
140 };
141 
/*! \brief a - b */
143 class SubNode : public BinaryOpNode<SubNode> {
144  public:
  /*! \brief Type key string of this node type. */
145  static constexpr const char* _type_key = "tir.Sub";
146 };
147 
/*! \brief Managed reference to SubNode. */
152 class Sub : public PrimExpr {
153  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
154  TVM_DLL Sub(PrimExpr a, PrimExpr b, Span span = Span());
155 
158 };
159 
/*! \brief a * b */
161 class MulNode : public BinaryOpNode<MulNode> {
162  public:
  /*! \brief Type key string of this node type. */
163  static constexpr const char* _type_key = "tir.Mul";
164 };
165 
/*! \brief Managed reference to MulNode. */
170 class Mul : public PrimExpr {
171  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
172  TVM_DLL Mul(PrimExpr a, PrimExpr b, Span span = Span());
175 };
176 
/*! \brief a / b in the C semantics. */
181 class DivNode : public BinaryOpNode<DivNode> {
182  public:
  /*! \brief Type key string of this node type. */
183  static constexpr const char* _type_key = "tir.Div";
184 };
185 
/*! \brief Managed reference to DivNode. */
190 class Div : public PrimExpr {
191  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
192  TVM_DLL Div(PrimExpr a, PrimExpr b, Span span = Span());
195 };
196 
/*! \brief a % b in the C semantics. */
201 class ModNode : public BinaryOpNode<ModNode> {
202  public:
  /*! \brief Type key string of this node type. */
203  static constexpr const char* _type_key = "tir.Mod";
204 };
205 
/*! \brief Managed reference to ModNode. */
210 class Mod : public PrimExpr {
211  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
212  TVM_DLL Mod(PrimExpr a, PrimExpr b, Span span = Span());
215 };
216 
/*! \brief Floor division, floor(a/b) */
218 class FloorDivNode : public BinaryOpNode<FloorDivNode> {
219  public:
  /*! \brief Type key string of this node type. */
220  static constexpr const char* _type_key = "tir.FloorDiv";
221 };
222 
/*! \brief Managed reference to FloorDivNode. */
227 class FloorDiv : public PrimExpr {
228  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
229  TVM_DLL FloorDiv(PrimExpr a, PrimExpr b, Span span = Span());
232 };
233 
/*! \brief The remainder of the floordiv. */
235 class FloorModNode : public BinaryOpNode<FloorModNode> {
236  public:
  /*! \brief Type key string of this node type. */
237  static constexpr const char* _type_key = "tir.FloorMod";
238 };
239 
/*! \brief Managed reference to FloorModNode. */
244 class FloorMod : public PrimExpr {
245  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
246  TVM_DLL FloorMod(PrimExpr a, PrimExpr b, Span span = Span());
249 };
250 
/*! \brief min(a, b) */
252 class MinNode : public BinaryOpNode<MinNode> {
253  public:
  /*! \brief Type key string of this node type. */
254  static constexpr const char* _type_key = "tir.Min";
255 };
256 
/*! \brief Managed reference to MinNode. */
261 class Min : public PrimExpr {
262  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
263  TVM_DLL Min(PrimExpr a, PrimExpr b, Span span = Span());
266 };
267 
/*! \brief max(a, b) */
269 class MaxNode : public BinaryOpNode<MaxNode> {
270  public:
  /*! \brief Type key string of this node type. */
271  static constexpr const char* _type_key = "tir.Max";
272 };
273 
/*! \brief Managed reference to MaxNode. */
278 class Max : public PrimExpr {
279  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
280  TVM_DLL Max(PrimExpr a, PrimExpr b, Span span = Span());
283 };
284 
/*!
 * \brief Base template to implement comparison ops.
 * \tparam T The concrete node type that derives from this template
 *           (each derived node passes itself as T, e.g. CmpOpNode<EQNode>).
 */
289 template <typename T>
290 class CmpOpNode : public PrimExprNode {
291  public:
296 
  /*! \brief Registers the operands `a` and `b` of the derived type T as reflected fields. */
297  static void RegisterReflection() {
298  namespace refl = tvm::ffi::reflection;
299  refl::ObjectDef<T>().def_ro("a", &T::a).def_ro("b", &T::b);
300  }
301 
  // Type metadata constants: zero child slots, and the type is flagged final.
  // [[maybe_unused]] keeps compilers quiet when the constants are not referenced.
302  static const constexpr int _type_child_slots [[maybe_unused]] = 0;
303  static const constexpr bool _type_final [[maybe_unused]] = true;
305 };
306 
/*! \brief a == b */
308 class EQNode : public CmpOpNode<EQNode> {
309  public:
  /*! \brief Type key string of this node type. */
310  static constexpr const char* _type_key = "tir.EQ";
311 };
312 
/*! \brief Managed reference to EQNode. */
317 class EQ : public PrimExpr {
318  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
319  TVM_DLL EQ(PrimExpr a, PrimExpr b, Span span = Span());
322 };
323 
/*! \brief a != b */
325 class NENode : public CmpOpNode<NENode> {
326  public:
  /*! \brief Type key string of this node type. */
327  static constexpr const char* _type_key = "tir.NE";
328 };
329 
/*! \brief Managed reference to NENode. */
334 class NE : public PrimExpr {
335  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
336  TVM_DLL NE(PrimExpr a, PrimExpr b, Span span = Span());
339 };
340 
/*! \brief a < b */
342 class LTNode : public CmpOpNode<LTNode> {
343  public:
  /*! \brief Type key string of this node type. */
344  static constexpr const char* _type_key = "tir.LT";
345 };
346 
/*! \brief Managed reference to LTNode. */
351 class LT : public PrimExpr {
352  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
353  TVM_DLL LT(PrimExpr a, PrimExpr b, Span span = Span());
356 };
357 
/*! \brief a <= b */
// Declared with the `class` key like every other comparison node, so that forward
// declarations written as `class LENode;` do not trip mismatched-tag diagnostics.
// The base is explicitly public and members sit under `public:`, so access is unchanged.
359 class LENode : public CmpOpNode<LENode> {
360  public:
  /*! \brief Type key string of this node type. */
361  static constexpr const char* _type_key = "tir.LE";
362 };
363 
/*! \brief Managed reference to LENode. */
368 class LE : public PrimExpr {
369  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
370  TVM_DLL LE(PrimExpr a, PrimExpr b, Span span = Span());
373 };
374 
/*! \brief a > b */
376 class GTNode : public CmpOpNode<GTNode> {
377  public:
  /*! \brief Type key string of this node type. */
378  static constexpr const char* _type_key = "tir.GT";
379 };
380 
/*! \brief Managed reference to GTNode. */
385 class GT : public PrimExpr {
386  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
387  TVM_DLL GT(PrimExpr a, PrimExpr b, Span span = Span());
390 };
391 
/*! \brief a >= b */
393 class GENode : public CmpOpNode<GENode> {
394  public:
  /*! \brief Type key string of this node type. */
395  static constexpr const char* _type_key = "tir.GE";
396 };
397 
/*! \brief Managed reference to GENode. */
402 class GE : public PrimExpr {
403  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
404  TVM_DLL GE(PrimExpr a, PrimExpr b, Span span = Span());
407 };
408 
/*! \brief a && b */
410 class AndNode : public PrimExprNode {
411  public:
416 
  /*! \brief Registers the operands `a` and `b` of AndNode as reflected fields (def_ro). */
417  static void RegisterReflection() {
418  namespace refl = tvm::ffi::reflection;
419  refl::ObjectDef<AndNode>().def_ro("a", &AndNode::a).def_ro("b", &AndNode::b);
420  }
422 };
423 
/*! \brief Managed reference to AndNode. */
428 class And : public PrimExpr {
429  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
430  TVM_DLL And(PrimExpr a, PrimExpr b, Span span = Span());
433 };
434 
/*! \brief a || b */
436 class OrNode : public PrimExprNode {
437  public:
442 
  /*! \brief Registers the operands `a` and `b` of OrNode as reflected fields (def_ro). */
443  static void RegisterReflection() {
444  namespace refl = tvm::ffi::reflection;
445  refl::ObjectDef<OrNode>().def_ro("a", &OrNode::a).def_ro("b", &OrNode::b);
446  }
448 };
449 
/*! \brief Managed reference to OrNode. */
454 class Or : public PrimExpr {
455  public:
  /*! \brief Construct from operands `a` and `b`; `span` defaults to Span(). */
456  TVM_DLL Or(PrimExpr a, PrimExpr b, Span span = Span());
459 };
460 
/*! \brief !a */
462 class NotNode : public PrimExprNode {
463  public:
466 
  /*! \brief Registers the input operand `a` of NotNode as a reflected field (def_ro). */
467  static void RegisterReflection() {
468  namespace refl = tvm::ffi::reflection;
469  refl::ObjectDef<NotNode>().def_ro("a", &NotNode::a);
470  }
472 };
473 
/*! \brief Managed reference to NotNode. */
478 class Not : public PrimExpr {
479  public:
  /*! \brief Construct from the operand `a`; `span` defaults to Span(). */
480  TVM_DLL Not(PrimExpr a, Span span = Span());
483 };
484 
/*! \brief Return true_value if condition is true, otherwise return false_value. */
492 class SelectNode : public PrimExprNode {
493  public:
500 
  /*! \brief Registers `condition`, `true_value` and `false_value` as reflected fields (def_ro). */
501  static void RegisterReflection() {
502  namespace refl = tvm::ffi::reflection;
503  refl::ObjectDef<SelectNode>()
504  .def_ro("condition", &SelectNode::condition)
505  .def_ro("true_value", &SelectNode::true_value)
506  .def_ro("false_value", &SelectNode::false_value);
507  }
509 };
510 
/*! \brief Managed reference to SelectNode. */
515 class Select : public PrimExpr {
516  public:
  /*! \brief Construct from a condition and the two candidate values; `span` defaults to Span(). */
517  TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span = Span());
518 
521 };
522 
/*! \brief Load value from the high dimension buffer. */
533 class BufferLoadNode : public PrimExprNode {
534  public:
  /*! \brief The indices location to be loaded. */
538  ffi::Array<PrimExpr> indices;
  /*! \brief The predicate mask for loading values. */
540  ffi::Optional<PrimExpr> predicate;
541 
  /*! \brief Registers `buffer`, `indices` and `predicate` as reflected fields (def_ro). */
542  static void RegisterReflection() {
543  namespace refl = tvm::ffi::reflection;
544  refl::ObjectDef<BufferLoadNode>()
545  .def_ro("buffer", &BufferLoadNode::buffer)
546  .def_ro("indices", &BufferLoadNode::indices)
547  .def_ro("predicate", &BufferLoadNode::predicate);
548  }
550 
551  private:
  // Declared here only; the definition is not in this header.
  // NOTE(review): by its name this presumably fixes up the node's dtype — confirm at the definition.
561  void LegalizeDType();
  // The friends below are granted access to the private section (i.e. LegalizeDType).
562  friend class BufferLoad;
564  friend class VectorTypeRewriter;
565  friend class Vectorizer;
566 };
567 
/*! \brief Managed reference to BufferLoadNode. */
572 class BufferLoad : public PrimExpr {
573  public:
  /*!
   * \brief Explicit constructor from a buffer and its indices.
   *        `predicate` defaults to std::nullopt and `span` to Span().
   */
574  TVM_DLL explicit BufferLoad(Buffer buffer, ffi::Array<PrimExpr> indices,
575  ffi::Optional<PrimExpr> predicate = std::nullopt, Span span = Span());
578 };
579 
590  public:
594  ffi::Array<PrimExpr> indices;
595 
596  static void RegisterReflection() {
597  namespace refl = tvm::ffi::reflection;
598  refl::ObjectDef<ProducerLoadNode>()
599  .def_ro("producer", &ProducerLoadNode::producer)
600  .def_ro("indices", &ProducerLoadNode::indices);
601  }
603 };
604 
/*! \brief Managed reference to ProducerLoadNode. */
609 class ProducerLoad : public PrimExpr {
610  public:
  /*! \brief Explicit constructor from a producer and the location arguments; `span` defaults to Span(). */
611  TVM_DLL explicit ProducerLoad(DataProducer producer, ffi::Array<PrimExpr> indices,
612  Span span = Span());
613 
616 };
617 
/*! \brief Construct a vector with lanes elements where its i-th element equals base + i * stride. */
627 class RampNode : public PrimExprNode {
628  public:
635 
  /*! \brief Registers `base`, `stride` and `lanes` as reflected fields (def_ro). */
636  static void RegisterReflection() {
637  namespace refl = tvm::ffi::reflection;
638  refl::ObjectDef<RampNode>()
639  .def_ro("base", &RampNode::base)
640  .def_ro("stride", &RampNode::stride)
641  .def_ro("lanes", &RampNode::lanes);
642  }
644 };
645 
/*! \brief Managed reference to RampNode. */
650 class Ramp : public PrimExpr {
651  public:
  /*! \brief Construct from base, stride and lanes (all PrimExpr); `span` defaults to Span(). */
652  TVM_DLL Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span = Span());
655 };
656 
/*! \brief Create a vector where all the elements are value. */
658 class BroadcastNode : public PrimExprNode {
659  public:
664 
  /*! \brief Registers `value` and `lanes` as reflected fields (def_ro). */
665  static void RegisterReflection() {
666  namespace refl = tvm::ffi::reflection;
667  refl::ObjectDef<BroadcastNode>()
668  .def_ro("value", &BroadcastNode::value)
669  .def_ro("lanes", &BroadcastNode::lanes);
670  }
672 };
673 
/*! \brief Managed reference to BroadcastNode. */
678 class Broadcast : public PrimExpr {
679  public:
  /*! \brief Construct from the value and the number of lanes; `span` defaults to Span(). */
680  TVM_DLL Broadcast(PrimExpr value, PrimExpr lanes, Span span = Span());
683 };
684 
/*! \brief Let binding. Bind var to value then evaluate body. */
688 class LetNode : public PrimExprNode {
689  public:
696 
  /*!
   * \brief Registers `var`, `value` and `body` as reflected fields (def_ro).
   *
   * `var` is registered with the SEqHashDef field flag.
   * NOTE(review): presumably this marks `var` as a definition site for structural
   * equality/hashing — confirm against the reflection documentation.
   */
697  static void RegisterReflection() {
698  namespace refl = tvm::ffi::reflection;
699  refl::ObjectDef<LetNode>()
700  .def_ro("var", &LetNode::var, refl::AttachFieldFlag::SEqHashDef())
701  .def_ro("value", &LetNode::value)
702  .def_ro("body", &LetNode::body);
703  }
705 };
706 
/*! \brief Managed reference to LetNode. */
711 class Let : public PrimExpr {
712  public:
  /*! \brief Construct from the variable, the value bound to it, and the body; `span` defaults to Span(). */
713  TVM_DLL Let(Var var, PrimExpr value, PrimExpr body, Span span = Span());
716 };
717 
/*! \brief Call node. */
721 class CallNode : public PrimExprNode {
722  public:
730 
  /*! \brief The arguments. */
732  ffi::Array<PrimExpr> args;
733 
  /*! \brief Registers `op` and `args` as reflected fields (def_ro). */
734  static void RegisterReflection() {
735  namespace refl = tvm::ffi::reflection;
736  refl::ObjectDef<CallNode>().def_ro("op", &CallNode::op).def_ro("args", &CallNode::args);
737  }
739 };
740 
/*! \brief Managed reference to CallNode. */
745 class Call : public PrimExpr {
746  public:
  /*! \brief Construct from a data type, the operator being invoked and its arguments; `span` defaults to Span(). */
747  TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Span span = Span());
750 };
751 
/*!
 * \brief Shuffle instruction.
 *  vec = concat(vectors)
 *  result = (vec[indices[0]], vec[indices[1]] ...)
 */
757 class ShuffleNode : public PrimExprNode {
758  public:
  /*! \brief The input vectors. */
760  ffi::Array<PrimExpr> vectors;
  /*! \brief The indices of each element. */
762  ffi::Array<PrimExpr> indices;
763 
  /*! \brief Registers `vectors` and `indices` as reflected fields (def_ro). */
764  static void RegisterReflection() {
765  namespace refl = tvm::ffi::reflection;
766  refl::ObjectDef<ShuffleNode>()
767  .def_ro("vectors", &ShuffleNode::vectors)
768  .def_ro("indices", &ShuffleNode::indices);
769  }
771 };
772 
/*! \brief Managed reference to ShuffleNode. */
777 class Shuffle : public PrimExpr {
778  public:
  /*! \brief Construct from the input vectors and the element indices; `span` defaults to Span(). */
779  TVM_DLL Shuffle(ffi::Array<PrimExpr> vectors, ffi::Array<PrimExpr> indices, Span span = Span());
  // Static helper taking a list of vectors and returning a PrimExpr.
  // NOTE(review): presumably builds the concatenation of `vectors` — confirm at the definition.
780  TVM_DLL static PrimExpr Concat(ffi::Array<PrimExpr> vectors, Span span = Span());
  // Static helper taking one vector and an integer index and returning a PrimExpr.
  // NOTE(review): presumably selects element `index` of `vector` — confirm at the definition.
781  TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index, Span span = Span());
782 
785 };
786 
787 // Reduce operator
/*!
 * \brief A commutative reducer node to represent a commutative
 *  binary operator with identity element.
 */
792 class CommReducerNode : public Object {
793  public:
  /*! \brief The left argument of reducer. */
795  ffi::Array<Var> lhs;
  /*! \brief The right argument of reducer. */
797  ffi::Array<Var> rhs;
  /*! \brief The result of reducer. */
799  ffi::Array<PrimExpr> result;
  /*!
   * \brief The identity element of reducer, which leaves other
   *  elements unchanged when combined with it.
   */
805  ffi::Array<PrimExpr> identity_element;
  /*! \brief Function call operator to combine a and b. */
807  ffi::Array<PrimExpr> operator()(ffi::Array<PrimExpr> a, ffi::Array<PrimExpr> b) const;
  /*!
   * \brief Span that points to the original source code.
   *        Reserved debug information.
   */
812  mutable Span span;
813 
  /*!
   * \brief Registers all five fields as reflected fields (def_ro).
   *
   * `lhs` and `rhs` carry the SEqHashDef field flag; `span` carries SEqHashIgnore.
   * NOTE(review): presumably the reducer arguments act as variable definitions and the
   * span is excluded from structural equality/hashing — confirm against the reflection docs.
   */
814  static void RegisterReflection() {
815  namespace refl = tvm::ffi::reflection;
816  refl::ObjectDef<CommReducerNode>()
817  .def_ro("lhs", &CommReducerNode::lhs, refl::AttachFieldFlag::SEqHashDef())
818  .def_ro("rhs", &CommReducerNode::rhs, refl::AttachFieldFlag::SEqHashDef())
819  .def_ro("result", &CommReducerNode::result)
820  .def_ro("identity_element", &CommReducerNode::identity_element)
821  .def_ro("span", &CommReducerNode::span, refl::AttachFieldFlag::SEqHashIgnore());
822  }
823 
  // Structural equality/hash kind of this type is set to the tree-node kind.
824  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
826 };
827 
/*! \brief Managed reference to CommReducerNode. */
832 class CommReducer : public ObjectRef {
833  public:
  /*! \brief Construct from the reducer arguments, result and identity element; `span` defaults to Span(). */
834  TVM_DLL CommReducer(ffi::Array<Var> lhs, ffi::Array<Var> rhs, ffi::Array<PrimExpr> result,
835  ffi::Array<PrimExpr> identity_element, Span span = Span());
836 
838 };
839 
/*! \brief Reduction operator. */
841 class ReduceNode : public PrimExprNode {
842  public:
  /*! \brief The source operand. */
846  ffi::Array<PrimExpr> source;
  /*! \brief The init operand. */
848  ffi::Array<PrimExpr> init;
  /*! \brief The reduction axis. */
850  ffi::Array<IterVar> axis;
858 
  /*!
   * \brief Registers `combiner`, `source`, `init`, `axis`, `condition` and
   *        `value_index` as reflected fields (def_ro).
   */
859  static void RegisterReflection() {
860  namespace refl = tvm::ffi::reflection;
861  refl::ObjectDef<ReduceNode>()
862  .def_ro("combiner", &ReduceNode::combiner)
863  .def_ro("source", &ReduceNode::source)
864  .def_ro("init", &ReduceNode::init)
865  .def_ro("axis", &ReduceNode::axis)
866  .def_ro("condition", &ReduceNode::condition)
867  .def_ro("value_index", &ReduceNode::value_index);
868  }
870 };
871 
/*! \brief Managed reference to ReduceNode. */
876 class Reduce : public PrimExpr {
877  public:
  /*!
   * \brief Construct from a combiner, source operands, reduction domain, condition,
   *        value index and init operands; `span` defaults to Span().
   * \note The parameters `src` and `rdom` are named differently from the node fields;
   *       presumably they populate `source` and `axis` — confirm at the definition.
   */
878  TVM_DLL Reduce(CommReducer combiner, ffi::Array<PrimExpr> src, ffi::Array<IterVar> rdom,
879  PrimExpr condition, int value_index, ffi::Array<PrimExpr> init,
880  Span span = Span());
881 
884 };
885 
886 /*!
887  * \brief Template function to convert a Map into a std::unordered_map.
888  * Sometimes useful for API gluing when internal code uses unordered_map.
889  * \param dmap The container map.
890  * \return A std::unordered_map holding a copy of every key/value pair of dmap.
891  * \tparam K the key of the Map.
892  * \tparam V the value of the Map. It does not need to be default-constructible.
893  */
894 template <typename K, typename V>
895 inline std::unordered_map<K, V> as_unordered_map(const ffi::Map<K, V>& dmap) {
896  std::unordered_map<K, V> ret;
  // Bind by const reference so an entry is never copied just to be iterated over.
897  for (const auto& kv : dmap) {
  // insert_or_assign keeps the overwrite semantics of `ret[k] = v` without
  // first default-constructing a V for a missing key.
898  ret.insert_or_assign(kv.first, kv.second);
899  }
900  return ret;
901 }
902 } // namespace tir
903 
904 namespace ffi {
905 
// Turn off the default type traits for tir::StringImm; a dedicated TypeTraits
// specialization is supplied for it in this namespace.
906 template <>
907 inline constexpr bool use_default_type_traits_v<tvm::tir::StringImm> = false;
908 
/*!
 * \brief Type traits for tir::StringImm with ffi::String as the fallback type:
 *        a fallback ffi::String value is converted by wrapping it in a StringImm.
 */
909 template <>
910 struct TypeTraits<tvm::tir::StringImm>
911  : public ObjectRefWithFallbackTraitsBase<tvm::tir::StringImm, ffi::String> {
  /*!
   * \brief Build a StringImm from a fallback ffi::String value.
   * \param value The string taken by value; it is moved into the StringImm constructor.
   * \return The StringImm wrapping the string.
   */
912  TVM_FFI_INLINE static tvm::tir::StringImm ConvertFallbackValue(ffi::String value) {
  // `value` is a by-value sink parameter: move it rather than copying it again.
913  return tvm::tir::StringImm(std::move(value));
914  }
915 };
916 } // namespace ffi
917 } // namespace tvm
918 
919 namespace std {
// std::hash specialization for tir::IterVar that inherits its call operator from
// tvm::ObjectPtrHash, so IterVar can be used as a key in std unordered containers.
920 template <>
921 struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {};
922 } // namespace std
923 #endif // TVM_TIR_EXPR_H_
Symbolic n-dimensional array, to represent a memory buffer.
Constant floating point literals in the program.
Definition: expr.h:528
Constant integer literals in the program.
Definition: expr.h:493
Base node of all primitive expressions.
Definition: expr.h:91
Reference to PrimExprNode.
Definition: expr.h:124
DataType dtype() const
Definition: expr.h:138
Managed reference to RelaxExprNode.
Definition: expr.h:439
Definition: source_map.h:111
Runtime primitive data type.
Definition: data_type.h:47
a + b
Definition: expr.h:126
static constexpr const char * _type_key
Definition: expr.h:128
Managed reference to AddNode.
Definition: expr.h:135
TVM_DEFINE_OBJECT_REF_COW_METHOD(AddNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Add, PrimExpr, AddNode)
Add(PrimExpr a, PrimExpr b, Span span=Span())
a && b
Definition: expr.h:410
PrimExpr a
The left operand.
Definition: expr.h:413
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.And", AndNode, PrimExprNode)
PrimExpr b
The right operand.
Definition: expr.h:415
static void RegisterReflection()
Definition: expr.h:417
Managed reference to AndNode.
Definition: expr.h:428
TVM_DEFINE_OBJECT_REF_COW_METHOD(AndNode)
And(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(And, PrimExpr, AndNode)
Base template to implement binary ops.
Definition: expr.h:108
PrimExpr b
The right operand.
Definition: expr.h:113
TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(T, PrimExprNode)
static void RegisterReflection()
Definition: expr.h:115
PrimExpr a
The left operand.
Definition: expr.h:111
static constexpr const bool _type_final
Definition: expr.h:121
static constexpr const int _type_child_slots
Definition: expr.h:120
Create a vector where all the elements are value.
Definition: expr.h:658
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Broadcast", BroadcastNode, PrimExprNode)
PrimExpr value
The base value.
Definition: expr.h:661
static void RegisterReflection()
Definition: expr.h:665
PrimExpr lanes
The number of lanes.
Definition: expr.h:663
Managed reference to BroadcastNode.
Definition: expr.h:678
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Broadcast, PrimExpr, BroadcastNode)
Broadcast(PrimExpr value, PrimExpr lanes, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode)
Load value from the high dimension buffer.
Definition: expr.h:533
ffi::Array< PrimExpr > indices
The indices location to be loaded.
Definition: expr.h:538
friend class VectorTypeRewriter
Definition: expr.h:564
friend class CustomDatatypesLowerer
Definition: expr.h:563
ffi::Optional< PrimExpr > predicate
The predicate mask for loading values.
Definition: expr.h:540
friend class Vectorizer
Definition: expr.h:565
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BufferLoad", BufferLoadNode, PrimExprNode)
Buffer buffer
The buffer variable.
Definition: expr.h:536
static void RegisterReflection()
Definition: expr.h:542
Managed reference to BufferLoadNode.
Definition: expr.h:572
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferLoad, PrimExpr, BufferLoadNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode)
BufferLoad(Buffer buffer, ffi::Array< PrimExpr > indices, ffi::Optional< PrimExpr > predicate=std::nullopt, Span span=Span())
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:156
Call node.
Definition: expr.h:721
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Call", CallNode, PrimExprNode)
RelaxExpr op
The operator(function) being invoked.
Definition: expr.h:729
ffi::Array< PrimExpr > args
The arguments.
Definition: expr.h:732
static void RegisterReflection()
Definition: expr.h:734
Managed reference to CallNode.
Definition: expr.h:745
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode)
Call(DataType dtype, RelaxExpr op, ffi::Array< PrimExpr > args, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, PrimExpr, CallNode)
Cast value from one data type to another.
Definition: expr.h:80
PrimExpr value
Original data type.
Definition: expr.h:83
static void RegisterReflection()
Definition: expr.h:85
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Cast", CastNode, PrimExprNode)
Managed reference to CastNode.
Definition: expr.h:96
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Cast, PrimExpr, CastNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode)
Cast(DataType dtype, PrimExpr value, Span span=Span())
Base template to implement comparison ops.
Definition: expr.h:290
static constexpr const int _type_child_slots
Definition: expr.h:302
static void RegisterReflection()
Definition: expr.h:297
PrimExpr a
The left operand.
Definition: expr.h:293
PrimExpr b
The right operand.
Definition: expr.h:295
static constexpr const bool _type_final
Definition: expr.h:303
TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(T, PrimExprNode)
A commutative reducer node to represent a commutative binary operator with identity element.
Definition: expr.h:792
ffi::Array< PrimExpr > identity_element
The identity element of reducer, which leaves other elements unchanged when combined with it,...
Definition: expr.h:805
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: expr.h:824
ffi::Array< PrimExpr > result
The result of reducer.
Definition: expr.h:799
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.CommReducer", CommReducerNode, Object)
ffi::Array< Var > rhs
The right argument of reducer.
Definition: expr.h:797
ffi::Array< Var > lhs
The left argument of reducer.
Definition: expr.h:795
ffi::Array< PrimExpr > operator()(ffi::Array< PrimExpr > a, ffi::Array< PrimExpr > b) const
Function call operator to combine a and b.
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:812
static void RegisterReflection()
Definition: expr.h:814
Managed reference to CommReducerNode.
Definition: expr.h:832
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CommReducer, ObjectRef, CommReducerNode)
CommReducer(ffi::Array< Var > lhs, ffi::Array< Var > rhs, ffi::Array< PrimExpr > result, ffi::Array< PrimExpr > identity_element, Span span=Span())
Managed reference to DataProducerNode.
Definition: buffer.h:286
a / b in the C semantics.
Definition: expr.h:181
static constexpr const char * _type_key
Definition: expr.h:183
Managed reference to DivNode.
Definition: expr.h:190
TVM_DEFINE_OBJECT_REF_COW_METHOD(DivNode)
Div(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Div, PrimExpr, DivNode)
a == b
Definition: expr.h:308
static constexpr const char * _type_key
Definition: expr.h:310
Managed reference to EQNode.
Definition: expr.h:317
EQ(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(EQ, PrimExpr, EQNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(EQNode)
Floor division, floor(a/b)
Definition: expr.h:218
static constexpr const char * _type_key
Definition: expr.h:220
Managed reference to FloorDivNode.
Definition: expr.h:227
FloorDiv(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloorDiv, PrimExpr, FloorDivNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode)
The remainder of the floordiv.
Definition: expr.h:235
static constexpr const char * _type_key
Definition: expr.h:237
Managed reference to FloorModNode.
Definition: expr.h:244
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloorMod, PrimExpr, FloorModNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorModNode)
FloorMod(PrimExpr a, PrimExpr b, Span span=Span())
a >= b
Definition: expr.h:393
static constexpr const char * _type_key
Definition: expr.h:395
Managed reference to GENode.
Definition: expr.h:402
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GE, PrimExpr, GENode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(GENode)
GE(PrimExpr a, PrimExpr b, Span span=Span())
a > b
Definition: expr.h:376
static constexpr const char * _type_key
Definition: expr.h:378
Managed reference to GTNode.
Definition: expr.h:385
TVM_DEFINE_OBJECT_REF_COW_METHOD(GTNode)
GT(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GT, PrimExpr, GTNode)
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:297
Managed reference to LENode.
Definition: expr.h:368
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LE, PrimExpr, LENode)
LE(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(LENode)
a < b
Definition: expr.h:342
static constexpr const char * _type_key
Definition: expr.h:344
Managed reference to LTNode.
Definition: expr.h:351
TVM_DEFINE_OBJECT_REF_COW_METHOD(LTNode)
LT(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LT, PrimExpr, LTNode)
Let binding. Bind var to value then evaluate body.
Definition: expr.h:688
static void RegisterReflection()
Definition: expr.h:697
Var var
The variable.
Definition: expr.h:691
PrimExpr value
The value to be bound.
Definition: expr.h:693
PrimExpr body
The result expression.
Definition: expr.h:695
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Let", LetNode, PrimExprNode)
Managed reference to LetNode.
Definition: expr.h:711
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Let, PrimExpr, LetNode)
Let(Var var, PrimExpr value, PrimExpr body, Span span=Span())
max(a, b)
Definition: expr.h:269
static constexpr const char * _type_key
Definition: expr.h:271
Managed reference to MaxNode.
Definition: expr.h:278
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Max, PrimExpr, MaxNode)
Max(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(MaxNode)
min(a, b)
Definition: expr.h:252
static constexpr const char * _type_key
Definition: expr.h:254
Managed reference to MinNode.
Definition: expr.h:261
Min(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Min, PrimExpr, MinNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MinNode)
a % b in the C semantics.
Definition: expr.h:201
static constexpr const char * _type_key
Definition: expr.h:203
Managed reference to ModNode.
Definition: expr.h:210
Mod(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Mod, PrimExpr, ModNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ModNode)
a * b
Definition: expr.h:161
static constexpr const char * _type_key
Definition: expr.h:163
Managed reference to MulNode.
Definition: expr.h:170
Mul(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(MulNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Mul, PrimExpr, MulNode)
a != b
Definition: expr.h:325
static constexpr const char * _type_key
Definition: expr.h:327
Managed reference to NENode.
Definition: expr.h:334
NE(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(NENode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(NE, PrimExpr, NENode)
!a
Definition: expr.h:462
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Not", NotNode, PrimExprNode)
PrimExpr a
The input operand.
Definition: expr.h:465
static void RegisterReflection()
Definition: expr.h:467
Managed reference to NotNode.
Definition: expr.h:478
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Not, PrimExpr, NotNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(NotNode)
Not(PrimExpr a, Span span=Span())
a || b
Definition: expr.h:436
PrimExpr b
The right operand.
Definition: expr.h:441
PrimExpr a
The left operand.
Definition: expr.h:439
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Or", OrNode, PrimExprNode)
static void RegisterReflection()
Definition: expr.h:443
Managed reference to OrNode.
Definition: expr.h:454
Or(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Or, PrimExpr, OrNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(OrNode)
Load value from the result produced by the producer.
Definition: expr.h:589
static void RegisterReflection()
Definition: expr.h:596
DataProducer producer
The buffer producer.
Definition: expr.h:592
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.ProducerLoad", ProducerLoadNode, PrimExprNode)
ffi::Array< PrimExpr > indices
The location arguments.
Definition: expr.h:594
Managed reference to ProducerLoadNode.
Definition: expr.h:609
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ProducerLoad, PrimExpr, ProducerLoadNode)
ProducerLoad(DataProducer producer, ffi::Array< PrimExpr > indices, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode)
Construct a vector with lanes elements where its i-th element equals base + i * stride....
Definition: expr.h:627
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Ramp", RampNode, PrimExprNode)
PrimExpr stride
The stride of each step.
Definition: expr.h:632
PrimExpr lanes
Total number of lanes.
Definition: expr.h:634
static void RegisterReflection()
Definition: expr.h:636
PrimExpr base
The base value.
Definition: expr.h:630
Managed reference to RampNode.
Definition: expr.h:650
TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode)
Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Ramp, PrimExpr, RampNode)
Reduction operator.
Definition: expr.h:841
int value_index
the index of this reduce node
Definition: expr.h:857
CommReducer combiner
The commutative combiner.
Definition: expr.h:844
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Reduce", ReduceNode, PrimExprNode)
static void RegisterReflection()
Definition: expr.h:859
ffi::Array< PrimExpr > source
The source operand.
Definition: expr.h:846
PrimExpr condition
Predicate on the reduction. Only add the body to the reduction if the condition is true.
Definition: expr.h:855
ffi::Array< IterVar > axis
The reduction axis.
Definition: expr.h:850
ffi::Array< PrimExpr > init
The init operand.
Definition: expr.h:848
Managed reference to ReduceNode.
Definition: expr.h:876
TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode)
Reduce(CommReducer combiner, ffi::Array< PrimExpr > src, ffi::Array< IterVar > rdom, PrimExpr condition, int value_index, ffi::Array< PrimExpr > init, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Reduce, PrimExpr, ReduceNode)
return true_value if condition is true, otherwise return false_value.
Definition: expr.h:492
PrimExpr condition
The condition.
Definition: expr.h:495
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Select", SelectNode, PrimExprNode)
PrimExpr true_value
value to be returned when condition is true.
Definition: expr.h:497
PrimExpr false_value
value to be returned when condition is false.
Definition: expr.h:499
static void RegisterReflection()
Definition: expr.h:501
Managed reference to SelectNode.
Definition: expr.h:515
TVM_DEFINE_OBJECT_REF_COW_METHOD(SelectNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Select, PrimExpr, SelectNode)
Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Shuffle instruction. vec = concat(vectors) result = (vec[indices[0]], vec[indices[1]] ....
Definition: expr.h:757
ffi::Array< PrimExpr > indices
The indices of each element.
Definition: expr.h:762
static void RegisterReflection()
Definition: expr.h:764
ffi::Array< PrimExpr > vectors
the input vectors.
Definition: expr.h:760
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Shuffle", ShuffleNode, PrimExprNode)
Managed reference to ShuffleNode.
Definition: expr.h:777
static PrimExpr Concat(ffi::Array< PrimExpr > vectors, Span span=Span())
Shuffle(ffi::Array< PrimExpr > vectors, ffi::Array< PrimExpr > indices, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Shuffle, PrimExpr, ShuffleNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShuffleNode)
static PrimExpr ExtractElement(PrimExpr vector, int index, Span span=Span())
ffi::String constants, only used in asserts.
Definition: expr.h:53
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.StringImm", StringImmNode, PrimExprNode)
ffi::String value
The constant value content.
Definition: expr.h:56
static void RegisterReflection()
Definition: expr.h:58
Managed reference to StringImmNode.
Definition: expr.h:69
StringImm(ffi::String value, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StringImm, PrimExpr, StringImmNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode)
a - b
Definition: expr.h:143
static constexpr const char * _type_key
Definition: expr.h:145
Managed reference to SubNode.
Definition: expr.h:152
Sub(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Sub, PrimExpr, SubNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(SubNode)
a named variable in TIR
Definition: var.h:77
Defines the Functor data structures.
Base expr nodes in TVM.
Definition: repr_printer.h:91
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
std::unordered_map< K, V > as_unordered_map(const ffi::Map< K, V > &dmap)
Definition: expr.h:895
tvm::FloatImmNode FloatImmNode
Definition: expr.h:50
tvm::IntImmNode IntImmNode
Definition: expr.h:49
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Definitions and helper macros for IR/AST nodes.
static TVM_FFI_INLINE tvm::tir::StringImm ConvertFallbackValue(ffi::String value)
Definition: expr.h:912
a <= b
Definition: expr.h:359
static constexpr const char * _type_key
Definition: expr.h:361
Variables in the TIR.