// Enhanced conversation_model.hpp #pragma once #include "transformer_model.hpp" #include "bpe_tokenizer.hpp" #include "context_manager.hpp" #include #include #include namespace lm { class ConversationModel { public: ConversationModel(size_t vocab_size, size_t d_model = 512, size_t n_layers = 6, size_t n_heads = 8, size_t d_ff = 2048, float dropout = 0.1); // Train the model void train(const std::vector& conversations); // Generate a response with context management std::string generate_response(const std::string& user_input); // Context management void clear_context(); void set_system_prompt(const std::string& prompt); size_t get_context_token_count() const; // Save and load bool save_model(const std::string& path); bool load_model(const std::string& path); // Set tokenizer void set_tokenizer(std::shared_ptr tokenizer) { tokenizer_ = tokenizer; context_manager_ = std::make_unique(2048, 20); } private: std::shared_ptr tokenizer_; std::unique_ptr transformer_; std::unique_ptr context_manager_; std::string system_prompt_; // Format conversation for training std::string format_conversation(const std::vector& turns); }; } // namespace lm