| |
|
| |
|
| |
|
| | import json
|
| | import os
|
| | import itertools
|
| | from pathlib import Path
|
| | from datasets import load_dataset
|
| | from transformers import AutoTokenizer
|
| | import langdetect
|
| | from tqdm import tqdm
|
| | import argparse
|
| | import random
|
| | from collections import defaultdict
|
| |
|
| |
|
class ConversationDataPreprocessor:
    """Download, standardize, and format conversational datasets for fine-tuning.

    The pipeline has three stages, each writing into a subdirectory of
    ``output_dir``:

    * ``conversation_raw``       -- raw JSONL dumps from HuggingFace datasets
    * ``conversation_processed`` -- standardized ``{"messages": [...]}`` records
    * ``conversation_final``     -- train/test splits in the chosen training format
    """

    def __init__(self, output_dir="data", max_length=1024):
        """
        Args:
            output_dir: Root directory for all pipeline artifacts.
            max_length: Length budget stored on the instance; not consumed by
                the processing steps visible here (retained for callers that
                read it).
        """
        self.output_dir = Path(output_dir)
        self.max_length = max_length
        self.setup_directories()

    def setup_directories(self):
        """Create necessary directories (idempotent)."""
        dirs = ["conversation_raw", "conversation_processed", "conversation_final"]
        for d in dirs:
            (self.output_dir / d).mkdir(parents=True, exist_ok=True)

    def download_conversational_data(self, dataset_name="OpenAssistant/oasst1", num_conversations=20000):
        """Download up to ``num_conversations`` records from a HuggingFace dataset.

        Streams the dataset so the full corpus never has to fit in memory, and
        writes one JSON object per line. On any failure, falls back to
        ``download_alternative_dataset``.

        Returns:
            Path to the raw JSONL file.
        """
        print(f"Downloading {num_conversations} conversations from {dataset_name}...")

        raw_path = self.output_dir / "conversation_raw" / f"{dataset_name.replace('/', '_')}_raw.jsonl"

        try:
            # streaming=True yields rows lazily instead of downloading the
            # entire dataset up front.
            ds = load_dataset(dataset_name, split="train", streaming=True)

            downloaded = 0
            with open(raw_path, "w", encoding="utf-8") as f:
                for row in tqdm(itertools.islice(ds, num_conversations), total=num_conversations):
                    f.write(json.dumps(row, ensure_ascii=False) + "\n")
                    downloaded += 1

            print(f"Raw conversational data saved to: {raw_path}")
            print(f"Downloaded {downloaded} conversation records")
            return raw_path

        except Exception as e:
            print(f"Error downloading {dataset_name}: {e}")
            print("Trying alternative dataset...")
            return self.download_alternative_dataset(num_conversations)

    def download_alternative_dataset(self, num_conversations=20000):
        """Try alternative conversational datasets if primary fails.

        Attempts each candidate in order and returns the path of the first
        successful download.

        Raises:
            Exception: If every candidate dataset fails.
        """
        alternative_datasets = [
            "databricks/databricks-dolly-15k",
            "tatsu-lab/alpaca",
            "vicgalle/alpaca-gpt4"
        ]

        for dataset_name in alternative_datasets:
            try:
                print(f"Trying {dataset_name}...")
                raw_path = self.output_dir / "conversation_raw" / f"{dataset_name.replace('/', '_')}_raw.jsonl"

                ds = load_dataset(dataset_name, split="train")

                # Subsample deterministically when the dataset is larger than needed.
                if len(ds) > num_conversations:
                    ds = ds.shuffle(seed=42).select(range(num_conversations))

                with open(raw_path, "w", encoding="utf-8") as f:
                    for row in tqdm(ds):
                        f.write(json.dumps(row, ensure_ascii=False) + "\n")

                print(f"Successfully downloaded {len(ds)} records from {dataset_name}")
                return raw_path

            except Exception as e:
                print(f"Failed to download {dataset_name}: {e}")
                continue

        raise Exception("All conversational datasets failed to download")

    def process_conversations(self, input_path, dataset_name="auto"):
        """Process raw conversational data into standard format.

        Args:
            input_path: Raw JSONL file produced by a download step.
            dataset_name: Accepted for interface compatibility; format
                detection is based on the file path itself (which embeds the
                dataset name).

        Returns:
            Path to the standardized JSONL file.
        """
        print("Processing conversations into standard format...")

        input_path = Path(input_path)

        if "OpenAssistant" in str(input_path) or "oasst" in str(input_path):
            return self.process_openassistant_messages(input_path)
        else:
            return self.process_other_datasets(input_path)

    def process_openassistant_messages(self, input_path):
        """Process OpenAssistant individual messages into conversation chains.

        Loads per-message records, groups them by ``message_tree_id``, walks
        each tree from its roots to build linear conversation paths, and writes
        validated conversations as standardized JSONL.

        Returns:
            Path to the standardized JSONL file.
        """
        print("🚀 Processing OpenAssistant messages into conversations...")

        messages = []
        print("Loading messages...")

        with open(input_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="Reading messages"):
                try:
                    msg = json.loads(line)

                    # Keep only non-deleted, reviewed English messages with text.
                    if (msg.get('lang') == 'en' and
                        not msg.get('deleted', False) and
                        msg.get('review_result', False) and
                        msg.get('text', '').strip()):

                        messages.append(msg)
                except Exception:
                    # FIX: was a bare `except:` which also swallowed
                    # KeyboardInterrupt/SystemExit. Skip malformed lines only.
                    continue

        print(f"Loaded {len(messages)} valid English messages")

        # Group messages into their conversation trees.
        trees = defaultdict(list)
        for msg in messages:
            tree_id = msg.get('message_tree_id')
            if tree_id:
                trees[tree_id].append(msg)

        print(f"Found {len(trees)} conversation trees")

        conversations = []

        for tree_id, tree_messages in tqdm(trees.items(), desc="Building conversations"):

            msg_dict = {msg['message_id']: msg for msg in tree_messages}

            # Roots are messages with no parent (conversation starters).
            roots = [msg for msg in tree_messages if not msg.get('parent_id')]

            for root in roots:
                try:
                    paths = self.build_conversation_paths(root, msg_dict)

                    for path in paths:
                        conversation = []
                        for msg in path:
                            # OASST labels the human side "prompter".
                            role = "user" if msg['role'] == "prompter" else "assistant"
                            conversation.append({
                                "role": role,
                                "content": msg['text'].strip()
                            })

                        if self.is_valid_conversation(conversation):
                            conversations.append({
                                "messages": conversation,
                                "tree_id": tree_id,
                                "source": "oasst1"
                            })
                except Exception as e:
                    # A malformed tree should not abort the whole run.
                    continue

        print(f"Extracted {len(conversations)} valid conversations")

        output_path = self.output_dir / "conversation_processed" / "conversations_standardized.jsonl"
        with open(output_path, "w", encoding="utf-8") as f:
            for conv in conversations:
                f.write(json.dumps(conv, ensure_ascii=False) + "\n")

        print(f"Processed data saved to: {output_path}")
        return output_path

    def build_conversation_paths(self, root_msg, msg_dict, max_length=8):
        """Build all conversation paths starting from a root message.

        Follows the best-ranked child at every level (rank 0 is best; missing
        ranks sort last), and also emits every prefix of length >= 2 so that
        shorter exchanges are kept as separate conversations.

        Args:
            root_msg: The root message dict of a tree.
            msg_dict: Mapping of message_id -> message for the whole tree.
            max_length: Maximum number of messages in a single path.

        Returns:
            List of paths; each path is a list of message dicts.
        """

        def build_paths_recursive(msg, current_path):
            paths = []
            new_path = current_path + [msg]

            # Children are messages whose parent_id points at this message.
            children = []
            for candidate in msg_dict.values():
                if candidate.get('parent_id') == msg['message_id']:
                    children.append(candidate)

            if not children:
                # Leaf: keep the path if it has at least one exchange.
                if len(new_path) >= 2:
                    paths.append(new_path)
            else:
                def get_rank(x):
                    # Missing/None ranks sort after any real rank.
                    rank = x.get('rank')
                    return rank if rank is not None else 999

                try:
                    children.sort(key=get_rank)
                    best_child = children[0]

                    if len(new_path) < max_length:
                        child_paths = build_paths_recursive(best_child, new_path)
                        paths.extend(child_paths)

                    # Also keep the current prefix as its own conversation.
                    if len(new_path) >= 2:
                        paths.append(new_path)
                except Exception:
                    # FIX: was a bare `except:`. If ranking fails (e.g. mixed
                    # rank types), fall back to the first child unranked.
                    if children and len(new_path) < max_length:
                        child_paths = build_paths_recursive(children[0], new_path)
                        paths.extend(child_paths)

            return paths

        return build_paths_recursive(root_msg, [])

    def is_valid_conversation(self, conversation):
        """Validate conversation quality.

        Requires at least one full exchange, strictly alternating roles, and
        per-message / total length within bounds.

        Returns:
            bool: True if the conversation passes all checks.
        """
        if len(conversation) < 2:
            return False

        # Roles must alternate (no two consecutive messages from one side).
        for i in range(1, len(conversation)):
            if conversation[i]['role'] == conversation[i-1]['role']:
                return False

        # Per-message length bounds.
        for msg in conversation:
            content = msg['content']
            if len(content) < 5 or len(content) > 1500:
                return False

        # Total length bounds.
        total_length = sum(len(msg['content']) for msg in conversation)
        if total_length < 20 or total_length > 3000:
            return False

        return True

    def process_other_datasets(self, input_path):
        """Process non-OpenAssistant datasets (Dolly, Alpaca, etc.).

        Returns:
            Path to the standardized JSONL file.
        """
        output_path = self.output_dir / "conversation_processed" / "conversations_standardized.jsonl"
        conversations = []
        total_count = 0
        valid_count = 0

        with open(input_path, "r", encoding="utf-8") as infile:
            for line in tqdm(infile, desc="Processing conversations"):
                total_count += 1
                try:
                    raw_data = json.loads(line)

                    conversation = self.extract_conversation_other_formats(raw_data)

                    if conversation and self.validate_simple_conversation(conversation):
                        conversations.append(conversation)
                        valid_count += 1

                except Exception as e:
                    # Skip malformed or unrecognized records.
                    continue

        with open(output_path, "w", encoding="utf-8") as outfile:
            for conv in conversations:
                outfile.write(json.dumps(conv, ensure_ascii=False) + "\n")

        print(f"Processed {valid_count}/{total_count} valid conversations")
        print(f"Processed data saved to: {output_path}")
        return output_path

    def extract_conversation_other_formats(self, raw_data):
        """Extract a conversation from various dataset formats.

        Recognizes Dolly (``instruction``/``response``, optional ``context``)
        and Alpaca (``instruction``/``output``, optional ``input``) records.

        Returns:
            Standardized conversation dict, or None if the format is unknown.
        """
        # Dolly-style record.
        if 'instruction' in raw_data and 'response' in raw_data:
            messages = [
                {"role": "user", "content": raw_data['instruction'].strip()}
            ]
            if raw_data.get('context'):
                messages[0]['content'] += f"\nContext: {raw_data['context'].strip()}"

            messages.append({
                "role": "assistant",
                "content": raw_data['response'].strip()
            })

            return {
                "messages": messages,
                "category": raw_data.get('category', 'general'),
                "source": "dolly"
            }

        # Alpaca-style record.
        elif 'instruction' in raw_data and 'output' in raw_data:
            messages = [
                {"role": "user", "content": raw_data['instruction'].strip()}
            ]
            if raw_data.get('input'):
                messages[0]['content'] += f"\nInput: {raw_data['input'].strip()}"

            messages.append({
                "role": "assistant",
                "content": raw_data['output'].strip()
            })

            return {
                "messages": messages,
                "source": "alpaca"
            }

        return None

    def validate_simple_conversation(self, conversation):
        """Validate conversation quality for simple formats.

        Looser than ``is_valid_conversation``: no role-alternation check and
        smaller length bounds, matching single-turn instruction data.

        Returns:
            bool: True if the conversation passes all checks.
        """
        messages = conversation.get('messages', [])

        if len(messages) < 1:
            return False

        for msg in messages:
            content = msg.get('content', '').strip()
            if not content or len(content) < 5:
                return False

        total_length = sum(len(msg['content']) for msg in messages)
        if total_length < 10 or total_length > 2000:
            return False

        return True

    def format_for_training(self, input_path, train_format="instruction"):
        """Format conversations for fine-tuning and split 90/10 train/test.

        Args:
            input_path: Standardized JSONL produced by a processing step.
            train_format: "instruction" or "chat" (see ``save_training_format``).

        Returns:
            Tuple of (train_path, test_path).
        """
        print(f"Formatting conversations for {train_format} training...")

        input_path = Path(input_path)
        output_path = self.output_dir / "conversation_final" / "conversation_train.jsonl"
        test_path = self.output_dir / "conversation_final" / "conversation_test.jsonl"

        conversations = []

        with open(input_path, "r", encoding="utf-8") as f:
            for line in f:
                conv = json.loads(line)
                conversations.append(conv)

        # Shuffle before splitting so train/test are drawn from the same
        # distribution. NOTE(review): unseeded, so the split is not reproducible
        # across runs — seed here if reproducibility is required.
        random.shuffle(conversations)
        split_point = int(len(conversations) * 0.9)
        train_conversations = conversations[:split_point]
        test_conversations = conversations[split_point:]

        self.save_training_format(train_conversations, output_path, train_format)
        self.save_training_format(test_conversations, test_path, train_format)

        print(f"Training conversations: {len(train_conversations)}")
        print(f"Test conversations: {len(test_conversations)}")
        print(f"Training data saved to: {output_path}")
        print(f"Test data saved to: {test_path}")

        # Show a few truncated samples for a quick sanity check.
        if train_conversations:
            print("\n📝 Sample conversations:")
            for i, conv in enumerate(train_conversations[:3]):
                print(f"\nConversation {i+1}:")
                for j, msg in enumerate(conv['messages']):
                    content = msg['content'][:80] + "..." if len(msg['content']) > 80 else msg['content']
                    print(f"  {j+1}. {msg['role'].title()}: {content}")

        return output_path, test_path

    def save_training_format(self, conversations, output_path, format_type):
        """Save conversations in training format.

        "instruction" flattens all but the last message into an input and uses
        the last message as the target output; "chat" prepends a system message
        and keeps the full message list. Conversations with fewer than two
        messages, or an unrecognized format_type, are skipped.
        """
        with open(output_path, "w", encoding="utf-8") as f:
            for conv in conversations:
                messages = conv['messages']

                if len(messages) >= 2:
                    if format_type == "instruction":
                        input_messages = []
                        for msg in messages[:-1]:
                            input_messages.append(f"{msg['role'].title()}: {msg['content']}")

                        training_example = {
                            "instruction": "Continue this conversation naturally and helpfully.",
                            "input": "\n".join(input_messages),
                            "output": messages[-1]['content']
                        }

                    elif format_type == "chat":
                        training_example = {
                            "messages": [
                                {"role": "system", "content": "You are MAP-NEO, a helpful AI assistant."}
                            ] + messages
                        }

                    else:
                        # BUG FIX: previously an unknown format_type fell
                        # through to the write below with `training_example`
                        # undefined, raising UnboundLocalError. Skip instead.
                        continue

                    f.write(json.dumps(training_example, ensure_ascii=False) + "\n")
