sitammeur committed on
Commit
356ec31
·
verified ·
1 Parent(s): 71b6ff5

Update src/worker.js

Browse files
Files changed (1) hide show
  1. src/worker.js +56 -55
src/worker.js CHANGED
@@ -1,56 +1,57 @@
1
- /**
2
- * Worker script for zero-shot classification.
3
- * Loads the pipeline and handles classification requests.
4
- */
5
- import { env, pipeline } from '@huggingface/transformers';
6
-
7
- // Skip local model check since we are downloading the model from the Hugging Face Hub.
8
- env.allowLocalModels = false;
9
-
10
- /**
11
- * Class for zero-shot classification.
12
- * Loads the pipeline and handles classification requests.
13
- */
14
- class MyZeroShotClassificationPipeline {
15
- // Task and model for zero-shot classification.
16
- static task = 'zero-shot-classification';
17
- static model = 'MoritzLaurer/ModernBERT-large-zeroshot-v2.0';
18
- static instance = null;
19
-
20
- // Get the pipeline instance.
21
- static async getInstance(progress_callback = null) {
22
- if (this.instance === null) {
23
- this.instance = pipeline(this.task, this.model, {
24
- quantized: true,
25
- progress_callback,
26
- });
27
- }
28
-
29
- return this.instance;
30
- }
31
- }
32
-
33
- // Listen for messages from the main thread
34
- self.addEventListener('message', async (event) => {
35
- // Retrieve the pipeline. When called for the first time,
36
- // this will load the pipeline and save it for future use.
37
- const classifier = await MyZeroShotClassificationPipeline.getInstance(x => {
38
- // We also add a progress callback to the pipeline so that we can
39
- // track model loading.
40
- self.postMessage(x);
41
- });
42
-
43
- const { text, labels } = event.data;
44
-
45
- const split = text.split('\n');
46
- for (const line of split) {
47
- const output = await classifier(line, labels, {
48
- hypothesis_template: 'This text is about {}.',
49
- multi_label: true,
50
- });
51
- // Send the output back to the main thread
52
- self.postMessage({ status: 'output', output });
53
- }
54
- // Send the output back to the main thread
55
- self.postMessage({ status: 'complete' });
 
56
  });
 
1
/**
 * Worker script for zero-shot classification.
 * Loads the pipeline and handles classification requests.
 */
import { env, pipeline } from '@huggingface/transformers';

// Skip local model check since we are downloading the model from the Hugging Face Hub.
env.allowLocalModels = false;
10
/**
 * Lazily-initialised singleton wrapper around the zero-shot
 * classification pipeline. The pipeline is created on first use and
 * shared by every subsequent request handled by this worker.
 */
class MyZeroShotClassificationPipeline {
  // Task and model used for zero-shot classification.
  static task = 'zero-shot-classification';
  static model = 'MoritzLaurer/ModernBERT-large-zeroshot-v2.0';
  static instance = null;

  /**
   * Returns the shared pipeline, creating it on first call.
   *
   * The pipeline *promise* (not the resolved pipeline) is cached, so
   * concurrent callers all await the same in-flight load instead of
   * starting duplicate downloads.
   *
   * @param {?Function} progress_callback - Optional callback receiving
   *   model-loading progress events; only used on the first call.
   * @returns {Promise} Resolves to the classification pipeline.
   */
  static async getInstance(progress_callback = null) {
    this.instance ??= pipeline(this.task, this.model, {
      // NOTE(review): `quantized` is a transformers.js v2-era option;
      // v3 (`@huggingface/transformers`) uses `dtype` instead — confirm
      // whether this flag still has any effect here.
      quantized: true,
      progress_callback,
      device: "webgpu"
    });
    return this.instance;
  }
}
33
+
34
// Listen for classification requests from the main thread.
// Expected event.data shape: { text: string, labels: string[] }.
self.addEventListener('message', async (event) => {
  try {
    // Retrieve the pipeline. When called for the first time,
    // this will load the pipeline and save it for future use.
    const classifier = await MyZeroShotClassificationPipeline.getInstance((x) => {
      // Forward model-loading progress events so the main thread can
      // render a download/initialisation progress bar.
      self.postMessage(x);
    });

    const { text, labels } = event.data;

    // Classify each line independently so results stream back to the
    // main thread as soon as they are produced.
    const split = text.split('\n');
    for (const line of split) {
      const output = await classifier(line, labels, {
        hypothesis_template: 'This text is about {}.',
        multi_label: true,
      });
      // Send this line's result back to the main thread.
      self.postMessage({ status: 'output', output });
    }
    // Signal that every line has been processed.
    self.postMessage({ status: 'complete' });
  } catch (err) {
    // Bug fix: previously any error thrown during model loading or
    // classification was swallowed by this async handler, so the main
    // thread never received 'complete' and waited forever. Report the
    // failure explicitly instead.
    self.postMessage({ status: 'error', error: String(err?.message ?? err) });
  }
});