tvm
expr.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 // Acknowledgement: Many low-level IR nodes originate from Halide.
25 #ifndef TVM_TIR_EXPR_H_
26 #define TVM_TIR_EXPR_H_
27 
28 #include <tvm/ffi/container/array.h>
29 #include <tvm/ffi/container/map.h>
30 #include <tvm/ffi/string.h>
31 #include <tvm/ir/expr.h>
32 #include <tvm/node/functor.h>
33 #include <tvm/runtime/base.h>
34 #include <tvm/runtime/data_type.h>
35 #include <tvm/tirx/buffer.h>
36 #include <tvm/tirx/var.h>
37 
38 #include <algorithm>
39 #include <iostream>
40 #include <limits>
41 #include <string>
42 #include <unordered_map>
43 #include <utility>
44 
45 namespace tvm {
46 namespace tirx {
47 
50 
/*! \brief ffi::String constants, only used in asserts. */
52 class StringImmNode : public PrimExprNode {
53  public:
 /*! \brief The constant value content. */
55  ffi::String value;
56 
 /*! \brief Register `value` with FFI reflection as a read-only (def_ro) field. */
57  static void RegisterReflection() {
58  namespace refl = tvm::ffi::reflection;
59  refl::ObjectDef<StringImmNode>().def_ro("value", &StringImmNode::value);
60  }
62 };
63 
/*!
 * \brief Managed reference to StringImmNode.
 * Constructed from the string `value` and an optional source Span.
 */
68 class StringImm : public PrimExpr {
69  public:
70  TVM_DLL StringImm(ffi::String value, Span span = Span());
73 };
74 
/*!
 * \brief Cast value from one data type to another.
 * Reflected field: `PrimExpr value`, the expression being cast. Its
 * declaration (expr.h:82) is not shown in this listing.
 */
79 class CastNode : public PrimExprNode {
80  public:
83 
 /*! \brief Register `value` with FFI reflection as a read-only (def_ro) field. */
84  static void RegisterReflection() {
85  namespace refl = tvm::ffi::reflection;
86  refl::ObjectDef<CastNode>().def_ro("value", &CastNode::value);
87  }
89 };
90 
/*!
 * \brief Managed reference to CastNode.
 * Constructed from a DataType `dtype`, the PrimExpr `value` and an optional
 * source Span.
 */
95 class Cast : public PrimExpr {
96  public:
97  TVM_DLL Cast(DataType dtype, PrimExpr value, Span span = Span());
100 };
101 
/*!
 * \brief Base template to implement binary ops.
 * \tparam T The concrete node type deriving from BinaryOpNode<T> (CRTP),
 *         e.g. AddNode. Reflection is registered on T.
 * Operands `PrimExpr a` (left) and `PrimExpr b` (right) are declared at
 * expr.h:110 and expr.h:112. They are not shown in this listing.
 */
106 template <typename T>
107 class BinaryOpNode : public PrimExprNode {
108  public:
113 
 /*! \brief Register T's operands `a` and `b` as read-only (def_ro) reflected fields. */
114  static void RegisterReflection() {
115  namespace refl = tvm::ffi::reflection;
116  refl::ObjectDef<T>().def_ro("a", &T::a).def_ro("b", &T::b);
117  }
118 
 // Zero child slots and a final flag: concrete ops such as AddNode are leaf
 // types that are not meant to be subclassed further.
119  static const constexpr int _type_child_slots [[maybe_unused]] = 0;
120  static const constexpr bool _type_final [[maybe_unused]] = true;
122 };
123 
/*! \brief a + b */
125 class AddNode : public BinaryOpNode<AddNode> {
126  public:
127  static constexpr const char* _type_key = "tirx.Add";
128 };
129 
/*! \brief Managed reference to AddNode; built from operands `a`, `b` and an optional Span. */
134 class Add : public PrimExpr {
135  public:
136  TVM_DLL Add(PrimExpr a, PrimExpr b, Span span = Span());
139 };
140 
/*! \brief a - b */
142 class SubNode : public BinaryOpNode<SubNode> {
143  public:
144  static constexpr const char* _type_key = "tirx.Sub";
145 };
146 
/*! \brief Managed reference to SubNode; built from operands `a`, `b` and an optional Span. */
151 class Sub : public PrimExpr {
152  public:
153  TVM_DLL Sub(PrimExpr a, PrimExpr b, Span span = Span());
154 
157 };
158 
/*! \brief a * b */
160 class MulNode : public BinaryOpNode<MulNode> {
161  public:
162  static constexpr const char* _type_key = "tirx.Mul";
163 };
164 
/*! \brief Managed reference to MulNode; built from operands `a`, `b` and an optional Span. */
169 class Mul : public PrimExpr {
170  public:
171  TVM_DLL Mul(PrimExpr a, PrimExpr b, Span span = Span());
174 };
175 
/*! \brief a / b in the C semantics. */
180 class DivNode : public BinaryOpNode<DivNode> {
181  public:
182  static constexpr const char* _type_key = "tirx.Div";
183 };
184 
/*! \brief Managed reference to DivNode; built from operands `a`, `b` and an optional Span. */
189 class Div : public PrimExpr {
190  public:
191  TVM_DLL Div(PrimExpr a, PrimExpr b, Span span = Span());
194 };
195 
/*! \brief a % b in the C semantics. */
200 class ModNode : public BinaryOpNode<ModNode> {
201  public:
202  static constexpr const char* _type_key = "tirx.Mod";
203 };
204 
/*! \brief Managed reference to ModNode; built from operands `a`, `b` and an optional Span. */
209 class Mod : public PrimExpr {
210  public:
211  TVM_DLL Mod(PrimExpr a, PrimExpr b, Span span = Span());
214 };
215 
/*! \brief Floor division, floor(a/b). */
217 class FloorDivNode : public BinaryOpNode<FloorDivNode> {
218  public:
219  static constexpr const char* _type_key = "tirx.FloorDiv";
220 };
221 
/*! \brief Managed reference to FloorDivNode; built from operands `a`, `b` and an optional Span. */
226 class FloorDiv : public PrimExpr {
227  public:
228  TVM_DLL FloorDiv(PrimExpr a, PrimExpr b, Span span = Span());
231 };
232 
/*! \brief The remainder of the floordiv. */
234 class FloorModNode : public BinaryOpNode<FloorModNode> {
235  public:
236  static constexpr const char* _type_key = "tirx.FloorMod";
237 };
238 
/*! \brief Managed reference to FloorModNode; built from operands `a`, `b` and an optional Span. */
243 class FloorMod : public PrimExpr {
244  public:
245  TVM_DLL FloorMod(PrimExpr a, PrimExpr b, Span span = Span());
248 };
249 
/*! \brief min(a, b) */
251 class MinNode : public BinaryOpNode<MinNode> {
252  public:
253  static constexpr const char* _type_key = "tirx.Min";
254 };
255 
/*! \brief Managed reference to MinNode; built from operands `a`, `b` and an optional Span. */
260 class Min : public PrimExpr {
261  public:
262  TVM_DLL Min(PrimExpr a, PrimExpr b, Span span = Span());
265 };
266 
/*! \brief max(a, b) */
268 class MaxNode : public BinaryOpNode<MaxNode> {
269  public:
270  static constexpr const char* _type_key = "tirx.Max";
271 };
272 
/*! \brief Managed reference to MaxNode; built from operands `a`, `b` and an optional Span. */
277 class Max : public PrimExpr {
278  public:
279  TVM_DLL Max(PrimExpr a, PrimExpr b, Span span = Span());
282 };
283 
/*!
 * \brief Base template to implement comparison ops.
 * \tparam T The concrete node type deriving from CmpOpNode<T> (CRTP),
 *         e.g. EQNode. Reflection is registered on T.
 * Operands `PrimExpr a` (left) and `PrimExpr b` (right) are declared at
 * expr.h:292 and expr.h:294. They are not shown in this listing.
 */
288 template <typename T>
289 class CmpOpNode : public PrimExprNode {
290  public:
295 
 /*! \brief Register T's operands `a` and `b` as read-only (def_ro) reflected fields. */
296  static void RegisterReflection() {
297  namespace refl = tvm::ffi::reflection;
298  refl::ObjectDef<T>().def_ro("a", &T::a).def_ro("b", &T::b);
299  }
300 
 // Zero child slots and a final flag: concrete comparisons such as EQNode are
 // leaf types that are not meant to be subclassed further.
301  static const constexpr int _type_child_slots [[maybe_unused]] = 0;
302  static const constexpr bool _type_final [[maybe_unused]] = true;
304 };
305 
/*! \brief a == b */
307 class EQNode : public CmpOpNode<EQNode> {
308  public:
309  static constexpr const char* _type_key = "tirx.EQ";
310 };
311 
/*! \brief Managed reference to EQNode; built from operands `a`, `b` and an optional Span. */
316 class EQ : public PrimExpr {
317  public:
318  TVM_DLL EQ(PrimExpr a, PrimExpr b, Span span = Span());
321 };
322 
/*! \brief a != b */
324 class NENode : public CmpOpNode<NENode> {
325  public:
326  static constexpr const char* _type_key = "tirx.NE";
327 };
328 
/*! \brief Managed reference to NENode; built from operands `a`, `b` and an optional Span. */
333 class NE : public PrimExpr {
334  public:
335  TVM_DLL NE(PrimExpr a, PrimExpr b, Span span = Span());
338 };
339 
/*! \brief a < b */
341 class LTNode : public CmpOpNode<LTNode> {
342  public:
343  static constexpr const char* _type_key = "tirx.LT";
344 };
345 
/*! \brief Managed reference to LTNode; built from operands `a`, `b` and an optional Span. */
350 class LT : public PrimExpr {
351  public:
352  TVM_DLL LT(PrimExpr a, PrimExpr b, Span span = Span());
355 };
356 
/*!
 * \brief a <= b
 * Declared with `class` to match every other CmpOpNode subclass
 * (EQNode, NENode, LTNode, GTNode, GENode). Members were already public,
 * so the layout and access are unchanged.
 */
358 class LENode : public CmpOpNode<LENode> {
359  public:
360  static constexpr const char* _type_key = "tirx.LE";
361 };
362 
/*! \brief Managed reference to LENode; built from operands `a`, `b` and an optional Span. */
367 class LE : public PrimExpr {
368  public:
369  TVM_DLL LE(PrimExpr a, PrimExpr b, Span span = Span());
372 };
373 
/*! \brief a > b */
375 class GTNode : public CmpOpNode<GTNode> {
376  public:
377  static constexpr const char* _type_key = "tirx.GT";
378 };
379 
/*! \brief Managed reference to GTNode; built from operands `a`, `b` and an optional Span. */
384 class GT : public PrimExpr {
385  public:
386  TVM_DLL GT(PrimExpr a, PrimExpr b, Span span = Span());
389 };
390 
/*! \brief a >= b */
392 class GENode : public CmpOpNode<GENode> {
393  public:
394  static constexpr const char* _type_key = "tirx.GE";
395 };
396 
/*! \brief Managed reference to GENode; built from operands `a`, `b` and an optional Span. */
401 class GE : public PrimExpr {
402  public:
403  TVM_DLL GE(PrimExpr a, PrimExpr b, Span span = Span());
406 };
407 
/*!
 * \brief a && b
 * Operands `PrimExpr a` (left) and `PrimExpr b` (right) are declared at
 * expr.h:412 and expr.h:414. They are not shown in this listing.
 */
409 class AndNode : public PrimExprNode {
410  public:
415 
 /*! \brief Register `a` and `b` as read-only (def_ro) reflected fields. */
416  static void RegisterReflection() {
417  namespace refl = tvm::ffi::reflection;
418  refl::ObjectDef<AndNode>().def_ro("a", &AndNode::a).def_ro("b", &AndNode::b);
419  }
421 };
422 
/*! \brief Managed reference to AndNode; built from operands `a`, `b` and an optional Span. */
427 class And : public PrimExpr {
428  public:
429  TVM_DLL And(PrimExpr a, PrimExpr b, Span span = Span());
432 };
433 
/*!
 * \brief a || b
 * Operands `PrimExpr a` (left) and `PrimExpr b` (right) are declared at
 * expr.h:438 and expr.h:440. They are not shown in this listing.
 */
435 class OrNode : public PrimExprNode {
436  public:
441 
 /*! \brief Register `a` and `b` as read-only (def_ro) reflected fields. */
442  static void RegisterReflection() {
443  namespace refl = tvm::ffi::reflection;
444  refl::ObjectDef<OrNode>().def_ro("a", &OrNode::a).def_ro("b", &OrNode::b);
445  }
447 };
448 
/*! \brief Managed reference to OrNode; built from operands `a`, `b` and an optional Span. */
453 class Or : public PrimExpr {
454  public:
455  TVM_DLL Or(PrimExpr a, PrimExpr b, Span span = Span());
458 };
459 
/*!
 * \brief !a
 * The input operand `PrimExpr a` is declared at expr.h:464. It is not shown
 * in this listing.
 */
461 class NotNode : public PrimExprNode {
462  public:
465 
 /*! \brief Register `a` as a read-only (def_ro) reflected field. */
466  static void RegisterReflection() {
467  namespace refl = tvm::ffi::reflection;
468  refl::ObjectDef<NotNode>().def_ro("a", &NotNode::a);
469  }
471 };
472 
/*! \brief Managed reference to NotNode; built from operand `a` and an optional Span. */
477 class Not : public PrimExpr {
478  public:
479  TVM_DLL Not(PrimExpr a, Span span = Span());
482 };
483 
/*!
 * \brief Return true_value if condition is true, otherwise return false_value.
 * Fields are declared at expr.h:494-498 and are not shown in this listing:
 *  - `PrimExpr condition`: the condition.
 *  - `PrimExpr true_value`: value returned when condition is true.
 *  - `PrimExpr false_value`: value returned when condition is false.
 */
491 class SelectNode : public PrimExprNode {
492  public:
499 
 /*! \brief Register condition/true_value/false_value as read-only reflected fields. */
500  static void RegisterReflection() {
501  namespace refl = tvm::ffi::reflection;
502  refl::ObjectDef<SelectNode>()
503  .def_ro("condition", &SelectNode::condition)
504  .def_ro("true_value", &SelectNode::true_value)
505  .def_ro("false_value", &SelectNode::false_value);
506  }
508 };
509 
/*! \brief Managed reference to SelectNode. */
514 class Select : public PrimExpr {
515  public:
516  TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span = Span());
517 
520 };
521 
/*!
 * \brief Load value from the high dimension buffer.
 * The field `Buffer buffer` (the buffer variable) is declared at expr.h:535.
 * It is not shown in this listing.
 */
532 class BufferLoadNode : public PrimExprNode {
533  public:
 /*! \brief The indices location to be loaded. */
537  ffi::Array<PrimExpr> indices;
 /*! \brief The predicate mask for loading values (optional). */
539  ffi::Optional<PrimExpr> predicate;
540 
 /*! \brief Register buffer/indices/predicate as read-only reflected fields. */
541  static void RegisterReflection() {
542  namespace refl = tvm::ffi::reflection;
543  refl::ObjectDef<BufferLoadNode>()
544  .def_ro("buffer", &BufferLoadNode::buffer)
545  .def_ro("indices", &BufferLoadNode::indices)
546  .def_ro("predicate", &BufferLoadNode::predicate);
547  }
549 
550  private:
 // NOTE(review): the definition is not in this listing; the name suggests it
 // derives/checks this node's dtype from its inputs. Confirm before relying on it.
560  void LegalizeDType();
 // The friends below may call the private LegalizeDType(). Per the symbol index,
 // source line 562 also declares `friend class CustomDatatypesLowerer;`, which
 // this listing omits.
561  friend class BufferLoad;
563  friend class VectorTypeRewriter;
564  friend class Vectorizer;
565 };
566 
/*!
 * \brief Managed reference to BufferLoadNode.
 * `predicate` defaults to std::nullopt (no mask).
 */
571 class BufferLoad : public PrimExpr {
572  public:
573  TVM_DLL explicit BufferLoad(Buffer buffer, ffi::Array<PrimExpr> indices,
574  ffi::Optional<PrimExpr> predicate = std::nullopt, Span span = Span());
577 };
578 
/*!
 * \brief Load value from the result produced by the producer.
 * The field `DataProducer producer` (the buffer producer) is declared at
 * expr.h:591. It is not shown in this listing.
 */
588 class ProducerLoadNode : public PrimExprNode {
589  public:
 /*! \brief The location arguments. */
593  ffi::Array<PrimExpr> indices;
594 
 /*! \brief Register `producer` and `indices` as read-only (def_ro) reflected fields. */
595  static void RegisterReflection() {
596  namespace refl = tvm::ffi::reflection;
597  refl::ObjectDef<ProducerLoadNode>()
598  .def_ro("producer", &ProducerLoadNode::producer)
599  .def_ro("indices", &ProducerLoadNode::indices);
600  }
602 };
603 
/*! \brief Managed reference to ProducerLoadNode. */
608 class ProducerLoad : public PrimExpr {
609  public:
610  TVM_DLL explicit ProducerLoad(DataProducer producer, ffi::Array<PrimExpr> indices,
611  Span span = Span());
612 
615 };
616 
/*!
 * \brief Construct a vector with `lanes` elements where the i-th element
 *        equals base + i * stride.
 * Fields are declared at expr.h:629, 631 and 633 and are not shown in this listing:
 *  - `PrimExpr base`: the base value.
 *  - `PrimExpr stride`: the stride of each step.
 *  - `PrimExpr lanes`: the total number of lanes.
 */
626 class RampNode : public PrimExprNode {
627  public:
634 
 /*! \brief Register base/stride/lanes as read-only (def_ro) reflected fields. */
635  static void RegisterReflection() {
636  namespace refl = tvm::ffi::reflection;
637  refl::ObjectDef<RampNode>()
638  .def_ro("base", &RampNode::base)
639  .def_ro("stride", &RampNode::stride)
640  .def_ro("lanes", &RampNode::lanes);
641  }
643 };
644 
/*! \brief Managed reference to RampNode. */
649 class Ramp : public PrimExpr {
650  public:
651  TVM_DLL Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span = Span());
654 };
655 
/*!
 * \brief Create a vector where all the elements are `value`.
 * Fields are declared at expr.h:660 and 662 and are not shown in this listing:
 *  - `PrimExpr value`: the base value.
 *  - `PrimExpr lanes`: the number of lanes.
 */
