bpe_framework/include/lm/models/feed_forward.hpp
2025-08-27 14:02:03 -07:00

33 lines
624 B
C++

#pragma once
#include "lm/core/tensor.hpp"
#include <vector>
namespace lm {
/// @brief Feed-forward layer interface built on lm::Tensor.
///
/// Only the declaration lives in this header; the definitions are elsewhere.
/// The object holds two weight tensors and two bias tensors (w1_/b1_ and
/// w2_/b2_), a dropout rate, and a training flag that starts out false.
///
/// NOTE(review): judging by the names, this looks like the position-wise
/// feed-forward sub-layer of a transformer (linear -> GELU -> dropout ->
/// linear) — confirm against the implementation file before relying on it.
class FeedForward {
public:
    /// @brief Build a layer from its two size parameters and a dropout rate.
    /// @param d_model First size parameter (presumably the model/embedding
    ///                width — TODO confirm in the implementation).
    /// @param d_ff    Second size parameter (presumably the hidden width —
    ///                TODO confirm in the implementation).
    /// @param dropout Dropout rate; defaults to 0.1f when omitted.
    FeedForward(size_t d_model, size_t d_ff, float dropout = 0.1f);

    /// @brief Run the layer on @p input and return the resulting tensor.
    /// @param input Tensor to process; taken by const reference.
    /// @return A new Tensor returned by value.
    /// Declared const, so it cannot modify non-mutable members of the layer.
    Tensor forward(const Tensor& input) const;

    /// @brief Set the training-mode flag.
    /// @param training Requested mode. Presumably stored in training_ and
    ///                 used to gate dropout — verify in the implementation.
    void set_training(bool training);

    /// @brief Return the layer's parameter tensors as a vector, by value.
    /// NOTE(review): whether the returned Tensors alias the layer's own
    /// storage depends on Tensor's copy semantics, which this header does
    /// not show.
    std::vector<Tensor> parameters() const;

private:
    /// @brief Activation helper; by its name, presumably the GELU function
    ///        applied to @p input (definition not visible here).
    Tensor gelu(const Tensor& input) const;

    /// @brief Dropout helper taking an explicit rate (definition not visible
    ///        here).
    /// @param input        Tensor to process.
    /// @param dropout_rate Rate to apply for this call.
    Tensor apply_dropout(const Tensor& input, float dropout_rate) const;

    // Data members. Their declaration order is intentionally left untouched:
    // C++ initializes members in declaration order, and the constructor is
    // defined out of line.
    size_t d_model_;        // size given as d_model at construction (presumably)
    size_t d_ff_;           // size given as d_ff at construction (presumably)
    float dropout_;         // dropout rate (presumably the constructor argument)
    bool training_{false};  // training-mode flag; false unless changed
    Tensor w1_;             // first weight tensor
    Tensor b1_;             // first bias tensor
    Tensor w2_;             // second weight tensor
    Tensor b2_;             // second bias tensor
};
} // namespace lm