tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
structural_hash.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 #ifndef TVM_NODE_STRUCTURAL_HASH_H_
24 #define TVM_NODE_STRUCTURAL_HASH_H_
25 
#include <tvm/node/functor.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/ndarray.h>

#include <cstring>
#include <functional>
#include <string>
#include <type_traits>
32 
33 namespace tvm {
34 
39  protected:
40  template <typename T, typename U>
41  uint64_t Reinterpret(T value) const {
42  union Union {
43  T a;
44  U b;
45  } u;
46  static_assert(sizeof(Union) == sizeof(T), "sizeof(Union) != sizeof(T)");
47  static_assert(sizeof(Union) == sizeof(U), "sizeof(Union) != sizeof(U)");
48  u.b = 0;
49  u.a = value;
50  return u.b;
51  }
52 
53  public:
54  uint64_t operator()(const float& key) const { return Reinterpret<float, uint32_t>(key); }
55  uint64_t operator()(const double& key) const { return Reinterpret<double, uint64_t>(key); }
56  uint64_t operator()(const int64_t& key) const { return Reinterpret<int64_t, uint64_t>(key); }
57  uint64_t operator()(const uint64_t& key) const { return key; }
58  uint64_t operator()(const int& key) const { return Reinterpret<int, uint32_t>(key); }
59  uint64_t operator()(const bool& key) const { return key; }
60  uint64_t operator()(const runtime::DataType& key) const {
61  return Reinterpret<DLDataType, uint32_t>(key);
62  }
63  template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
64  uint64_t operator()(const ENum& key) const {
65  return Reinterpret<int64_t, uint64_t>(static_cast<int64_t>(key));
66  }
67  uint64_t operator()(const std::string& key) const {
68  return runtime::String::StableHashBytes(key.data(), key.length());
69  }
70 };
71 
/*!
 * \brief Content-aware structural hashing.
 *
 * Adds an ObjectRef overload on top of the value overloads provided by
 * BaseValueHash.
 */
83 class StructuralHash : public BaseValueHash {
84  public:
// The using-declaration below keeps the BaseValueHash overloads visible;
// without it the operator() declared in this class would hide them.
85  // inherit operator()
86  using BaseValueHash::operator();
/*!
 * \brief Compute the hash of an object reference.
 * \param key The object to hash.
 * \return The 64-bit hash value (implementation is not in this header).
 */
92  TVM_DLL uint64_t operator()(const ObjectRef& key) const;
93 };
94 
111  public:
113  class Handler {
114  public:
120  virtual void SHashReduceHashedValue(uint64_t hashed_value) = 0;
127  virtual void SHashReduce(const ObjectRef& key, bool map_free_vars) = 0;
141  virtual void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) = 0;
150  virtual bool LookupHashedValue(const ObjectRef& key, uint64_t* hashed_value) = 0;
155  virtual void MarkGraphNode() = 0;
156  };
157 
159  SHashReducer() = default;
165  explicit SHashReducer(Handler* handler, bool map_free_vars)
166  : handler_(handler), map_free_vars_(map_free_vars) {}
171  template <typename T,
172  typename = typename std::enable_if<!std::is_base_of<ObjectRef, T>::value>::type>
173  void operator()(const T& key) const {
174  // handle normal values.
175  handler_->SHashReduceHashedValue(BaseValueHash()(key));
176  }
181  void operator()(const ObjectRef& key) const { return handler_->SHashReduce(key, map_free_vars_); }
187  void DefHash(const ObjectRef& key) const { return handler_->SHashReduce(key, true); }
193  void FreeVarHashImpl(const runtime::Object* var) const {
194  handler_->SHashReduceFreeVar(var, map_free_vars_);
195  }
196 
198  Handler* operator->() const { return handler_; }
199 
200  private:
202  Handler* handler_;
208  bool map_free_vars_;
209 };
210 
217  public:
/*! \brief Virtual destructor; declared here and defined outside this header. */
219  virtual ~SHashHandlerDefault();
220 
// The five declarations below override virtual functions of the base class;
// their names and signatures match the pure virtual functions declared in
// SHashReducer::Handler. The definitions are not part of this header.
221  void SHashReduceHashedValue(uint64_t hashed_value) override;
222  void SHashReduce(const ObjectRef& key, bool map_free_vars) override;
223  void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) override;
224  bool LookupHashedValue(const ObjectRef& key, uint64_t* hashed_value) override;
225  void MarkGraphNode() override;
226 
/*!
 * \brief Hash an object and return the result as a 64-bit value.
 *
 * Virtual, so a subclass may replace it. The definition is not in this header.
 *
 * \param object The object to hash.
 * \param map_free_vars Presumably whether free variables are mapped while
 *        hashing. NOTE(review): confirm against the implementation.
 * \return The hash value.
 */
