tvm
string.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RUNTIME_CONTAINER_STRING_H_
25 #define TVM_RUNTIME_CONTAINER_STRING_H_
26 
#include <dmlc/endian.h>
#include <dmlc/logging.h>
#include <tvm/runtime/container/base.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <memory>
#include <ostream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
45 
46 namespace tvm {
47 namespace runtime {
48 
49 // Forward declare TVMArgValue
50 class TVMArgValue;
51 
53 class StringObj : public Object {
54  public:
56  const char* data;
57 
59  uint64_t size;
60 
61  static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString;
62  static constexpr const char* _type_key = "runtime.String";
64 
65  private:
67  class FromStd;
68 
69  friend class String;
70 };
71 
98 class String : public ObjectRef {
99  public:
103  String() : String(std::string()) {}
112  String(std::string other); // NOLINT(*)
113 
119  String(const char* other) // NOLINT(*)
120  : String(std::string(other)) {}
121 
125  String(std::nullptr_t) // NOLINT(*)
126  : ObjectRef(nullptr) {}
127 
134  inline String& operator=(std::string other);
135 
141  inline String& operator=(const char* other);
142 
151  int compare(const String& other) const {
152  return memncmp(data(), other.data(), size(), other.size());
153  }
154 
163  int compare(const std::string& other) const {
164  return memncmp(data(), other.data(), size(), other.size());
165  }
166 
175  int compare(const char* other) const {
176  return memncmp(data(), other, size(), std::strlen(other));
177  }
178 
184  const char* c_str() const { return get()->data; }
185 
191  size_t size() const {
192  const auto* ptr = get();
193  return ptr->size;
194  }
195 
201  size_t length() const { return size(); }
202 
208  bool empty() const { return size() == 0; }
209 
216  char at(size_t pos) const {
217  if (pos < size()) {
218  return data()[pos];
219  } else {
220  throw std::out_of_range("tvm::String index out of bounds");
221  }
222  }
223 
229  const char* data() const { return get()->data; }
230 
236  operator std::string() const { return std::string{get()->data, size()}; }
237 
243  inline static bool CanConvertFrom(const TVMArgValue& val);
244 
251  static uint64_t StableHashBytes(const char* data, size_t size) {
252  const constexpr uint64_t kMultiplier = 1099511628211ULL;
253  const constexpr uint64_t kMod = 2147483647ULL;
254  union Union {
255  uint8_t a[8];
256  uint64_t b;
257  } u;
258  static_assert(sizeof(Union) == sizeof(uint64_t), "sizeof(Union) != sizeof(uint64_t)");
259  const char* it = data;
260  const char* end = it + size;
261  uint64_t result = 0;
262  for (; it + 8 <= end; it += 8) {
263  if (DMLC_IO_NO_ENDIAN_SWAP) {
264  u.a[0] = it[0];
265  u.a[1] = it[1];
266  u.a[2] = it[2];
267  u.a[3] = it[3];
268  u.a[4] = it[4];
269  u.a[5] = it[5];
270  u.a[6] = it[6];
271  u.a[7] = it[7];
272  } else {
273  u.a[0] = it[7];
274  u.a[1] = it[6];
275  u.a[2] = it[5];
276  u.a[3] = it[4];
277  u.a[4] = it[3];
278  u.a[5] = it[2];
279  u.a[6] = it[1];
280  u.a[7] = it[0];
281  }
282  result = (result * kMultiplier + u.b) % kMod;
283  }
284  if (it < end) {
285  u.b = 0;
286  uint8_t* a = u.a;
287  if (it + 4 <= end) {
288  a[0] = it[0];
289  a[1] = it[1];
290  a[2] = it[2];
291  a[3] = it[3];
292  it += 4;
293  a += 4;
294  }
295  if (it + 2 <= end) {
296  a[0] = it[0];
297  a[1] = it[1];
298  it += 2;
299  a += 2;
300  }
301  if (it + 1 <= end) {
302  a[0] = it[0];
303  it += 1;
304  a += 1;
305  }
306  if (!DMLC_IO_NO_ENDIAN_SWAP) {
307  std::swap(u.a[0], u.a[7]);
308  std::swap(u.a[1], u.a[6]);
309  std::swap(u.a[2], u.a[5]);
310  std::swap(u.a[3], u.a[4]);
311  }
312  result = (result * kMultiplier + u.b) % kMod;
313  }
314  return result;
315  }
316 
318 
319  private:
330  static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count);
331 
342  static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) {
343  std::string ret(lhs, lhs_size);
344  ret.append(rhs, rhs_size);
345  return String(ret);
346  }
347 
348  // Overload + operator
349  friend String operator+(const String& lhs, const String& rhs);
350  friend String operator+(const String& lhs, const std::string& rhs);
351  friend String operator+(const std::string& lhs, const String& rhs);
352  friend String operator+(const String& lhs, const char* rhs);
353  friend String operator+(const char* lhs, const String& rhs);
354 
356 };
357 
360  public:
369  explicit FromStd(std::string other) : data_container{other} {}
370 
371  private:
373  std::string data_container;
374 
375  friend class String;
376 };
377 
378 inline String::String(std::string other) {
379  auto ptr = make_object<StringObj::FromStd>(std::move(other));
380  ptr->size = ptr->data_container.size();
381  ptr->data = ptr->data_container.data();
382  data_ = std::move(ptr);
383 }
384 
385 inline String& String::operator=(std::string other) {
386  String replace{std::move(other)};
387  data_.swap(replace.data_);
388  return *this;
389 }
390 
391 inline String& String::operator=(const char* other) { return operator=(std::string(other)); }
392 
393 inline String operator+(const String& lhs, const String& rhs) {
394  size_t lhs_size = lhs.size();
395  size_t rhs_size = rhs.size();
396  return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size);
397 }
398 
399 inline String operator+(const String& lhs, const std::string& rhs) {
400  size_t lhs_size = lhs.size();
401  size_t rhs_size = rhs.size();
402  return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size);
403 }
404 
405 inline String operator+(const std::string& lhs, const String& rhs) {
406  size_t lhs_size = lhs.size();
407  size_t rhs_size = rhs.size();
408  return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size);
409 }
410 
411 inline String operator+(const char* lhs, const String& rhs) {
412  size_t lhs_size = std::strlen(lhs);
413  size_t rhs_size = rhs.size();
414  return String::Concat(lhs, lhs_size, rhs.data(), rhs_size);
415 }
416 
417 inline String operator+(const String& lhs, const char* rhs) {
418  size_t lhs_size = lhs.size();
419  size_t rhs_size = std::strlen(rhs);
420  return String::Concat(lhs.data(), lhs_size, rhs, rhs_size);
421 }
422 
423 // Overload < operator
424 inline bool operator<(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) < 0; }
425 
426 inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; }
427 
428 inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; }
429 
430 inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; }
431 
432 inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; }
433 
434 // Overload > operator
435 inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; }
436 
437 inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; }
438 
439 inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; }
440 
441 inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; }
442 
443 inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; }
444 
445 // Overload <= operator
446 inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; }
447 
448 inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; }
449 
450 inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; }
451 
452 inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; }
453 
454 inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; }
455 
456 // Overload >= operator
457 inline bool operator>=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) >= 0; }
458 
459 inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; }
460 
461 inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; }
462 
463 inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; }
464 
465 inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(rhs) <= 0; }
466 
467 // Overload == operator
468 inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; }
469 
470 inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; }
471 
472 inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; }
473 
474 inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; }
475 
476 inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; }
477 
478 // Overload != operator
479 inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; }
480 
481 inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; }
482 
483 inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; }
484 
485 inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; }
486 
487 inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; }
488 
489 inline std::ostream& operator<<(std::ostream& out, const String& input) {
490  out.write(input.data(), input.size());
491  return out;
492 }
493 
494 inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) {
495  if (lhs == rhs && lhs_count == rhs_count) return 0;
496 
497  for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) {
498  if (lhs[i] < rhs[i]) return -1;
499  if (lhs[i] > rhs[i]) return 1;
500  }
501  if (lhs_count < rhs_count) {
502  return -1;
503  } else if (lhs_count > rhs_count) {
504  return 1;
505  } else {
506  return 0;
507  }
508 }
509 
510 inline size_t ObjectHash::operator()(const ObjectRef& a) const {
511  if (const auto* str = a.as<StringObj>()) {
512  return String::StableHashBytes(str->data, str->size);
513  }
514  return ObjectPtrHash()(a);
515 }
516 
517 inline bool ObjectEqual::operator()(const ObjectRef& a, const ObjectRef& b) const {
518  if (a.same_as(b)) {
519  return true;
520  }
521  if (const auto* str_a = a.as<StringObj>()) {
522  if (const auto* str_b = b.as<StringObj>()) {
523  return String::memncmp(str_a->data, str_b->data, str_a->size, str_b->size) == 0;
524  }
525  }
526  return false;
527 }
528 } // namespace runtime
529 
530 // expose the functions to the root namespace.
531 using runtime::String;
532 using runtime::StringObj;
533 } // namespace tvm
534 
535 namespace std {
536 
537 template <>
538 struct hash<::tvm::runtime::String> {
539  std::size_t operator()(const ::tvm::runtime::String& str) const {
540  return ::tvm::runtime::String::StableHashBytes(str.data(), str.size());
541  }
542 };
543 } // namespace std
544 
545 #endif // TVM_RUNTIME_CONTAINER_STRING_H_
Base class of all object references.
Definition: object.h:519
const Object * get() const
Definition: object.h:554
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:605
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:906
bool same_as(const ObjectRef &other) const
Comparator.
Definition: object.h:530
base class of all object containers.
Definition: object.h:171
An object representing string moved from std::string.
Definition: string.h:359
FromStd(std::string other)
Construct a new FromStd object.
Definition: string.h:369
An object representing a string. It's a POD type.
Definition: string.h:53
const char * data
The pointer to string data.
Definition: string.h:56
static constexpr const char * _type_key
Definition: string.h:62
uint64_t size
The length of the string object.
Definition: string.h:59
static constexpr const uint32_t _type_index
Definition: string.h:61
TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object)
Reference to string objects.
Definition: string.h:98
int compare(const char *other) const
Compares this to other.
Definition: string.h:175
static bool CanConvertFrom(const TVMArgValue &val)
Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String.
Definition: packed_func.h:2221
friend String operator+(const String &lhs, const String &rhs)
Definition: string.h:393
String & operator=(std::string other)
Change the value the reference object points to.
Definition: string.h:385
const char * data() const
Return the data pointer.
Definition: string.h:229
String(const char *other)
Construct a new String object.
Definition: string.h:119
size_t length() const
Return the length of the string.
Definition: string.h:201
static uint64_t StableHashBytes(const char *data, size_t size)
Hash the binary bytes.
Definition: string.h:251
bool empty() const
Return whether the string is empty.
Definition: string.h:208
char at(size_t pos) const
Read an element.
Definition: string.h:216
int compare(const std::string &other) const
Compares this String object to other.
Definition: string.h:163
String()
Construct an empty string.
Definition: string.h:103
const char * c_str() const
Returns a pointer to the char array in the string.
Definition: string.h:184
String(std::nullptr_t)
Construct a new null object.
Definition: string.h:125
int compare(const String &other) const
Compares this String object to other.
Definition: string.h:151
size_t size() const
Return the length of the string.
Definition: string.h:191
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj)
IntSet Union(const Array< IntSet > &sets)
Create a union set of all sets, possibly relaxed.
bool operator<(const String &lhs, const std::string &rhs)
Definition: string.h:424
String operator+(const String &lhs, const String &rhs)
Definition: string.h:393
Array< T > Concat(Array< T > lhs, const Array< T > &rhs)
Concat two Arrays.
Definition: array.h:889
bool operator!=(const String &lhs, const std::string &rhs)
Definition: string.h:479
bool operator<=(const String &lhs, const std::string &rhs)
Definition: string.h:446
bool operator>=(const String &lhs, const std::string &rhs)
Definition: string.h:457
bool operator==(const String &lhs, const std::string &rhs)
Definition: string.h:468
std::ostream & operator<<(std::ostream &os, const ObjectRef &n)
Definition: repr_printer.h:97
bool operator>(const String &lhs, const std::string &rhs)
Definition: string.h:435
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
A managed object in the TVM runtime.
Base utilities for common POD(plain old data) container types.
Runtime memory management.
String-aware ObjectRef hash functor.
Definition: base.h:50
bool operator()(const ObjectRef &a, const ObjectRef &b) const
Check if the two ObjectRef are equal.
Definition: string.h:517
size_t operator()(const ObjectRef &a) const
Calculate the hash code of an ObjectRef.
Definition: string.h:510
ObjectRef hash functor.
Definition: object.h:655
@ kRuntimeString
runtime::String.
Definition: object.h:66