// bpe_framework/include/lm/models/transformer_block.hpp
// Last modified: 2025-08-27 14:02:03 -07:00
// (33 lines, 796 B, C++)

#pragma once
#include "lm/core/tensor.hpp"
#include "lm/models/attention.hpp"
#include "lm/models/feed_forward.hpp"
#include "lm/models/layer_norm.hpp"
#include <memory>
#include <vector>
namespace lm {
/// One transformer layer: multi-head self-attention plus a position-wise
/// feed-forward network, with two layer norms. The exact residual/norm
/// ordering (pre-norm vs post-norm) lives in the .cpp implementation and
/// is not visible from this header.
class TransformerBlock {
public:
    /// @param d_model   Model (embedding/hidden) dimension.
    /// @param num_heads Number of attention heads.
    /// @param d_ff      Inner dimension of the feed-forward sublayer.
    /// @param dropout   Dropout rate (presumably a probability in [0,1];
    ///                  applied by the sublayers — confirm in the .cpp).
    TransformerBlock(size_t d_model, size_t num_heads, size_t d_ff, float dropout);

    /// Returns the learnable tensors of all four sublayers so an optimizer
    /// can update them. Ignoring the result is always a mistake, hence
    /// [[nodiscard]].
    [[nodiscard]] std::vector<Tensor> parameters() const;

    /// Sets the training/eval flag (presumably gates dropout in the
    /// sublayers — implementation not visible here).
    void set_training(bool training);

    /// Runs the block on `input`. `mask` is an optional attention mask
    /// forwarded to the attention sublayer; the default empty Tensor means
    /// "no mask".
    [[nodiscard]] Tensor forward(const Tensor& input, const Tensor& mask = Tensor()) const;

private:
    size_t d_model_, num_heads_, d_ff_;  // hyperparameters, as passed to the ctor
    float dropout_;                      // dropout rate, as passed to the ctor
    bool training_ = false;              // toggled via set_training(); defaults to eval mode
    std::unique_ptr<MultiHeadAttention> attention_;    // self-attention sublayer
    std::unique_ptr<FeedForward> feed_forward_;        // position-wise FFN sublayer
    std::unique_ptr<LayerNorm> norm1_;                 // norm paired with attention
    std::unique_ptr<LayerNorm> norm2_;                 // norm paired with feed-forward
};
} // namespace lm