233  virtual uint64_t Hash(const ObjectRef& object, bool map_free_vars);
234 
235  protected:
/*!
 * \brief Protected virtual hook taking the object being hashed.
 *
 * NOTE(review): presumably dispatches to the type-specific hashing routine
 * of the object; the definition is not in this header, so confirm there.
 *
 * \param object The object to process.
 * \param map_free_vars Flag passed along with the object (see Hash).
 */
242  virtual void DispatchSHash(const ObjectRef& object, bool map_free_vars);
243 
244  private:
// Forward declaration of the implementation class; its definition is not
// visible in this header.
245  class Impl;
// Raw pointer to the implementation object.
// NOTE(review): ownership and lifetime of this pointer are not visible here;
// confirm how it is allocated, released and handled on copy.
246  Impl* impl;
247 };
248 
249 class SEqualReducer;
251  static constexpr const std::nullptr_t VisitAttrs = nullptr;
252  static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce);
253  static bool SEqualReduce(const runtime::NDArray::Container* lhs,
255 };
256 
257 } // namespace tvm
258 #endif // TVM_NODE_STRUCTURAL_HASH_H_
static uint64_t StableHashBytes(const char *data, size_t size)
Hash the binary bytes.
Definition: string.h:251
void FreeVarHashImpl(const runtime::Object *var) const
Implementation of the hash for a free variable.
Definition: structural_hash.h:193
Internal handler that defines custom behaviors.
Definition: structural_hash.h:113
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:110
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
uint64_t operator()(const ENum &key) const
Definition: structural_hash.h:64
uint64_t Reinterpret(T value) const
Definition: structural_hash.h:41
Base class of all object containers.
Definition: object.h:167
uint64_t operator()(const double &key) const
Definition: structural_hash.h:55
uint64_t operator()(const bool &key) const
Definition: structural_hash.h:59
IntSet Union(const Array< IntSet > &sets)
Create a union set of all sets, possibly relaxed.
The default handler for hash key computation.
Definition: structural_hash.h:216
A device-independent managed NDArray abstraction.
uint64_t operator()(const uint64_t &key) const
Definition: structural_hash.h:57
Definition: structural_hash.h:250
Handler * operator->() const
Definition: structural_hash.h:198
Runtime primitive data type.
Definition: data_type.h:41
Hash definition of base value classes.
Definition: structural_hash.h:38
Object container class that backs NDArray.
Definition: ndarray.h:286
uint64_t operator()(const int64_t &key) const
Definition: structural_hash.h:56
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Defines the Functor data structures.
Base class of all object references.
Definition: object.h:511
SHashReducer(Handler *handler, bool map_free_vars)
Constructor with a specific handler.
Definition: structural_hash.h:165
uint64_t operator()(const std::string &key) const
Definition: structural_hash.h:67
uint64_t operator()(const int &key) const
Definition: structural_hash.h:58
uint64_t operator()(const runtime::DataType &key) const
Definition: structural_hash.h:60
void operator()(const T &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:173
Content-aware structural hashing.
Definition: structural_hash.h:83
uint64_t operator()(const float &key) const
Definition: structural_hash.h:54
void operator()(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:181
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:187