/******************************************************************************* * Copyright 2017-2023 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ #include #include #include #include #include "oneapi/dnnl/dnnl.h" #define for_ for #define CHECK(f) \ do { \ dnnl_status_t s = f; \ if (s != dnnl_success) { \ printf("[%s:%d] error: %s returns %d\n", __FILE__, __LINE__, #f, \ s); \ exit(2); \ } \ } while (0) #define CHECK_TRUE(expr) \ do { \ int e_ = expr; \ if (!e_) { \ printf("[%s:%d] %s failed\n", __FILE__, __LINE__, #expr); \ exit(2); \ } \ } while (0) typedef float real_t; #define LENGTH_100 100 void test1() { dnnl_engine_t engine; CHECK(dnnl_engine_create(&engine, dnnl_cpu, 0)); dnnl_dims_t dims = {LENGTH_100}; real_t data[LENGTH_100]; dnnl_memory_desc_t md; const_dnnl_memory_desc_t c_md_tmp; dnnl_memory_t m; CHECK(dnnl_memory_desc_create_with_tag(&md, 1, dims, dnnl_f32, dnnl_x)); CHECK(dnnl_memory_create(&m, md, engine, NULL)); void *req = NULL; CHECK(dnnl_memory_get_data_handle(m, &req)); CHECK_TRUE(req == NULL); #if DNNL_CPU_RUNTIME != DNNL_RUNTIME_SYCL CHECK(dnnl_memory_set_data_handle(m, data)); CHECK(dnnl_memory_get_data_handle(m, &req)); CHECK_TRUE(req == data); #endif CHECK_TRUE(dnnl_memory_desc_get_size(md) == LENGTH_100 * sizeof(data[0])); CHECK(dnnl_memory_get_memory_desc(m, &c_md_tmp)); CHECK_TRUE(dnnl_memory_desc_equal(md, c_md_tmp)); CHECK(dnnl_memory_destroy(m)); CHECK(dnnl_memory_desc_destroy(md)); CHECK(dnnl_engine_destroy(engine)); } #if DNNL_CPU_RUNTIME != DNNL_RUNTIME_SYCL static size_t product(dnnl_dim_t *arr, size_t size) { size_t prod = 1; for (size_t i = 0; i < size; ++i) prod *= arr[i]; return prod; } void test2() { /* AlexNet: c3 * {2, 256, 13, 13} (x) {384, 256, 3, 3} -> {2, 384, 13, 13} * pad: {1, 1} * strides: {1, 1} */ const dnnl_dim_t mb = 2; const dnnl_dim_t groups = 2; const int ndims = 4; dnnl_dims_t c3_src_sizes = {mb, 256, 13, 13}; dnnl_dims_t c3_weights_sizes = {groups, 384 / groups, 256 / groups, 3, 3}; dnnl_dims_t c3_bias_sizes = {384}; dnnl_dims_t strides = {1, 1}; dnnl_dims_t dilation = {0, 0}; dnnl_dims_t padding = {0, 0}; // set proper values dnnl_dims_t c3_dst_sizes = {mb, 384, (c3_src_sizes[2] + 2 * padding[0] - c3_weights_sizes[3]) / strides[0] + 1, (c3_src_sizes[3] + 2 * padding[1] - c3_weights_sizes[4]) / strides[1] + 1}; real_t *src = (real_t *)calloc(product(c3_src_sizes, ndims), sizeof(real_t)); real_t *weights = (real_t *)calloc( product(c3_weights_sizes, ndims + 1), sizeof(real_t)); real_t *bias = (real_t *)calloc(product(c3_bias_sizes, 1), sizeof(real_t)); real_t *dst = (real_t *)calloc(product(c3_dst_sizes, ndims), sizeof(real_t)); real_t *out_mem = (real_t *)calloc(product(c3_dst_sizes, ndims), sizeof(real_t)); CHECK_TRUE(src && weights && bias && dst && out_mem); for (dnnl_dim_t i = 0; i < c3_bias_sizes[0]; ++i) bias[i] = i; dnnl_engine_t engine; CHECK(dnnl_engine_create(&engine, dnnl_cpu, 0)); dnnl_stream_t stream; CHECK(dnnl_stream_create(&stream, engine, dnnl_stream_default_flags)); /* first describe user data and create data descriptors for future * convolution w/ the specified format -- we do not want to do a reorder */ dnnl_memory_desc_t c3_src_md, c3_weights_md, c3_bias_md, c3_dst_md, out_md; dnnl_memory_t c3_src, c3_weights, c3_bias, c3_dst, out; // src { CHECK(dnnl_memory_desc_create_with_tag( &c3_src_md, 4, c3_src_sizes, dnnl_f32, dnnl_nChw8c)); CHECK(dnnl_memory_create(&c3_src, c3_src_md, engine, src)); } // weights { CHECK(dnnl_memory_desc_create_with_tag(&c3_weights_md, 4 + (groups != 1), c3_weights_sizes + (groups == 1), dnnl_f32, groups == 1 ? dnnl_OIhw8i8o : dnnl_gOIhw8i8o)); CHECK(dnnl_memory_create(&c3_weights, c3_weights_md, engine, weights)); } // bias { CHECK(dnnl_memory_desc_create_with_tag( &c3_bias_md, 1, c3_bias_sizes, dnnl_f32, dnnl_x)); CHECK(dnnl_memory_create(&c3_bias, c3_bias_md, engine, bias)); } // c3_dst { CHECK(dnnl_memory_desc_create_with_tag( &c3_dst_md, 4, c3_dst_sizes, dnnl_f32, dnnl_nChw8c)); CHECK(dnnl_memory_create(&c3_dst, c3_dst_md, engine, dst)); } // out { CHECK(dnnl_memory_desc_create_with_tag( &out_md, 4, c3_dst_sizes, dnnl_f32, dnnl_nchw)); CHECK(dnnl_memory_create(&out, out_md, engine, out_mem)); } /* create a convolution primitive descriptor */ dnnl_primitive_desc_t c3_pd; dnnl_primitive_t c3; CHECK(dnnl_convolution_forward_primitive_desc_create(&c3_pd, engine, dnnl_forward_training, dnnl_convolution_direct, c3_src_md, c3_weights_md, c3_bias_md, c3_dst_md, strides, dilation, padding, NULL, NULL)); CHECK_TRUE(dnnl_memory_desc_equal(c3_src_md, dnnl_primitive_desc_query_md(c3_pd, dnnl_query_src_md, 0))); CHECK_TRUE(dnnl_memory_desc_equal(c3_weights_md, dnnl_primitive_desc_query_md(c3_pd, dnnl_query_weights_md, 0))); CHECK_TRUE(dnnl_memory_desc_equal(c3_bias_md, dnnl_primitive_desc_query_md(c3_pd, dnnl_query_weights_md, 1))); CHECK_TRUE(dnnl_memory_desc_equal(c3_dst_md, dnnl_primitive_desc_query_md(c3_pd, dnnl_query_dst_md, 0))); CHECK_TRUE(dnnl_memory_desc_equal(c3_src_md, dnnl_primitive_desc_query_md( c3_pd, dnnl_query_exec_arg_md, DNNL_ARG_SRC))); CHECK_TRUE(dnnl_memory_desc_equal(c3_weights_md, dnnl_primitive_desc_query_md( c3_pd, dnnl_query_exec_arg_md, DNNL_ARG_WEIGHTS))); CHECK_TRUE(dnnl_memory_desc_equal(c3_bias_md, dnnl_primitive_desc_query_md( c3_pd, dnnl_query_exec_arg_md, DNNL_ARG_BIAS))); CHECK_TRUE(dnnl_memory_desc_equal(c3_dst_md, dnnl_primitive_desc_query_md( c3_pd, dnnl_query_exec_arg_md, DNNL_ARG_DST))); /* create a convolution and execute it */ CHECK(dnnl_primitive_create(&c3, c3_pd)); CHECK(dnnl_primitive_desc_destroy(c3_pd)); dnnl_exec_arg_t c3_args[4] = { {DNNL_ARG_SRC, c3_src}, {DNNL_ARG_WEIGHTS, c3_weights}, {DNNL_ARG_BIAS, c3_bias}, {DNNL_ARG_DST, c3_dst}, }; CHECK(dnnl_primitive_execute(c3, stream, 4, c3_args)); CHECK(dnnl_primitive_destroy(c3)); /* create a reorder primitive descriptor */ dnnl_primitive_desc_t r_pd; CHECK(dnnl_reorder_primitive_desc_create( &r_pd, c3_dst_md, engine, out_md, engine, NULL)); /* create a reorder and execute it */ dnnl_primitive_t r; CHECK(dnnl_primitive_create(&r, r_pd)); CHECK(dnnl_primitive_desc_destroy(r_pd)); dnnl_exec_arg_t r_args[2] = { {DNNL_ARG_FROM, c3_dst}, {DNNL_ARG_TO, out}, }; CHECK(dnnl_primitive_execute(r, stream, 2, r_args)); CHECK(dnnl_primitive_destroy(r)); CHECK(dnnl_stream_wait(stream)); /* clean-up */ CHECK(dnnl_memory_destroy(c3_src)); CHECK(dnnl_memory_destroy(c3_weights)); CHECK(dnnl_memory_destroy(c3_bias)); CHECK(dnnl_memory_destroy(c3_dst)); CHECK(dnnl_memory_destroy(out)); CHECK(dnnl_stream_destroy(stream)); CHECK(dnnl_engine_destroy(engine)); CHECK(dnnl_memory_desc_destroy(c3_src_md)); CHECK(dnnl_memory_desc_destroy(c3_weights_md)); CHECK(dnnl_memory_desc_destroy(c3_bias_md)); CHECK(dnnl_memory_desc_destroy(c3_dst_md)); CHECK(dnnl_memory_desc_destroy(out_md)); const dnnl_dim_t N = c3_dst_sizes[0], C = c3_dst_sizes[1], H = c3_dst_sizes[2], W = c3_dst_sizes[3]; for_(dnnl_dim_t n = 0; n < N; ++n) for_(dnnl_dim_t c = 0; c < C; ++c) for_(dnnl_dim_t h = 0; h < H; ++h) for (dnnl_dim_t w = 0; w < W; ++w) { dnnl_dim_t off = ((n * C + c) * H + h) * W + w; CHECK_TRUE(out_mem[off] == bias[c]); } free(src); free(weights); free(bias); free(dst); free(out_mem); } void test3() { const dnnl_dim_t mb = 2; const int ndims = 4; dnnl_dims_t l2_data_sizes = {mb, 256, 13, 13}; real_t *src = (real_t *)calloc(product(l2_data_sizes, ndims), sizeof(real_t)); real_t *dst = (real_t *)calloc(product(l2_data_sizes, ndims), sizeof(real_t)); CHECK_TRUE(src && dst); for (size_t i = 0; i < product(l2_data_sizes, ndims); ++i) src[i] = (i % 13) + 1; dnnl_engine_t engine; CHECK(dnnl_engine_create(&engine, dnnl_cpu, 0)); dnnl_stream_t stream; CHECK(dnnl_stream_create(&stream, engine, dnnl_stream_default_flags)); dnnl_memory_desc_t l2_data_md; dnnl_memory_t l2_src, l2_dst; // src, dst { CHECK(dnnl_memory_desc_create_with_tag( &l2_data_md, ndims, l2_data_sizes, dnnl_f32, dnnl_nchw)); CHECK(dnnl_memory_create(&l2_src, l2_data_md, engine, src)); CHECK(dnnl_memory_create(&l2_dst, l2_data_md, engine, dst)); } /* create an lrn */ dnnl_primitive_desc_t l2_pd; dnnl_primitive_t l2; CHECK(dnnl_lrn_forward_primitive_desc_create(&l2_pd, engine, dnnl_forward_inference, dnnl_lrn_across_channels, l2_data_md, l2_data_md, 5, 1e-4f, 0.75f, 1.0f, NULL)); CHECK_TRUE(dnnl_memory_desc_equal(l2_data_md, dnnl_primitive_desc_query_md(l2_pd, dnnl_query_src_md, 0))); CHECK_TRUE(dnnl_memory_desc_equal(l2_data_md, dnnl_primitive_desc_query_md(l2_pd, dnnl_query_dst_md, 0))); CHECK_TRUE(dnnl_primitive_desc_query_s32( l2_pd, dnnl_query_num_of_inputs_s32, 0) == 1); CHECK_TRUE(dnnl_primitive_desc_query_s32( l2_pd, dnnl_query_num_of_outputs_s32, 0) == 1); CHECK(dnnl_primitive_create(&l2, l2_pd)); CHECK(dnnl_primitive_desc_destroy(l2_pd)); dnnl_exec_arg_t l2_args[2] = { {DNNL_ARG_SRC, l2_src}, {DNNL_ARG_DST, l2_dst}, }; CHECK(dnnl_primitive_execute(l2, stream, 2, l2_args)); CHECK(dnnl_primitive_destroy(l2)); CHECK(dnnl_stream_wait(stream)); /* clean-up */ CHECK(dnnl_memory_destroy(l2_src)); CHECK(dnnl_memory_destroy(l2_dst)); CHECK(dnnl_stream_destroy(stream)); CHECK(dnnl_engine_destroy(engine)); CHECK(dnnl_memory_desc_destroy(l2_data_md)); const dnnl_dim_t N = l2_data_sizes[0], C = l2_data_sizes[1], H = l2_data_sizes[2], W = l2_data_sizes[3]; for_(dnnl_dim_t n = 0; n < N; ++n) for_(dnnl_dim_t c = 0; c < C; ++c) for_(dnnl_dim_t h = 0; h < H; ++h) for (dnnl_dim_t w = 0; w < W; ++w) { size_t off = ((n * C + c) * H + h) * W + w; real_t e = (off % 13) + 1; real_t diff = (real_t)fabs(dst[off] - e); if (diff / fabs(e) > 0.0125) printf("exp: %g, got: %g\n", e, dst[off]); CHECK_TRUE(diff / fabs(e) < 0.0125); } free(src); free(dst); } #endif int main() { test1(); #if DNNL_CPU_RUNTIME != DNNL_RUNTIME_SYCL test2(); test3(); #endif return 0; }