Files
blog-embedding/blog_search.py
2026-03-30 04:44:51 +00:00

116 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Blog Semantic Search
====================
讀取 blog_embeddings.json讓使用者輸入或貼上任意文字
用同一個 embedding 模型算出向量,列出最相關的 20 篇文章。
需求:
- ollama 已啟動且已 pull qwen3-embedding:8b
- pip install numpy requests
- 已經跑過 blog_embeddings.py 產生 blog_embeddings.json
"""
import json
import numpy as np
import requests
# ============================================================
# 設定區
# ============================================================
EMBEDDINGS_FILE = "./blog_embeddings.json"
OLLAMA_URL = "http://localhost:11434/api/embed"
OLLAMA_MODEL = "qwen3-embedding:8b"
# ============================================================
def get_embedding(text: str, timeout: float = 60.0) -> list[float]:
    """Fetch the embedding vector for *text* via the Ollama embed API.

    Args:
        text: Text to embed.
        timeout: Seconds to wait for the HTTP response. New parameter with a
            default, so existing callers are unaffected; previously no timeout
            was set and a hung server would block forever.

    Returns:
        The embedding vector for the single input.

    Raises:
        requests.HTTPError: If the API returns an error status.
        requests.Timeout: If the server does not respond within *timeout*.
    """
    resp = requests.post(
        OLLAMA_URL,
        json={
            "model": OLLAMA_MODEL,
            "input": text,
        },
        timeout=timeout,
    )
    resp.raise_for_status()
    # /api/embed returns {"embeddings": [[...], ...]}; we sent exactly one input.
    return resp.json()["embeddings"][0]
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity of two vectors.

    Defined as dot(a, b) / (|a| * |b|); returns 0.0 when either vector
    has zero norm, so degenerate inputs never divide by zero.
    """
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return float(np.dot(a, b) / denom)
def format_row(rank: int, sim: float, title: str, slug: str) -> str:
    """Format one result line: rank, similarity score, bar chart, title, slug.

    The bar length scales linearly with similarity: 30 characters at
    sim == 1.0, none at sim == 0.0.
    """
    # BUG FIX: the bar glyph had been stripped (an empty string was being
    # repeated), so the bar never rendered; use a full-block character.
    bar = "█" * int(sim * 30)
    return f" {rank:3d}. {sim:.4f} {bar} {title} ({slug})"
def read_query() -> str:
    """Read the user's query from stdin; multi-line paste is supported.

    Input ends after two consecutive empty lines, or at end-of-input
    (Ctrl-D / exhausted pipe). Previously EOF raised an uncaught
    EOFError and crashed the program; it is now treated the same as
    the double-Enter terminator.

    Returns:
        The entered text with surrounding whitespace stripped
        (so trailing blank lines are removed); "" if nothing was entered.
    """
    print("請輸入搜尋文字(連按兩次 Enter 結束):")
    lines: list[str] = []
    empty_count = 0
    while True:
        try:
            line = input()
        except EOFError:
            # End of stdin — finish the query instead of crashing.
            break
        if line == "":
            empty_count += 1
            if empty_count >= 2:
                break
            lines.append(line)
        else:
            empty_count = 0
            lines.append(line)
    return "\n".join(lines).strip()
def main():
    """Interactive loop: load embeddings, then answer queries until blank input."""
    top_k = 20  # number of results to display per query

    # Load the precomputed article embeddings.
    with open(EMBEDDINGS_FILE, encoding="utf-8") as f:
        data = json.load(f)
    if not data:
        print("找不到任何文章資料。")
        return
    print(f"已載入 {len(data)} 篇文章的 embeddings\n")

    # Convert to numpy arrays once so per-query comparisons are fast.
    vectors = [np.array(item["embedding"]) for item in data]

    while True:
        query = read_query()
        if not query:
            print("沒有輸入,掰掰!")
            break

        print(f"\n正在產生 embedding模型{OLLAMA_MODEL}...")
        query_vec = np.array(get_embedding(query))

        # Score every article against the query.
        similarities = [
            {
                "slug": item["slug"],
                "title": item.get("title", "(無標題)"),
                "similarity": cosine_similarity(query_vec, vec),
            }
            for item, vec in zip(data, vectors)
        ]
        similarities.sort(key=lambda x: x["similarity"], reverse=True)
        top = similarities[:top_k]

        print(f"\n🔍 與輸入文字最相關的 {top_k} 篇:\n")
        for rank, s in enumerate(top, 1):
            print(format_row(rank, s["similarity"], s["title"], s["slug"]))
        # BUG FIX: the separator glyph had been stripped (an empty string was
        # repeated), printing a blank line instead of a horizontal rule.
        print("\n" + "─" * 60 + "\n")


if __name__ == "__main__":
    main()