sitammeur committed (verified)
Commit 1a8ee5e · Parent: f000f26

Update src/worker.js

Files changed (1)
  1. src/worker.js +173 -173
src/worker.js CHANGED
@@ -1,173 +1,173 @@
-import {
-  AutoTokenizer,
-  AutoModelForCausalLM,
-  TextStreamer,
-  InterruptableStoppingCriteria,
-} from "@huggingface/transformers";
-
-/**
- * This class uses the Singleton pattern to enable lazy-loading of the pipeline
- */
-class TextGenerationPipeline {
-  static model_id = "onnx-community/LFM2-350M-Math-ONNX";
-
-  static async getInstance(progress_callback = null) {
-    this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
-      progress_callback,
-    });
-
-    this.model ??= AutoModelForCausalLM.from_pretrained(this.model_id, {
-      dtype: "fp16",
-      device: "webgpu",
-      progress_callback,
-    });
-
-    return Promise.all([this.tokenizer, this.model]);
-  }
-}
-
-const stopping_criteria = new InterruptableStoppingCriteria();
-
-let past_key_values_cache = null;
-/**
- * Generate text based on the input messages
- */
-async function generate(messages) {
-  // Retrieve the text-generation pipeline.
-  const [tokenizer, model] = await TextGenerationPipeline.getInstance();
-
-  const inputs = tokenizer.apply_chat_template(messages, {
-    add_generation_prompt: true,
-    return_dict: true,
-  });
-
-  let startTime;
-  let numTokens = 0;
-  let tps;
-  const token_callback_function = () => {
-    startTime ??= performance.now();
-
-    if (numTokens++ > 0) {
-      tps = (numTokens / (performance.now() - startTime)) * 1000;
-    }
-  };
-  const callback_function = (output) => {
-    self.postMessage({
-      status: "update",
-      output,
-      tps,
-      numTokens,
-    });
-  };
-
-  const streamer = new TextStreamer(tokenizer, {
-    skip_prompt: true,
-    skip_special_tokens: true,
-    callback_function,
-    token_callback_function,
-  });
-
-  // Tell the main thread we are starting
-  self.postMessage({ status: "start" });
-
-  const { past_key_values, sequences } = await model.generate({
-    ...inputs,
-    // TODO: Add when model is fixed
-    past_key_values: past_key_values_cache,
-
-    // Sampling
-    do_sample: false,
-    temperature: 0.6,
-    top_p: 0.95,
-    repetition_penalty: 1.05,
-
-    max_new_tokens: 512,
-    streamer,
-    stopping_criteria,
-    return_dict_in_generate: true,
-  });
-  past_key_values_cache = past_key_values;
-
-  const decoded = tokenizer.batch_decode(sequences, {
-    skip_special_tokens: true,
-  });
-
-  // Send the output back to the main thread
-  self.postMessage({
-    status: "complete",
-    output: decoded,
-  });
-}
-
-/**
- * Helper function to perform feature detection for WebGPU
- */
-async function check() {
-  try {
-    const adapter = await navigator.gpu.requestAdapter();
-    if (!adapter) {
-      throw new Error("WebGPU is not supported (no adapter found)");
-    }
-  } catch (e) {
-    self.postMessage({
-      status: "error",
-      data: e.toString(),
-    });
-  }
-}
-
-/**
- * Helper function to load the model and tokenizer
- */
-async function load() {
-  self.postMessage({
-    status: "loading",
-    data: "Loading model...",
-  });
-
-  // Load the pipeline and save it for future use.
-  const [tokenizer, model] = await TextGenerationPipeline.getInstance((x) => {
-    // We also add a progress callback to the pipeline so that we can
-    // track model loading.
-    self.postMessage(x);
-  });
-
-  self.postMessage({
-    status: "loading",
-    data: "Compiling shaders and warming up the model...",
-  });
-
-  // Run model with dummy input to compile shaders
-  const inputs = tokenizer("a");
-  await model.generate({ ...inputs, max_new_tokens: 1 });
-  self.postMessage({ status: "ready" });
-}
-
-// Listen for messages from the main thread
-self.addEventListener("message", async (e) => {
-  const { type, data } = e.data;
-
-  switch (type) {
-    case "check":
-      check();
-      break;
-
-    case "load":
-      load();
-      break;
-
-    case "generate":
-      stopping_criteria.reset();
-      generate(data);
-      break;
-
-    case "interrupt":
-      stopping_criteria.interrupt();
-      break;
-
-    case "reset":
-      past_key_values_cache = null;
-      stopping_criteria.reset();
-      break;
-  }
-});
+import {
+  AutoTokenizer,
+  AutoModelForCausalLM,
+  TextStreamer,
+  InterruptableStoppingCriteria,
+} from "@huggingface/transformers";
+
+/**
+ * This class uses the Singleton pattern to enable lazy-loading of the pipeline
+ */
+class TextGenerationPipeline {
+  static model_id = "onnx-community/LFM2-350M-Math-ONNX";
+
+  static async getInstance(progress_callback = null) {
+    this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
+      progress_callback,
+    });
+
+    this.model ??= AutoModelForCausalLM.from_pretrained(this.model_id, {
+      dtype: "q4f16",
+      device: "webgpu",
+      progress_callback,
+    });
+
+    return Promise.all([this.tokenizer, this.model]);
+  }
+}
+
+const stopping_criteria = new InterruptableStoppingCriteria();
+
+let past_key_values_cache = null;
+/**
+ * Generate text based on the input messages
+ */
+async function generate(messages) {
+  // Retrieve the text-generation pipeline.
+  const [tokenizer, model] = await TextGenerationPipeline.getInstance();
+
+  const inputs = tokenizer.apply_chat_template(messages, {
+    add_generation_prompt: true,
+    return_dict: true,
+  });
+
+  let startTime;
+  let numTokens = 0;
+  let tps;
+  const token_callback_function = () => {
+    startTime ??= performance.now();
+
+    if (numTokens++ > 0) {
+      tps = (numTokens / (performance.now() - startTime)) * 1000;
+    }
+  };
+  const callback_function = (output) => {
+    self.postMessage({
+      status: "update",
+      output,
+      tps,
+      numTokens,
+    });
+  };
+
+  const streamer = new TextStreamer(tokenizer, {
+    skip_prompt: true,
+    skip_special_tokens: true,
+    callback_function,
+    token_callback_function,
+  });
+
+  // Tell the main thread we are starting
+  self.postMessage({ status: "start" });
+
+  const { past_key_values, sequences } = await model.generate({
+    ...inputs,
+    // TODO: Add when model is fixed
+    past_key_values: past_key_values_cache,
+
+    // Sampling
+    do_sample: false,
+    temperature: 0.6,
+    top_p: 0.95,
+    repetition_penalty: 1.05,
+
+    max_new_tokens: 512,
+    streamer,
+    stopping_criteria,
+    return_dict_in_generate: true,
+  });
+  past_key_values_cache = past_key_values;
+
+  const decoded = tokenizer.batch_decode(sequences, {
+    skip_special_tokens: true,
+  });
+
+  // Send the output back to the main thread
+  self.postMessage({
+    status: "complete",
+    output: decoded,
+  });
+}
+
+/**
+ * Helper function to perform feature detection for WebGPU
+ */
+async function check() {
+  try {
+    const adapter = await navigator.gpu.requestAdapter();
+    if (!adapter) {
+      throw new Error("WebGPU is not supported (no adapter found)");
+    }
+  } catch (e) {
+    self.postMessage({
+      status: "error",
+      data: e.toString(),
+    });
+  }
+}
+
+/**
+ * Helper function to load the model and tokenizer
+ */
+async function load() {
+  self.postMessage({
+    status: "loading",
+    data: "Loading model...",
+  });
+
+  // Load the pipeline and save it for future use.
+  const [tokenizer, model] = await TextGenerationPipeline.getInstance((x) => {
+    // We also add a progress callback to the pipeline so that we can
+    // track model loading.
+    self.postMessage(x);
+  });
+
+  self.postMessage({
+    status: "loading",
+    data: "Compiling shaders and warming up the model...",
+  });
+
+  // Run model with dummy input to compile shaders
+  const inputs = tokenizer("a");
+  await model.generate({ ...inputs, max_new_tokens: 1 });
+  self.postMessage({ status: "ready" });
+}
+
+// Listen for messages from the main thread
+self.addEventListener("message", async (e) => {
+  const { type, data } = e.data;
+
+  switch (type) {
+    case "check":
+      check();
+      break;
+
+    case "load":
+      load();
+      break;
+
+    case "generate":
+      stopping_criteria.reset();
+      generate(data);
+      break;
+
+    case "interrupt":
+      stopping_criteria.interrupt();
+      break;
+
+    case "reset":
+      past_key_values_cache = null;
+      stopping_criteria.reset();
+      break;
+  }
+});
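
The worker above talks to the page entirely through postMessage. As a rough sketch of that protocol (the worker path, module-worker setup, and prompt below are illustrative assumptions, not part of this commit), a main-thread client could look like:

// Minimal main-thread client sketch for src/worker.js; the message shapes
// follow the listener and postMessage calls shown above.
const worker = new Worker(new URL("./worker.js", import.meta.url), {
  type: "module",
});

worker.addEventListener("message", (e) => {
  const { status, output, tps, data } = e.data;
  switch (status) {
    case "ready":
      // Model loaded and warmed up: send a chat-formatted prompt.
      worker.postMessage({
        type: "generate",
        data: [{ role: "user", content: "What is 17 * 24?" }],
      });
      break;
    case "update":
      // Streamed text plus the tokens-per-second the worker reports.
      console.log(output, tps ? `(${tps.toFixed(1)} tok/s)` : "");
      break;
    case "complete":
      console.log("Final output:", output);
      break;
    case "error":
      console.error(data);
      break;
  }
});

worker.postMessage({ type: "check" }); // WebGPU feature detection
worker.postMessage({ type: "load" }); // download weights, compile shaders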