ทำความเข้าใจ Optimizer

Original article was published by Pakawat Nakwijit on Deep Learning on Medium


ต่อมา ขอพูดถึงในรูปแบบของสมการกันบ้าง

ตามนิยามที่อธิบายไปแล้ว ว่า momentum คือ การคงสภาพความเร็วเดิม ซึ่งความเร็วในที่นี้ คือ gradient ที่จะอัพเดตในแต่ละรอบ ดังนั้น momentum จึงสามารถเขียนเป็นสมการได้ดังนี้

อธิบายในเป็นข้อความ คือ สิ่งที่จะอัพเดตในรอบนี้ v_new คือ ส่วนหนึ่งของการอัพเดตครั้งก่อน (η * v_old) บวกกับ gradient ของรอบนี้ (-α * dL/dw)

โดย η คือ momentum factor ซึ่งมีค่าระหว่าง 0 ถึง 1

จากนั้น การอัพเดต parameter จึงเป็นไปตามสมการนี้

Regularization

เป็นอีกหนึ่ง term ที่มักพบเมื่อมีการใช้ optimizer ซึ่งทำหน้าที่จำกัดไม่ให้โมเดลยึดติดกับข้อมูลที่ใช้เทรนมากเกินไป ซึ่งมักทำให้โมเดลไม่สามารถทำนายข้อมูลแปลกๆที่ไม่เคยเจอมาก่อนได้

โดยทั่วไปแล้ว regularization สามารถทำได้โดยการเพิ่ม term เข้าไปใน loss function เพิ่มกำหนดลักษณะบางอย่างของโมเดล หากฝ่าฝืน จะทำให้ loss ในส่วน regularization เพิ่มขึ้น เช่น L2 regularization ซึ่งจะลงโทษ เมื่อโมเดลพยายามใช้ parameter ที่มีค่าเยอะๆ

ซึ่งเมื่อใช้ L2 regularization แล้ว จำเป็นต้องเปลี่ยนสมการการปรับ parameters ดังนี้

ซึ่ง term ใหม่ที่เพิ่มเข้ามานี้มักโดนเรียกว่า weight decay เพราะว่า มันทำหน้าที่ ลบ wt ออกไป ตาม decay rate λ

แต่อย่างไรก็ตาม สามารถมองได้ว่า regularization เป็นเพียงแค่ส่วนหนึ่งของ loss function โดยไม่ได้กระทบกับการทำงานของ optimizer 🙂

//ทำงานแบบเดิม เพิ่มเติม คือ สมการเยอะขึ้น

Advanced Optimizers

ต่อมา จะมาพูดถึง optimizers แบบซับซ้อนกันบ้าง ซึ่งมีเป้าหมายเหมือนกัน คือ ช่วยเร่งให้โมเดลยิ่งเก่งขึ้นในระยะเวลาที่ลดลง

โดยทั่วไปแล้ว advanced optimizers เป็นการต่อยอด gradient descent ไม่ว่าจะเป็น [เพิ่มเติม]

  1. เปลี่ยนทิศทางในอัพเดตในรอบต่อไป (gradient term)
  2. เปลี่ยนขนาดในอัพเดตในรอบต่อไป (learning rate term)
  3. เปลี่ยนทั่งคู่เลย

Nesterov Accelerated Gradients (NAG)

NAG เป็นเทคนิคที่เพิ่มเติมจาก momentum โดยเมื่อเปรียบเทียบกับลูกแก้วที่มี momentum ในตัวอย่างก่อนๆ NAG เป็นเสมือนการเปลี่ยนลูกแก้วธรรมดาเป็นลูกแก้ววิเศษที่สามารถรู้อนาคตของตัวเองได้

โดยถ้าลูกแก้วรู้ว่า ในรอบต่อไปมันจะเจอกับทางลาดขึ้น ซึ่งจะทำให้มันต้องไหลย้อนกลับมาทางเดิมอีกครั้ง ลูกแก้วก็จะชิงลดความเร็วลงไปก่อนเลย เพื่อจะไม่ต้องเสียเวลากลิ้งไปมาๆ

