aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGuo, Yejun <yejun.guo@intel.com>2021-03-16 13:02:56 +0800
committerGuo, Yejun <yejun.guo@intel.com>2021-05-06 10:50:44 +0800
commitfc26dca64e0e5d20bb0fcc8743d073cf5b107264 (patch)
treeba5a30a5c1cbb922c8f8c3ba9b10ad22f86ccf95
parenta3b74651a0408ddb19c2f0334ad4ad3f368376a6 (diff)
downloadffmpeg-fc26dca64e0e5d20bb0fcc8743d073cf5b107264.tar.gz
lavfi/dnn: add classify support with openvino backend
Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
-rw-r--r--libavfilter/dnn/dnn_backend_openvino.c143
-rw-r--r--libavfilter/dnn/dnn_io_proc.c60
-rw-r--r--libavfilter/dnn/dnn_io_proc.h1
-rw-r--r--libavfilter/dnn_filter_common.c21
-rw-r--r--libavfilter/dnn_filter_common.h2
-rw-r--r--libavfilter/dnn_interface.h10
6 files changed, 218 insertions, 19 deletions
diff --git a/libavfilter/dnn/dnn_backend_openvino.c b/libavfilter/dnn/dnn_backend_openvino.c
index 4e58ff6d9c..1ff8a720b9 100644
--- a/libavfilter/dnn/dnn_backend_openvino.c
+++ b/libavfilter/dnn/dnn_backend_openvino.c
@@ -29,6 +29,7 @@
#include "libavutil/avassert.h"
#include "libavutil/opt.h"
#include "libavutil/avstring.h"
+#include "libavutil/detection_bbox.h"
#include "../internal.h"
#include "queue.h"
#include "safe_queue.h"
@@ -74,6 +75,7 @@ typedef struct TaskItem {
// one task might have multiple inferences
typedef struct InferenceItem {
TaskItem *task;
+ uint32_t bbox_index;
} InferenceItem;
// one request for one call to openvino
@@ -182,12 +184,23 @@ static DNNReturnType fill_model_input_ov(OVModel *ov_model, RequestItem *request
request->inferences[i] = inference;
request->inference_count = i + 1;
task = inference->task;
- if (task->do_ioproc) {
- if (ov_model->model->frame_pre_proc != NULL) {
- ov_model->model->frame_pre_proc(task->in_frame, &input, ov_model->model->filter_ctx);
- } else {
- ff_proc_from_frame_to_dnn(task->in_frame, &input, ov_model->model->func_type, ctx);
+ switch (task->ov_model->model->func_type) {
+ case DFT_PROCESS_FRAME:
+ case DFT_ANALYTICS_DETECT:
+ if (task->do_ioproc) {
+ if (ov_model->model->frame_pre_proc != NULL) {
+ ov_model->model->frame_pre_proc(task->in_frame, &input, ov_model->model->filter_ctx);
+ } else {
+ ff_proc_from_frame_to_dnn(task->in_frame, &input, ov_model->model->func_type, ctx);
+ }
}
+ break;
+ case DFT_ANALYTICS_CLASSIFY:
+ ff_frame_to_dnn_classify(task->in_frame, &input, inference->bbox_index, ctx);
+ break;
+ default:
+ av_assert0(!"should not reach here");
+ break;
}
input.data = (uint8_t *)input.data
+ input.width * input.height * input.channels * get_datatype_size(input.dt);
@@ -276,6 +289,13 @@ static void infer_completion_callback(void *args)
}
task->ov_model->model->detect_post_proc(task->out_frame, &output, 1, task->ov_model->model->filter_ctx);
break;
+ case DFT_ANALYTICS_CLASSIFY:
+ if (!task->ov_model->model->classify_post_proc) {
+ av_log(ctx, AV_LOG_ERROR, "classify filter needs to provide post proc\n");
+ return;
+ }
+ task->ov_model->model->classify_post_proc(task->out_frame, &output, request->inferences[i]->bbox_index, task->ov_model->model->filter_ctx);
+ break;
default:
av_assert0(!"should not reach here");
break;
@@ -513,7 +533,44 @@ static DNNReturnType get_input_ov(void *model, DNNData *input, const char *input
return DNN_ERROR;
}
-static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, TaskItem *task, Queue *inference_queue)
+static int contain_valid_detection_bbox(AVFrame *frame)
+{
+ AVFrameSideData *sd;
+ const AVDetectionBBoxHeader *header;
+ const AVDetectionBBox *bbox;
+
+ sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES);
+ if (!sd) { // this frame has nothing detected
+ return 0;
+ }
+
+ if (!sd->size) {
+ return 0;
+ }
+
+ header = (const AVDetectionBBoxHeader *)sd->data;
+ if (!header->nb_bboxes) {
+ return 0;
+ }
+
+ for (uint32_t i = 0; i < header->nb_bboxes; i++) {
+ bbox = av_get_detection_bbox(header, i);
+ if (bbox->x < 0 || bbox->w < 0 || bbox->x + bbox->w >= frame->width) {
+ return 0;
+ }
+ if (bbox->y < 0 || bbox->h < 0 || bbox->y + bbox->h >= frame->width) {
+ return 0;
+ }
+
+ if (bbox->classify_count == AV_NUM_DETECTION_BBOX_CLASSIFY) {
+ return 0;
+ }
+ }
+
+ return 1;
+}
+
+static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, TaskItem *task, Queue *inference_queue, DNNExecBaseParams *exec_params)
{
switch (func_type) {
case DFT_PROCESS_FRAME:
@@ -532,6 +589,45 @@ static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, Task
}
return DNN_SUCCESS;
}
+ case DFT_ANALYTICS_CLASSIFY:
+ {
+ const AVDetectionBBoxHeader *header;
+ AVFrame *frame = task->in_frame;
+ AVFrameSideData *sd;
+ DNNExecClassificationParams *params = (DNNExecClassificationParams *)exec_params;
+
+ task->inference_todo = 0;
+ task->inference_done = 0;
+
+ if (!contain_valid_detection_bbox(frame)) {
+ return DNN_SUCCESS;
+ }
+
+ sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES);
+ header = (const AVDetectionBBoxHeader *)sd->data;
+
+ for (uint32_t i = 0; i < header->nb_bboxes; i++) {
+ InferenceItem *inference;
+ const AVDetectionBBox *bbox = av_get_detection_bbox(header, i);
+
+ if (av_strncasecmp(bbox->detect_label, params->target, sizeof(bbox->detect_label)) != 0) {
+ continue;
+ }
+
+ inference = av_malloc(sizeof(*inference));
+ if (!inference) {
+ return DNN_ERROR;
+ }
+ task->inference_todo++;
+ inference->task = task;
+ inference->bbox_index = i;
+ if (ff_queue_push_back(inference_queue, inference) < 0) {
+ av_freep(&inference);
+ return DNN_ERROR;
+ }
+ }
+ return DNN_SUCCESS;
+ }
default:
av_assert0(!"should not reach here");
return DNN_ERROR;
@@ -598,7 +694,7 @@ static DNNReturnType get_output_ov(void *model, const char *input_name, int inpu
task.out_frame = out_frame;
task.ov_model = ov_model;
- if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue) != DNN_SUCCESS) {
+ if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue, NULL) != DNN_SUCCESS) {
av_frame_free(&out_frame);
av_frame_free(&in_frame);
av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n");
@@ -690,6 +786,14 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNExecBaseParams *
return DNN_ERROR;
}
+ if (model->func_type == DFT_ANALYTICS_CLASSIFY) {
+ // Once we add async support for tensorflow backend and native backend,
+ // we'll combine the two sync/async functions in dnn_interface.h to
+ // simplify the code in filter, and async will be an option within backends.
+ // so, do not support now, and classify filter will not call this function.
+ return DNN_ERROR;
+ }
+
if (ctx->options.batch_size > 1) {
avpriv_report_missing_feature(ctx, "batch mode for sync execution");
return DNN_ERROR;
@@ -710,7 +814,7 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNExecBaseParams *
task.out_frame = exec_params->out_frame ? exec_params->out_frame : exec_params->in_frame;
task.ov_model = ov_model;
- if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue) != DNN_SUCCESS) {
+ if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue, exec_params) != DNN_SUCCESS) {
av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n");
return DNN_ERROR;
}
@@ -730,6 +834,7 @@ DNNReturnType ff_dnn_execute_model_async_ov(const DNNModel *model, DNNExecBasePa
OVContext *ctx = &ov_model->ctx;
RequestItem *request;
TaskItem *task;
+ DNNReturnType ret;
if (ff_check_exec_params(ctx, DNN_OV, model->func_type, exec_params) != 0) {
return DNN_ERROR;
@@ -761,23 +866,25 @@ DNNReturnType ff_dnn_execute_model_async_ov(const DNNModel *model, DNNExecBasePa
return DNN_ERROR;
}
- if (extract_inference_from_task(ov_model->model->func_type, task, ov_model->inference_queue) != DNN_SUCCESS) {
+ if (extract_inference_from_task(model->func_type, task, ov_model->inference_queue, exec_params) != DNN_SUCCESS) {
av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n");
return DNN_ERROR;
}
- if (ff_queue_size(ov_model->inference_queue) < ctx->options.batch_size) {
- // not enough inference items queued for a batch
- return DNN_SUCCESS;
- }
+ while (ff_queue_size(ov_model->inference_queue) >= ctx->options.batch_size) {
+ request = ff_safe_queue_pop_front(ov_model->request_queue);
+ if (!request) {
+ av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
+ return DNN_ERROR;
+ }
- request = ff_safe_queue_pop_front(ov_model->request_queue);
- if (!request) {
- av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
- return DNN_ERROR;
+ ret = execute_model_ov(request, ov_model->inference_queue);
+ if (ret != DNN_SUCCESS) {
+ return ret;
+ }
}
- return execute_model_ov(request, ov_model->inference_queue);
+ return DNN_SUCCESS;
}
DNNAsyncStatusType ff_dnn_get_async_result_ov(const DNNModel *model, AVFrame **in, AVFrame **out)
diff --git a/libavfilter/dnn/dnn_io_proc.c b/libavfilter/dnn/dnn_io_proc.c
index e104cc5064..5f60d68078 100644
--- a/libavfilter/dnn/dnn_io_proc.c
+++ b/libavfilter/dnn/dnn_io_proc.c
@@ -22,6 +22,7 @@
#include "libavutil/imgutils.h"
#include "libswscale/swscale.h"
#include "libavutil/avassert.h"
+#include "libavutil/detection_bbox.h"
DNNReturnType ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx)
{
@@ -175,6 +176,65 @@ static enum AVPixelFormat get_pixel_format(DNNData *data)
return AV_PIX_FMT_BGR24;
}
+DNNReturnType ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index, void *log_ctx)
+{
+ const AVPixFmtDescriptor *desc;
+ int offsetx[4], offsety[4];
+ uint8_t *bbox_data[4];
+ struct SwsContext *sws_ctx;
+ int linesizes[4];
+ enum AVPixelFormat fmt;
+ int left, top, width, height;
+ const AVDetectionBBoxHeader *header;
+ const AVDetectionBBox *bbox;
+ AVFrameSideData *sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES);
+ av_assert0(sd);
+
+ header = (const AVDetectionBBoxHeader *)sd->data;
+ bbox = av_get_detection_bbox(header, bbox_index);
+
+ left = bbox->x;
+ width = bbox->w;
+ top = bbox->y;
+ height = bbox->h;
+
+ fmt = get_pixel_format(input);
+ sws_ctx = sws_getContext(width, height, frame->format,
+ input->width, input->height, fmt,
+ SWS_FAST_BILINEAR, NULL, NULL, NULL);
+ if (!sws_ctx) {
+ av_log(log_ctx, AV_LOG_ERROR, "Failed to create scale context for the conversion "
+ "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
+ av_get_pix_fmt_name(frame->format), width, height,
+ av_get_pix_fmt_name(fmt), input->width, input->height);
+ return DNN_ERROR;
+ }
+
+ if (av_image_fill_linesizes(linesizes, fmt, input->width) < 0) {
+ av_log(log_ctx, AV_LOG_ERROR, "unable to get linesizes with av_image_fill_linesizes");
+ sws_freeContext(sws_ctx);
+ return DNN_ERROR;
+ }
+
+ desc = av_pix_fmt_desc_get(frame->format);
+ offsetx[1] = offsetx[2] = AV_CEIL_RSHIFT(left, desc->log2_chroma_w);
+ offsetx[0] = offsetx[3] = left;
+
+ offsety[1] = offsety[2] = AV_CEIL_RSHIFT(top, desc->log2_chroma_h);
+ offsety[0] = offsety[3] = top;
+
+ for (int k = 0; frame->data[k]; k++)
+ bbox_data[k] = frame->data[k] + offsety[k] * frame->linesize[k] + offsetx[k];
+
+ sws_scale(sws_ctx, (const uint8_t *const *)&bbox_data, frame->linesize,
+ 0, height,
+ (uint8_t *const *)(&input->data), linesizes);
+
+ sws_freeContext(sws_ctx);
+
+ return DNN_SUCCESS;
+}
+
static DNNReturnType proc_from_frame_to_dnn_analytics(AVFrame *frame, DNNData *input, void *log_ctx)
{
struct SwsContext *sws_ctx;
diff --git a/libavfilter/dnn/dnn_io_proc.h b/libavfilter/dnn/dnn_io_proc.h
index 91ad3cb261..16dcdd6d1a 100644
--- a/libavfilter/dnn/dnn_io_proc.h
+++ b/libavfilter/dnn/dnn_io_proc.h
@@ -32,5 +32,6 @@
DNNReturnType ff_proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, DNNFunctionType func_type, void *log_ctx);
DNNReturnType ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx);
+DNNReturnType ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index, void *log_ctx);
#endif
diff --git a/libavfilter/dnn_filter_common.c b/libavfilter/dnn_filter_common.c
index c085884eb4..52c7a5392a 100644
--- a/libavfilter/dnn_filter_common.c
+++ b/libavfilter/dnn_filter_common.c
@@ -77,6 +77,12 @@ int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc)
return 0;
}
+int ff_dnn_set_classify_post_proc(DnnContext *ctx, ClassifyPostProc post_proc)
+{
+ ctx->model->classify_post_proc = post_proc;
+ return 0;
+}
+
DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input)
{
return ctx->model->get_input(ctx->model->model, input, ctx->model_inputname);
@@ -112,6 +118,21 @@ DNNReturnType ff_dnn_execute_model_async(DnnContext *ctx, AVFrame *in_frame, AVF
return (ctx->dnn_module->execute_model_async)(ctx->model, &exec_params);
}
+DNNReturnType ff_dnn_execute_model_classification(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame, char *target)
+{
+ DNNExecClassificationParams class_params = {
+ {
+ .input_name = ctx->model_inputname,
+ .output_names = (const char **)&ctx->model_outputname,
+ .nb_output = 1,
+ .in_frame = in_frame,
+ .out_frame = out_frame,
+ },
+ .target = target,
+ };
+ return (ctx->dnn_module->execute_model_async)(ctx->model, &class_params.base);
+}
+
DNNAsyncStatusType ff_dnn_get_async_result(DnnContext *ctx, AVFrame **in_frame, AVFrame **out_frame)
{
return (ctx->dnn_module->get_async_result)(ctx->model, in_frame, out_frame);
diff --git a/libavfilter/dnn_filter_common.h b/libavfilter/dnn_filter_common.h
index 8deb18b39a..e7736d2bac 100644
--- a/libavfilter/dnn_filter_common.h
+++ b/libavfilter/dnn_filter_common.h
@@ -50,10 +50,12 @@ typedef struct DnnContext {
int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx);
int ff_dnn_set_frame_proc(DnnContext *ctx, FramePrePostProc pre_proc, FramePrePostProc post_proc);
int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc);
+int ff_dnn_set_classify_post_proc(DnnContext *ctx, ClassifyPostProc post_proc);
DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input);
DNNReturnType ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height);
DNNReturnType ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame);
DNNReturnType ff_dnn_execute_model_async(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame);
+DNNReturnType ff_dnn_execute_model_classification(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame, char *target);
DNNAsyncStatusType ff_dnn_get_async_result(DnnContext *ctx, AVFrame **in_frame, AVFrame **out_frame);
DNNReturnType ff_dnn_flush(DnnContext *ctx);
void ff_dnn_uninit(DnnContext *ctx);
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index 941670675d..799244ee14 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -52,7 +52,7 @@ typedef enum {
DFT_NONE,
DFT_PROCESS_FRAME, // process the whole frame
DFT_ANALYTICS_DETECT, // detect from the whole frame
- // we can add more such as detect_from_crop, classify_from_bbox, etc.
+ DFT_ANALYTICS_CLASSIFY, // classify for each bounding box
}DNNFunctionType;
typedef struct DNNData{
@@ -71,8 +71,14 @@ typedef struct DNNExecBaseParams {
AVFrame *out_frame;
} DNNExecBaseParams;
+typedef struct DNNExecClassificationParams {
+ DNNExecBaseParams base;
+ const char *target;
+} DNNExecClassificationParams;
+
typedef int (*FramePrePostProc)(AVFrame *frame, DNNData *model, AVFilterContext *filter_ctx);
typedef int (*DetectPostProc)(AVFrame *frame, DNNData *output, uint32_t nb, AVFilterContext *filter_ctx);
+typedef int (*ClassifyPostProc)(AVFrame *frame, DNNData *output, uint32_t bbox_index, AVFilterContext *filter_ctx);
typedef struct DNNModel{
// Stores model that can be different for different backends.
@@ -97,6 +103,8 @@ typedef struct DNNModel{
FramePrePostProc frame_post_proc;
// set the post process to interpret detect result from DNNData
DetectPostProc detect_post_proc;
+ // set the post process to interpret classify result from DNNData
+ ClassifyPostProc classify_post_proc;
} DNNModel;
// Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.