Interactive MNIST Digit Classification Sample using ONNX RT.
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <random>
#include <sstream>
#include <string>
#include <vector>
/**
 * @brief Print the digit pattern to the console.
 *
 * Renders a 28x28 row-major image as ASCII art on std::cout: a pixel whose
 * value is greater than 0.5 is drawn as '#', anything else as '.'.
 *
 * @param data Pixel values in row-major order (index = row * 28 + col).
 *             If fewer than 28*28 values are supplied, the missing pixels
 *             are drawn as background ('.') instead of being read out of
 *             bounds.
 */
void printDigitPattern(const std::vector<float> &data)
{
    constexpr std::size_t kSide = 28;

    std::cout << "\nDigit Pattern (28x28):" << '\n';
    std::cout << "=====================" << '\n';
    for (std::size_t row = 0; row < kSide; ++row)
    {
        for (std::size_t col = 0; col < kSide; ++col)
        {
            const std::size_t idx = row * kSide + col;
            // Bounds check first so a short buffer never triggers an OOB read.
            const bool lit = idx < data.size() && data[idx] > 0.5f;
            std::cout << (lit ? '#' : '.');
        }
        std::cout << '\n';
    }
    // Single flush at the end instead of one flush per row.
    std::cout << std::endl;
}
/**
 * @brief Create a digit pattern from user input.
 *
 * Builds a synthetic 28x28 row-major image (index = row * 28 + col) for one
 * decimal digit. Stroke pixels are set to 1.0, then uniform noise drawn from
 * [0, 0.3) is added to every pixel and the result is clamped to [0, 1].
 *
 * @param input "auto" to pick a random digit in 0..9, or a single character
 *              '0'..'9' selecting the digit to draw.
 * @return A vector of 28*28 floats. For any other input an "Invalid input"
 *         message is printed and an all-zero image (no noise) is returned.
 */
