#!/usr/bin/env python3
"""
HTTP TTS 客户端测试脚本
"""

import base64
import json
import os
import sys
import time
from typing import List, Optional

import numpy as np
import requests
import soundfile as sf

# Configuration: base URL of the TTS server under test.
# Override with the TTS_BASE_URL environment variable; a trailing slash is
# stripped so that f"{BASE_URL}/path" never produces a double slash.
BASE_URL = os.environ.get("TTS_BASE_URL", "http://107.151.234.179:8001").rstrip("/")


def _iter_sse_payloads(response):
    """Yield the JSON-decoded dict payload of each ``data:`` line of *response*.

    *response* is expected to be a streaming ``requests.Response``. Blank lines
    and lines without the ``data:`` prefix are ignored. Payloads that are not
    valid JSON, or that decode to something other than a JSON object, are
    reported and skipped rather than silently dropped.
    """
    prefix = 'data:'
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        line_str = raw_line.decode('utf-8')
        if not line_str.startswith(prefix):
            continue
        data_str = line_str[len(prefix):].strip()
        try:
            payload = json.loads(data_str)
        except json.JSONDecodeError:
            print(f"  跳过无法解析的数据: {data_str[:80]}")
            continue
        if not isinstance(payload, dict):
            print(f"  跳过非对象数据: {data_str[:80]}")
            continue
        yield payload


def _stream_sentence(line: str, voice: str, index: int, output_dir: str,
                     sample_rate: int, timeout: float) -> bool:
    """Synthesize one sentence via the streaming endpoint and save it as WAV.

    Audio chunks are base64-encoded float32 PCM taken from each payload's
    ``audio`` field. They are concatenated and written to
    ``<output_dir>/<voice>_<index>_streaming.wav``.

    Returns:
        True if audio was received and saved. False on a non-200 status, a
        server-reported error (no partial audio is saved in that case), or an
        empty stream.

    Raises:
        Whatever ``requests``, base64/numpy decoding or ``soundfile`` raise; the
        caller treats any exception as a failed sentence.
    """
    req = {
        'input': line,
        'voice': voice,
        'response_format': 'pcm',
        'sample_rate': sample_rate,
        'stream': True
    }
    start_time = time.time()
    first_chunk_time = None
    audio_chunks = []
    finished = False

    # `with` guarantees the streamed connection is released on every path,
    # including early returns and exceptions.
    with requests.post(f"{BASE_URL}/v1/audio/speech",
                       json=req,
                       stream=True,
                       timeout=timeout) as response:
        if response.status_code != 200:
            print(f"  HTTP请求失败: {response.status_code}")
            return False

        for chunk_data in _iter_sse_payloads(response):
            if 'error' in chunk_data:
                print(f"  服务端错误: {chunk_data['error']}")
                # A server-side error means the audio is incomplete: fail.
                return False

            if chunk_data.get('audio'):
                if first_chunk_time is None:
                    first_chunk_time = time.time()
                    first_delay = (first_chunk_time - start_time) * 1000
                    print(f"  首包延迟: {first_delay:.1f}ms")

                audio_bytes = base64.b64decode(chunk_data['audio'])
                audio_chunk = np.frombuffer(audio_bytes, dtype=np.float32)
                audio_chunks.append(audio_chunk)

                chunk_duration = len(audio_chunk) / sample_rate * 1000
                print(f"  块#{len(audio_chunks)}: {len(audio_chunk)}采样点, 时长{chunk_duration:.1f}ms")

            if chunk_data.get('done', False):
                print("  流式传输完成")
                finished = True
                break

    if not audio_chunks:
        print("  未接收到音频数据")
        return False
    if not finished:
        # The stream ended (e.g. connection closed) without a `done` marker.
        print("  警告: 未收到 done 标记, 音频可能不完整")

    combined_audio = np.concatenate(audio_chunks)
    output_path = os.path.join(output_dir, f"{voice}_{index}_streaming.wav")
    sf.write(output_path, combined_audio, samplerate=sample_rate, format='WAV')

    total_elapsed = time.time() - start_time
    audio_duration = len(combined_audio) / sample_rate
    # Real-time factor: wall-clock time per second of generated audio.
    rtf = total_elapsed / audio_duration if audio_duration > 0 else 0

    print(f"  保存成功: {output_path}")
    print(f"  总耗时: {total_elapsed:.2f}s, RTF: {rtf:.3f}")
    return True


