[STM32]:使用X-CUBE-AI进行模型推理的指南

基于X-CUBE-AI的模型推理

本文所使用的版本如下:

  • X-CUBE-AI:8.1.0
  • STM32CUBEMX:6.7.0
基于CubeMX导出模型

    首先需要在软件包(Software Packs)中选中X-CUBE-AI:

    导入模型并进行转换,运行时(Runtime)这里选择STM32Cube.AI Runtime。

    界面底部会显示模型的RAM与ROM占用:

    基于STM32实现模型推理

    ST官方提供了相关文档(how_to_run_a_model_locally.html),可以在pack包的安装目录下找到。我的安装路径如下,每个人电脑上的路径可能不同:

    file:///D:/IDE/STM32CUBEMX/Repository/Packs/STMicroelectronics/X-CUBE-AI/8.1.0/Documentation/how_to_run_a_model_locally.html
    

    接下来,我们按照文档编写示例代码。本文所使用的模型输入为长度2048的一维浮点数据。

    1.引入必要的头文件

    #include "stdio.h"
    #include <stdlib.h>
    #include <time.h>
    #include <string.h>
    #include "network.h"
    #include "network_data.h"
    

    2.创建模型的输入输出以及句柄

    /* Working memory handed to ai_network_create_and_init() in ai_init(). */
    AI_ALIGNED(32)
    static ai_u8 activations[AI_NETWORK_DATA_ACTIVATIONS_SIZE];
    AI_ALIGNED(32)
    static ai_float in_data[AI_NETWORK_IN_1_SIZE]; // Change ai_float to your model's input type; size the array with the element-count ..._SIZE macro, not the byte count
    AI_ALIGNED(32)
    static ai_float out_data[AI_NETWORK_OUT_1_SIZE]; // Likewise, use the element-count ..._SIZE macro here
    ai_buffer *ai_input;      // input buffer descriptors, obtained from ai_network_inputs_get()
    ai_buffer *ai_output;     // output buffer descriptors, obtained from ai_network_outputs_get()
    ai_handle network = AI_HANDLE_NULL; // network instance; AI_HANDLE_NULL until ai_init() creates it
    ai_error err;             // error status reported by the runtime
    ai_network_report report; // model information filled in by ai_init()
    

    3.创建模型初始化代码

    /**
     * @brief Create and initialise the X-CUBE-AI network, then print its report.
     *
     * Binds the static `activations` buffer to a new network instance stored in
     * the global `network`, and fills the global `report` with model metadata.
     *
     * @return 0 on success; -1 if network creation/initialisation or the report
     *         query fails. On failure the runtime's error type/code is printed.
     */
    int ai_init(void)
    {
      const ai_handle acts[] = {activations};

      err = ai_network_create_and_init(&network, acts, NULL);
      if (err.type != AI_ERROR_NONE)
      {
        printf("ai init_and_create error\n");
        /* Surface the runtime's diagnosis instead of discarding it. */
        printf("  error type=0x%x code=0x%x\n", (unsigned)err.type, (unsigned)err.code);
        return -1;
      }
      printf("ai init success\n");

      if (ai_network_get_report(network, &report) != true)
      {
        err = ai_network_get_error(network);
        printf("ai get report error\n");
        printf("  error type=0x%x code=0x%x\n", (unsigned)err.type, (unsigned)err.code);
        return -1;
      }

      printf("Model name      : %s\n", report.model_name);
      printf("Model signature : %s\n", report.model_signature);
      return 0;
    }
    

    4.赋值与推理

    /**
     * @brief Copy one input sample into the model input buffer and run inference.
     *
     * @param in_data  Buffer bound as the network input; must hold at least
     *                 AI_NETWORK_IN_1_SIZE elements.
     * @param out_data Buffer bound as the network output; must hold at least
     *                 AI_NETWORK_OUT_1_SIZE elements.
     * @param data     Source samples copied into in_data (only read).
     * @param length   Number of samples in data, in [0, AI_NETWORK_IN_1_SIZE].
     *                 Elements of in_data beyond length keep their previous values.
     * @return 0 on success; -1 if length is out of range or the run fails
     *         (the runtime's error type/code is printed).
     */
    int ai_run(ai_float *in_data, ai_float *out_data, float *data, int length)
    {
      ai_i32 n_batch;

      /* Refuse lengths that would write past the model's input buffer. */
      if (length < 0 || length > (int)AI_NETWORK_IN_1_SIZE)
      {
        printf("invalid input length %d (max %d)\r\n", length, (int)AI_NETWORK_IN_1_SIZE);
        return -1;
      }

      for (int i = 0; i < length; i++)
      {
        in_data[i] = data[i];
      }

      ai_input = ai_network_inputs_get(network, NULL);
      ai_output = ai_network_outputs_get(network, NULL);
      ai_input[0].data = AI_HANDLE_PTR(in_data);
      ai_output[0].data = AI_HANDLE_PTR(out_data);

      /* The run reports how many batches it processed; one sample => 1. */
      n_batch = ai_network_run(network, &ai_input[0], &ai_output[0]);
      if (n_batch != 1)
      {
        err = ai_network_get_error(network);
        printf("run failed\r\n");
        printf("  error type=0x%x code=0x%x\r\n", (unsigned)err.type, (unsigned)err.code);
        return -1;
      }

      return 0; // success
    }
    

    接下来,我们就可以根据out_data来查看推理结果

      for (int i = 0; i < AI_NETWORK_OUT_1_SIZE; i++)
      {
        printf("%.2f, ", out_data[i]);
      }
    


    推理结果与上位机(PC端)的推理结果保持一致。

    全部代码

    /* USER CODE BEGIN Header */
    /**
     ******************************************************************************
     * @file           : main.c
     * @brief          : Main program body
     ******************************************************************************
     * @attention
     *
     * Copyright (c) 2024 STMicroelectronics.
     * All rights reserved.
     *
     * This software is licensed under terms that can be found in the LICENSE file
     * in the root directory of this software component.
     * If no LICENSE file comes with this software, it is provided AS-IS.
     *
     ******************************************************************************
     */
    /* USER CODE END Header */
    /* Includes ------------------------------------------------------------------*/
    #include "main.h"
    
    /* Private includes ----------------------------------------------------------*/
    /* USER CODE BEGIN Includes */
    #include "stdio.h"
    #include <stdlib.h>
    #include <time.h>
    #include <string.h>
    #include "network.h"
    #include "network_data.h"
    /* USER CODE END Includes */
    
    /* Private typedef -----------------------------------------------------------*/
    /* USER CODE BEGIN PTD */
    
    /* USER CODE END PTD */
    
    /* Private define ------------------------------------------------------------*/
    /* USER CODE BEGIN PD */
    /* USER CODE END PD */
    
    /* Private macro -------------------------------------------------------------*/
    /* USER CODE BEGIN PM */
    
    /* USER CODE END PM */
    
    /* Private variables ---------------------------------------------------------*/
    /* HAL handle for the CRC unit, initialised in MX_CRC_Init().
     * NOTE(review): CRC is not used directly in this file; presumably it is
     * enabled because the X-CUBE-AI runtime needs it - confirm before removing. */
    CRC_HandleTypeDef hcrc;
    
    /* HAL handle for I2C1, initialised in MX_I2C1_Init(); not used by the AI code here. */
    I2C_HandleTypeDef hi2c1;
    
    /* HAL handle for USART1; fputc() transmits printf output through it. */
    UART_HandleTypeDef huart1;
    
    /* USER CODE BEGIN PV */
    
    /* USER CODE END PV */
    
    /* Private function prototypes -----------------------------------------------*/
    /* Clock and peripheral initialisation routines, defined after main(). */
    void SystemClock_Config(void);
    static void MX_GPIO_Init(void);
    static void MX_CRC_Init(void);
    static void MX_I2C1_Init(void);
    static void MX_USART1_UART_Init(void);
    /* USER CODE BEGIN PFP */
    
    /* USER CODE END PFP */
    
    /* Private user code ---------------------------------------------------------*/
    /* USER CODE BEGIN 0 */
    /**
     * @brief Retarget C library character output to USART1 so printf() output
     *        appears on the serial port.
     *
     * NOTE(review): fputc() retargeting is honoured by the Keil/ARM C library
     * (MicroLIB); GCC/newlib builds usually retarget via _write() or
     * __io_putchar() instead - confirm for your toolchain.
     *
     * @param ch Character to send; only its low byte is transmitted.
     * @param f  Stream; ignored - every stream goes to USART1.
     * @return ch
     */
    int fputc(int ch, FILE *f)
    {
      /* Convert explicitly instead of aliasing &ch, which picks a byte of the
       * int depending on endianness. */
      uint8_t byte = (uint8_t)ch;

      (void)f;
      HAL_UART_Transmit(&huart1, &byte, 1, 0xFFFF);
      return ch;
    }
    
    /* Working memory handed to ai_network_create_and_init() in ai_init(). */
    AI_ALIGNED(32)
    static ai_u8 activations[AI_NETWORK_DATA_ACTIVATIONS_SIZE];
    /* Model input buffer, sized in elements (AI_NETWORK_IN_1_SIZE). */
    AI_ALIGNED(32)
    static ai_float in_data[AI_NETWORK_IN_1_SIZE];
    /* Model output buffer, sized in elements (AI_NETWORK_OUT_1_SIZE). */
    AI_ALIGNED(32)
    static ai_float out_data[AI_NETWORK_OUT_1_SIZE];
    ai_buffer *ai_input;      /* input buffer descriptors, obtained in ai_run() */
    ai_buffer *ai_output;     /* output buffer descriptors, obtained in ai_run() */
    ai_handle network = AI_HANDLE_NULL; /* network instance; AI_HANDLE_NULL until ai_init() creates it */
    ai_error err;             /* error status reported by the runtime */
    ai_network_report report; /* model information filled in by ai_init() */
    
    /*
     * Input sample fed to the model - replace with your own data.
     * It should contain exactly AI_NETWORK_IN_1_SIZE values.
     * An empty initializer `{}` (a zero-length array) is only valid from C23
     * or as a GNU extension, so a single zero keeps this placeholder valid ISO C.
     */
    float data[] = {0.0f};
    
    /**
     * @brief Create and initialise the X-CUBE-AI network, then print its report.
     *
     * Binds the static `activations` buffer to a new network instance stored in
     * the global `network`, and fills the global `report` with model metadata.
     *
     * @return 0 on success; -1 if network creation/initialisation or the report
     *         query fails. On failure the runtime's error type/code is printed.
     */
    int ai_init(void)
    {
      const ai_handle acts[] = {activations};

      err = ai_network_create_and_init(&network, acts, NULL);
      if (err.type != AI_ERROR_NONE)
      {
        printf("ai init_and_create error\n");
        /* Surface the runtime's diagnosis instead of discarding it. */
        printf("  error type=0x%x code=0x%x\n", (unsigned)err.type, (unsigned)err.code);
        return -1;
      }
      printf("ai init success\n");

      if (ai_network_get_report(network, &report) != true)
      {
        err = ai_network_get_error(network);
        printf("ai get report error\n");
        printf("  error type=0x%x code=0x%x\n", (unsigned)err.type, (unsigned)err.code);
        return -1;
      }

      printf("Model name      : %s\n", report.model_name);
      printf("Model signature : %s\n", report.model_signature);
      return 0;
    }
    
    /**
     * @brief Copy one input sample into the model input buffer and run inference.
     *
     * @param in_data  Buffer bound as the network input; must hold at least
     *                 AI_NETWORK_IN_1_SIZE elements.
     * @param out_data Buffer bound as the network output; must hold at least
     *                 AI_NETWORK_OUT_1_SIZE elements.
     * @param data     Source samples copied into in_data (only read).
     * @param length   Number of samples in data, in [0, AI_NETWORK_IN_1_SIZE].
     *                 Elements of in_data beyond length keep their previous values.
     * @return 0 on success; -1 if length is out of range or the run fails
     *         (the runtime's error type/code is printed).
     */
    int ai_run(ai_float *in_data, ai_float *out_data, float *data, int length)
    {
      ai_i32 n_batch;

      /* Refuse lengths that would write past the model's input buffer. */
      if (length < 0 || length > (int)AI_NETWORK_IN_1_SIZE)
      {
        printf("invalid input length %d (max %d)\r\n", length, (int)AI_NETWORK_IN_1_SIZE);
        return -1;
      }

      for (int i = 0; i < length; i++)
      {
        in_data[i] = data[i];
      }

      ai_input = ai_network_inputs_get(network, NULL);
      ai_output = ai_network_outputs_get(network, NULL);
      ai_input[0].data = AI_HANDLE_PTR(in_data);
      ai_output[0].data = AI_HANDLE_PTR(out_data);

      /* The run reports how many batches it processed; one sample => 1. */
      n_batch = ai_network_run(network, &ai_input[0], &ai_output[0]);
      if (n_batch != 1)
      {
        err = ai_network_get_error(network);
        printf("run failed\r\n");
        printf("  error type=0x%x code=0x%x\r\n", (unsigned)err.type, (unsigned)err.code);
        return -1;
      }

      return 0; // success
    }
    
    /* USER CODE END 0 */
    
    /**
     * @brief  Application entry point.
     *
     * Initialises the HAL, system clock and peripherals, creates the AI network,
     * runs one inference on `data` and prints the outputs over stdout, then
     * toggles the LedHeart LED once per second forever. Any AI initialisation,
     * input-size or inference failure traps in Error_Handler().
     *
     * @retval int  Does not return.
     */
    int main(void)
    {
      /* USER CODE BEGIN 1 */
    
      /* USER CODE END 1 */
    
      /* MCU Configuration--------------------------------------------------------*/
    
      /* Reset of all peripherals, Initializes the Flash interface and the Systick. */
      HAL_Init();
    
      /* USER CODE BEGIN Init */
    
      /* USER CODE END Init */
    
      /* Configure the system clock */
      SystemClock_Config();
    
      /* USER CODE BEGIN SysInit */
    
      /* USER CODE END SysInit */
    
      /* Initialize all configured peripherals */
      MX_GPIO_Init();
      MX_CRC_Init();
      MX_I2C1_Init();
      MX_USART1_UART_Init();
      /* USER CODE BEGIN 2 */
    
      /* Number of samples actually supplied in `data`. Using this instead of
       * AI_NETWORK_IN_1_SIZE unconditionally avoids reading past its end. */
      const int n_samples = (int)(sizeof data / sizeof data[0]);
    
      /* What happens after main() returns depends on the startup code, so
       * failures trap in Error_Handler() like the peripheral init does. */
      if (ai_init() != 0)
      {
        Error_Handler();
      }
    
      if (n_samples != (int)AI_NETWORK_IN_1_SIZE)
      {
        printf("data has %d samples, model expects %d\r\n", n_samples, (int)AI_NETWORK_IN_1_SIZE);
        Error_Handler();
      }
    
      if (ai_run(in_data, out_data, data, n_samples) != 0)
      {
        Error_Handler();
      }
    
      for (int i = 0; i < AI_NETWORK_OUT_1_SIZE; i++)
      {
        printf("%.2f, ", out_data[i]);
      }
      printf("\r\n"); /* terminate the line so a line-buffered stdout emits it */
      /* USER CODE END 2 */
    
      /* Infinite loop */
      /* USER CODE BEGIN WHILE */
      while (1)
      {
        /* USER CODE END WHILE */
    
        /* USER CODE BEGIN 3 */
        HAL_GPIO_TogglePin(LedHeart_GPIO_Port, LedHeart_Pin);
        HAL_Delay(1000);
      }
      /* USER CODE END 3 */
    }
    
    /**
     * @brief System Clock Configuration
     *
     * Runs the core from the main PLL fed by the external HSE oscillator:
     * SYSCLK = HSE / PLLM(4) * PLLN(168) / PLLP(2); HCLK = SYSCLK,
     * APB1 = HCLK / 4, APB2 = HCLK / 2, flash latency FLASH_LATENCY_5.
     * NOTE(review): with an 8 MHz HSE this gives a 168 MHz SYSCLK - confirm
     * HSE_VALUE for your board.
     * Any HAL failure traps in Error_Handler().
     * @retval None
     */
    void SystemClock_Config(void)
    {
      RCC_OscInitTypeDef RCC_OscInitStruct = {0};
      RCC_ClkInitTypeDef RCC_ClkInitStruct = {0};
    
      /** Configure the main internal regulator output voltage
       */
      /* The PWR clock must be on before the voltage scale can be written. */
      __HAL_RCC_PWR_CLK_ENABLE();
      __HAL_PWR_VOLTAGESCALING_CONFIG(PWR_REGULATOR_VOLTAGE_SCALE1);
    
      /** Initializes the RCC Oscillators according to the specified parameters
       * in the RCC_OscInitTypeDef structure.
       */
      RCC_OscInitStruct.OscillatorType = RCC_OSCILLATORTYPE_HSE;
      RCC_OscInitStruct.HSEState = RCC_HSE_ON;
      RCC_OscInitStruct.PLL.PLLState = RCC_PLL_ON;
      RCC_OscInitStruct.PLL.PLLSource = RCC_PLLSOURCE_HSE;
      RCC_OscInitStruct.PLL.PLLM = 4;
      RCC_OscInitStruct.PLL.PLLN = 168;
      RCC_OscInitStruct.PLL.PLLP = RCC_PLLP_DIV2;
      RCC_OscInitStruct.PLL.PLLQ = 4;
      if (HAL_RCC_OscConfig(&RCC_OscInitStruct) != HAL_OK)
      {
        Error_Handler();
      }
    
      /** Initializes the CPU, AHB and APB buses clocks
       */
      RCC_ClkInitStruct.ClockType = RCC_CLOCKTYPE_HCLK | RCC_CLOCKTYPE_SYSCLK | RCC_CLOCKTYPE_PCLK1 | RCC_CLOCKTYPE_PCLK2;
      RCC_ClkInitStruct.SYSCLKSource = RCC_SYSCLKSOURCE_PLLCLK;
      RCC_ClkInitStruct.AHBCLKDivider = RCC_SYSCLK_DIV1;
      RCC_ClkInitStruct.APB1CLKDivider = RCC_HCLK_DIV4;
      RCC_ClkInitStruct.APB2CLKDivider = RCC_HCLK_DIV2;
    
      /* Switch SYSCLK to the PLL; flash wait states are set in the same call. */
      if (HAL_RCC_ClockConfig(&RCC_ClkInitStruct, FLASH_LATENCY_5) != HAL_OK)
      {
        Error_Handler();
      }
    
      /** Enables the Clock Security System
       */
      HAL_RCC_EnableCSS();
    }
    
    /**
     * @brief CRC Initialization Function
     *
     * Binds the hcrc handle to the CRC peripheral and initialises it through
     * the HAL; traps in Error_Handler() if HAL_CRC_Init() fails.
     * NOTE(review): CRC is not used directly in this file; presumably it is
     * enabled because the X-CUBE-AI runtime requires it - confirm before removing.
     * @param None
     * @retval None
     */
    static void MX_CRC_Init(void)
    {
    
      /* USER CODE BEGIN CRC_Init 0 */
    
      /* USER CODE END CRC_Init 0 */
    
      /* USER CODE BEGIN CRC_Init 1 */
    
      /* USER CODE END CRC_Init 1 */
      hcrc.Instance = CRC;
      if (HAL_CRC_Init(&hcrc) != HAL_OK)
      {
        Error_Handler();
      }
      /* USER CODE BEGIN CRC_Init 2 */
    
      /* USER CODE END CRC_Init 2 */
    }
    
    /**
     * @brief I2C1 Initialization Function
     *
     * Configures I2C1 for 100 kHz, 7-bit addressing, own address 0, no dual
     * addressing, no general call, clock stretching enabled; traps in
     * Error_Handler() if HAL_I2C_Init() fails.
     * NOTE(review): nothing else in this file uses hi2c1.
     * @param None
     * @retval None
     */
    static void MX_I2C1_Init(void)
    {
    
      /* USER CODE BEGIN I2C1_Init 0 */
    
      /* USER CODE END I2C1_Init 0 */
    
      /* USER CODE BEGIN I2C1_Init 1 */
    
      /* USER CODE END I2C1_Init 1 */
      hi2c1.Instance = I2C1;
      hi2c1.Init.ClockSpeed = 100000;
      hi2c1.Init.DutyCycle = I2C_DUTYCYCLE_2;
      hi2c1.Init.OwnAddress1 = 0;
      hi2c1.Init.AddressingMode = I2C_ADDRESSINGMODE_7BIT;
      hi2c1.Init.DualAddressMode = I2C_DUALADDRESS_DISABLE;
      hi2c1.Init.OwnAddress2 = 0;
      hi2c1.Init.GeneralCallMode = I2C_GENERALCALL_DISABLE;
      hi2c1.Init.NoStretchMode = I2C_NOSTRETCH_DISABLE;
      if (HAL_I2C_Init(&hi2c1) != HAL_OK)
      {
        Error_Handler();
      }
      /* USER CODE BEGIN I2C1_Init 2 */
    
      /* USER CODE END I2C1_Init 2 */
    }
    
    /**
     * @brief USART1 Initialization Function
     *
     * Configures USART1 as 115200 baud, 8 data bits, no parity, 1 stop bit,
     * TX and RX enabled, no hardware flow control, 16x oversampling. This is
     * the port fputc() writes printf output to. Traps in Error_Handler() if
     * HAL_UART_Init() fails.
     * @param None
     * @retval None
     */
    static void MX_USART1_UART_Init(void)
    {
    
      /* USER CODE BEGIN USART1_Init 0 */
    
      /* USER CODE END USART1_Init 0 */
    
      /* USER CODE BEGIN USART1_Init 1 */
    
      /* USER CODE END USART1_Init 1 */
      huart1.Instance = USART1;
      huart1.Init.BaudRate = 115200;
      huart1.Init.WordLength = UART_WORDLENGTH_8B;
      huart1.Init.StopBits = UART_STOPBITS_1;
      huart1.Init.Parity = UART_PARITY_NONE;
      huart1.Init.Mode = UART_MODE_TX_RX;
      huart1.Init.HwFlowCtl = UART_HWCONTROL_NONE;
      huart1.Init.OverSampling = UART_OVERSAMPLING_16;
      if (HAL_UART_Init(&huart1) != HAL_OK)
      {
        Error_Handler();
      }
      /* USER CODE BEGIN USART1_Init 2 */
    
      /* USER CODE END USART1_Init 2 */
    }
    
    /**
     * @brief GPIO Initialization Function
     *
     * Enables the GPIOH/A/D/B port clocks and configures the LedHeart pin as a
     * low-speed push-pull output without pull resistor, initially driven low.
     * @param None
     * @retval None
     */
    static void MX_GPIO_Init(void)
    {
      GPIO_InitTypeDef GPIO_InitStruct = {0};
    
      /* GPIO Ports Clock Enable */
      __HAL_RCC_GPIOH_CLK_ENABLE();
      __HAL_RCC_GPIOA_CLK_ENABLE();
      __HAL_RCC_GPIOD_CLK_ENABLE();
      __HAL_RCC_GPIOB_CLK_ENABLE();
    
      /*Configure GPIO pin Output Level */
      /* Set the output level before switching the pin to output mode so it
       * starts low. */
      HAL_GPIO_WritePin(LedHeart_GPIO_Port, LedHeart_Pin, GPIO_PIN_RESET);
    
      /*Configure GPIO pin : LedHeart_Pin */
      GPIO_InitStruct.Pin = LedHeart_Pin;
      GPIO_InitStruct.Mode = GPIO_MODE_OUTPUT_PP;
      GPIO_InitStruct.Pull = GPIO_NOPULL;
      GPIO_InitStruct.Speed = GPIO_SPEED_FREQ_LOW;
      HAL_GPIO_Init(LedHeart_GPIO_Port, &GPIO_InitStruct);
    }
    
    /* USER CODE BEGIN 4 */
    
    /* USER CODE END 4 */
    
    /**
     * @brief  Fatal-error trap, entered when an initialisation step fails
     *         (e.g. a HAL call that does not return HAL_OK).
     *
     * Masks interrupts and parks the CPU forever; it never returns.
     * @retval None
     */
    void Error_Handler(void)
    {
      /* USER CODE BEGIN Error_Handler_Debug */
      /* User can add his own implementation to report the HAL error return state */
      __disable_irq();
      for (;;)
      {
        /* stay here */
      }
      /* USER CODE END Error_Handler_Debug */
    }
    
    #ifdef USE_FULL_ASSERT
    /**
     * @brief  Reports the name of the source file and the source line number
     *         where the assert_param error has occurred.
     *
     * Currently a no-op that returns immediately; add reporting in the
     * USER CODE section below.
     * @param  file: pointer to the source file name
     * @param  line: assert_param error line source number
     * @retval None
     */
    void assert_failed(uint8_t *file, uint32_t line)
    {
      /* USER CODE BEGIN 6 */
      /* User can add his own implementation to report the file name and line number,
         ex: printf("Wrong parameters value: file %s on line %lu\r\n", (char *)file, (unsigned long)line) */
      /* Silence unused-parameter warnings until an implementation is added. */
      (void)file;
      (void)line;
      /* USER CODE END 6 */
    }
    #endif /* USE_FULL_ASSERT */
    
    

    作者:YAN_KEEE

    物联沃分享整理
    物联沃-IOTWORD物联网 » [STM32]:使用X-CUBE-AI进行模型推理的指南

    发表评论