657 class BroadcastNode : public PrimExprNode {
658  public:
663 
 /*! \brief Register `value` and `lanes` as read-only (def_ro) reflected fields. */
664  static void RegisterReflection() {
665  namespace refl = tvm::ffi::reflection;
666  refl::ObjectDef<BroadcastNode>()
667  .def_ro("value", &BroadcastNode::value)
668  .def_ro("lanes", &BroadcastNode::lanes);
669  }
671 };
672 
/*! \brief Managed reference to BroadcastNode. */
677 class Broadcast : public PrimExpr {
678  public:
679  TVM_DLL Broadcast(PrimExpr value, PrimExpr lanes, Span span = Span());
682 };
683 
/*!
 * \brief Let binding. Bind var to value then evaluate body.
 * Fields are declared at expr.h:690, 692 and 694 and are not shown in this listing:
 *  - `Var var`: the variable.
 *  - `PrimExpr value`: the value to be bound.
 *  - `PrimExpr body`: the result expression.
 */
687 class LetNode : public PrimExprNode {
688  public:
695 
 /*!
  * \brief Register var/value/body as read-only reflected fields.
  * `var` carries the SEqHashDef flag, since it is the variable this node binds.
  */
696  static void RegisterReflection() {
697  namespace refl = tvm::ffi::reflection;
698  refl::ObjectDef<LetNode>()
699  .def_ro("var", &LetNode::var, refl::AttachFieldFlag::SEqHashDef())
700  .def_ro("value", &LetNode::value)
701  .def_ro("body", &LetNode::body);
702  }
704 };
705 
/*! \brief Managed reference to LetNode. */
710 class Let : public PrimExpr {
711  public:
712  TVM_DLL Let(Var var, PrimExpr value, PrimExpr body, Span span = Span());
715 };
716 
/*!
 * \brief Call node.
 * The field `RelaxExpr op` (the operator/function being invoked) is declared
 * at expr.h:728. It is not shown in this listing.
 */
