tvm.auto_scheduler

Namespace for TVM Auto-scheduler.

Classes:

ComputeDAG(compute_or_sche)

The auto-scheduler’s computational graph and related program analyses.

EmptyPolicy(task[, init_search_callbacks])

A simple example of the search policy which always returns the initial naive schedule (state).

HardwareParams(num_cores, vector_unit_bytes, …)

The parameters of target hardware used to guide the search policy

LocalBuilder([timeout, n_parallel, build_func])

LocalBuilder uses local CPU cores to build programs in parallel.

LocalRPCMeasureContext([priority, …])

A context wrapper for running RPCRunner locally.

LocalRunner([timeout, number, repeat, …])

LocalRunner that uses the local CPU/GPU to measure the time cost of programs.

MeasureInput(task, state)

Store the input of a measurement.

MeasureResult(costs, error_no, error_msg, …)

Store the results of a measurement.

PreloadMeasuredStates([filename])

A SearchCallback to load measured states from the log file for a search policy.

RPCRunner(key, host, port[, priority, …])

RPCRunner that uses RPC calls to measure the time cost of programs on remote devices.

RandomModel()

A model that returns random estimations for all inputs

RecordReader([filename])

Reader of the json log file.

RecordToFile([filename])

A measurement callback that writes measurement records into a file.

SearchTask(dag, workload_key, target[, …])

The computation information and hardware parameters for a schedule search task.

SketchPolicy(task[, program_cost_model, …])

The search policy that searches in a hierarchical search space defined by sketches.

TaskScheduler(tasks[, objective_func, strategy])

Allocate the time resources when tuning multiple tasks together.

TuningOptions([num_measure_trials, …])

This controls the options of performance tuning.

XGBModel([verbose_eval, num_warmup_sample, seed])

Train an XGBoost model to predict the normalized throughputs of programs.

Functions:

auto_schedule(task[, search_policy, …])

Run auto scheduling search for a task

create_task(func, args, target[, …])

Create a search task

load_best(filename[, workload_key, target])

Return the best measurement pair from a log file.

load_records(filename)

Load measurement records from a file.

make_workload_key(func, args)

Make a workload key by function and arguments.

register_workload(func_name[, f, override])

Register a function that generates a certain workload.

save_records(filename, inputs, results)

Append measure records to file.

class tvm.auto_scheduler.TuningOptions(num_measure_trials=0, early_stopping=None, num_measures_per_round=64, verbose=1, builder='local', runner='local', measure_callbacks=None)

This controls the options of performance tuning.

Parameters
  • num_measure_trials (int = 0) – The number of measurement trials. The search policy measures num_measure_trials schedules in total and returns the best one among them. With num_measure_trials == 0, the policy will do the schedule search but won’t involve measurement. This can be used to get a runnable schedule quickly without auto-tuning.

  • early_stopping (Optional[int]) – Stop the tuning early if getting no improvement after n measurements.

  • num_measures_per_round (int = 64) – The number of schedules to be measured at each search round. The whole schedule search process will measure num_measure_trials schedules in total, spread over several rounds.

  • verbose (int = 1) – Verbosity level. 0 for silent, 1 to output information during schedule search.

  • builder (Union[ProgramBuilder, str] = 'local') – ProgramBuilder which builds the program.

  • runner (Union[ProgramRunner, str] = 'local') – ProgramRunner which runs the program and measures time costs.

  • measure_callbacks (Optional[List[MeasureCallback]]) – Callback functions called after each measurement. Candidates: auto_scheduler.RecordToFile

class tvm.auto_scheduler.HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes)

The parameters of target hardware used to guide the search policy

TODO(jcf94): This is considered to be merged with the new Target specification: https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844

Parameters
  • num_cores (int) – The number of device cores.

  • vector_unit_bytes (int) – The width of vector units in bytes.

  • cache_line_bytes (int) – The size of cache line in bytes.

tvm.auto_scheduler.create_task(func, args, target, target_host=None, hardware_params=None)

Create a search task

Parameters
  • func (Union[Function, str]) – The function that returns the compute declaration Tensors. Can be a function or the function name.

  • args (Union[Tuple[Any, ..], List[Any]]) – The args of the function.

  • target (Union[tvm.target.Target, str]) – The target device of this search task.

  • target_host (Optional[Union[tvm.target.Target, str]]) – The target host device of this search task.

  • hardware_params (Optional[HardwareParams]) – Hardware parameters used in this search task.

Returns

task – The created task.

Return type

SearchTask

tvm.auto_scheduler.auto_schedule(task, search_policy=None, tuning_options=auto_scheduler.TuningOptions())

Run auto scheduling search for a task

Parameters
  • task (SearchTask) – The SearchTask for the computation declaration.

  • search_policy (Optional[SearchPolicy]) – The search policy to be used for schedule search.

  • tuning_options (Optional[TuningOptions]) – Tuning and measurement options.

Returns

Return type

A te.Schedule and a list of te.Tensor to be used in tvm.lower or tvm.build.

class tvm.auto_scheduler.ComputeDAG(compute_or_sche)

The auto-scheduler’s computational graph and related program analyses.

We convert a compute declaration described by tvm.compute (either a single operator or a subgraph) to a ComputeDAG. It keeps the input/output tensors, all operations in the DAG, and some static analysis results for the DAG (e.g. the total float operation count, consumer/producer relations of operations, whether an operation stage should be tiled or compute-inlined). These analyses can help the search policy make decisions during the search. ComputeDAG is also responsible for the interaction between the auto-scheduler’s LoopState and the TVM schedule (e.g. applying the LoopState transform steps to a TVM schedule, providing LoopState with extra information obtained from the TVM schedule).

Parameters

compute_or_sche (Union[List[Tensor], str, Schedule]) – Input/output tensors, a workload key, or a TVM schedule for a compute declaration.

Methods:

apply_steps_from_state(state[, layout_rewrite])

Apply the history transform steps from a State to get a TVM schedule.

get_init_state()

Get the init state of this ComputeDAG.

infer_bound_from_state(state)

Infer and fill the bound of all iterators of a state.

print_python_code_from_state(state)

Print transform steps in the history of a State as TVM’s python schedule code.

get_init_state()

Get the init state of this ComputeDAG.

Returns

state – The initial State without any transform steps.

Return type

State

apply_steps_from_state(state, layout_rewrite=False)

Apply the history transform steps from a State to get a TVM schedule.

Parameters
  • state (Union[State, StateObject]) – The state from which we get transform steps.

  • layout_rewrite (bool = False) – Rewrite the layout of placeholders specified by the “layout_free_placeholders” attr to make them most friendly for the generated schedule to read from.

Returns

Return type

A te.Schedule and a list of te.Tensor to be used in tvm.lower or tvm.build.

print_python_code_from_state(state)

Print transform steps in the history of a State as TVM’s python schedule code.

This is used to print transformation steps for debugging. Use apply_steps_from_state if you want to get a schedule for code generation.

Parameters

state (Union[State, StateObject]) – The state from which we get transform steps.

Returns

code – The Python schedule code.

Return type

str

infer_bound_from_state(state)

Infer and fill the bound of all iterators of a state.

The states may lose complete bound information after some transform steps (e.g., compute_at). We can call this function to infer and fill all the bound information. This function calls TVM InferBound pass internally to get the bound. The returned state of this function is guaranteed to have complete iterator extent information.

Parameters

state (Union[State, StateObject]) – The state from which we get transform steps.

Returns

updated_state – The State with complete bound information.

Return type

State

class tvm.auto_scheduler.RandomModel

A model that returns random estimations for all inputs

Methods:

predict(search_task, states)

Predict the scores of states

update(inputs, results)

Update the cost model according to new measurement results (training data).

update(inputs, results)

Update the cost model according to new measurement results (training data).

Parameters
  • inputs (List[auto_scheduler.measure.MeasureInput]) – The measurement inputs

  • results (List[auto_scheduler.measure.MeasureResult]) – The measurement results

predict(search_task, states)

Predict the scores of states

Parameters
  • search_task (SearchTask) – The search task of states

  • states (List[State]) – The input states

Returns

scores – The predicted scores for all states

Return type

List[float]

class tvm.auto_scheduler.XGBModel(verbose_eval=25, num_warmup_sample=100, seed=None)

Train an XGBoost model to predict the normalized throughputs of programs. Let the normalized throughput be the score of a program (higher is better). We predict the (approximate) score of a program as the sum of the scores of all stages in this program, i.e. score(P) = score_s0 + score_s1 + … + score_sn, where score_si is the score of Stage i in Program P. We extract features for each stage and let XGBoost predict the score of each stage. We then sum up the predictions as the score of the whole program. We use RMSE as the loss function, i.e. loss(P, y) = 1/2 * (score(P) - y)^2, where P is the program and y is the normalized throughput according to the ground truth (measurement). XGBoost does not support this loss function out of the box because score(P) is a sum of the predictions of several samples, so we implemented a custom loss function and call it pack-sum-rmse. It is called “pack-sum” because we combine several samples into a “pack” and sum up their predictions.

Methods:

load(file_name)

Load the model from a file.

predict(task, states)

Predict the scores of states.

predict_stages(task, states)

Predict the scores of all stages in states.

save(file_name)

Save the model to a file.

update(inputs, results)

Update the cost model according to new measurement results (training data).

update_from_file(file_name[, n_lines])

Load measure records from a log file to update the cost model.

update(inputs, results)

Update the cost model according to new measurement results (training data). XGBoost does not support incremental training, so we re-train a new model every time.

Parameters
  • inputs (List[MeasureInput]) – The measurement inputs

  • results (List[MeasureResult]) – The measurement results

predict(task, states)

Predict the scores of states

Parameters
  • task (SearchTask) – The search task of states

  • states (List[State]) – The input states

Returns

scores – The predicted scores for all states

Return type

List[float]

predict_stages(task, states)

Predict the scores of all stages in states. This is the breakdown version of predict.

Parameters
  • task (SearchTask) – The search task of states

  • states (List[State]) – The input states

Returns

scores – The predicted scores for all stages in all states in the packed format

Return type

List[float]

Note

For faster data copy between C++ and Python, the Python part returns scores in a single flattened array using a packed format. The C++ part then unpacks the flattened array.

The packed format is:

{
  float scores[N];                   // scores[i] is the score for states[i]
  int   n_stage_0;                   // the number of stages in states[0]
  float stage_scores_0[n_stage_0];   // the scores for all stages in states[0]
  int   n_stage_1;                   // the number of stages in states[1]
  float stage_scores_1[n_stage_1];   // the scores for all stages in states[1]
  ...
  int   n_stage_i;                   // the number of stages in states[i]
  float stage_scores_i[n_stage_i];   // the scores for all stages in states[i]
  ...                                // until i == N - 1
}

To implement this format, we also store ints as floats, so we can store all numbers in a single float array.
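
To make the layout concrete, here is an illustrative pure-Python decoder for such a packed array; this helper is not part of the TVM API:

```python
def unpack_stage_scores(packed, n_states):
    """Split the flattened float array described above into
    per-state scores and per-stage score lists."""
    scores = packed[:n_states]          # scores[i] is the score for states[i]
    stage_scores = []
    idx = n_states
    for _ in range(n_states):
        n_stage = int(packed[idx])      # stage counts are stored as floats
        idx += 1
        stage_scores.append(packed[idx:idx + n_stage])
        idx += n_stage
    return scores, stage_scores

# Example: two states, with 2 and 1 stages respectively.
packed = [0.7, 0.3,  2.0, 0.4, 0.3,  1.0, 0.3]
scores, per_stage = unpack_stage_scores(packed, 2)
# scores == [0.7, 0.3]; per_stage == [[0.4, 0.3], [0.3]]
```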

update_from_file(file_name, n_lines=None)

Load measure records from a log file to update the cost model. This function can be used to pre-train the cost model with history log files.

Parameters
  • file_name (str) – The filename

  • n_lines (Optional[int]) – Only load the first n lines of the log file

save(file_name: str)

Save the model to a file.

Parameters

file_name (str) – The filename

load(file_name: str)

Load the model from a file.

Parameters

file_name (str) – The filename

class tvm.auto_scheduler.MeasureInput(task, state)

Store the input of a measurement.

Parameters
  • task (SearchTask) – The SearchTask of this measurement.

  • state (Union[State, StateObject]) – The State to be measured.

class tvm.auto_scheduler.MeasureResult(costs, error_no, error_msg, all_cost, timestamp)

Store the results of a measurement.

Parameters
  • costs (List[float]) – The time costs of execution.

  • error_no (int) – The error code.

  • error_msg (Optional[str]) – The error message if there is any error.

  • all_cost (float) – The time cost of build and run.

  • timestamp (float) – The time stamps of this measurement.

class tvm.auto_scheduler.LocalBuilder(timeout=15, n_parallel=8, build_func='default')

LocalBuilder uses local CPU cores to build programs in parallel.

Parameters
  • timeout (int = 15) – The timeout limit (in seconds) for each build thread. This is used in a wrapper of multiprocessing.Process.join().

  • n_parallel (int = multiprocessing.cpu_count()) – Number of threads used to build in parallel.

  • build_func (str = 'default') – The name of registered build function.

class tvm.auto_scheduler.LocalRunner(timeout=10, number=3, repeat=1, min_repeat_ms=100, cooldown_interval=0.0, enable_cpu_cache_flush=False)

LocalRunner that uses the local CPU/GPU to measure the time cost of programs.

Parameters
  • timeout (int = 10) – The timeout limit (in seconds) for each run. This is used in a wrapper of multiprocessing.Process.join().

  • number (int = 3) – The number of times to run the generated code when taking the average. We call these runs one repeat of measurement.

  • repeat (int = 1) – The number of times to repeat the measurement. In total, the generated code will be run (1 + number x repeat) times, where the first “1” is warm up and will be discarded. The returned result contains repeat costs, each of which is an average of number costs.

  • min_repeat_ms (int = 100) – The minimum duration of one repeat in milliseconds. By default, one repeat contains number runs. If this parameter is set, the parameter number will be dynamically adjusted to meet the minimum duration requirement of one repeat, i.e., when the run time of one repeat falls below this time, the number parameter will be automatically increased.

  • cooldown_interval (float = 0.0) – The cool down interval between two measurements.

  • enable_cpu_cache_flush (bool = False) – Whether to flush the cache on CPU between repeated measurements. Flushing the cache can make the measured latency of one operator closer to its actual latency during end-to-end inference. To make this option effective, the argument number should also be set to 1. This only has an effect on CPU tasks.

class tvm.auto_scheduler.RPCRunner(key, host, port, priority=1, n_parallel=1, timeout=10, number=3, repeat=1, min_repeat_ms=100, cooldown_interval=0.0, enable_cpu_cache_flush=False)

RPCRunner that uses RPC calls to measure the time cost of programs on remote devices. Sometimes we may need to use RPC even for local runs to isolate the thread environment (e.g. when running CUDA programs).

Parameters
  • key (str) – The key of the device registered in the RPC tracker.

  • host (str) – The host address of the RPC Tracker.

  • port (int) – The port of RPC Tracker.

  • priority (int = 1) – The priority of this run request; a larger value means higher priority.

  • n_parallel (int = 1) – The number of tasks run in parallel.

  • timeout (int = 10) – The timeout limit (in seconds) for each run. This is used in a wrapper of multiprocessing.Process.join().

  • number (int = 3) – The number of times to run the generated code when taking the average. We call these runs one repeat of measurement.

  • repeat (int = 1) – The number of times to repeat the measurement. In total, the generated code will be run (1 + number x repeat) times, where the first “1” is warm up and will be discarded. The returned result contains repeat costs, each of which is an average of number costs.

  • min_repeat_ms (int = 100) – The minimum duration of one repeat in milliseconds. By default, one repeat contains number runs. If this parameter is set, the parameter number will be dynamically adjusted to meet the minimum duration requirement of one repeat, i.e., when the run time of one repeat falls below this time, the number parameter will be automatically increased.

  • cooldown_interval (float = 0.0) – The cool down interval between two measurements.

  • enable_cpu_cache_flush (bool = False) – Whether to flush the cache on CPU between repeated measurements. Flushing the cache can make the measured latency of one operator closer to its actual latency during end-to-end inference. To make this option effective, the argument number should also be set to 1. This only has an effect on CPU tasks.

class tvm.auto_scheduler.LocalRPCMeasureContext(priority=1, n_parallel=1, timeout=10, number=3, repeat=1, min_repeat_ms=0, cooldown_interval=0.0, enable_cpu_cache_flush=False)

A context wrapper for running RPCRunner locally. This will launch a local RPC Tracker and local RPC Server.

Parameters
  • priority (int = 1) – The priority of this run request; a larger value means higher priority.

  • n_parallel (int = 1) – The number of tasks run in parallel.

  • timeout (int = 10) – The timeout limit (in seconds) for each run. This is used in a wrapper of multiprocessing.Process.join().

  • number (int = 3) – The number of times to run the generated code when taking the average. We call these runs one repeat of measurement.

  • repeat (int = 1) – The number of times to repeat the measurement. In total, the generated code will be run (1 + number x repeat) times, where the first “1” is warm up and will be discarded. The returned result contains repeat costs, each of which is an average of number costs.

  • min_repeat_ms (int = 0) – The minimum duration of one repeat in milliseconds. By default, one repeat contains number runs. If this parameter is set, the parameter number will be dynamically adjusted to meet the minimum duration requirement of one repeat, i.e., when the run time of one repeat falls below this time, the number parameter will be automatically increased.

  • cooldown_interval (float = 0.0) – The cool down interval between two measurements.

  • enable_cpu_cache_flush (bool = False) – Whether to flush the cache on CPU between repeated measurements. Flushing the cache can make the measured latency of one operator closer to its actual latency during end-to-end inference. To make this option effective, the argument number should also be set to 1. This only has an effect on CPU tasks.

class tvm.auto_scheduler.RecordToFile(filename='auto_scheduler_tuning.json')

A measurement callback that writes measurement records into a file.

Parameters

filename (str) – File name for this callback to write log to.

class tvm.auto_scheduler.RecordReader(filename='auto_scheduler_tuning.json')

Reader of the json log file.

Parameters

filename (str = "auto_scheduler_tuning.json") – File name for this reader to load log from.

Methods:

read_lines([max_lines, skip_lines])

Read multiple lines from the log file.

read_lines(max_lines=None, skip_lines=0)

Read multiple lines from the log file.

Parameters
  • max_lines (Optional[int]) – The maximum number of lines. None to read all lines.

  • skip_lines (int = 0) – Skip the first n lines.

Returns

  • inputs (List[auto_scheduler.measure.MeasureInput]) – The MeasureInputs loaded from the log file.

  • results (List[auto_scheduler.measure.MeasureResult]) – The MeasureResults loaded from the log file.

Notes

Some unimportant and expensive fields in the returned MeasureInput are not deserialized for faster read speed (e.g. input.task.compute_dag, input.state.stages). If you want to use them, you can call the recover_measure_input below to rebuild these fields.

tvm.auto_scheduler.load_best(filename, workload_key=None, target=None)

Return the best measurement pair from a log file. This may return no result if there is no valid measurement pair with the specified workload_key/target in the log file.

Parameters
  • filename (str) – File name to load log from.

  • workload_key (Optional[str]) – The workload key of the compute declaration. With None, this returns the best measure pair of all workloads.

  • target (Optional[tvm.target.Target]) – The target device. With None, this returns the best measure pair of all target devices.

Returns

  • input (auto_scheduler.measure.MeasureInput) – The best State’s MeasureInput from this log file.

  • result (auto_scheduler.measure.MeasureResult) – The best State’s MeasureResult from this log file.

tvm.auto_scheduler.load_records(filename)

Load measurement records from a file.

Parameters

filename (str) – File name to load log from.

Returns

logs

Return type

List[Tuple[auto_scheduler.measure.MeasureInput, auto_scheduler.measure.MeasureResult]]

Notes

Some unimportant and expensive fields in the returned MeasureInput are not deserialized for faster read speed (e.g., input.task.compute_dag, input.state.stages). If you want to use them, you can call the recover_measure_input below to rebuild these fields.

tvm.auto_scheduler.save_records(filename, inputs, results)

Append measure records to file.

Parameters
  • filename (str) – File name to write log to.

  • inputs (List[MeasureInputs]) – The MeasureInputs to be written.

  • results (List[MeasureResults]) – The MeasureResults to be written.

class tvm.auto_scheduler.SearchTask(dag, workload_key, target, target_host=None, hardware_params=None)

The computation information and hardware parameters for a schedule search task.

Parameters
  • dag (ComputeDAG) – The ComputeDAG for the corresponding compute declaration.

  • workload_key (str) – The workload key for the corresponding compute declaration.

  • target (tvm.target.Target) – The target device of this search task.

  • target_host (Optional[tvm.target.Target]) – The target host device of this search task.

  • hardware_params (Optional[HardwareParams]) – Hardware parameters used in this search task.

class tvm.auto_scheduler.EmptyPolicy(task, init_search_callbacks=None)

A simple example of the search policy which always returns the initial naive schedule (state).

Parameters
  • task (SearchTask) – The SearchTask for the computation declaration.

  • init_search_callbacks (Optional[List[SearchCallback]]) – Callback functions called before the search process.

class tvm.auto_scheduler.SketchPolicy(task, program_cost_model=auto_scheduler.RandomModel(), params=None, seed=None, verbose=1, init_search_callbacks=None)

The search policy that searches in a hierarchical search space defined by sketches. The policy randomly samples programs from the space defined by sketches and uses evolutionary search to fine-tune them.

Parameters
  • task (SearchTask) – The SearchTask for the computation declaration.

  • program_cost_model (CostModel = RandomModel()) – The cost model to estimate the complete schedules.

  • params (Optional[Dict[str, Any]]) – Parameters of the search policy. See src/auto_scheduler/search_policy/sketch_search_policy.h for the definitions. See DEFAULT_PARAMS below to find the default values.

  • seed (Optional[int]) – Random seed.

  • verbose (int = 1) – Verbosity level. 0 for silent, 1 to output information during schedule search.

  • init_search_callbacks (Optional[List[SearchCallback]]) –

    Callback functions called before the search process, usually used to do extra initializations. Possible callbacks:

    • auto_scheduler.PreloadMeasuredStates

    • auto_scheduler.PreloadCustomSketchRule

    TODO(jcf94): Add these search callback implementations.

Methods:

evolutionary_search(init_populations, out_size)

Perform evolutionary search.

generate_sketches([print_for_debug])

Generate the sketches.

sample_initial_population(pop_size)

Sample initial population.

generate_sketches(print_for_debug=False)

Generate the sketches. This python interface is mainly used for debugging and testing. The actual search is all done in c++.

Parameters

print_for_debug (bool = False) – Whether to print out the sketches for debugging.

Returns

sketches – The generated sketches of this search task.

Return type

List[State]

sample_initial_population(pop_size)

Sample initial population. This python interface is mainly used for debugging and testing. The actual search is all done in c++.

Parameters

pop_size (int) – The size of sampled population

Returns

states – The sampled states

Return type

List[State]

evolutionary_search(init_populations, out_size)

Perform evolutionary search. This python interface is mainly used for debugging and testing. The actual search is all done in c++.

Parameters
  • init_populations (List[State]) – The initial population states

  • out_size (int) – The size of generated states

Returns

states – The generated states

Return type

List[State]

class tvm.auto_scheduler.PreloadMeasuredStates(filename='auto_scheduler_tuning.json')

A SearchCallback to load measured states from the log file for a search policy.

This can resume the state of the search policy:
  • Making sure an already measured state in former searches will never be measured again.

  • The history states can be used to speed up the search process (e.g. SketchPolicy uses history states as starting points to perform evolutionary search).

Parameters

filename (str) – The name of the record file.

class tvm.auto_scheduler.TaskScheduler(tasks, objective_func=None, strategy='gradient', load_model_file: str = None, load_log_file: str = None, verbose: int = 1, alpha: float = 0.2, beta: float = 2, gamma: float = 0.5, backward_window_size: int = 3)

Allocate the time resources when tuning multiple tasks together. This implements two strategies: “round-robin” and “gradient”.

Parameters
  • tasks (List[SearchTask]) – All tasks to tune

  • objective_func (Optional[Callable[List[float] -> float]]) – The objective function to be minimized. The objective function accepts the current latencies of all tasks and returns the objective. If not provided, the objective is the sum of the latencies of all tasks.

  • strategy (str = "gradient") – The scheduling strategy. “round-robin”: Tune tasks in round robin order. “gradient” : Tune tasks with gradient descent.

  • load_model_file (Optional[str]) – Load pre-trained model from this file. If this is None, the cost model will be trained from scratch.

  • load_log_file (Optional[str]) – Load measurement records from this file. If it is not None, the status of the task scheduler, search policies and cost models will be restored according to this file.

  • verbose (int = 1) – The level of verbosity. 0 means silent.

  • alpha (float = 0.2) – The parameter used for ‘gradient’ strategy

  • beta (float = 2) – The parameter used for ‘gradient’ strategy

  • gamma (float = 0.5) – The parameter used for ‘gradient’ strategy

  • backward_window_size (int = 3) – The parameter used for ‘gradient’ strategy

Methods:

tune(tune_option[, search_policy])

Tune a batch of tasks together.

tune(tune_option, search_policy='default')

Tune a batch of tasks together.

Parameters
  • tune_option (TuningOptions) – The options of tuning

  • search_policy (Union[str, List[SearchPolicy]]) – The list of search policies. If it is a str, use “sketch.xgb” for SketchPolicy + XGBModel or “sketch.random” for SketchPolicy + RandomModel.

tvm.auto_scheduler.register_workload(func_name, f=None, override=False)

Register a function that generates a certain workload.

The input function should take hashable and jsonable arguments (int, float, tuple of int, tvm.tensor.Tensor, …) and return a list of tvm.tensor.Tensor.

Parameters
  • func_name (Union[Function, str]) – The generation function that returns the compute declaration Tensors or its function name.

  • f (Optional[Function]) – The generation function to be registered.

  • override (boolean = False) – Whether to override an existing entry.

Examples

@auto_scheduler.register_workload
def matmul(N, M, K):
    A = te.placeholder((N, K), name='A')
    B = te.placeholder((K, M), name='B')
    k = te.reduce_axis((0, K), name='k')
    C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=[k]), name='C')
    return [A, B, C]

tvm.auto_scheduler.make_workload_key(func, args)

Make a workload key by function and arguments.

Parameters
  • func (Union[Function, str]) – The function that returns the compute declaration Tensors. Can be a function or the function name.

  • args (Args) – The args of the function.

Returns

workload_key – The workload key of the function.

Return type

str