def test_single_sentence_streaming(voice: str,
                                   lines: Optional[List[str]] = None,
                                   output_dir: str = 'test_dir/',
                                   sample_rate: int = 24000,
                                   timeout: float = 120) -> bool:
    """Test streaming synthesis, one request per sentence.

    Args:
        voice: Voice name sent to the server.
        lines: Sentences to synthesize; defaults to a single Japanese sample.
        output_dir: Directory for the resulting WAV files (created if missing).
        sample_rate: Requested sample rate, also used to write and time audio.
        timeout: ``requests`` timeout in seconds for each request.

    Returns:
        True only if every sentence was synthesized and saved successfully.
    """
    print("\n=== HTTP 单句合成测试（流式）===")
    print(f"使用音色: {voice}")

    os.makedirs(output_dir, exist_ok=True)

    if lines is None:
        lines = ["わぁ〜、雑誌（ざっし）に載ってるケーキ、めっちゃ美味しそうだモン！"]
    print(f"流式测试 {len(lines)} 句文本")

    success_count = 0
    for i, line in enumerate(lines):
        print(f"\n流式合成第 {i+1}/{len(lines)} 句: {line[:50]}...")
        try:
            if _stream_sentence(line, voice, i, output_dir, sample_rate, timeout):
                success_count += 1
        except Exception as e:
            # Deliberate per-sentence boundary: report and move on so one bad
            # sentence does not abort the rest of the run.
            print(f"  流式请求失败: {e}")

    print(f"\n流式合成测试完成: {success_count}/{len(lines)}")
    return success_count == len(lines)



def main() -> int:
    """Run the HTTP TTS client tests against ``BASE_URL``.

    Usage: ``python test_http_client.py <voice_name>``.

    Returns:
        Process exit code: 0 if every test passed, 1 otherwise.

    Exits directly with status 1 when the voice argument is missing, or when
    the ``/health`` check fails or the server cannot be reached.
    """
    if len(sys.argv) < 2:
        print("用法: python test_http_client.py <voice_name>")
        print("示例: python test_http_client.py 800")
        sys.exit(1)

    voice = sys.argv[1]

    print("HTTP TTS 客户端测试")
    print("=" * 60)
    print(f"目标服务: {BASE_URL}")
    print(f"使用音色: {voice}")

    # Check that the service is up before running any synthesis test.
    try:
        response = requests.get(f"{BASE_URL}/health", timeout=10)
    except requests.RequestException as e:
        print(f"无法连接服务: {e}")
        print("请确保HTTP服务器已启动: python start_http.py")
        sys.exit(1)
    if response.status_code != 200:
        print(f"服务响应异常: {response.status_code}")
        sys.exit(1)
    print("服务连接正常")

    # Run the tests; each entry is (test name, passed?).
    results = []

    # Single-sentence synthesis (streaming).
    print(f"\n{'='*60}")
    results.append(("单句合成（流式）", test_single_sentence_streaming(voice)))

    # Summary: report each test and turn the outcome into the exit code.
    print(f"\n{'='*60}")
    print("测试总结:")
    for name, passed in results:
        print(f"  {name}: {'通过' if passed else '失败'}")
    passed_count = sum(1 for _, passed in results if passed)
    print(f"总计: {passed_count}/{len(results)} 通过")
    return 0 if passed_count == len(results) else 1


if __name__ == "__main__":
    # Assume failure; only a normal return from main() supplies the status.
    # SystemExit raised inside main() is not an Exception and propagates as-is.
    exit_code = 1
    try:
        exit_code = main()
    except KeyboardInterrupt:
        print("用户中断测试")
    except Exception as e:
        print(f"程序异常: {e}")
    sys.exit(exit_code)