반응형

이번엔 line buffer 를 이용한 maxpool2d 모듈을 한 번 만들어 보도록 한다.

 

1. Overview

maxpool 은 pool size 중 가장 큰 값을 추출하는 과정이다.

예시를 들어서 하나의 input image 가 있다면,

(0,0) (0,1) (0,2) (0,3) (0,4) (0,5)
(1,0) ... ... ... ... ...
(2,0) ... ... ... ... ...
(3,0) ... ... ... ... ...
... ... ... ... ... ...
... ... ... ... ... ...

 

이 중에서

[(0,0), (0,1),

(1,0), (1,1)] 부터 시작해서 이 중에서 가장 큰 값을 찾아나가는 과정이다.

이 때 pool size 가 2라면 stride 가 2이고, 처음 연산을 진행한 후 윈도우 위치를 2만큼 옮긴다.

 

이 과정을 거치면서, stride 가 2이므로 input image size 가 절반으로 줄어든다.

 

input image size (height, width) : (6, 6) -> (3, 3) 

 

이를 Verilog, line buffer 를 사용해서 구현을 해 보도록 한다.

 


line buffer 에는, 늘 그렇듯이, 데이터가 한 클럭마다 계속 들어오는 구조를 지닌다.

근데 이 데이터가 들어오는 시점을 조절 해야 하는데, 이를 start flag 로 지정한다. (전 모듈이 convolution 이니, output o_valid 이 처음 들어왔다면 시작 신호로 보면 될 것이다.)

start flag 가 들어 왔다면, 이후에는 지속적으로 연산을 진행한다. (이는 clock dependency 를 가진다)

 

따라서 나는 State 를 다음과 같이 설정했다.

Start signal 을 받는 IDLE, 

line buffer 를 채우는 동작인 BUF_FILL,

버퍼가 다 채워지면 실제 maxpool 연산을 수행하는 WORK.

 

초기에 만든 코드는 다음과 같다. (동작하지는 않을 것이다.)

module maxpool2d#
    (
        parameter DATA_WIDTH = 16,  // fixed point 16-bit
        parameter IMG_WIDTH = 6,    // original image size
        parameter IMG_HEIGHT = 6,
    )
    (
        input i_clk,
        input i_rst,
        input i_start,
        input [DATA_WIDTH-1:0] i_data,
        output reg [DATA_WIDTH-1:0] o_data,
        output reg o_valid                      // current o_data validation flag
    );

/*========================================\\
                P A R A M                  
\\========================================*/
localparam  IDLE = 0,
            BUF_FILL = 1,
            WORK = 2;

localparam  LINE_BUF_SIZE   = IMG_WIDTH * 2;
localparam  MAX_WORK        = IMG_WIDTH / 2;
localparam  HALF_BUF_SIZE   = LINE_BUF_SIZE / 2;

/*========================================\\
                R E G                   
\\========================================*/
reg [DATA_WIDTH-1:0] c_line_buf[0:LINE_BUF_SIZE-1];
reg [DATA_WIDTH-1:0] n_line_buf[0:LINE_BUF_SIZE-1];       // 2 lines of image width size buffer


reg [1:0] c_state, n_state;
reg [$clog2(IMG_WIDTH):0] c_x, n_x;         // line buf x-counter
reg [$clog2(IMG_WIDTH):0] c_y, n_y;         // line buf y-counter
reg [$clog2(MAX_WORK):0] c_work_cnt, n_work_cnt;
integer i;
/*========================================\\
                F F                   
\\========================================*/
always@(posedge i_clk, negedge i_rst)
if(!i_rst) begin
    c_state <= 0;
    for(i = 0; i < LINE_BUF_SIZE; i=i+1)
        c_line_buf[i] <= 0;
    c_x <= 0;
    c_y <= 0;
    c_work_cnt <= 0;
end else begin
    c_state <= n_state;
    for(i = 0; i < LINE_BUF_SIZE; i=i+1)
        c_line_buf[i] <= n_line_buf[i];
    c_x <= n_x;
    c_y <= n_y;
    c_work_cnt <= n_work_cnt;
