๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ
์ธ๊ณต์ง€๋Šฅ

์‹ ๊ฒฝ๋ง ํ•™์Šต์— ํ•„์š”ํ•œ ํ•จ์ˆ˜๋“ค

by sh119 2025. 7. 24.

์ง€๊ธˆ๊นŒ์ง€ ํผ์…‰ํŠธ๋ก , ์‹ ๊ฒฝ๋ง์— ๋Œ€ํ•ด ๋ฐฐ์šฐ๊ณ  ์ด์ œ ์‹ ๊ฒฝ๋ง ํ•™์Šต์˜ ๋Œ€๋žต์ ์ธ ๋ฐฉํ–ฅ์— ๋Œ€ํ•ด ๋ฐฐ์› ๋‹ค. 

์‹ ๊ฒฝ๋ง ํ•™์Šต์„ ์œ„ํ•ด ์ค‘์š”ํ•œ ์‚ฌ์‹ค๋“ค์ด ๋งŽ์ง€๋งŒ, ๊ทธ๊ฑด ๋‹ค๋ฅธ ํŽ˜์ด์ง€์—์„œ ์ •๋ฆฌํ•ด๋ณด๋„๋ก ํ•˜๊ณ  ํฐ ํ๋ฆ„์€ ์•„๋ž˜์™€ ๊ฐ™๋‹ค.

 

"๋ฏธ๋‹ˆ๋ฐฐ์น˜๋กœ ๋ฐ์ดํ„ฐ ๋ฝ‘๊ธฐ → ์†์‹คํ•จ์ˆ˜๋กœ ๊ธฐ์šธ๊ธฐ ๊ณ„์‚ฐ → ๊ฒฝ์‚ฌํ•˜๊ฐ•๋ฒ•์œผ๋กœ ๊ฐ€์ค‘์น˜ ๊ฐฑ์‹  → ๋ฐ˜๋ณต"

 

์—ฌ๊ธฐ์—์„  ์‹ ๊ฒฝ๋ง ํ•™์Šต์„ ์ฝ”๋“œ๋กœ ๊ตฌํ˜„ํ•จ์— ์žˆ์–ด ํ•„์š”ํ•œ ํ•จ์ˆ˜๋“ค์„ ์ด ์ •๋ฆฌํ•ด ๋ณด๋ ค๊ณ  ํ•œ๋‹ค.

 

1. ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ

  • get_mini_batch()
    • ๋ฐ์ดํ„ฐ์…‹์—์„œ ๋ฌด์ž‘์œ„๋กœ ์ผ๋ถ€ ๋ฐ์ดํ„ฐ๋ฅผ ๋ฝ‘์•„์˜ค๋Š” ํ•จ์ˆ˜
    • ์ž…๋ ฅ: ์ „์ฒด ๋ฐ์ดํ„ฐ (x, t), ๋ฐฐ์น˜ ํฌ๊ธฐ
    • ์ถœ๋ ฅ: x_batch, t_batch

2.  ์ˆœ์ „ํŒŒ (Forward propagation)

  • predict(x)
    • ์ž…๋ ฅ ๋ฐ์ดํ„ฐ๋ฅผ ๋„ฃ์—ˆ์„ ๋•Œ ์ถœ๋ ฅ y๋ฅผ ๊ณ„์‚ฐ
    • ๋‚ด๋ถ€์ ์œผ๋กœ: A=XW+B → Z=f(A)Y = softmax(Z)
  • loss(x, t)
    • ํ˜„์žฌ ํŒŒ๋ผ๋ฏธํ„ฐ๋กœ ์˜ˆ์ธกํ•œ ๊ฐ’๊ณผ ์ •๋‹ต ์‚ฌ์ด์˜ ์†์‹ค๊ฐ’ ๊ณ„์‚ฐ
    • ๋‚ด๋ถ€์ ์œผ๋กœ: predict(x) ํ˜ธ์ถœ + ์†์‹ค ํ•จ์ˆ˜(cross_entropy_error())

3.  ์†์‹ค ํ•จ์ˆ˜ (Loss function)

  • mean_squared_error(y, t)
    • ํšŒ๊ท€ ๋ฌธ์ œ์šฉ
  • cross_entropy_error(y, t)
    • ๋ถ„๋ฅ˜ ๋ฌธ์ œ์šฉ (์›-ํ•ซ ์ธ์ฝ”๋”ฉ๋œ t์™€ softmax y๋ฅผ ๋น„๊ต)

