diff --git a/src/LuaAST.ts b/src/LuaAST.ts index 2bfa16c18..994755617 100644 --- a/src/LuaAST.ts +++ b/src/LuaAST.ts @@ -44,6 +44,7 @@ export enum SyntaxKind { MethodCallExpression, Identifier, TableIndexExpression, + ParenthesizedExpression, // Operators @@ -843,3 +844,20 @@ export function isInlineFunctionExpression(expression: FunctionExpression): expr (expression.flags & NodeFlags.Inline) !== 0 ); } + +export type ParenthesizedExpression = Expression & { + expression: Expression; +}; + +export function isParenthesizedExpression(node: Node): node is ParenthesizedExpression { + return node.kind === SyntaxKind.ParenthesizedExpression; +} + +export function createParenthesizedExpression(expression: Expression, tsOriginal?: ts.Node): ParenthesizedExpression { + const parenthesizedExpression = createNode( + SyntaxKind.ParenthesizedExpression, + tsOriginal + ) as ParenthesizedExpression; + parenthesizedExpression.expression = expression; + return parenthesizedExpression; +} diff --git a/src/LuaPrinter.ts b/src/LuaPrinter.ts index 97a68fd74..1f89dd4ec 100644 --- a/src/LuaPrinter.ts +++ b/src/LuaPrinter.ts @@ -612,6 +612,8 @@ export class LuaPrinter { return this.printIdentifier(expression as lua.Identifier); case lua.SyntaxKind.TableIndexExpression: return this.printTableIndexExpression(expression as lua.TableIndexExpression); + case lua.SyntaxKind.ParenthesizedExpression: + return this.printParenthesizedExpression(expression as lua.ParenthesizedExpression); default: throw new Error(`Tried to print unknown statement kind: ${lua.SyntaxKind[expression.kind]}`); } @@ -822,6 +824,10 @@ export class LuaPrinter { return this.createSourceNode(expression, chunks); } + public printParenthesizedExpression(expression: lua.ParenthesizedExpression) { + return this.createSourceNode(expression, ["(", this.printExpression(expression.expression), ")"]); + } + public printOperator(kind: lua.Operator): SourceNode { return new SourceNode(null, null, this.relativeSourcePath, LuaPrinter.operatorMap[kind]); } diff --git a/src/transformation/visitors/access.ts b/src/transformation/visitors/access.ts index 12e9385ad..92955d15d 100644 --- a/src/transformation/visitors/access.ts +++ b/src/transformation/visitors/access.ts @@ -79,13 +79,15 @@ export function transformElementAccessExpressionWithCapture( context.diagnostics.push(invalidMultiReturnAccess(node)); } - // When selecting the first element, we can shortcut - if (ts.isNumericLiteral(node.argumentExpression) && node.argumentExpression.text === "0") { - return { expression: table }; - } else { - const selectIdentifier = lua.createIdentifier("select"); - return { expression: lua.createCallExpression(selectIdentifier, [updatedAccessExpression, table]) }; + const canOmitSelect = ts.isNumericLiteral(node.argumentExpression) && node.argumentExpression.text === "0"; + if (canOmitSelect) { + // wrapping in parenthesis ensures only the first return value is used + // https://www.lua.org/manual/5.1/manual.html#2.5 + return { expression: lua.createParenthesizedExpression(table) }; } + + const selectIdentifier = lua.createIdentifier("select"); + return { expression: lua.createCallExpression(selectIdentifier, [updatedAccessExpression, table]) }; } if (thisValueCapture) { diff --git a/test/unit/language-extensions/multi.spec.ts b/test/unit/language-extensions/multi.spec.ts index 3350876e8..8282ef4fc 100644 --- a/test/unit/language-extensions/multi.spec.ts +++ b/test/unit/language-extensions/multi.spec.ts @@ -338,3 +338,38 @@ test("LuaMultiReturn applies after casting a function (#1404)", () => { .withLanguageExtensions() .expectToEqual([3, 4]); }); + +// https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1411 +describe("LuaMultiReturn returns all values even when indexed with [0] #1411", () => { + const sharedCode = ` + declare function select(this:void, index: '#', ...args: T[]): number; + + function foo(this: void): LuaMultiReturn<[number, number]> { + return $multi(123, 456); + } + + function bar(this: void): number { + return foo()[0]; + } + `; + test("Returns the correct value", () => { + util.testFunction` + return bar(); + ` + .setTsHeader(sharedCode) + .withLanguageExtensions() + .expectToEqual(123); + }); + + // We require an extra test since the test above would succeed even if bar() returned more than one value. + // This is because our test helper always returns just the first lua value returned from the testFunction. + // Here we need to test explicitly that only one value is returned. + test("Does not return multiple values", () => { + util.testFunction` + return select("#", bar()); + ` + .setTsHeader(sharedCode) + .withLanguageExtensions() + .expectToEqual(1); + }); +}); diff --git a/test/util.ts b/test/util.ts index aab9d303f..f2c73e1e2 100644 --- a/test/util.ts +++ b/test/util.ts @@ -161,7 +161,7 @@ export abstract class TestBuilder { skipLibCheck: true, target: ts.ScriptTarget.ES2017, lib: ["lib.esnext.d.ts"], - moduleResolution: ts.ModuleResolutionKind.NodeJs, + moduleResolution: ts.ModuleResolutionKind.Node10, resolveJsonModule: true, experimentalDecorators: true, sourceMap: true, @@ -173,7 +173,8 @@ export abstract class TestBuilder { } public withLanguageExtensions(): this { - this.setOptions({ types: [path.resolve(__dirname, "..", "language-extensions")] }); + const langExtTypes = path.resolve(__dirname, "..", "language-extensions"); + this.options.types = this.options.types ? [...this.options.types, langExtTypes] : [langExtTypes]; // Polyfill lualib for JS this.setJsHeader(` function $multi(...args) { return args; }