116 lines
3.3 KiB
Python
116 lines
3.3 KiB
Python
"""
|
||
Blog Semantic Search
|
||
====================
|
||
讀取 blog_embeddings.json,讓使用者輸入(或貼上)任意文字,
|
||
用同一個 embedding 模型算出向量,列出最相關的 20 篇文章。
|
||
|
||
需求:
|
||
- ollama 已啟動且已 pull qwen3-embedding:8b
|
||
- pip install numpy requests
|
||
- 已經跑過 blog_embeddings.py 產生 blog_embeddings.json
|
||
"""
|
||
|
||
import json
|
||
import numpy as np
|
||
import requests
|
||
|
||
# ============================================================
# Configuration
# ============================================================

# Path to the JSON file produced by blog_embeddings.py.
EMBEDDINGS_FILE = "./blog_embeddings.json"
# Ollama embeddings endpoint (default local daemon port).
OLLAMA_URL = "http://localhost:11434/api/embed"
# Must match the model used when blog_embeddings.json was generated,
# otherwise query vectors live in a different embedding space.
OLLAMA_MODEL = "qwen3-embedding:8b"
|
||
|
||
# ============================================================
|
||
|
||
|
||
def get_embedding(text: str) -> list[float]:
    """Return the embedding vector for *text* via the Ollama embed API.

    Args:
        text: Arbitrary input text to embed.

    Returns:
        The first (and only) embedding vector from the response.

    Raises:
        requests.HTTPError: on a non-2xx response from Ollama.
        requests.Timeout: if Ollama does not answer in time.
    """
    # BUGFIX: the original call had no timeout, so a hung Ollama daemon
    # would block this script forever.  Embedding with a large model can
    # be slow on first load, so allow a generous window.
    resp = requests.post(
        OLLAMA_URL,
        json={
            "model": OLLAMA_MODEL,
            "input": text,
        },
        timeout=300,
    )
    resp.raise_for_status()
    return resp.json()["embeddings"][0]
|
||
|
||
|
||
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of two vectors; 0.0 when either norm is zero."""
    denominator = float(np.linalg.norm(a)) * float(np.linalg.norm(b))
    if denominator == 0:
        # Degenerate case: a zero vector has no direction to compare.
        return 0.0
    return float(np.dot(a, b) / denominator)
|
||
|
||
|
||
def format_row(rank: int, sim: float, title: str, slug: str) -> str:
    """Render one ranked result line with a proportional similarity bar."""
    bar_cells = int(sim * 30)  # 30 cells at similarity 1.0
    return " {:3d}. {:.4f} {} {} ({})".format(
        rank, sim, "█" * bar_cells, title, slug
    )
|
||
|
||
|
||
def read_query() -> str:
    """Prompt for free-form query text on stdin.

    Multi-line paste is supported; two blank lines in a row (i.e.
    pressing Enter twice) terminate input.  Returns the stripped text.
    """
    print("請輸入搜尋文字(連按兩次 Enter 結束):")
    collected: list[str] = []
    blanks_in_a_row = 0
    while True:
        entry = input()
        # Track consecutive blank lines; any non-blank line resets the run.
        blanks_in_a_row = blanks_in_a_row + 1 if entry == "" else 0
        if blanks_in_a_row >= 2:
            break
        collected.append(entry)
    return "\n".join(collected).strip()
|
||
|
||
|
||
def main():
    """Interactive loop: embed each user query and print the 20 most
    similar blog posts by cosine similarity."""
    # Load the precomputed post embeddings produced by blog_embeddings.py.
    with open(EMBEDDINGS_FILE, encoding="utf-8") as fh:
        posts = json.load(fh)

    if not posts:
        print("找不到任何文章資料。")
        return

    print(f"已載入 {len(posts)} 篇文章的 embeddings\n")

    # Convert every stored embedding to a numpy array up front so the
    # per-query scoring loop stays cheap.
    post_vectors = [np.array(p["embedding"]) for p in posts]

    while True:
        query = read_query()
        if not query:
            # Empty input ends the session.
            print("沒有輸入,掰掰!")
            break

        print(f"\n正在產生 embedding(模型:{OLLAMA_MODEL})...")
        query_vec = np.array(get_embedding(query))

        # Score every post against the query vector.
        scored = [
            {
                "slug": post["slug"],
                "title": post.get("title", "(無標題)"),
                "similarity": cosine_similarity(query_vec, vec),
            }
            for post, vec in zip(posts, post_vectors)
        ]
        scored.sort(key=lambda entry: entry["similarity"], reverse=True)

        print(f"\n🔍 與輸入文字最相關的 20 篇:\n")
        for rank, entry in enumerate(scored[:20], 1):
            print(format_row(rank, entry["similarity"], entry["title"], entry["slug"]))

        print("\n" + "─" * 60 + "\n")
|
||
|
||
|
||
# Run the interactive search loop only when executed as a script,
# not when imported as a module.
if __name__ == "__main__":
    main()
|