std::vector<float> createDigitFromInput(std::string &input)
{
    std::vector<float> data(28 * 28, 0.0f);
    int digit = 0;

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<> digit_gen(0, 9);
    std::uniform_real_distribution<float> noise_gen(0.0f, 0.3f);

    if (input == "auto")
    {
        digit = digit_gen(gen);
        std::cout << "Auto-generating digit: " << digit << std::endl;
    }
    // std::isdigit has undefined behaviour for negative char values, so the
    // character is converted to unsigned char before the call.
    else if (input.length() == 1 && std::isdigit(static_cast<unsigned char>(input[0])))
    {
        digit = input[0] - '0';
        std::cout << "Creating simple pattern for digit: " << digit << std::endl;
    }
    else
    {
        std::cout << "Invalid input: " << input << std::endl;
        return data;
    }

    switch (digit)
    {
        case 0:
            // Elliptical ring centred on (14, 14).
            for (int i = 6; i < 22; i++)
            {
                for (int j = 6; j < 22; j++)
                {
                    int di = i - 14, dj = j - 14;
                    float oval = (di * di) / 36.0f + (dj * dj) / 64.0f;
                    if (oval >= 0.8f && oval <= 1.2f)
                        data[i * 28 + j] = 1.0f;
                }
            }
            break;
        case 1:
            // Vertical stroke in column 14 with a short flag at the top.
            for (int i = 4; i < 24; i++)
            {
                data[i * 28 + 14] = 1.0f;
                if (i < 8) data[i * 28 + 13] = 1.0f;
            }
            break;
        case 2:
            // Top bar.
            for (int j = 6; j < 22; j++)
            {
                data[6 * 28 + j] = 1.0f;
                data[7 * 28 + j] = 1.0f;
                if (j > 8 && j < 20) data[5 * 28 + j] = 1.0f;
            }
            // Upper-right vertical, thickening towards the middle bar.
            for (int i = 6; i < 12; i++)
            {
                data[i * 28 + 20] = 1.0f;
                data[i * 28 + 21] = 1.0f;
                if (i > 7) data[i * 28 + 19] = 1.0f;
                if (i > 8) data[i * 28 + 18] = 1.0f;
                if (i > 9) data[i * 28 + 17] = 1.0f;
            }
            // Middle bar.
            for (int j = 6; j < 22; j++)
            {
                data[12 * 28 + j] = 1.0f;
                data[13 * 28 + j] = 1.0f;
                if (j > 8 && j < 20) data[11 * 28 + j] = 1.0f;
            }
            // Lower-left vertical, thinning towards the bottom bar.
            for (int i = 13; i < 20; i++)
            {
                data[i * 28 + 6] = 1.0f;
                data[i * 28 + 7] = 1.0f;
                if (i < 18) data[i * 28 + 8] = 1.0f;
                if (i < 17) data[i * 28 + 9] = 1.0f;
                if (i < 16) data[i * 28 + 10] = 1.0f;
            }
            // Bottom bar.
            for (int j = 6; j < 22; j++)
            {
                data[20 * 28 + j] = 1.0f;
                data[21 * 28 + j] = 1.0f;
                if (j > 8 && j < 20) data[22 * 28 + j] = 1.0f;
            }
            break;
        case 3:
            // Top bar.
            for (int j = 6; j < 22; j++)
            {
                data[6 * 28 + j] = 1.0f;
                data[7 * 28 + j] = 1.0f;
            }
            // Upper-right vertical.
            for (int i = 6; i < 12; i++)
            {
                data[i * 28 + 20] = 1.0f;
                data[i * 28 + 21] = 1.0f;
            }
            // Middle bar.
            for (int j = 6; j < 22; j++)
            {
                data[12 * 28 + j] = 1.0f;
                data[13 * 28 + j] = 1.0f;
            }
            // Lower-right vertical.
            for (int i = 13; i < 20; i++)
            {
                data[i * 28 + 20] = 1.0f;
                data[i * 28 + 21] = 1.0f;
            }
            // Bottom bar.
            for (int j = 6; j < 22; j++)
            {
                data[20 * 28 + j] = 1.0f;
                data[21 * 28 + j] = 1.0f;
            }
            break;
        case 4:
            for (int i = 4; i < 24; i++)
            {
                if (i < 14)
                {
                    // Two verticals above the crossbar.
                    data[i * 28 + 4] = 1.0f;
                    data[i * 28 + 20] = 1.0f;
                }
                else if (i < 18)
                {
                    // Thick crossbar.
                    for (int j = 4; j < 24; j++)
                        data[i * 28 + j] = 1.0f;
                }
                else
                {
                    // Right vertical below the crossbar.
                    data[i * 28 + 20] = 1.0f;
                }
            }
            break;
        case 5:
            for (int i = 4; i < 24; i++)
            {
                if (i < 8 || (i > 10 && i < 14) || i > 20)
                {
                    // Top, middle and bottom bars.
                    for (int j = 4; j < 24; j++)
                        data[i * 28 + j] = 1.0f;
                }
                else if (i < 14)
                {
                    // Left vertical between top and middle bars.
                    data[i * 28 + 4] = 1.0f;
                }
                else
                {
                    // Right vertical between middle and bottom bars.
                    data[i * 28 + 20] = 1.0f;
                }
            }
            break;
        case 6:
            for (int i = 4; i < 24; i++)
            {
                if (i < 8 || (i > 10 && i < 14) || i > 20)
                {
                    // Top, middle and bottom bars.
                    for (int j = 4; j < 24; j++)
                        data[i * 28 + j] = 1.0f;
                }
                else if (i < 14)
                {
                    // Left vertical between top and middle bars.
                    data[i * 28 + 4] = 1.0f;
                }
                else
                {
                    // Both verticals close the lower loop.
                    data[i * 28 + 4] = 1.0f;
                    data[i * 28 + 20] = 1.0f;
                }
            }
            break;
        case 7:
            for (int i = 4; i < 24; i++)
            {
                if (i < 8)
                {
                    // Thick top bar.
                    for (int j = 4; j < 24; j++)
                        data[i * 28 + j] = 1.0f;
                }
                else
                {
                    // Right vertical.
                    data[i * 28 + 20] = 1.0f;
                }
            }
            break;
        case 8:
            // Upper (smaller) ring.
            for (int i = 6; i < 12; i++)
            {
                for (int j = 8; j < 20; j++)
                {
                    int di = i - 9, dj = j - 14;
                    float circle = (di * di) / 9.0f + (dj * dj) / 36.0f;
                    if (circle >= 0.8f && circle <= 1.2f)
                        data[i * 28 + j] = 1.0f;
                }
            }
            // Lower (larger) ring.
            for (int i = 14; i < 22; i++)
            {
                for (int j = 6; j < 22; j++)
                {
                    int di = i - 18, dj = j - 14;
                    float circle = (di * di) / 16.0f + (dj * dj) / 64.0f;
                    if (circle >= 0.7f && circle <= 1.3f)
                        data[i * 28 + j] = 1.0f;
                }
            }
            // Side strokes joining the two rings.
            for (int i = 10; i < 16; i++)
            {
                data[i * 28 + 6] = 1.0f;
                data[i * 28 + 7] = 1.0f;
                data[i * 28 + 20] = 1.0f;
                data[i * 28 + 21] = 1.0f;
            }
            break;
        case 9:
            // Upper ring.
            for (int i = 6; i < 14; i++)
            {
                for (int j = 6; j < 22; j++)
                {
                    int di = i - 10, dj = j - 14;
                    float circle = (di * di) / 16.0f + (dj * dj) / 64.0f;
                    if (circle >= 0.7f && circle <= 1.3f)
                        data[i * 28 + j] = 1.0f;
                }
            }
            // Small lower curl.
            for (int i = 16; i < 22; i++)
            {
                for (int j = 10; j < 20; j++)
                {
                    int di = i - 19, dj = j - 15;
                    float circle = (di * di) / 4.0f + (dj * dj) / 25.0f;
                    if (circle >= 0.8f && circle <= 1.2f)
                        data[i * 28 + j] = 1.0f;
                }
            }
            // Right vertical.
            for (int i = 10; i < 18; i++)
            {
                data[i * 28 + 20] = 1.0f;
                data[i * 28 + 21] = 1.0f;
            }
            // Upper-left vertical.
            for (int i = 6; i < 12; i++)
            {
                data[i * 28 + 6] = 1.0f;
                data[i * 28 + 7] = 1.0f;
            }
            break;
    }

    // Add noise to every pixel and clamp back into [0, 1]; stroke pixels
    // therefore stay at exactly 1.0.
    for (float &pixel : data)
    {
        pixel += noise_gen(gen);
        pixel = std::min(1.0f, std::max(0.0f, pixel));
    }
    return data;
}
// Body of the sample's entry point: prints a banner, sets up the objects
// needed for inference, then runs an interactive read/classify/print loop
// until the user quits.
// NOTE(review): the function header and many statements (object creation
// calls and the conditions guarding the error blocks below) are not present
// in this listing; each bare `{ ... return EXIT_FAILURE; }` block presumably
// followed a status check - confirm against the original sample.
{
std::cout << "Interactive MNIST Digit Classification Using ONNX RT" << std::endl;
std::cout << "===================================================" << std::endl;
// Error path for context creation (guarding condition not visible here).
{
std::cerr << "Failed to create Context" << std::endl;
return EXIT_FAILURE;
}
// Error path for graph creation (guarding condition not visible here).
{
std::cerr << "Failed to create Graph" << std::endl;
return EXIT_FAILURE;
}
// Hard-coded absolute path to the ONNX model file.
// NOTE(review): machine-specific; consider taking it from the command line.
std::string model_path = "/Users/Andrew/Downloads/mnist.onnx";
// Error path for creating the array that carries the model path.
{
std::cerr << "Failed to create model path array" << std::endl;
return EXIT_FAILURE;
}
// Copy the path characters into the array; length() + 1 includes the
// terminating NUL provided by c_str().
if (model_path_array->addItems(model_path.length() + 1, model_path.c_str(),
sizeof(
char)) !=
VX_SUCCESS)
{
std::cerr << "Failed to add model path to array" << std::endl;
return EXIT_FAILURE;
}
// Input dimensions: four axes, 1 x 1 x 28 x 28.
// NOTE(review): presumably (batch, channel, height, width) - confirm.
size_t input_dims[] = {1, 1, 28, 28};
// Error path for input tensor creation (creation call not visible here).
{
std::cerr << "Failed to create input tensor" << std::endl;
return EXIT_FAILURE;
}
// Output dimensions: 1 x 10, one value per digit class.
size_t output_dims[] = {1, 10};
// Error path for output tensor creation (creation call not visible here).
{
std::cerr << "Failed to create output tensor" << std::endl;
return EXIT_FAILURE;
}
// Error path for creating the input/output object arrays.
{
std::cerr << "Failed to create object arrays" << std::endl;
return EXIT_FAILURE;
}
// Place the single input tensor at index 0 of the input object array.
if (input_tensors->setItem(0, input_tensor) !=
VX_SUCCESS)
{
std::cerr << "Failed to set input tensor in array" << std::endl;
return EXIT_FAILURE;
}
// Place the single output tensor at index 0 of the output object array.
if (output_tensors->setItem(0, output_tensor) !=
VX_SUCCESS)
{
std::cerr << "Failed to set output tensor in array" << std::endl;
return EXIT_FAILURE;
}
// Error path for the kernel lookup (lookup call not visible here).
{
std::cerr << "Failed to get ONNX runtime kernel. Make sure ONNX RT target is loaded." << std::endl;
return EXIT_FAILURE;
}
// Create the graph node, passing the model path array and the input and
// output tensor arrays as its parameters.
auto node =
Node::createNode(graph, kernel, {model_path_array, input_tensors, output_tensors});
// Error path for node creation (guarding condition not visible here).
{
std::cerr << "Failed to create ONNX node" << std::endl;
return EXIT_FAILURE;
}
// Patch origins (all zeros) and tensor strides used by the copyPatch calls
// inside the loop.
size_t view_start[4] = {0, 0, 0, 0};
size_t output_view_start[2] = {0, 0};
const size_t* input_strides = input_tensor->strides();
const size_t* output_strides = output_tensor->strides();
// Interactive loop: one classification per line of user input.
while (true)
{
std::string input;
std::cout << "\nEnter single digit (0-9) (or type 'auto' to auto-generate digit, 'quit' to exit): ";
std::getline(std::cin, input);
// NOTE(review): the stream state is not checked; on EOF getline leaves
// `input` empty and this `continue` would spin forever - confirm and
// consider breaking out when std::cin fails.
if (input.empty())
{
continue;
}
if (input == "quit" || input == "q" || input == "exit")
{
std::cout << "Goodbye!" << std::endl;
return EXIT_SUCCESS;
}
// NOTE(review): the remaining arguments and the status comparison of this
// copyPatch call, as well as whatever produced the image data, are not
// present in this listing.
if (input_tensor->copyPatch(4, view_start, input_dims, input_strides,
{
std::cerr << "Failed to copy input data to tensor" << std::endl;
continue;
}
std::cout << "Processing digit classification..." << std::endl;
// Error path for graph execution (the call itself is not visible here).
{
std::cerr << "Failed to process graph" << std::endl;
continue;
}
// Host buffer for the ten output values.
std::vector<float> output_data(10);
// NOTE(review): the remaining arguments and the status comparison of this
// copyPatch call are not present in this listing.
if (output_tensor->copyPatch(2, output_view_start, output_dims, output_strides,
{
std::cerr << "Failed to copy output data from tensor" << std::endl;
continue;
}
// Arg-max over the ten outputs: index of the largest value is the prediction.
int predicted_digit = 0;
float max_prob = output_data[0];
for (int i = 1; i < 10; i++)
{
if (output_data[i] > max_prob)
{
max_prob = output_data[i];
predicted_digit = i;
}
}
std::cout << "\nClassification Results:" << std::endl;
std::cout << "======================" << std::endl;
// Print every class value with four decimals and mark the winning class.
for (int i = 0; i < 10; i++)
{
std::cout << "Digit " << i << ": " << std::fixed << std::setprecision(4)
<< output_data[i] << (i == predicted_digit ? " <-- PREDICTED" : "") << std::endl;
}
// NOTE(review): the maximum value is printed as a percentage; this assumes
// the model emits probabilities in [0, 1] - confirm (raw scores would make
// the "confidence" figure misleading).
std::cout << "\nPredicted digit: " << predicted_digit << " (confidence: "
<< std::fixed << std::setprecision(2) << (max_prob * 100) << "%)" << std::endl;
std::cout << "\n" << std::string(50, '=') << std::endl;
}
// Not reached by the loop above: it exits only through the return inside it.
return EXIT_SUCCESS;
}
CoreVX single-include header for C++ development.
int main()
Definition blur_pipeline.cpp:15
@ VX_TYPE_TENSOR
A vx_tensor.
Definition vx_types.h:507
@ VX_TYPE_FLOAT32
A vx_float32.
Definition vx_types.h:444
@ VX_TYPE_CHAR
A vx_char.
Definition vx_types.h:435
@ VX_SUCCESS
No error.
Definition vx_types.h:543
@ VX_READ_ONLY
The memory shall be treated by the system as if it were read-only. If the User writes to this memory, the results are implementation-defined.
Definition vx_types.h:1515
@ VX_WRITE_ONLY
The memory shall be treated by the system as if it were write-only. If the User reads from this memory, the results are implementation-defined.
Definition vx_types.h:1519
@ VX_MEMORY_TYPE_HOST
The default memory type to import from the Host.
Definition vx_types.h:1338
@ VX_KERNEL_ORT_CPU_INF
The ONNX Runtime CPU Inference kernel.
Definition vx_corevx_ext.h:236
static vx_array createArray(vx_context context, vx_enum item_type, vx_size capacity, vx_bool is_virtual=vx_false_e, vx_enum type=VX_TYPE_ARRAY)
Create an Array object.
static vx_context createContext()
Create a new context.
static vx_status getStatus(vx_reference ref)
Provides a generic API to return status values from Object constructors if they fail.
static vx_graph createGraph(vx_context context)
Create a graph.
static vx_kernel getKernelByEnum(vx_context context, vx_enum kernelenum)
Get the Kernel By Enum.
static vx_node createNode(vx_graph graph, vx_kernel kernel)
Create a new node.
static vx_object_array createObjectArray(vx_context context, vx_enum type)
Create an Object Array object.
static vx_tensor createTensor(vx_context context, vx_size number_of_dims, const vx_size *dims, vx_enum data_type, vx_int8 fixed_point_position)
Create a tensor object.
The internal representation of a vx_array.
Definition vx_array.h:34
std::vector< float > createDigitFromInput(std::string &input)
Create a digit pattern from user input.
Definition ort_classification_sample.cpp:49
void printDigitPattern(const std::vector< float > &data)
Print the digit pattern to the console.
Definition ort_classification_sample.cpp:27