Skip to content

Commit ec7bd8f

Browse files
Merge pull request #6 from crispthinking/feature/multiple-inputs
Feature/multiple inputs
2 parents ca68de1 + 4690006 commit ec7bd8f

5 files changed

Lines changed: 289 additions & 59 deletions

File tree

README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,14 @@ Athena is a gRPC-based image classification service designed for CSAM (Child Sex
1414
- **Error Handling**: Comprehensive error codes and detailed error messages
1515
- **Monitoring**: Active deployment tracking and backlog monitoring
1616

17-
## Contributing
17+
# Contributing
18+
19+
## Updating the Protobuf definitions
20+
21+
Protobufs are stored in the [@crispthinking/athena-protobufs](https://github.com/crispthinking/athena-protobufs) repository.
22+
23+
To update the protobuf definitions for client generation, run:
24+
`git subtree pull --prefix=athena-protobufs https://github.com/crispthinking/athena-protobufs.git <sha> --squash`
1825

1926
## Regenerating the TypeScript gRPC Client
2027

__tests__/unit/main.test.ts

Lines changed: 196 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { describe, it,} from 'vitest';
2-
import { ClassifierSdk, ClassifyImageOptions, ClassifyResponse, ImageFormat } from '../../src/main';
2+
import { ClassificationOutput, ClassifierSdk, ClassifyImageInput, ImageFormat } from '../../src';
33
import fs from 'fs';
44
import { randomUUID } from 'crypto';
55

@@ -31,7 +31,99 @@ describe('classifierHelper', () => {
3131
});
3232

3333
describe('classifyImage function', () => {
34-
it('should classify Steamboat-willie.jpg and return responses (integration smoke test)', async ({expect}) => {
34+
it('should classify 10 images in a single request and return responses (integration smoke test)', async ({expect, annotate}) => {
35+
// This is a smoke test. You must have a running gRPC server at localhost:50051 for this to pass.
36+
// You may want to mock the gRPC client for true unit testing.
37+
const imagePath = __dirname + '/Steamboat-willie.jpg';
38+
const sdk = new ClassifierSdk({
39+
deploymentId: process.env.VITE_ATHENA_DEPLOYMENT_ID,
40+
affiliate: process.env.VITE_ATHENA_AFFILIATE,
41+
authentication: {
42+
issuerUrl: process.env.VITE_OAUTH_ISSUER,
43+
clientId: process.env.VITE_ATHENA_CLIENT_ID,
44+
clientSecret: process.env.VITE_ATHENA_CLIENT_SECRET,
45+
scope: 'manage:classify'
46+
}
47+
});
48+
49+
// Generate 10 unique correlationIds
50+
const correlationIds = Array.from({ length: 10 }, () => randomUUID().toString());
51+
52+
correlationIds.sort((a, b) => a.localeCompare(b));
53+
54+
annotate(`Correlation IDs: ${correlationIds.join(', ')}`);
55+
56+
// Create 10 input objects, each with a new stream and unique correlationId
57+
const inputs: ClassifyImageInput[] = correlationIds.map((correlationId) => ({
58+
imageStream: fs.createReadStream(imagePath),
59+
format: ImageFormat.PNG,
60+
correlationId
61+
}));
62+
63+
// Create a promise to wrap the event emitter event 'data'
64+
const promise = new Promise<ClassificationOutput[]>((resolve, reject) => {
65+
const results:ClassificationOutput[] = [];
66+
67+
sdk.on('data', (data) => {
68+
if (data.globalError)
69+
{
70+
reject(data.globalError);
71+
}
72+
73+
// Collect every output whose correlationId belongs to this request
74+
for(const result of data.outputs)
75+
{
76+
if (correlationIds.includes(result.correlationId)) {
77+
results.push(result);
78+
}
79+
}
80+
if (results.length == correlationIds.length) {
81+
resolve(results);
82+
}
83+
});
84+
sdk.once('error', (err) => {
85+
reject(err);
86+
});
87+
});
88+
89+
let error: any = undefined;
90+
91+
await sdk.open();
92+
93+
try {
94+
await sdk.sendClassifyRequest(inputs);
95+
} catch (err) {
96+
error = err;
97+
}
98+
99+
// Wait for classifier to process some data....
100+
const outputs = await promise;
101+
sdk.close();
102+
103+
expect(error).toBeUndefined();
104+
105+
outputs.sort((a, b) => a.correlationId.localeCompare(b.correlationId));
106+
107+
expect(outputs).toBeDefined();
108+
// Check that all correlationIds are present in the outputs
109+
expect(outputs.length).toBe(correlationIds.length);
110+
111+
const expectedOutputs = correlationIds.map(id =>(
112+
{
113+
correlationId: id,
114+
classifications: expect.arrayContaining([
115+
{
116+
label: expect.any(String),
117+
weight: expect.any(Number)
118+
}
119+
])
120+
}
121+
));
122+
123+
expect(outputs).toMatchObject(expectedOutputs);
124+
}, 120000);
125+
126+
it('should classify Steamboat-willie.jpg with raw uint8 resize return responses (integration smoke test)', async ({expect, annotate}) => {
35127
// This is a smoke test. You must have a running gRPC server at localhost:50051 for this to pass.
36128
// You may want to mock the gRPC client for true unit testing.
37129
const imagePath = __dirname + '/Steamboat-willie.jpg';
@@ -48,16 +140,24 @@ describe('classifyImage function', () => {
48140

49141
const correlationId = randomUUID();
50142

143+
annotate(`Correlation IDs: ${correlationId}`);
144+
51145
// Create a promise to wrap the event emitter event 'data'
52-
const promise = new Promise<ClassifyResponse>((resolve, reject) => {
53-
sdk.once('data', (data) => {
146+
const promise = new Promise<ClassificationOutput[]>((resolve, reject) => {
147+
// Add a timeout to reject the promise if no data is received in 30 seconds
148+
const timeout = setTimeout(() => {
149+
reject(new Error('Timeout waiting for classification response'));
150+
}, 30000);
151+
152+
sdk.on('data', (data) => {
54153
const byCorrelationId = data.outputs.filter(o => o.correlationId === correlationId);
55154
if (byCorrelationId.length > 0) {
56-
resolve(data);
155+
clearTimeout(timeout);
156+
resolve(byCorrelationId);
57157
}
58-
sdk.close();
59158
});
60159
sdk.once('error', (err) => {
160+
clearTimeout(timeout);
61161
reject(err);
62162
});
63163
});
@@ -68,10 +168,10 @@ describe('classifyImage function', () => {
68168
await sdk.open();
69169

70170
const imageStream = fs.createReadStream(imagePath);
71-
const options: ClassifyImageOptions = {
171+
const options: ClassifyImageInput = {
72172
imageStream,
73-
format: ImageFormat.PNG,
74-
correlationId
173+
correlationId,
174+
resize: true,
75175
};
76176
try {
77177
await sdk.sendClassifyRequest(options);
@@ -81,12 +181,97 @@ describe('classifyImage function', () => {
81181

82182
// Wait for classifier to process some data....
83183
const first = await promise;
184+
sdk.close();
84185

85186
expect(first).toBeDefined();
187+
expect(first).toMatchObject([
188+
{
189+
correlationId,
190+
classifications: expect.arrayContaining([
191+
{
192+
label: expect.any(String),
193+
weight: expect.any(Number)
194+
}
195+
])
196+
} as ClassificationOutput
197+
]);
198+
199+
// Sending the classify request must not have produced an error
200+
expect(error).toBeUndefined();
201+
}, 120000);
202+
203+
it('should classify Steamboat-willie.jpg and return responses (integration smoke test)', async ({expect, annotate}) => {
204+
// This is a smoke test. You must have a running gRPC server at localhost:50051 for this to pass.
205+
// You may want to mock the gRPC client for true unit testing.
206+
const imagePath = __dirname + '/Steamboat-willie.jpg';
207+
const sdk = new ClassifierSdk({
208+
deploymentId: process.env.VITE_ATHENA_DEPLOYMENT_ID,
209+
affiliate: process.env.VITE_ATHENA_AFFILIATE,
210+
authentication: {
211+
issuerUrl: process.env.VITE_OAUTH_ISSUER,
212+
clientId: process.env.VITE_ATHENA_CLIENT_ID,
213+
clientSecret: process.env.VITE_ATHENA_CLIENT_SECRET,
214+
scope: 'manage:classify'
215+
}
216+
});
86217

87-
const byCorrelationId = first.outputs.filter(o => o.correlationId === correlationId);
218+
const correlationId = randomUUID();
219+
220+
annotate(`Correlation IDs: ${correlationId}`);
221+
222+
// Create a promise to wrap the event emitter event 'data'
223+
const promise = new Promise<ClassificationOutput[]>((resolve, reject) => {
224+
// Add a timeout to reject the promise if no data is received in 30 seconds
225+
const timeout = setTimeout(() => {
226+
reject(new Error('Timeout waiting for classification response'));
227+
}, 30000);
88228

89-
expect(byCorrelationId.length).toBeGreaterThan(0);
229+
sdk.on('data', (data) => {
230+
const byCorrelationId = data.outputs.filter(o => o.correlationId === correlationId);
231+
if (byCorrelationId.length > 0) {
232+
clearTimeout(timeout);
233+
resolve(byCorrelationId);
234+
}
235+
});
236+
sdk.once('error', (err) => {
237+
clearTimeout(timeout);
238+
reject(err);
239+
});
240+
});
241+
242+
// This will fail if no server is running, but will exercise the code path.
243+
let error: any = undefined;
244+
245+
await sdk.open();
246+
247+
const imageStream = fs.createReadStream(imagePath);
248+
const options: ClassifyImageInput = {
249+
imageStream,
250+
correlationId,
251+
format: ImageFormat.JPEG
252+
};
253+
try {
254+
await sdk.sendClassifyRequest(options);
255+
} catch (err) {
256+
error = err;
257+
}
258+
259+
// Wait for classifier to process some data....
260+
const first = await promise;
261+
sdk.close();
262+
263+
expect(first).toBeDefined();
264+
expect(first).toMatchObject([
265+
{
266+
correlationId,
267+
classifications: expect.arrayContaining([
268+
{
269+
label: expect.any(String),
270+
weight: expect.any(Number)
271+
}
272+
])
273+
} as ClassificationOutput
274+
]);
90275

91276
// Sending the classify request must not have thrown an error
92277
expect(error).toBeUndefined();

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
"test:watch": "vitest unit"
4848
},
4949
"author": "James Abbott <abbottdev@users.noreply.github.com>",
50-
"license": "Apache-2.0",
50+
"license": "MIT",
5151
"dependencies": {
5252
"@bufbuild/protobuf": "^2.7.0",
5353
"@grpc/grpc-js": "^1.13.4",

src/hashing.ts

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import { Readable } from 'stream';
22
import crypto from 'crypto';
33
import sharp from 'sharp';
4-
import { RequestEncoding } from './main';
4+
import { ImageFormat, RequestEncoding } from '.';
55
import brotli from 'brotli';
6+
import { buffer } from 'stream/consumers';
67

78
/**
89
* Computes MD5 and SHA1 hashes from a readable stream and, when `resize` is true, resizes the image data to 448x448 raw pixels.
@@ -13,16 +14,25 @@ import brotli from 'brotli';
1314
export async function computeHashesFromStream(
1415
stream: Readable,
1516
encoding: RequestEncoding = RequestEncoding.UNCOMPRESSED,
16-
): Promise<{ md5: string; sha1: string; data: Buffer }> {
17+
imageFormat: ImageFormat = ImageFormat.UNSPECIFIED,
18+
resize: boolean = false,
19+
): Promise<{ md5: string; sha1: string; data: Buffer; format: ImageFormat }> {
1720
const md5 = crypto.createHash('md5');
1821
const sha1 = crypto.createHash('sha1');
19-
const resizer = sharp().resize(448, 448).raw({ depth: 'uint' });
22+
23+
let data: Buffer<ArrayBufferLike>;
2024

2125
stream.pipe(md5);
2226
stream.pipe(sha1);
23-
stream.pipe(resizer);
2427

25-
let data = await resizer.toBuffer();
28+
if (resize) {
29+
const resizer = sharp().resize(448, 448).raw({ depth: 'uint' });
30+
stream.pipe(resizer);
31+
data = await resizer.toBuffer();
32+
imageFormat = ImageFormat.RAW_UINT8;
33+
} else {
34+
data = await buffer(stream);
35+
}
2636

2737
if (encoding === RequestEncoding.BROTLI) {
2838
data = Buffer.from(await brotli.compress(data));
@@ -32,5 +42,6 @@ export async function computeHashesFromStream(
3242
md5: md5.digest('hex'),
3343
sha1: sha1.digest('hex'),
3444
data,
45+
format: imageFormat,
3546
};
3647
}

0 commit comments

Comments
 (0)