SourceXtractorPlusPlus  0.15
Please provide a description of the project.
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
OnnxSourceTask.cpp
Go to the documentation of this file.
1 
22 #include <NdArray/NdArray.h>
24 #include <onnxruntime_cxx_api.h>
25 
26 namespace NdArray = Euclid::NdArray;
27 
28 namespace SourceXtractor {
29 
30 
31 template<typename T>
32 static void fillCutout(const Image<T>& image, int center_x, int center_y, int width, int height, std::vector<T>& out) {
33  int x_start = center_x - width / 2;
34  int y_start = center_y - height / 2;
35  int x_end = x_start + width;
36  int y_end = y_start + height;
37 
38  ImageAccessor<T> accessor(image);
39 
40  int index = 0;
41  for (int iy = y_start; iy < y_end; iy++) {
42  for (int ix = x_start; ix < x_end; ix++, index++) {
43  if (ix >= 0 && iy >= 0 && ix < image.getWidth() && iy < image.getHeight()) {
44  out[index] = accessor.getValue(ix, iy);
45  }
46  }
47  }
48 }
49 
50 OnnxSourceTask::OnnxSourceTask(const std::vector<OnnxModel>& models) : m_models(models) {}
51 
59 template<typename O>
61 computePropertiesSpecialized(const OnnxModel& model, const DetectionFrameImages& detection_frame_images,
62  const PixelCentroid& centroid) {
63  Ort::RunOptions run_options;
64  auto mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
65 
66  const int center_x = static_cast<int>(centroid.getCentroidX() + 0.5);
67  const int center_y = static_cast<int>(centroid.getCentroidY() + 0.5);
68 
69  // Allocate memory
70  std::vector<int64_t> input_shape(model.m_input_shape.begin(), model.m_input_shape.end());
71  input_shape[0] = 1;
72  size_t input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1u, std::multiplies<size_t>());
73  std::vector<float> input_data(input_size);
74 
75  std::vector<int64_t> output_shape(model.m_output_shape.begin(), model.m_output_shape.end());
76  output_shape[0] = 1;
77  size_t output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1u, std::multiplies<size_t>());
78  std::vector<O> output_data(output_size);
79 
80  // Cut the needed area
81  {
82  const auto& image = detection_frame_images.getLockedImage(LayerSubtractedImage);
83  fillCutout(*image, center_x, center_y, input_shape[2], input_shape[3], input_data);
84  }
85 
86  // Setup input/output tensors
87  auto input_tensor = Ort::Value::CreateTensor<float>(
88  mem_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size());
89  auto output_tensor = Ort::Value::CreateTensor<O>(
90  mem_info, output_data.data(), output_data.size(), output_shape.data(), output_shape.size());
91 
92  // Run the model
93  const char *input_name = model.m_input_name.c_str();
94  const char *output_name = model.m_output_name.c_str();
95  model.m_session->Run(run_options,
96  &input_name, &input_tensor, 1,
97  &output_name, &output_tensor, 1);
98 
99  // Set the output
100  std::vector<size_t> catalog_shape{model.m_output_shape.begin() + 1, model.m_output_shape.end()};
101  return Euclid::make_unique<OnnxProperty::NdWrapper<O>>(catalog_shape, output_data);
102 }
103 
// Body of OnnxSourceTask::computeProperties. The signature line (listing line 104) is missing
// here; per the member index it is `void computeProperties(SourceInterface &source) const override`.
// It reads the source's DetectionFrameImages and PixelCentroid properties, runs every model in
// m_models, and attaches all model outputs to the source as a single OnnxProperty.
// NOTE(review): listing lines 108 and 111 are also missing. They presumably declare
// `output_dict` and `result` (the latter a std::unique_ptr<OnnxProperty::NdWrapperBase>, the
// type computePropertiesSpecialized returns) — confirm against the real source.
105  const auto& detection_frame_images = source.getProperty<DetectionFrameImages>();
106  const auto& centroid = source.getProperty<PixelCentroid>();
107 
109 
110  for (const auto& model : m_models) {
112 
     // Dispatch on the model's output element type: only FLOAT and INT32 outputs are handled.
113  switch (model.m_output_type) {
114  case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
115  result = computePropertiesSpecialized<float>(model, detection_frame_images, centroid);
116  break;
117  case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
118  result = computePropertiesSpecialized<int32_t>(model, detection_frame_images, centroid);
119  break;
     // Any other output type is reported as an error.
     // NOTE(review): the message has no separator before the type value; consider rewording it.
120  default:
121  throw Elements::Exception() << "This should have not happened!" << model.m_output_type;
122  }
123 
     // Store this model's output under its configured property name.
     // NOTE(review): if output_dict is a std::map/unordered_map, emplace keeps the first entry, so
     // a later model with a duplicate m_prop_name would be silently dropped — confirm.
124  output_dict.emplace(model.m_prop_name, std::move(result));
125  }
126 
127  source.setProperty<OnnxProperty>(std::move(output_dict));
128 }
129 
130 } // end of namespace SourceXtractor
OnnxSourceTask(const std::vector< OnnxModel > &models)
Euclid::NdArray::NdArray< T > NdArray
const PropertyType & getProperty(unsigned int index=0) const
Convenience template method to call getProperty() with a more user-friendly syntax.
std::vector< std::int64_t > m_input_shape
Input tensor shape.
Definition: OnnxModel.h:38
SeFloat getCentroidY() const
Y coordinate of centroid.
Definition: PixelCentroid.h:53
T end(T...args)
The centroid of all the pixels in the source, weighted by their DetectionImage pixel values...
Definition: PixelCentroid.h:37
static std::unique_ptr< OnnxProperty::NdWrapperBase > computePropertiesSpecialized(const OnnxModel &model, const DetectionFrameImages &detection_frame_images, const PixelCentroid &centroid)
virtual int getHeight() const =0
Returns the height of the image in pixels.
std::string m_input_name
Input tensor name.
Definition: OnnxModel.h:34
STL class.
const std::vector< OnnxModel > & m_models
std::string m_output_name
Output tensor name.
Definition: OnnxModel.h:35
T data(T...args)
std::shared_ptr< ImageAccessor< SeFloat > > getLockedImage(FrameImageLayer layer) const
void computeProperties(SourceInterface &source) const override
Computes one or more properties for the Source.
T move(T...args)
std::vector< std::int64_t > m_output_shape
Output tensor shape.
Definition: OnnxModel.h:39
T size(T...args)
STL class.
STL class.
T begin(T...args)
T c_str(T...args)
Interface representing an image.
Definition: Image.h:43
static void fillCutout(const Image< T > &image, int center_x, int center_y, int width, int height, std::vector< T > &out)
T accumulate(T...args)
The SourceInterface is an abstract "source" that has properties attached to it.
virtual int getWidth() const =0
Returns the width of the image in pixels.
SeFloat getCentroidX() const
X coordinate of centroid.
Definition: PixelCentroid.h:48
std::unique_ptr< Ort::Session > m_session
Session, one per model. In theory, it is thread-safe.
Definition: OnnxModel.h:41