end




always@* begin
    n_state = c_state;
    n_line_buf = c_line_buf;

    // x, y movement (always, buf_filling & working)
    n_x =  (c_x == IMG_WIDTH - 1)  ? c_x + 1 : 0;
    n_y =  (c_y == IMG_HEIGHT)     ? 0 : 
            (c_x == IMG_WIDTH - 1)  ? c_y + 1 : c_y;

    n_work_cnt = 0;        
    o_valid = 0;

    case(c_state)
        IDLE: begin
            n_x = 0;
            n_y = 0;
            for(i = 0; i < LINE_BUF_SIZE; i=i+1)    n_line_buf[i] = 0;
            
            
            if(i_start) begin
                n_state = BUF_FILL;
                n_line_buf[0] = i_data;        // start flag with 1 input data
                n_x = c_x + 1;                 // start flag with 1 input data
            end

        end

        BUF_FILL: begin
            // buffer movement
            n_line_buf[0] = i_data;
            for(i = 1; i < LINE_BUF_SIZE; i=i+1)
                n_line_buf[i] = c_line_buf[i-1];       // buffer left shift 

            

            if(c_x == IMG_WIDTH-1 && c_y[0]) begin      // if buffer filled
                n_state = WORK;


            end

        end

        WORK: begin
            o_valid = 1;   // wire ? reg ? 
            o_data  = result;
            // buffer movement
            n_line_buf[0] = i_data;
            for(i = 1; i < LINE_BUF_SIZE; i=i+1)
                n_line_buf[i] = c_line_buf[i-1];       // buffer left shift 

            // work counter, MAX == IMG_WIDTH / 2
            n_work_cnt = c_work_cnt + 1;

            n_state =  c_work_cnt == MAX_WORK ? 
                        c_y == (IMG_HEIGHT - 1) ? IDLE : BUF_FILL : c_state;

        end
    endcase

end

// make this to register ? or wire ?
wire signed [DATA_WIDTH-1:0] comp1, comp2, result;
assign comp1 = c_line_buf[LINE_BUF_SIZE - c_work_cnt] > c_line_buf[LINE_BUF_SIZE - c_work_cnt - 1] ? 
                c_line_buf[LINE_BUF_SIZE - c_work_cnt] : c_line_buf[LINE_BUF_SIZE - c_work_cnt - 1];
assign comp2 = c_line_buf[HALF_BUF_SIZE - c_work_cnt] > c_line_buf[HALF_BUF_SIZE - c_work_cnt - 1] ?
                c_line_buf[HALF_BUF_SIZE - c_work_cnt] : c_line_buf[HALF_BUF_SIZE - c_work_cnt - 1];
assign result = comp1 > comp2 ? comp1 : comp2;




endmodule

 

 

그리고 gpt 를 돌려서 수정한 코드는 다음과 같다. (이건 동작한다. 하지만 세부적인 검증은 해보지 않았다.)

module maxpool2d_gpt#
(
    parameter DATA_WIDTH = 16,  // fixed point 16-bit
    parameter IMG_WIDTH = 6,    // original image size
    parameter IMG_HEIGHT = 6
)
(
    input i_clk,
    input i_rst,
    input i_start,
    input [DATA_WIDTH-1:0] i_data,
    output [DATA_WIDTH-1:0] o_data,
    output reg o_valid                      // current o_data validation flag
);

/*========================================\\
                P A R A M
\\========================================*/
localparam  IDLE = 0,
            BUF_FILL = 1,
            WORK = 2;

localparam  LINE_BUF_SIZE   = IMG_WIDTH * 2;
localparam  MAX_WORK        = IMG_WIDTH / 2;
localparam  HALF_BUF_SIZE   = LINE_BUF_SIZE / 2;

/*========================================\\
            R E G & W I R E S
\\========================================*/
reg [DATA_WIDTH-1:0] c_line_buf[0:LINE_BUF_SIZE-1];
reg [DATA_WIDTH-1:0] n_line_buf[0:LINE_BUF_SIZE-1];       // 2 lines of image width size buffer

