| import { defaultModel } from "$lib/server/models"; |
| import { modelEndpoint } from "./modelEndpoint"; |
| import { trimSuffix } from "$lib/utils/trimSuffix"; |
| import { trimPrefix } from "$lib/utils/trimPrefix"; |
| import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken"; |
| import { AwsClient } from "aws4fetch"; |
|
|
/**
 * Generation settings that callers may override per request.
 *
 * Every field is merged over `defaultModel.parameters` and forwarded verbatim
 * in the JSON payload sent to the model endpoint.
 *
 * NOTE(review): this local name shadows TypeScript's built-in `Parameters<F>`
 * utility type within this module.
 */
interface Parameters {
  /** Forwarded to the endpoint as `temperature`. */
  temperature: number;
  /** Forwarded to the endpoint as `truncate` (presumably an input-length cap — confirm against the endpoint API). */
  truncate: number;
  /** Forwarded to the endpoint as `max_new_tokens`. */
  max_new_tokens: number;
  /** Forwarded to the endpoint as `stop`; these strings are also stripped from the end of the returned text. */
  stop: string[];
}
/**
 * Sends `prompt` to an endpoint of the default model and returns the
 * generated text with prompt echoes, special tokens and stop sequences removed.
 *
 * The endpoint is the one returned by `modelEndpoint(defaultModel)`. When its
 * `host` is `"sagemaker"` the request goes through an `AwsClient` built from
 * the endpoint's credentials; otherwise a plain `fetch` with the endpoint's
 * `authorization` header is used. Both transports POST the same JSON payload.
 *
 * @param prompt - Text sent as the `inputs` field of the request.
 * @param parameters - Optional overrides merged over `defaultModel.parameters`.
 * @returns The cleaned-up `generated_text` of the first result.
 * @throws Error when the response status is not OK, when the response has no
 *   body, or when the payload does not contain a string `generated_text` in
 *   its first element. A `SyntaxError` propagates if the body is not valid JSON.
 */
export async function generateFromDefaultEndpoint(
  prompt: string,
  parameters?: Partial<Parameters>
): Promise<string> {
  const newParameters = {
    ...defaultModel.parameters,
    ...parameters,
    // Always requested as false; a leading prompt echo is still stripped below.
    return_full_text: false,
  };

  const randomEndpoint = modelEndpoint(defaultModel);

  // Both transports send the same payload, so serialize it once.
  const requestBody = JSON.stringify({
    ...newParameters,
    inputs: prompt,
  });

  let resp: Response;

  if (randomEndpoint.host === "sagemaker") {
    const aws = new AwsClient({
      accessKeyId: randomEndpoint.accessKey,
      secretAccessKey: randomEndpoint.secretKey,
      sessionToken: randomEndpoint.sessionToken,
      service: "sagemaker",
    });

    resp = await aws.fetch(randomEndpoint.url, {
      method: "POST",
      body: requestBody,
      headers: {
        "Content-Type": "application/json",
      },
    });
  } else {
    resp = await fetch(randomEndpoint.url, {
      method: "POST",
      body: requestBody,
      headers: {
        "Content-Type": "application/json",
        Authorization: randomEndpoint.authorization,
      },
    });
  }

  if (!resp.ok) {
    const errorBody = await resp.text();
    // Fall back to the status so the error is never thrown with an empty message.
    throw new Error(errorBody || `Request failed with status ${resp.status}`);
  }

  if (!resp.body) {
    throw new Error("Response body is empty");
  }

  // `text()` drains and decodes the whole stream, including the final flush
  // that a manual reader loop with `{ stream: true }` would otherwise skip.
  const results: unknown = JSON.parse(await resp.text());

  const firstResult: unknown = Array.isArray(results) ? results[0] : undefined;
  const rawText =
    typeof firstResult === "object" && firstResult !== null
      ? (firstResult as { generated_text?: unknown }).generated_text
      : undefined;

  if (typeof rawText !== "string") {
    throw new Error("Unexpected response format: missing generated_text");
  }

  let generated_text = trimSuffix(
    trimPrefix(trimPrefix(rawText, "<|startoftext|>"), prompt),
    PUBLIC_SEP_TOKEN
  ).trimEnd();

  for (const stop of [...(newParameters.stop ?? []), "<|endoftext|>"]) {
    // An empty stop string matches every text, and `slice(0, -0)` would
    // erase the whole result, so skip it.
    if (stop.length > 0 && generated_text.endsWith(stop)) {
      generated_text = generated_text.slice(0, -stop.length).trimEnd();
    }
  }

  return generated_text;
}
|
|