bpe_framework/include/lm/models/conversation_model.hpp
2025-09-13 12:45:42 -07:00

55 lines
1.5 KiB
C++

// Enhanced conversation_model.hpp
#pragma once
#include "transformer_model.hpp"
#include "bpe_tokenizer.hpp"
#include "context_manager.hpp"
#include <string>
#include <vector>
#include <memory>
namespace lm {
class ConversationModel {
public:
ConversationModel(size_t vocab_size,
size_t d_model = 512,
size_t n_layers = 6,
size_t n_heads = 8,
size_t d_ff = 2048,
float dropout = 0.1);
// Train the model
void train(const std::vector<std::string>& conversations);
// Generate a response with context management
std::string generate_response(const std::string& user_input);
// Context management
void clear_context();
void set_system_prompt(const std::string& prompt);
size_t get_context_token_count() const;
// Save and load
bool save_model(const std::string& path);
bool load_model(const std::string& path);
// Set tokenizer
void set_tokenizer(std::shared_ptr<BPETokenizer> tokenizer) {
tokenizer_ = tokenizer;
context_manager_ = std::make_unique<ContextManager>(2048, 20);
}
private:
std::shared_ptr<BPETokenizer> tokenizer_;
std::unique_ptr<TransformerModel> transformer_;
std::unique_ptr<ContextManager> context_manager_;
std::string system_prompt_;
// Format conversation for training
std::string format_conversation(const std::vector<std::string>& turns);
};
} // namespace lm