reg [1:0] c_state, n_state;
reg [$clog2(IMG_WIDTH):0] c_x, n_x;         // line buf x-counter
reg [$clog2(IMG_WIDTH):0] c_y, n_y;         // line buf y-counter
reg [$clog2(MAX_WORK):0] c_work_cnt, n_work_cnt;


reg n_o_valid;
reg [DATA_WIDTH-1:0] n_o_data;
integer i;

wire signed [DATA_WIDTH-1:0] comp1, comp2, result;

/*========================================\\
                F F
\\========================================*/
always@(posedge i_clk or negedge i_rst) begin
    if(!i_rst) begin
        c_state <= IDLE;
        for(i = 0; i < LINE_BUF_SIZE; i=i+1)
            c_line_buf[i] <= 0;
        c_x <= 0;
        c_y <= 0;
        c_work_cnt <= 0;
    end else begin
        c_state <= n_state;
        for(i = 0; i < LINE_BUF_SIZE; i=i+1)
            c_line_buf[i] <= n_line_buf[i];
        c_x <= n_x;
        c_y <= n_y;
        c_work_cnt <= n_work_cnt;  
    end
end

/*========================================\\
                C O M B
\\========================================*/


always@* begin
    n_state = c_state;
    for(i = 0; i < LINE_BUF_SIZE; i=i+1)
        n_line_buf[i] = c_line_buf[i];

    n_x = c_x;
    n_y = c_y;
    n_work_cnt = c_work_cnt;    
    o_valid = 0;    
   

    case(c_state)
        IDLE: begin
            n_x = 0;
            n_y = 0;
            for(i = 0; i < LINE_BUF_SIZE; i=i+1)
                n_line_buf[i] = 0;

            if(i_start) begin
                n_state = BUF_FILL;
                n_line_buf[0] = i_data;        // start flag with 1 input data
                n_x = c_x + 1;                 // start flag with 1 input data
            end
        end

        BUF_FILL: begin
            // buffer movement
            n_line_buf[0] = i_data;
            for(i = 1; i < LINE_BUF_SIZE; i=i+1)
                n_line_buf[i] = c_line_buf[i-1];       // buffer left shift 

            // x, y movement
            if (c_x == IMG_WIDTH - 1) begin
                n_x = 0;
                n_y = c_y + 1;
            end else begin
                n_x = c_x + 1;
                n_y = c_y;
            end

            if(c_y == IMG_HEIGHT + 1)
                n_state = IDLE;
            if(c_x == IMG_WIDTH-1 && c_y[0]) begin      // if buffer filled
                n_state = WORK;
                n_work_cnt = 0;
            end

        end

        WORK: begin
            o_valid = 1;
            
            // buffer movement
            n_line_buf[0] = i_data;
            for(i = 1; i < LINE_BUF_SIZE; i=i+1)
                n_line_buf[i] = c_line_buf[i-1];       // buffer left shift 

            // x, y movement
            if (c_x == IMG_WIDTH - 1) begin
                n_x = 0;
                n_y = c_y + 1;
            end else begin
                n_x = c_x + 1;
                n_y = c_y;
            end

            // work counter, MAX == IMG_WIDTH / 2
            n_work_cnt = c_work_cnt + 1;

            if (c_work_cnt == MAX_WORK - 1) begin
                if (c_y == IMG_HEIGHT - 1) begin
                    n_state = IDLE;
                end else begin
                    n_state = BUF_FILL;
                end
            end
        end
    endcase
end

// Result calculation
// cannot be replaced to parameter assign 
wire [$clog2(LINE_BUF_SIZE):0]  idx1_comp1,
                                idx2_comp1,
                                idx1_comp2,
                                idx2_comp2;
assign idx1_comp1 = LINE_BUF_SIZE - c_work_cnt-1;
assign idx2_comp1 = LINE_BUF_SIZE - c_work_cnt-2;
assign idx1_comp2 = HALF_BUF_SIZE - c_work_cnt-1;
assign idx2_comp2 = HALF_BUF_SIZE - c_work_cnt-2;


