-
Notifications
You must be signed in to change notification settings - Fork 315
Expand file tree
/
Copy path07_custom.py
More file actions
122 lines (98 loc) · 3.34 KB
/
07_custom.py
File metadata and controls
122 lines (98 loc) · 3.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
Optional: Change where pretrained models from huggingface will be downloaded (cached) to:
export TRANSFORMERS_CACHE=/whatever/path/you/want
"""
# import os
# os.environ["TRANSFORMERS_CACHE"] = "/media/samuel/UDISK1/transformers_cache"
import functools
import os
import time
from typing import ClassVar, Any, Iterator

import torch
from dotenv import load_dotenv
from llama_index.core import (
    SummaryIndex,
    PromptHelper,
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    load_index_from_storage,
)
from llama_index.core.llms import CustomLLM, CompletionResponse, LLMMetadata
from transformers import pipeline
# load_dotenv()
os.environ["OPENAI_API_KEY"] = "random"
def timeit():
"""
a utility decoration to time running time
"""
def decorator(func):
def wrapper(*args, **kwargs):
start = time.time()
result = func(*args, **kwargs)
end = time.time()
args = [str(arg) for arg in args]
print(f"[{(end - start):.8f} seconds]: f({args}) -> {result}")
return result
return wrapper
return decorator
prompt_helper = PromptHelper(
# maximum input size
context_window=2048,
# number of output tokens
num_output=256,
# the maximum overlap between chunks.
chunk_overlap_ratio=0.1,
)
class LocalOPT(CustomLLM):
# model_name = "facebook/opt-iml-max-30b" (this is a 60gb model)
model_name: ClassVar[str] = "facebook/opt-iml-1.3b" # ~2.63gb model
# https://huggingface.co/docs/transformers/main_classes/pipelines
_pipeline: ClassVar[Any] = pipeline(
"text-generation",
model=model_name,
device="cpu",
dtype=torch.float32,
)
@property
def metadata(self) -> LLMMetadata:
return LLMMetadata(model_name=self.model_name)
def complete(self, prompt: str, **kwargs) -> CompletionResponse:
response = self._pipeline(prompt, max_new_tokens=256)[0]["generated_text"]
# only return newly generated tokens
return CompletionResponse(text=response[len(prompt):])
def stream_complete(self, prompt: str, **kwargs) -> Iterator[CompletionResponse]:
raise NotImplementedError
@timeit()
def create_index():
print("Creating index")
Settings.llm = LocalOPT()
Settings.prompt_helper = prompt_helper
docs = SimpleDirectoryReader("news").load_data()
index = SummaryIndex.from_documents(docs)
print("Done creating index", index)
return index
@timeit()
def execute_query():
query_engine = index.as_query_engine()
response = query_engine.query(
"Who does Indonesia export its coal to in 2023?",
)
return response
if __name__ == "__main__":
"""
Check if a local cache of the model exists,
if not, it will download the model from huggingface
"""
if not os.path.exists("7_custom_opt"):
print("No local cache of model found, downloading from huggingface")
index = create_index()
index.storage_context.persist(persist_dir="./7_custom_opt")
else:
print("Loading local cache of model")
Settings.llm = LocalOPT()
Settings.prompt_helper = prompt_helper
storage_context = StorageContext.from_defaults(persist_dir="./7_custom_opt")
index = load_index_from_storage(storage_context)
response = execute_query()
print(response)
print(response.source_nodes)