โดยใช้สมการ momentum เดิม

Gradient update with momentum term

NAG ประมาณค่า parameter ต่างๆในอนาคต ตามสมการด้านล่าง โดยลบ gradient term โดยใช้สมมติฐานว่า gradient term เป็นแค่เลขน้อยๆ เมื่อเทียบกับ term อื่นๆ

จากนั้น จึงแทนสมการคำนวน สิ่งที่จะอัพเดต เป็น สิ่งที่จะอัพเดตในรอบนี้ v_new คือ ส่วนหนึ่งของการอัพเดตครั้งก่อน (η * v_old) บวกกับ gradient ที่จะเกิดขึ้นในอนาคต (-α * dL/dw_future)

Adaptive Optimizers

อีกหนึ่งข้อสังเกตที่เกิดขึ้นหลัง นักวิจัยศึกษา Gradient descent มาระยะหนึ่ง คือ จริงๆแล้ว learning rate และ momentum ไม่ควรเป็นค่าที่ใช้ร่วมกับสำหรับทุกๆ parameters เพราะ parameters แต่ละตัว ต่างทำหน้าที่แตกต่างกัน parameters ที่อยู่ใน neural network ชั้นแรกๆ ทำหน้าที่จัดการความหมายระดับคำ แต่ ชั้นหลังๆ ทำหน้าที่จัดการความหมายโดยรวม — ตัวอย่างนี้ อาจจะไม่ตรงตามความเป็นจริง ยกขึ้นมาเพื่อให้เห็นภาพตรงกัน

ทำให้เกิดแนวคิดที่เป็น Adaptive optimizers ซึ่งจะให้ learning rate และ momentum ของ parameters แต่ละตัวไม่เท่ากัน บางคนต้องการเปลี่ยนแปลงอย่างรวดเร็ว ก็จะมี learning rate มากๆ แต่บางตัวไม่ ก็จะมี learning rate น้อยๆ

Adagrad

Adagrad เป็นแนวคิดแรกๆของการทำ adaptive optimization โดยนิยามตัวแปรขึ้นมาอีกหนึ่งตัว เรียกว่า cache ซึ่งเป็นผลรวมของ gradients ของ parameters

โดย cache ของแต่ละ parameter จะเก็บแยกกัน ถ้าโมเดลมี parameters 10 ตัว ก็จะมี caches 10 ค่า ใช้สำหรับ parameters แต่ละตัว

ซึ่ง cache นี้เป็นลดบทบาทของ learning rate โดย

  • ขณะที่ parameters wi อัพเดตไปแล้วเยอะๆ จะทำให้ learning rate ของ wi มีค่าน้อยๆ ในรอบต่อๆไป wi จะโดยอัพเดตน้อยลง
  • ขณะที่ parameters wj ไม่เคยโดยอัพเดตเลย cache จะช่วยเพิ่ม learning rate ของ wj ให้มีค่ามากขึ้น ในรอบต่อๆไป wj จะโดยอัพเดตเยอะขึ้น

โดยการอัพเดต จะเป็นไปตามสมการต่อไปนี้

ค่าที่อยู่ใน cache ใช้เป็นการปรับจูน learning rate ตามที่อธิบายไปแล้ว โดยใส่รูท เพื่อทำให้ cache อยู่ในหน่วยเดียวกับ α (เดิม cache มีหน่วยเป็น α² เพราะว่า เป็นผลรวมของ gradient² )และ e เป็นตัวเลขเล็กๆ ใช้เพื่อป้องกันไม่ให้เกิดเหตุการณ์ที่เศษส่วนเป็น 0 ซึ่งโดยปกติแล้ว จะให้ e = 1e-8

