aboutsummaryrefslogtreecommitdiffstats
path: root/libavfilter/dnn_filter_common.c
blob: 860ca7591f42d2d0c7ad6f2dfc7d3a1404145da1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
/*
 * This file is part of FFmpeg.
 *
 * FFmpeg is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * FFmpeg is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with FFmpeg; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */

#include "dnn_filter_common.h"
#include "libavutil/avstring.h"
#include "libavutil/mem.h"
#include "libavutil/opt.h"

#define MAX_SUPPORTED_OUTPUTS_NB 4

/**
 * Split a separator-delimited list of output names into an array of strings.
 *
 * @param expr          string to split; each name is extracted with
 *                      av_get_token()
 * @param val_sep       passed to av_get_token() as the terminator characters
 * @param separated_nb  on success, receives the number of names stored
 * @return newly allocated, NULL-terminated array of newly allocated strings
 *         (the caller frees every entry and the array itself), or NULL if an
 *         argument is NULL, an allocation fails, or expr holds more than
 *         MAX_SUPPORTED_OUTPUTS_NB names. *separated_nb is left untouched
 *         when NULL is returned.
 */
static char **separate_output_names(const char *expr, const char *val_sep, int *separated_nb)
{
    char **parsed_vals;
    int val_num = 0;

    if (!expr || !val_sep || !separated_nb) {
        return NULL;
    }

    /* One extra slot so that the terminating NULL always fits. */
    parsed_vals = av_calloc(MAX_SUPPORTED_OUTPUTS_NB + 1, sizeof(*parsed_vals));
    if (!parsed_vals) {
        return NULL;
    }

    do {
        char *val;

        /* Refuse to write past the end of the array. */
        if (val_num >= MAX_SUPPORTED_OUTPUTS_NB)
            goto fail;

        val = av_get_token(&expr, val_sep);
        if (!val)
            goto fail;
        parsed_vals[val_num++] = val;

        /* Step over the separator that ended the token, if any. */
        if (*expr) {
            expr++;
        }
    } while (*expr);

    parsed_vals[val_num] = NULL;
    *separated_nb = val_num;

    return parsed_vals;

fail:
    for (int i = 0; i < val_num; i++)
        av_free(parsed_vals[i]);
    av_free(parsed_vals);
    return NULL;
}

/*
 * Layout through which ff_dnn_filter_init_child_class() and
 * ff_dnn_filter_child_next() access a filter's private data: an AVClass
 * pointer followed by the DnnContext.
 * NOTE(review): presumably every filter using those helpers declares its
 * private context with these two members first, in this order -- confirm
 * against the callers.
 */
typedef struct DnnFilterBase {
    const AVClass *class;
    DnnContext dnnctx;
} DnnFilterBase;

/*
 * Call ff_dnn_init_child_class() on the DnnContext embedded in the filter's
 * private data (viewed as a DnnFilterBase). Always returns 0.
 */
int ff_dnn_filter_init_child_class(AVFilterContext *filter)
{
    DnnFilterBase *priv = filter->priv;

    ff_dnn_init_child_class(&priv->dnnctx);

    return 0;
}

void *ff_dnn_filter_child_next(void *obj, void *prev)
{
    DnnFilterBase *base = obj;
    return ff_dnn_child_next(&base->dnnctx, prev);
}

/**
 * Validate the DNN options stored in ctx, select the backend module and load
 * the model.
 *
 * @param ctx        DNN context holding the options; on success
 *                   ctx->dnn_module and ctx->model are set
 * @param func_type  function type handed through to the backend's load_model()
 * @param filter_ctx filter used as logging context and handed to the backend
 * @return 0 on success, a negative AVERROR code otherwise (EINVAL for missing
 *         or unparsable options and for a model that fails to load, ENOMEM
 *         when no backend module is returned, or the error reported by
 *         av_opt_set_from_string())
 */
int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx)
{
    DNNBackendType backend = ctx->backend_type;

    if (!ctx->model_filename) {
        av_log(filter_ctx, AV_LOG_ERROR, "model file for network is not specified\n");
        return AVERROR(EINVAL);
    }

    if (backend == DNN_TH) {
        if (ctx->model_inputname)
            av_log(filter_ctx, AV_LOG_WARNING, "LibTorch backend do not require inputname, "
                                               "inputname will be ignored.\n");
        /* Check the raw option string: the parsed array is only populated in
         * the TensorFlow branch below, so testing it here could never warn. */
        if (ctx->model_outputnames_string)
            av_log(filter_ctx, AV_LOG_WARNING, "LibTorch backend do not require outputname(s), "
                                               "all outputname(s) will be ignored.\n");
        ctx->nb_outputs = 1;
    } else if (backend == DNN_TF) {
        if (!ctx->model_inputname) {
            av_log(filter_ctx, AV_LOG_ERROR, "input name of the model network is not specified\n");
            return AVERROR(EINVAL);
        }
        /* Output names are given as a single '&'-separated string. */
        ctx->model_outputnames = separate_output_names(ctx->model_outputnames_string, "&", &ctx->nb_outputs);
        if (!ctx->model_outputnames) {
            av_log(filter_ctx, AV_LOG_ERROR, "could not parse model output names\n");
            return AVERROR(EINVAL);
        }
    }

    ctx->dnn_module = ff_get_dnn_module(ctx->backend_type, filter_ctx);
    if (!ctx->dnn_module) {
        av_log(filter_ctx, AV_LOG_ERROR, "could not create DNN module for requested backend\n");
        return AVERROR(ENOMEM);
    }
    if (!ctx->dnn_module->load_model) {
        av_log(filter_ctx, AV_LOG_ERROR, "load_model for network is not specified\n");
        return AVERROR(EINVAL);
    }

    if (ctx->backend_options) {
        void *child = NULL;

        av_log(filter_ctx, AV_LOG_WARNING,
               "backend_configs is deprecated, please set backend options directly\n");
        /* Apply the legacy "key=value&key=value" string to the child object
         * whose AVClass is the selected backend's class. */
        while ((child = ff_dnn_child_next(ctx, child))) {
            if (*(const AVClass **)child == &ctx->dnn_module->clazz) {
                int ret = av_opt_set_from_string(child, ctx->backend_options,
                                                 NULL, "=", "&");
                if (ret < 0) {
                    av_log(filter_ctx, AV_LOG_ERROR, "failed to parse options \"%s\"\n",
                           ctx->backend_options);
                    return ret;
                }
            }
        }
    }

    ctx->model = (ctx->dnn_module->load_model)(ctx, func_type, filter_ctx);
    if (!ctx->model) {
        av_log(filter_ctx, AV_LOG_ERROR, "could not load DNN model\n");
        return AVERROR(EINVAL);
    }

    return 0;
}