|
| |
|
| |
|
def main():
    """Command-line entry point: download, standardize, and format conversation data."""
    arg_parser = argparse.ArgumentParser(description="Preprocess conversational data for fine-tuning")
    arg_parser.add_argument("--dataset", type=str, default="OpenAssistant/oasst1",
                            help="Dataset to download")
    arg_parser.add_argument("--num_conversations", type=int, default=20000,
                            help="Number of conversations to download")
    arg_parser.add_argument("--format", type=str, default="instruction",
                            choices=["instruction", "chat"],
                            help="Training format")
    arg_parser.add_argument("--output_dir", type=str, default="data",
                            help="Output directory")
    opts = arg_parser.parse_args()

    pipeline = ConversationDataPreprocessor(opts.output_dir)

    print("Starting conversational data preprocessing pipeline...")

    # Stage 1: download the raw dataset.
    raw_file = pipeline.download_conversational_data(opts.dataset, opts.num_conversations)

    # Stage 2: standardize into {"messages": [...]} records.
    standardized_file = pipeline.process_conversations(raw_file, opts.dataset)

    # Stage 3: produce train/test files in the requested format.
    train_file, test_file = pipeline.format_for_training(standardized_file, opts.format)

    banner = "=" * 60
    print("\n" + banner)
    print("🎉 Conversational data preprocessing complete!")
    print(f"Training data: {train_file}")
    print(f"Test data: {test_file}")
    print("\n🚀 Ready for conversational fine-tuning!")
    print("Next step: python finetune_conversational.py")
    print(banner)
|
| |
|
| |
|
# Run the full CLI pipeline only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|
| |
|