// Copyright 2021 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

import {Status} from '@pigweed/pw_status';
import {Message} from 'google-protobuf';

import WaitQueue = require('wait-queue');

import {PendingCalls, Rpc} from './rpc_classes';

/** Signature shared by the onNext / onCompleted / onError user callbacks. */
export type Callback = (a: any) => any;

/** Error thrown to consumers when an RPC finished with an error status. */
class RpcError extends Error {
  readonly status: Status;

  constructor(rpc: Rpc, status: Status) {
    // Add a human-readable hint for the two statuses that have one.
    let message = '';
    if (status === Status.NOT_FOUND) {
      message = ': the RPC server does not support this RPC';
    } else if (status === Status.DATA_LOSS) {
      message = ': an error occurred while decoding the RPC payload';
    }
    super(`${rpc.method.name} failed with error ${Status[status]}${message}`);
    this.status = status;
  }
}

/** Error thrown when no response arrived within the requested timeout. */
class RpcTimeout extends Error {
  readonly rpc: Rpc;
  readonly timeoutMs: number;

  constructor(rpc: Rpc, timeoutMs: number) {
    super(`${rpc.method.name} timed out after ${timeoutMs} ms`);
    this.rpc = rpc;
    this.timeoutMs = timeoutMs;
  }
}

/** Represent an in-progress or completed RPC call. */
export class Call {
  // Responses ordered by arrival time. Undefined signifies stream completion.
  private readonly responseQueue = new WaitQueue<Message | undefined>();
  // Every response received so far, in arrival order.
  protected responses: Message[] = [];

  private readonly rpcs: PendingCalls;
  private readonly rpc: Rpc;

  private readonly onNext: Callback;
  private readonly onCompleted: Callback;
  private readonly onError: Callback;

  // Set by handleCompletion().
  status?: Status;
  // Set by handleError() or cancel().
  error?: Status;
  // Last Error thrown by a user callback; rethrown by checkErrors().
  callbackException?: Error;

  constructor(
    rpcs: PendingCalls,
    rpc: Rpc,
    onNext: Callback,
    onCompleted: Callback,
    onError: Callback
  ) {
    this.rpcs = rpcs;
    this.rpc = rpc;

    this.onNext = onNext;
    this.onCompleted = onCompleted;
    this.onError = onError;
  }

  /**
   * Calls the RPC. This must be called immediately after construction.
   *
   * If `sendRequest` returns a previous call that has not completed, that
   * call is terminated with `Status.CANCELLED`.
   */
  invoke(request?: Message, ignoreErrors = false): void {
    const previous = this.rpcs.sendRequest(
      this.rpc,
      this,
      ignoreErrors,
      request
    );

    if (previous !== undefined && !previous.completed) {
      previous.handleError(Status.CANCELLED);
    }
  }

  /** True once either a completion status or an error has been recorded. */
  get completed(): boolean {
    return this.status !== undefined || this.error !== undefined;
  }

  /**
   * Runs a user callback, logging anything it throws instead of propagating.
   *
   * A thrown `Error` is stored in `callbackException` so that checkErrors()
   * can surface it later; any other thrown value is only logged.
   */
  private invokeCallback(func: () => unknown): void {
    try {
      func();
    } catch (err: unknown) {
      if (err instanceof Error) {
        console.error(
          `An exception was raised while invoking a callback: ${err}`
        );
        this.callbackException = err;
      } else {
        // Only reached for non-Error values; previously this also ran for
        // Error instances, logging every exception twice.
        console.error(`Unexpected item thrown while invoking callback: ${err}`);
      }
    }
  }

  /** Records a server response, queues it for readers and calls onNext. */
  handleResponse(response: Message): void {
    this.responses.push(response);
    this.responseQueue.push(response);
    this.invokeCallback(() => this.onNext(response));
  }

  /** Records the final status, signals end of stream and calls onCompleted. */
  handleCompletion(status: Status): void {
    this.status = status;
    this.responseQueue.push(undefined);
    this.invokeCallback(() => this.onCompleted(status));
  }

  /** Records an error status, signals end of stream and calls onError. */
  handleError(error: Status): void {
    this.error = error;
    this.responseQueue.push(undefined);
    this.invokeCallback(() => this.onError(error));
  }

  /**
   * Pops the next queue entry, rejecting with RpcTimeout if none arrives
   * within `timeoutMs` milliseconds.
   *
   * @returns The next response, or undefined for the end-of-stream sentinel.
   */
  private queuePopWithTimeout(
    timeoutMs: number
  ): Promise<Message | undefined> {
    return new Promise<Message | undefined>((resolve, reject) => {
      let timeoutExpired = false;
      const timeoutWatcher = setTimeout(() => {
        timeoutExpired = true;
        reject(new RpcTimeout(this.rpc, timeoutMs));
      }, timeoutMs);

      this.responseQueue.shift().then(
        response => {
          if (timeoutExpired) {
            // The caller has already been told about the timeout; put the
            // entry back so a later reader still sees it.
            this.responseQueue.unshift(response);
            return;
          }
          clearTimeout(timeoutWatcher);
          resolve(response);
        },
        (err: unknown) => {
          // Forward queue failures instead of leaving the promise pending
          // until the timer fires.
          clearTimeout(timeoutWatcher);
          reject(err);
        }
      );
    });
  }