ทั้งนี้ เนื่องด้วย cache เป็นผลรวมของ gradient² ซึ่งจะมีค่าเพิ่มขึ้นเสมอ เพราะว่าค่ายกกำลังสองเป็นบวกเสมอ จึงพูดได้ว่า cache จะมีค่าเพิ่มขึ้นเรื่อยๆอย่างรวดเร็ว ส่งผลให้ learning rate จะลดลงอย่างรวดเร็ว ซึ่งบางครั้ง learning rate ลดลงเร็วเกินไปจนทำให้ parameters บางตัว ไม่เปลี่ยนแปลงอีกต่อไป (learning rate 0) ในขณะที่ยังไม่สิ้นสุดการเทรน

RMSProp

RMSProp พยายามแก้ปัญหาของ Adagrad ด้วยการแทนที่ cache ด้วยการใช้ exponential average ตามสมการ โดยมี γ เป็น hyperparameter กำหนด decay rate ซึ่งจะลดทอนความสำคัญของ gradient² ในรอบก่อนๆ โดยปกติแล้ว กำหนดให้ γ อยู่ที่ 0.9 ถึง 0.99

ตามตัวอย่างด้านล่าง จะพบว่า ยิ่งเวลาผ่านไป จะยิ่งทำให้ gradient² ในรอบก่อนๆมีค่อยๆมีค่าความสำคัญลดลง หรือ พูดอีกแง่ คือ cache เน้นสนใจเฉพาะ gradient² จากรอบที่อยู่ไม่ห่างจากรอบปัจจุบัน ซึ่งนี้ทำให้อัตราการเพิ่มขึ้นของ cache ช้าลงกว่าเดิม

แจกแจงการคำนวน cache โดยให้ γ = 0.9 และแทน w_t ด้วย parameter w ในรอบที่ t

Adam

Adam เป็นเทคนิคที่พูดได้ว่า popular ที่สุดในปัจจุบัน ซึ่งรวมเอาข้อดีจากทั้ง RMSProp และ momentum เข้าไว้ด้วยกัน โดยใช้รวมเอาไว้ด้วยกันซะเลย

สมการแรก ใช้คำนวน momentum

สมการที่ 2 ใช้คำนวน cache ที่จะใช้ในการปรับ learning rate

โดย β1 และ β2 มักจะกำหนดอยู่ที่ 0.9 และ 0.99 ตามลำดับ — ปกติก็ใช้ค่า default นี้เลย

แต่จากการวิเคราะห์ พบว่า ทั้ง momemtum term และ cache term ที่คำนวนด้วยวิธีนี้มี bias เล็กๆ ที่เกิดจากว่า m0 และ cache0 ที่เป็นค่าเริ่มต้น กำหนดให้เป็น 0 ทำให้มีอาการเพี้ยนๆในระยะแรก ทั้ง 2 term จำเป็นจะถูกนำแก้ bias ก่อน ตามสมการ

เพราะว่าเป็น exponential average จึงแก้ bias ด้วยการหารด้วย 1 — (β1 และ β2 ยกกำลัง t)

จากนั้นก็เอาแทนในสมการของ RMSProp เดิม โดยใช้ m และ cache ที่แก้ bias แล้ว

แต่ที่น่าสนใจคือ ในหลายๆครั้ง Adam ที่ดูเหมือนว่า สามารถเอาจุดเด่นของๆทุกๆ optimizers มาใช้ประโยชน์ได้อย่างมีประสิทธิภาพ แต่กลับทำงานได้แย่กว่า SGD with momentum ในบางกรณี

จึงมีความพยายามในการพัฒนา Adam ออกมาในหลากหลายรูปแบบ เช่น AdaMax ที่แทนสมการ cache โดยการใช้ ℓ-∞ converges ซึ่งออกมาในรูปของ max(…) หรือ NAdam ซึ่งรวม NAG เข้ากับ Adam หรือ AdamW ที่เป็นการเอา Adam มารวมกับ weight decay [เพิ่มเติม] แต่อย่างไรก็ตาม SGD with momentum และ Adam ก็ยังเป็น optimizer มาตรฐานที่ใช้กันทั่วๆไป