720 class CallNode : public PrimExprNode {
721  public:
729 
 /*! \brief The arguments. */
731  ffi::Array<PrimExpr> args;
732 
 /*! \brief Register `op` and `args` as read-only (def_ro) reflected fields. */
733  static void RegisterReflection() {
734  namespace refl = tvm::ffi::reflection;
735  refl::ObjectDef<CallNode>().def_ro("op", &CallNode::op).def_ro("args", &CallNode::args);
736  }
738 };
739 
/*!
 * \brief Managed reference to CallNode.
 * The result type is given explicitly by `dtype`.
 */
744 class Call : public PrimExpr {
745  public:
746  TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Span span = Span());
749 };
750 
/*!
 * \brief Shuffle instruction.
 *   vec = concat(vectors)
 *   result = (vec[indices[0]], vec[indices[1]], ...)
 */
756 class ShuffleNode : public PrimExprNode {
757  public:
 /*! \brief The input vectors. */
759  ffi::Array<PrimExpr> vectors;
 /*! \brief The indices of each element. */
761  ffi::Array<PrimExpr> indices;
762 
 /*! \brief Register `vectors` and `indices` as read-only (def_ro) reflected fields. */
763  static void RegisterReflection() {
764  namespace refl = tvm::ffi::reflection;
765  refl::ObjectDef<ShuffleNode>()
766  .def_ro("vectors", &ShuffleNode::vectors)
767  .def_ro("indices", &ShuffleNode::indices);
768  }
770 };
771 
/*!
 * \brief Managed reference to ShuffleNode.
 * NOTE(review): the bodies of the static helpers Concat and ExtractElement are
 * not in this listing. Judging by their names, they build a concatenation of
 * `vectors` and an extraction of element `index` from `vector`. Confirm in
 * the implementation.
 */