4. ๊ธฐ์šธ๊ธฐ ๊ณ„์‚ฐ (Gradient)

  • numerical_gradient(f, W)
    • ์ˆ˜์น˜๋ฏธ๋ถ„์œผ๋กœ ๊ธฐ์šธ๊ธฐ ๊ณ„์‚ฐ
    • f: ์†์‹ค ํ•จ์ˆ˜, W: ํŒŒ๋ผ๋ฏธํ„ฐ
  • gradient(x, t)
    • ์‹ ๊ฒฝ๋ง ์ „์ฒด ํŒŒ๋ผ๋ฏธํ„ฐ(W, b)์— ๋Œ€ํ•œ ์†์‹ค ํ•จ์ˆ˜์˜ ๊ธฐ์šธ๊ธฐ ๊ณ„์‚ฐ
    • ๋‚ด๋ถ€์ ์œผ๋กœ: numerical_gradient()๋ฅผ ๊ฐ ๋ ˆ์ด์–ด ํŒŒ๋ผ๋ฏธํ„ฐ์— ์ ์šฉ
    • (5์žฅ์—์„œ๋Š” ์ด๊ฒŒ ์—ญ์ „ํŒŒ ๋ฒ„์ „์œผ๋กœ ๋Œ€์ฒด)

5. ๋งค๊ฐœ๋ณ€์ˆ˜ ๊ฐฑ์‹  (Parameter Update)

  • SGD(params, grads, lr)
    • ๊ฒฝ์‚ฌํ•˜๊ฐ•๋ฒ•.
    • W←W−η⋅∇W
    • ๋ชจ๋“  ํŒŒ๋ผ๋ฏธํ„ฐ(W, b)์— ๋Œ€ํ•ด ๋ฐ˜๋ณต

6. ํ•™์Šต ๋ฃจํ”„

  • train()
    • ์ „์ฒด ๋ฐ์ดํ„ฐ ๋ฐ˜๋ณต ํ•™์Šต
    • ์ ˆ์ฐจ:
      1. get_mini_batch()
      2. gradient()
      3. SGD() : ๋งค๊ฐœ๋ณ€์ˆ˜ ๊ฐฑ์‹ 
      4. ์†์‹ค/์ •ํ™•๋„ ๊ธฐ๋ก

7. ์ •ํ™•๋„ ํ‰๊ฐ€

  • accuracy(x, t)
    • ํ˜„์žฌ ๋ชจ๋ธ์ด ๋งž์ถ˜ ๋น„์œจ ๊ณ„์‚ฐ (ํ…Œ์ŠคํŠธ์…‹ ํ‰๊ฐ€)

 

์ฆ‰, ์ •๋ฆฌํ•˜๋ฉด:

  • ๋ฐ์ดํ„ฐ ๋ฝ‘๊ธฐ: get_mini_batch()
  • ์ˆœ์ „ํŒŒ: predict()
  • ์†์‹ค: loss(), cross_entropy_error()
  • ๊ธฐ์šธ๊ธฐ: numerical_gradient(), gradient()
  • ์—…๋ฐ์ดํŠธ: SGD()
  • ํ‰๊ฐ€: accuracy()

์ด๊ฑธ ํ•ฉ์ณ์„œ TwoLayerNet ๊ฐ™์€ ํด๋ž˜์Šค๊ฐ€ ๋งŒ๋“ค์–ด์ง€๊ณ , ํ•™์Šต ๋ฃจํ”„์—์„œ ๊ณ„์† ๋ฐ˜๋ณตํ•˜๋Š” ๊ตฌ์กฐ๊ฐ€ ๋œ๋‹ค. 

 

 

๋งˆ์ง€๋ง‰์œผ๋กœ ํ•ด๋‹น ๊ณผ์ •์„ ์•„๋ž˜์™€ ๊ฐ™์ด ๊ฐ„๋‹จํ•˜๊ฒŒ ๋‹ค์ด์–ด๊ทธ๋žจ์œผ๋กœ ๊ทธ๋ ค๋ณด์•˜๋‹ค!

๋ฐ‘๋ฐ”๋‹ฅ๋ถ€ํ„ฐ ์‹œ์ž‘ํ•˜๋Š” ๋”ฅ๋Ÿฌ๋‹1๊ถŒ์˜ 4์žฅ์˜ ํ๋ฆ„์€ ์•„๋ž˜ ๋‹ค์ด์–ด๊ทธ๋žจ์œผ๋กœ ์„ค๋ช… ๊ฐ€๋Šฅํ•˜๋‹ค.