Spaces:
Runtime error
Runtime error
Update chess_board.py
Browse files- chess_board.py +8 -2
chess_board.py
CHANGED
|
@@ -26,7 +26,13 @@ class Game:
|
|
| 26 |
|
| 27 |
def compile_model(self):
|
| 28 |
self.model.compile(sampler=self.sampler)
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
def call_gemma(self, opening_move):
|
| 31 |
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
|
| 32 |
|
|
@@ -39,7 +45,7 @@ class Game:
|
|
| 39 |
instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
|
| 40 |
response="",)
|
| 41 |
|
| 42 |
-
output = self.model.generate(prompt, max_length=256)
|
| 43 |
|
| 44 |
gemma_move = output.split(' ')[-1].strip("'")
|
| 45 |
|
|
|
|
| 26 |
|
| 27 |
def compile_model(self):
|
| 28 |
self.model.compile(sampler=self.sampler)
|
| 29 |
+
|
| 30 |
+
@spaces.GPU
|
| 31 |
+
def inference_gemma(self, prompt, max_length=256):
|
| 32 |
+
"""Inference requires GPU"""
|
| 33 |
+
response = self.model.generate(prompt, max_length)
|
| 34 |
+
return response
|
| 35 |
+
|
| 36 |
def call_gemma(self, opening_move):
|
| 37 |
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
|
| 38 |
|
|
|
|
| 45 |
instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
|
| 46 |
response="",)
|
| 47 |
|
| 48 |
+
output = self.inference_gemma(prompt, max_length=256) #self.model.generate(prompt, max_length=256)
|
| 49 |
|
| 50 |
gemma_move = output.split(' ')[-1].strip("'")
|
| 51 |
|