776 class Shuffle : public PrimExpr {
777  public:
778  TVM_DLL Shuffle(ffi::Array<PrimExpr> vectors, ffi::Array<PrimExpr> indices, Span span = Span());
779  TVM_DLL static PrimExpr Concat(ffi::Array<PrimExpr> vectors, Span span = Span());
780  TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index, Span span = Span());
781 
784 };
785 
786 // Reduce operator
/*!
 * \brief A commutative reducer node to represent a commutative binary operator
 *        with identity element.
 */
791 class CommReducerNode : public Object {
792  public:
 /*! \brief The left argument of reducer. */
794  ffi::Array<Var> lhs;
 /*! \brief The right argument of reducer. */
796  ffi::Array<Var> rhs;
 /*! \brief The result of reducer. */
798  ffi::Array<PrimExpr> result;
 /*!
  * \brief The identity element of reducer, which leaves other elements
  *        unchanged when combined with it.
  */
804  ffi::Array<PrimExpr> identity_element;
 /*! \brief Function call operator to combine a and b. */
806  ffi::Array<PrimExpr> operator()(ffi::Array<PrimExpr> a, ffi::Array<PrimExpr> b) const;
 /*! \brief Span that points to the original source code. Reserved debug information. */
811  mutable Span span;
812 
 /*!
  * \brief Register the reducer's fields for reflection.
  * `lhs`/`rhs` are flagged SEqHashDef (the variables this reducer defines).
  * `span` is flagged SEqHashIgnore, so it is excluded from structural
  * equality and hashing.
  */
813  static void RegisterReflection() {
814  namespace refl = tvm::ffi::reflection;
815  refl::ObjectDef<CommReducerNode>()
816  .def_ro("lhs", &CommReducerNode::lhs, refl::AttachFieldFlag::SEqHashDef())
817  .def_ro("rhs", &CommReducerNode::rhs, refl::AttachFieldFlag::SEqHashDef())
818  .def_ro("result", &CommReducerNode::result)
819  .def_ro("identity_element", &CommReducerNode::identity_element)
820  .def_ro("span", &CommReducerNode::span, refl::AttachFieldFlag::SEqHashIgnore());
821  }
822 
 // Structural equality/hash treats this object as a tree node.
823  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
825 };
826 
/*! \brief Managed reference to CommReducerNode. */
831 class CommReducer : public ObjectRef {
832  public:
833  TVM_DLL CommReducer(ffi::Array<Var> lhs, ffi::Array<Var> rhs, ffi::Array<PrimExpr> result,
834  ffi::Array<PrimExpr> identity_element, Span span = Span());
835 
837 };
838 
/*!
 * \brief Reduction operator.
 * Fields declared at expr.h:843, 854 and 856, not shown in this listing:
 *  - `CommReducer combiner`: the commutative combiner.
 *  - `PrimExpr condition`: predicate on the reduction; the body is only added
 *    to the reduction if the condition is true.
 *  - `int value_index`: the index of this reduce node.
 */