assign comp1 = (c_line_buf[idx1_comp1] > c_line_buf[idx2_comp1]) ? c_line_buf[idx1_comp1] : c_line_buf[idx2_comp1];
assign comp2 = (c_line_buf[idx1_comp2] > c_line_buf[idx2_comp2]) ? c_line_buf[idx1_comp2] : c_line_buf[idx2_comp2];
assign result = (comp1 > comp2) ? comp1 : comp2;
assign o_data = result;

endmodule

 

Testbench 는 다음과 같다

`timescale 1ns/1ps

module tb_maxpool2d;

    // Parameters
    parameter DATA_WIDTH = 16;
    parameter IMG_WIDTH = 16;
    parameter IMG_HEIGHT = 16;
    parameter CLK_PERIOD = 10;

    // Inputs
    reg i_clk;
    reg i_rst;
    reg i_start;
    reg [DATA_WIDTH-1:0] i_data;

    // Outputs
    wire [DATA_WIDTH-1:0] o_data;
    wire o_valid;

    // Instantiate the Unit Under Test (UUT)
    maxpool2d_gpt #
    (
        .DATA_WIDTH(DATA_WIDTH),
        .IMG_WIDTH(IMG_WIDTH),
        .IMG_HEIGHT(IMG_HEIGHT)
    )
    uut
    (
        .i_clk(i_clk),
        .i_rst(i_rst),
        .i_start(i_start),
        .i_data(i_data),
        .o_data(o_data),
        .o_valid(o_valid)
    );

    // Clock generation
    initial begin
        i_clk = 0;
        forever #(CLK_PERIOD/2) i_clk = ~i_clk;
    end

    // Test vectors
    reg [DATA_WIDTH-1:0] image_data [0:IMG_WIDTH*IMG_HEIGHT-1];

    integer i;

    initial begin
        // Initialize Inputs
        i_rst = 0;
        i_start = 0;
        i_data = 0;

        // Reset sequence
        #(CLK_PERIOD);
        i_rst = 1;
        #(CLK_PERIOD);

        // Load image data
        // Here, we generate a test pattern; you can replace this with actual image data if needed
        for (i = 0; i < IMG_WIDTH*IMG_HEIGHT; i = i + 1) begin
            image_data[i] = (i*3) % 256; // Example data: pixel values from 0, 3, 6, ...
        end

        // Start the maxpool2d operation
        i_start = 1;
            

        // Feed input data
        for (i = 0; i < IMG_WIDTH*IMG_HEIGHT; i = i + 1) begin
            i_data = image_data[i];
            #(CLK_PERIOD);
            i_start = 0;
        end

        // Wait for processing to complete
        #(CLK_PERIOD * (IMG_WIDTH * IMG_HEIGHT));

        // Finish simulation
        $stop;
    end

    // Monitor outputs
    initial begin
        $display("Time\tclk\trst\tstart\tdata_in\tdata_out\tvalid");
        $monitor("%0t\t%b\t%b\t%b\t%h\t%h\t%b", $time, i_clk, i_rst, i_start, i_data, o_data, o_valid);
    end

endmodule

 

Modelsim 돌려본 결과는 다음과 같다

 

이건 16x16 예제인데, valid 가 총 8번 뜨는 것을 확인 가능하다.(line 마다 한번 뜬다고 생각하면 되겠다.)

들어가는 input 값은 0, 3, 6, ... 3n 인데, mod 256 를 적용했다.

그래서 처음 버퍼가 다 찬 경우에

line 1 : 0, 3, 6, ...

line 2 : 48, 51, 54, ...

처음 시작이 51, 그다음은 57, ... 이렇게 나와야 할 것이다.

 

시뮬레이션의 결과도 보게 되면 51부터 시작해서 93까지 총 8개 (하나의 line) 나오는 것을 확인 가능하다.

 

 

반응형

+ Recent posts