tvm
random_engine.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
24 #ifndef TVM_SUPPORT_RANDOM_ENGINE_H_
25 #define TVM_SUPPORT_RANDOM_ENGINE_H_
26 #include <tvm/runtime/logging.h>
27 
28 #include <cstdint>
29 #include <random>
30 
31 namespace tvm {
32 namespace support {
33 
45  public:
/*! \brief Type of the random state; signed so that the sentinel seed -1 can be passed to NormalizeSeed. */
using TRandState = std::int64_t;
/*! \brief Type of the values returned by operator(). */
using result_type = std::uint64_t;
/*! \brief LCG multiplier (the same value std::minstd_rand uses). */
static constexpr TRandState multiplier{48271};
/*! \brief LCG increment; zero makes this a purely multiplicative (Lehmer) generator. */
static constexpr TRandState increment{0};
/*! \brief LCG modulus, the prime 2^31 - 1. */
static constexpr TRandState modulus{2147483647};
/*!
 * \brief Lower bound on the random state, as reported to STL distributions.
 * \note NOTE(review): a state produced by NormalizeSeed is never below 1 (std::minstd_rand::min()
 *       is 1). The value stays 0 because changing it would alter downstream distribution outputs.
 */
static constexpr result_type min() { return 0; }
/*! \brief Largest possible random state, modulus - 1. */
static constexpr result_type max() { return modulus - 1; }
59 
64  static TRandState DeviceRandom() { return (std::random_device()()) % modulus; }
65 
77  (*rand_state_ptr_) = ((*rand_state_ptr_) * multiplier + increment) % modulus;
78  return *rand_state_ptr_;
79  }
85  static TRandState NormalizeSeed(TRandState rand_state) {
86  if (rand_state == -1) {
87  rand_state = DeviceRandom();
88  } else {
89  rand_state %= modulus;
90  }
91  if (rand_state == 0) {
92  rand_state = 1;
93  }
94  if (rand_state < 0) {
95  LOG(FATAL) << "ValueError: Random seed must be non-negative";
96  }
97  return rand_state;
98  }
103  void Seed(TRandState rand_state) {
104  ICHECK(rand_state_ptr_ != nullptr);
105  *rand_state_ptr_ = NormalizeSeed(rand_state);
106  }
107 
113  // In order for reproducibility, we compute the new seed using RNG's random state and a
114  // different set of parameters. Note that both 32767 and 1999999973 are prime numbers.
115  return ((*this)() * 32767) % 1999999973;
116  }
117 
125  explicit LinearCongruentialEngine(TRandState* rand_state_ptr) {
126  rand_state_ptr_ = rand_state_ptr;
127  }
128 
129  private:
/*!
 * \brief Pointer to the random state that operator() reads and advances and Seed() overwrites.
 *        It is supplied by the constructor's caller; the engine does not allocate it.
 */
130  TRandState* rand_state_ptr_;
131 };
132 
133 } // namespace support
134 } // namespace tvm
135 
136 #endif // TVM_SUPPORT_RANDOM_ENGINE_H_
This linear congruential engine is a drop-in replacement for std::minstd_rand. It strictly corresponds to std::minstd_rand (multiplier 48271, increment 0, modulus 2147483647).
Definition: random_engine.h:44
TRandState ForkSeed()
Fork a new seed for another RNG from the current random state.
Definition: random_engine.h:112
static constexpr TRandState multiplier
The multiplier.
Definition: random_engine.h:50
static constexpr TRandState modulus
The modulus.
Definition: random_engine.h:54
uint64_t result_type
The result type.
Definition: random_engine.h:48
static constexpr result_type max()
The maximum possible value of random state here.
Definition: random_engine.h:58
result_type operator()()
Advance the random state to its next value and return the new state. By the definition of a linear congruential engine, the next state is (state * multiplier + increment) % modulus.
Definition: random_engine.h:76
void Seed(TRandState rand_state)
Reset the RNG's random state to a new value derived from the given seed (normalized via NormalizeSeed).
Definition: random_engine.h:103
int64_t TRandState
Definition: random_engine.h:46
static TRandState NormalizeSeed(TRandState rand_state)
Normalize the random seed to the range of [1, modulus - 1].
Definition: random_engine.h:85
static TRandState DeviceRandom()
Get a device random state.
Definition: random_engine.h:64
static constexpr TRandState increment
The increment.
Definition: random_engine.h:52
static constexpr result_type min()
The lower bound reported for the random state. A state normalized by NormalizeSeed is always at least 1.
Definition: random_engine.h:56
LinearCongruentialEngine(TRandState *rand_state_ptr)
Construct a random number generator with a random state pointer.
Definition: random_engine.h:125
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36