840 class ReduceNode : public PrimExprNode {
841  public:
 /*! \brief The source operand. */
845  ffi::Array<PrimExpr> source;
 /*! \brief The init operand. */
847  ffi::Array<PrimExpr> init;
 /*! \brief The reduction axis. */
849  ffi::Array<IterVar> axis;
857 
 /*! \brief Register all six fields as read-only (def_ro) reflected fields. */
858  static void RegisterReflection() {
859  namespace refl = tvm::ffi::reflection;
860  refl::ObjectDef<ReduceNode>()
861  .def_ro("combiner", &ReduceNode::combiner)
862  .def_ro("source", &ReduceNode::source)
863  .def_ro("init", &ReduceNode::init)
864  .def_ro("axis", &ReduceNode::axis)
865  .def_ro("condition", &ReduceNode::condition)
866  .def_ro("value_index", &ReduceNode::value_index);
867  }
869 };
870 
/*!
 * \brief Managed reference to ReduceNode.
 * Note the constructor's parameter order: `init` comes after `value_index`.
 * `src` and `rdom` correspond to the fields `source` and `axis`.
 */
875 class Reduce : public PrimExpr {
876  public:
877  TVM_DLL Reduce(CommReducer combiner, ffi::Array<PrimExpr> src, ffi::Array<IterVar> rdom,
878  PrimExpr condition, int value_index, ffi::Array<PrimExpr> init,
879  Span span = Span());
880 
883 };
884 
885 /*!
886  * \brief Template function to convert Map to unordered_map.
887  * Sometimes useful for API gluing when internal code uses unordered_map.
888  * \param dmap The container map.
889  * \return The corresponding unordered_map, holding a copy of every entry of dmap.
890  * \tparam K The key type of the Map; std::hash<K> must be available.
891  * \tparam V The value type of the Map; it does not need to be default-constructible.
892  */
893 template <typename K, typename V>
894 inline std::unordered_map<K, V> as_unordered_map(const ffi::Map<K, V>& dmap) {
895  std::unordered_map<K, V> ret;
896  for (const auto& kv : dmap) {
 // insert_or_assign (not operator[]) so V need not be default-constructible
 // (e.g. non-nullable ObjectRefs), while keeping operator[]'s last-write-wins
 // behaviour for keys that std::equal_to<K> considers equal.
897  ret.insert_or_assign(kv.first, kv.second);
898  }
899  return ret;
900 }
901 } // namespace tirx
902 
903 namespace ffi {
904 
// StringImm opts out of the default ObjectRef type traits and uses the
// specialization below. Per the ObjectRefWithFallbackTraitsBase<StringImm, ffi::String>
// base, presumably an ffi::String is accepted where a StringImm is expected
// (confirm against the base's definition).
905 template <>
906 inline constexpr bool use_default_type_traits_v<tvm::tirx::StringImm> = false;
907 
908 template <>
909 struct TypeTraits<tvm::tirx::StringImm>
910  : public ObjectRefWithFallbackTraitsBase<tvm::tirx::StringImm, ffi::String> {
 /*!
  * \brief Wrap a fallback ffi::String value into a StringImm.
  * \param value The string, taken by value and moved into the new node to avoid
  *        an extra reference-count increment/decrement.
  */
911  TVM_FFI_INLINE static tvm::tirx::StringImm ConvertFallbackValue(ffi::String value) {
912  return tvm::tirx::StringImm(std::move(value));
913  }
914 };
915 } // namespace ffi
916 } // namespace tvm
917 
// std::hash specialization so IterVar can key std::unordered_map/std::unordered_set
// (e.g. the result of as_unordered_map). Hashing is delegated to
// ::tvm::ObjectPtrHash, i.e. by the referenced object rather than by contents.
918 namespace std {
919 template <>
920 struct hash<::tvm::tirx::IterVar> : public ::tvm::ObjectPtrHash {};
921 } // namespace std
922 #endif // TVM_TIR_EXPR_H_
Symbolic n-dimensional array, to represent a memory buffer.
Constant floating point literals in the program.
Definition: expr.h:529
Constant integer literals in the program.
Definition: expr.h:494
Base node of all primitive expressions.
Definition: expr.h:93
Reference to PrimExprNode.
Definition: expr.h:126
DataType dtype() const
Definition: expr.h:140
Managed reference to RelaxExprNode.
Definition: expr.h:441
Definition: source_map.h:111
Runtime primitive data type.
Definition: data_type.h:47
a + b
Definition: expr.h:125
static constexpr const char * _type_key
Definition: expr.h:127
Managed reference to AddNode.
Definition: expr.h:134
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Add, PrimExpr, AddNode)
Add(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AddNode)
a && b
Definition: expr.h:409
PrimExpr a
The left operand.
Definition: expr.h:412
PrimExpr b
The right operand.
Definition: expr.h:414
static void RegisterReflection()
Definition: expr.h:416
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.And", AndNode, PrimExprNode)
Managed reference to AndNode.
Definition: expr.h:427
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(And, PrimExpr, AndNode)
And(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AndNode)
Base template to implement binary ops.
Definition: expr.h:107
static constexpr const int _type_child_slots
Definition: expr.h:119
PrimExpr b
The right operand.
Definition: expr.h:112
TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(T, PrimExprNode)
static constexpr const bool _type_final
Definition: expr.h:120
static void RegisterReflection()
Definition: expr.h:114
PrimExpr a
The left operand.
Definition: expr.h:110
Create a vector where all the elements are value.
Definition: expr.h:657
PrimExpr value
The base value.
Definition: expr.h:660
static void RegisterReflection()
Definition: expr.h:664
PrimExpr lanes
The number of lanes.
Definition: expr.h:662
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Broadcast", BroadcastNode, PrimExprNode)
Managed reference to BroadcastNode.
Definition: expr.h:677
TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode)
Broadcast(PrimExpr value, PrimExpr lanes, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Broadcast, PrimExpr, BroadcastNode)
Load value from the high dimension buffer.
Definition: expr.h:532
Buffer buffer
The buffer variable.
Definition: expr.h:535
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.BufferLoad", BufferLoadNode, PrimExprNode)
friend class VectorTypeRewriter
Definition: expr.h:563
ffi::Optional< PrimExpr > predicate
The predicate mask for loading values.
Definition: expr.h:539
friend class CustomDatatypesLowerer
Definition: expr.h:562
friend class Vectorizer
Definition: expr.h:564
static void RegisterReflection()
Definition: expr.h:541
ffi::Array< PrimExpr > indices
The indices location to be loaded.
Definition: expr.h:537
Managed reference to BufferLoadNode.
Definition: expr.h:571
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferLoad, PrimExpr, BufferLoadNode)
BufferLoad(Buffer buffer, ffi::Array< PrimExpr > indices, ffi::Optional< PrimExpr > predicate=std::nullopt, Span span=Span())
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:156
Call node.
Definition: expr.h:720
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Call", CallNode, PrimExprNode)
ffi::Array< PrimExpr > args
The arguments.
Definition: expr.h:731
RelaxExpr op
The operator(function) being invoked.
Definition: expr.h:728
static void RegisterReflection()
Definition: expr.h:733
Managed reference to CallNode.
Definition: expr.h:744
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode)
Call(DataType dtype, RelaxExpr op, ffi::Array< PrimExpr > args, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, PrimExpr, CallNode)
Cast value from one data type to another.
Definition: expr.h:79
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Cast", CastNode, PrimExprNode)
PrimExpr value
The value to be cast (expressed in its original data type).
Definition: expr.h:82
static void RegisterReflection()
Definition: expr.h:84
Managed reference to CastNode.
Definition: expr.h:95
Cast(DataType dtype, PrimExpr value, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Cast, PrimExpr, CastNode)
Base template to implement comparison ops.
Definition: expr.h:289
static constexpr const bool _type_final
Definition: expr.h:302
PrimExpr a
The left operand.
Definition: expr.h:292
static constexpr const int _type_child_slots
Definition: expr.h:301
static void RegisterReflection()
Definition: expr.h:296
PrimExpr b
The right operand.
Definition: expr.h:294
TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(T, PrimExprNode)
A commutative reducer node to represent a commutative binary operator with identity element.
Definition: expr.h:791
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.CommReducer", CommReducerNode, Object)
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: expr.h:823
ffi::Array< Var > lhs
The left argument of reducer.
Definition: expr.h:794
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:811
ffi::Array< PrimExpr > identity_element
The identity element of reducer, which leaves other elements unchanged when combined with it,...
Definition: expr.h:804
static void RegisterReflection()
Definition: expr.h:813
ffi::Array< PrimExpr > operator()(ffi::Array< PrimExpr > a, ffi::Array< PrimExpr > b) const
Function call operator to combine a and b.
ffi::Array< Var > rhs
The right argument of reducer.
Definition: expr.h:796
ffi::Array< PrimExpr > result
The result of reducer.
Definition: expr.h:798
Managed reference to CommReducerNode.
Definition: expr.h:831
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CommReducer, ObjectRef, CommReducerNode)
CommReducer(ffi::Array< Var > lhs, ffi::Array< Var > rhs, ffi::Array< PrimExpr > result, ffi::Array< PrimExpr > identity_element, Span span=Span())
Managed reference to DataProducerNode.
Definition: buffer.h:286
a / b in the C semantics.
Definition: expr.h:180
static constexpr const char * _type_key
Definition: expr.h:182
Managed reference to DivNode.
Definition: expr.h:189
TVM_DEFINE_OBJECT_REF_COW_METHOD(DivNode)
Div(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Div, PrimExpr, DivNode)
a == b
Definition: expr.h:307
static constexpr const char * _type_key
Definition: expr.h:309
Managed reference to EQNode.
Definition: expr.h:316
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(EQ, PrimExpr, EQNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(EQNode)
EQ(PrimExpr a, PrimExpr b, Span span=Span())
Floor division, floor(a/b)
Definition: expr.h:217
static constexpr const char * _type_key
Definition: expr.h:219
Managed reference to FloorDivNode.
Definition: expr.h:226
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloorDiv, PrimExpr, FloorDivNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode)
FloorDiv(PrimExpr a, PrimExpr b, Span span=Span())
The remainder of the floordiv.
Definition: expr.h:234
static constexpr const char * _type_key
Definition: expr.h:236
Managed reference to FloorModNode.
Definition: expr.h:243
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloorMod, PrimExpr, FloorModNode)
FloorMod(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorModNode)
a >= b
Definition: expr.h:392
static constexpr const char * _type_key
Definition: expr.h:394
Managed reference to GENode.
Definition: expr.h:401
GE(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GE, PrimExpr, GENode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(GENode)
a > b
Definition: expr.h:375
static constexpr const char * _type_key
Definition: expr.h:377
Managed reference to GTNode.
Definition: expr.h:384
TVM_DEFINE_OBJECT_REF_COW_METHOD(GTNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GT, PrimExpr, GTNode)
GT(PrimExpr a, PrimExpr b, Span span=Span())
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:296
Managed reference to LENode.
Definition: expr.h:367
TVM_DEFINE_OBJECT_REF_COW_METHOD(LENode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LE, PrimExpr, LENode)
LE(PrimExpr a, PrimExpr b, Span span=Span())
a < b
Definition: expr.h:341
static constexpr const char * _type_key
Definition: expr.h:343
Managed reference to LTNode.
Definition: expr.h:350
TVM_DEFINE_OBJECT_REF_COW_METHOD(LTNode)
LT(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LT, PrimExpr, LTNode)
Let binding. Bind var to value then evaluate body.
Definition: expr.h:687
PrimExpr body
The result expression.
Definition: expr.h:694
PrimExpr value
The value to be bound.
Definition: expr.h:692
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Let", LetNode, PrimExprNode)
static void RegisterReflection()
Definition: expr.h:696
Var var
The variable.
Definition: expr.h:690
Managed reference to LetNode.
Definition: expr.h:710
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode)
Let(Var var, PrimExpr value, PrimExpr body, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Let, PrimExpr, LetNode)
max(a, b)
Definition: expr.h:268
static constexpr const char * _type_key
Definition: expr.h:270
Managed reference to MaxNode.
Definition: expr.h:277
TVM_DEFINE_OBJECT_REF_COW_METHOD(MaxNode)
Max(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Max, PrimExpr, MaxNode)
min(a, b)
Definition: expr.h:251
static constexpr const char * _type_key
Definition: expr.h:253
Managed reference to MinNode.
Definition: expr.h:260
TVM_DEFINE_OBJECT_REF_COW_METHOD(MinNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Min, PrimExpr, MinNode)
Min(PrimExpr a, PrimExpr b, Span span=Span())
a % b in the C semantics.
Definition: expr.h:200
static constexpr const char * _type_key
Definition: expr.h:202
Managed reference to ModNode.
Definition: expr.h:209
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Mod, PrimExpr, ModNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ModNode)
Mod(PrimExpr a, PrimExpr b, Span span=Span())
a * b
Definition: expr.h:160
static constexpr const char * _type_key
Definition: expr.h:162
Managed reference to MulNode.
Definition: expr.h:169
Mul(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Mul, PrimExpr, MulNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MulNode)
a != b
Definition: expr.h:324
static constexpr const char * _type_key
Definition: expr.h:326
Managed reference to NENode.
Definition: expr.h:333
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(NE, PrimExpr, NENode)
NE(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(NENode)
!a
Definition: expr.h:461
static void RegisterReflection()
Definition: expr.h:466
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Not", NotNode, PrimExprNode)
PrimExpr a
The input operand.
Definition: expr.h:464
Managed reference to NotNode.
Definition: expr.h:477
Not(PrimExpr a, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(NotNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Not, PrimExpr, NotNode)
a || b
Definition: expr.h:435
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Or", OrNode, PrimExprNode)
PrimExpr b
The right operand.
Definition: expr.h:440
static void RegisterReflection()
Definition: expr.h:442
PrimExpr a
The left operand.
Definition: expr.h:438
Managed reference to OrNode.
Definition: expr.h:453
Or(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(OrNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Or, PrimExpr, OrNode)
Load value from the result produced by the producer.
Definition: expr.h:588
ffi::Array< PrimExpr > indices
The location arguments.
Definition: expr.h:593
DataProducer producer
The buffer producer.
Definition: expr.h:591
static void RegisterReflection()
Definition: expr.h:595
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.ProducerLoad", ProducerLoadNode, PrimExprNode)
Managed reference to ProducerLoadNode.
Definition: expr.h:608
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ProducerLoad, PrimExpr, ProducerLoadNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode)
ProducerLoad(DataProducer producer, ffi::Array< PrimExpr > indices, Span span=Span())
Construct a vector with lanes elements where its i-th element equals base + i * stride....
Definition: expr.h:626
PrimExpr lanes
Total number of lanes.
Definition: expr.h:633
PrimExpr base
The base value.
Definition: expr.h:629
static void RegisterReflection()
Definition: expr.h:635
PrimExpr stride
The stride of each step.
Definition: expr.h:631
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Ramp", RampNode, PrimExprNode)
Managed reference to RampNode.
Definition: expr.h:649
TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode)
Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Ramp, PrimExpr, RampNode)
Reduction operator.
Definition: expr.h:840
ffi::Array< PrimExpr > source
The source operand.
Definition: expr.h:845
ffi::Array< IterVar > axis
The reduction axis.
Definition: expr.h:849
CommReducer combiner
The commutative combiner.
Definition: expr.h:843
static void RegisterReflection()
Definition: expr.h:858
PrimExpr condition
Predicate on the reduction. Only add the body to the reduction if the condition is true.
Definition: expr.h:854
int value_index
The index of this reduce node.
Definition: expr.h:856
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Reduce", ReduceNode, PrimExprNode)
ffi::Array< PrimExpr > init
The init operand.
Definition: expr.h:847
Managed reference to ReduceNode.
Definition: expr.h:875
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Reduce, PrimExpr, ReduceNode)
Reduce(CommReducer combiner, ffi::Array< PrimExpr > src, ffi::Array< IterVar > rdom, PrimExpr condition, int value_index, ffi::Array< PrimExpr > init, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode)
return true_value if condition is true, otherwise return false_value.
Definition: expr.h:491
static void RegisterReflection()
Definition: expr.h:500
PrimExpr false_value
value to be returned when condition is false.
Definition: expr.h:498
PrimExpr true_value
value to be returned when condition is true.
Definition: expr.h:496
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Select", SelectNode, PrimExprNode)
PrimExpr condition
The condition.
Definition: expr.h:494
Managed reference to SelectNode.
Definition: expr.h:514
Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(SelectNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Select, PrimExpr, SelectNode)
Shuffle instruction. vec = concat(vectors) result = (vec[indices[0]], vec[indices[1]] ....
Definition: expr.h:756
ffi::Array< PrimExpr > indices
The indices of each element.
Definition: expr.h:761
ffi::Array< PrimExpr > vectors
The input vectors.
Definition: expr.h:759
static void RegisterReflection()
Definition: expr.h:763
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Shuffle", ShuffleNode, PrimExprNode)
Managed reference to ShuffleNode.
Definition: expr.h:776
static PrimExpr ExtractElement(PrimExpr vector, int index, Span span=Span())
static PrimExpr Concat(ffi::Array< PrimExpr > vectors, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShuffleNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Shuffle, PrimExpr, ShuffleNode)
Shuffle(ffi::Array< PrimExpr > vectors, ffi::Array< PrimExpr > indices, Span span=Span())
ffi::String constants, only used in asserts.
Definition: expr.h:52
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.StringImm", StringImmNode, PrimExprNode)
ffi::String value
The constant value content.
Definition: expr.h:55
static void RegisterReflection()
Definition: expr.h:57
Managed reference to StringImmNode.
Definition: expr.h:68
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode)
StringImm(ffi::String value, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StringImm, PrimExpr, StringImmNode)
a - b
Definition: expr.h:142
static constexpr const char * _type_key
Definition: expr.h:144
Managed reference to SubNode.
Definition: expr.h:151
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Sub, PrimExpr, SubNode)
Sub(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(SubNode)
a named variable in TIR
Definition: var.h:76
Defines the Functor data structures.
Base expr nodes in TVM.
Definition: repr_printer.h:91
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
tvm::FloatImmNode FloatImmNode
Definition: expr.h:49
tvm::IntImmNode IntImmNode
Definition: expr.h:48
std::unordered_map< K, V > as_unordered_map(const ffi::Map< K, V > &dmap)
Definition: expr.h:894
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
static TVM_FFI_INLINE tvm::tirx::StringImm ConvertFallbackValue(ffi::String value)
Definition: expr.h:911
a <= b
Definition: expr.h:358
static constexpr const char * _type_key
Definition: expr.h:360
Variables in the TIR.