/*
 * Store the frame pre- and post-processing callbacks on the loaded model.
 * ctx->model is dereferenced without a check. Always returns 0.
 */
int ff_dnn_set_frame_proc(DnnContext *ctx, FramePrePostProc pre_proc, FramePrePostProc post_proc)
{
    ctx->model->frame_post_proc = post_proc;
    ctx->model->frame_pre_proc  = pre_proc;

    return 0;
}

/*
 * Store the detection post-processing callback on the loaded model.
 * ctx->model is dereferenced without a check. Always returns 0.
 */
int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc)
{
    (*ctx->model).detect_post_proc = post_proc;

    return 0;
}

/*
 * Store the classification post-processing callback on the loaded model.
 * ctx->model is dereferenced without a check. Always returns 0.
 */
int ff_dnn_set_classify_post_proc(DnnContext *ctx, ClassifyPostProc post_proc)
{
    (*ctx->model).classify_post_proc = post_proc;

    return 0;
}

/*
 * Forward to the model's get_input() callback, passing the backend model
 * handle, the caller's DNNData and the configured input name. Returns the
 * callback's result unchanged.
 */
int ff_dnn_get_input(DnnContext *ctx, DNNData *input)
{
    return ctx->model->get_input(ctx->model->model,
                                 input,
                                 ctx->model_inputname);
}

/*
 * Forward to the model's get_output() callback for the given input size.
 * The first parsed output name is passed along, except for the LibTorch
 * backend or when no names were parsed, in which case NULL is passed.
 * Returns the callback's result unchanged.
 */
int ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height)
{
    const char *first_output = NULL;

    if (ctx->backend_type != DNN_TH && ctx->model_outputnames)
        first_output = ctx->model_outputnames[0];

    return ctx->model->get_output(ctx->model->model, ctx->model_inputname,
                                  input_width, input_height,
                                  first_output, output_width, output_height);
}

/*
 * Bundle the configured input/output names and the two frames into a
 * DNNExecBaseParams and hand it to the backend's execute_model() callback.
 * Returns the callback's result unchanged.
 */
int ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame)
{
    DNNExecBaseParams params = {
        .in_frame     = in_frame,
        .out_frame    = out_frame,
        .input_name   = ctx->model_inputname,
        .nb_output    = ctx->nb_outputs,
        .output_names = (const char **)ctx->model_outputnames,
    };

    return ctx->dnn_module->execute_model(ctx->model, &params);
}

/*
 * Same as ff_dnn_execute_model(), but wraps the base parameters in a
 * DNNExecClassificationParams that additionally carries the target string.
 * Only the address of the embedded base parameters is passed to the
 * backend's execute_model() callback; its result is returned unchanged.
 */
int ff_dnn_execute_model_classification(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame, const char *target)
{
    DNNExecClassificationParams params = {
        .base = {
            .in_frame     = in_frame,
            .out_frame    = out_frame,
            .input_name   = ctx->model_inputname,
            .nb_output    = ctx->nb_outputs,
            .output_names = (const char **)ctx->model_outputnames,
        },
        .target = target,
    };

    return ctx->dnn_module->execute_model(ctx->model, &params.base);
}

/*
 * Forward to the backend's get_result() callback for the loaded model and
 * return its status unchanged.
 */
DNNAsyncStatusType ff_dnn_get_result(DnnContext *ctx, AVFrame **in_frame, AVFrame **out_frame)
{
    return ctx->dnn_module->get_result(ctx->model, in_frame, out_frame);
}

/*
 * Forward to the backend's flush() callback for the loaded model and return
 * its result unchanged.
 */
int ff_dnn_flush(DnnContext *ctx)
{
    return ctx->dnn_module->flush(ctx->model);
}

/*
 * Release what ff_dnn_init() set up: the model is freed through the backend's
 * free_model() callback when a backend module is present, then every parsed
 * output name and the array holding them are freed.
 */
void ff_dnn_uninit(DnnContext *ctx)
{
    int idx;

    if (ctx->dnn_module)
        ctx->dnn_module->free_model(&ctx->model);

    /* Nothing more to do when no output names were parsed. */
    if (!ctx->model_outputnames)
        return;

    idx = 0;
    while (idx < ctx->nb_outputs) {
        av_free(ctx->model_outputnames[idx]);
        idx++;
    }

    /* av_freep() is given the field's address, unlike av_free() above. */
    av_freep(&ctx->model_outputnames);
}