#pragma once

#include "lm/core/tensor.hpp"

#include <cstddef>
#include <vector>

namespace lm {

class MultiHeadAttention {
public:
    MultiHeadAttention(size_t d_model, size_t num_heads, float dropout = 0.1f);

    // Returns the learnable projection weights (w_q_, w_k_, w_v_, w_o_).
    std::vector<Tensor> parameters() const;

    // Toggles training mode; dropout is only applied while training.
    void set_training(bool training);

    // Computes multi-head attention over query/key/value. An empty mask
    // (the default) means no masking is applied.
    Tensor forward(const Tensor& query,
                   const Tensor& key,
                   const Tensor& value,
                   const Tensor& mask = Tensor()) const;

private:
    // Reshapes activations between [batch, seq_len, d_model] and the
    // per-head layout used inside the attention computation.
    Tensor split_heads(const Tensor& x) const;
    Tensor combine_heads(const Tensor& x) const;

    Tensor scaled_dot_product_attention(const Tensor& q,
                                        const Tensor& k,
                                        const Tensor& v,
                                        const Tensor& mask) const;

    Tensor apply_dropout(const Tensor& input, float dropout_rate) const;

    size_t d_model_;
    size_t num_heads_;
    size_t d_k_;           // per-head dimension (d_model_ / num_heads_)
    float dropout_;
    bool training_ = false;

    Tensor w_q_;           // query projection
    Tensor w_k_;           // key projection
    Tensor w_v_;           // value projection
    Tensor w_o_;           // output projection
};

} // namespace lm
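
// Usage sketch (illustrative only; `q`, `k`, `v`, and `causal_mask` are
// assumed to be Tensors of shape [batch, seq_len, d_model] built elsewhere,
// and the d_model/num_heads values below are arbitrary examples):
//
//   lm::MultiHeadAttention attn(/*d_model=*/512, /*num_heads=*/8);
//   attn.set_training(false);                         // inference: dropout off
//   lm::Tensor out    = attn.forward(q, k, v);        // unmasked attention
//   lm::Tensor masked = attn.forward(q, k, v, causal_mask);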