  /**
   * Yields responses up to the specified count as they are added.
   *
   * Throws an error as soon as it is received even if there are still
   * responses in the queue.
   *
   * Usage
   * ```
   * for await (const response of call.getResponses(5)) {
   *  console.log(response);
   * }
   * ```
   *
   * @param count The number of responses to read before returning.
   *    If no value is specified, getResponses will block until the stream
   *    either ends or hits an error.
   * @param timeoutMs The number of milliseconds to wait for a response
   *    before throwing an error.
   */
  async *getResponses(
    count?: number,
    timeoutMs?: number
  ): AsyncGenerator<Message> {
    this.checkErrors();

    // Nothing left to read: the call is finished and the queue is drained.
    if (this.completed && this.responseQueue.length === 0) {
      return;
    }

    let remaining = count ?? Number.POSITIVE_INFINITY;
    while (remaining > 0) {
      const response =
        timeoutMs === undefined
          ? await this.responseQueue.shift()
          : await this.queuePopWithTimeout(timeoutMs);
      this.checkErrors();
      // Undefined is the end-of-stream sentinel.
      if (response === undefined) {
        return;
      }
      yield response;
      remaining -= 1;
    }
  }

  /**
   * Marks the call as cancelled and asks `rpcs` to send a cancel.
   *
   * @returns false if the call had already completed; otherwise whatever
   *    `sendCancel` returns.
   */
  cancel(): boolean {
    if (this.completed) {
      return false;
    }

    // NOTE(review): no end-of-stream sentinel is pushed here, so a reader
    // already awaiting the queue without a timeout is not woken by this
    // method itself — confirm whether sendCancel covers that.
    this.error = Status.CANCELLED;
    return this.rpcs.sendCancel(this.rpc);
  }

  /** Throws a stored callback exception first, then any recorded error. */
  private checkErrors(): void {
    if (this.callbackException !== undefined) {
      throw this.callbackException;
    }
    if (this.error !== undefined) {
      throw new RpcError(this.rpc, this.error);
    }
  }

  /**
   * Waits for a single response and returns it with the final status.
   *
   * @throws Error if no status was recorded or the number of recorded
   *    responses is not exactly one.
   */
  protected async unaryWait(timeoutMs?: number): Promise<[Status, Message]> {
    // Drain only; handleResponse() already stored the message in
    // this.responses.
    for await (const _response of this.getResponses(1, timeoutMs)) {
      // Intentionally empty.
    }
    if (this.status === undefined) {
      throw Error('Unexpected undefined status at end of stream');
    }
    if (this.responses.length !== 1) {
      throw Error(`Unexpected number of responses: ${this.responses.length}`);
    }
    return [this.status, this.responses[0]];
  }

  /**
   * Waits for the stream to end and returns all responses with the status.
   *
   * @throws Error if no status was recorded by the time the stream ended.
   */
  protected async streamWait(timeoutMs?: number): Promise<[Status, Message[]]> {
    // Drain only; handleResponse() already stored each message in
    // this.responses.
    for await (const _response of this.getResponses(undefined, timeoutMs)) {
      // Intentionally empty.
    }
    if (this.status === undefined) {
      throw Error('Unexpected undefined status at end of stream');
    }
    return [this.status, this.responses];
  }

  /**
   * Sends one client-stream message.
   *
   * @throws RpcError with FAILED_PRECONDITION if the call already has a
   *    completion status, or whatever checkErrors() throws.
   */
  protected sendClientStream(request: Message): void {
    this.checkErrors();
    if (this.status !== undefined) {
      throw new RpcError(this.rpc, Status.FAILED_PRECONDITION);
    }
    this.rpcs.sendClientStream(this.rpc, request);
  }

  /** Sends the given messages, then ends the client stream if still open. */
  protected finishClientStream(requests: Message[]): void {
    for (const request of requests) {
      this.sendClientStream(request);
    }

    if (!this.completed) {
      this.rpcs.sendClientStreamEnd(this.rpc);
    }
  }
}

/** Tracks the state of a unary RPC call. */
export class UnaryCall extends Call {
  /** Awaits the server response */
  async complete(timeoutMs?: number): Promise<[Status, Message]> {
    return await this.unaryWait(timeoutMs);
  }
}

/** Tracks the state of a client streaming RPC call. */
export class ClientStreamingCall extends Call {
  /** Gets the last server message, if it exists */
  get response(): Message | undefined {
    return this.responses.length > 0
      ? this.responses[this.responses.length - 1]
      : undefined;
  }

  /** Sends a message from the client. */
  send(request: Message): void {
    this.sendClientStream(request);
  }

  /** Ends the client stream and waits for the RPC to complete. */
  async finishAndWait(
    requests: Message[] = [],
    timeoutMs?: number
  ): Promise<[Status, Message[]]> {
    this.finishClientStream(requests);
    return await this.streamWait(timeoutMs);
  }
}

/** Tracks the state of a server streaming RPC call. */
export class ServerStreamingCall extends Call {
  /** Waits for the server stream to end and returns every response. */
  complete(timeoutMs?: number): Promise<[Status, Message[]]> {
    return this.streamWait(timeoutMs);
  }
}

/** Tracks the state of a bidirectional streaming RPC call. */
export class BidirectionalStreamingCall extends Call {
  /** Sends a message from the client. */
  send(request: Message): void {
    this.sendClientStream(request);
  }

  /** Ends the client stream and waits for the RPC to complete. */
  async finishAndWait(
    requests: Array<Message> = [],
    timeoutMs?: number
  ): Promise<[Status, Array<Message>]> {
    this.finishClientStream(requests);
    return await this.streamWait(timeoutMs);
  }
}