// Copyright (c) 2023 Google LLC. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "extract_source.h" #include #include #include #include #include "source/opt/log.h" #include "spirv-tools/libspirv.hpp" #include "spirv/unified1/spirv.hpp" #include "tools/util/cli_consumer.h" namespace { constexpr auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_6; // Extract a string literal from a given range. // Copies all the characters from `begin` to the first '\0' it encounters, while // removing escape patterns. // Not finding a '\0' before reaching `end` fails the extraction. // // Returns `true` if the extraction succeeded. // `output` value is undefined if false is returned. spv_result_t ExtractStringLiteral(const spv_position_t& loc, const char* begin, const char* end, std::string* output) { size_t sourceLength = std::distance(begin, end); std::string escapedString; escapedString.resize(sourceLength); size_t writeIndex = 0; size_t readIndex = 0; for (; readIndex < sourceLength; writeIndex++, readIndex++) { const char read = begin[readIndex]; if (read == '\0') { escapedString.resize(writeIndex); output->append(escapedString); return SPV_SUCCESS; } if (read == '\\') { ++readIndex; } escapedString[writeIndex] = begin[readIndex]; } spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc, "Missing NULL terminator for literal string."); return SPV_ERROR_INVALID_BINARY; } spv_result_t extractOpString(const spv_position_t& loc, const spv_parsed_instruction_t& instruction, std::string* output) { assert(output != nullptr); assert(instruction.opcode == spv::Op::OpString); if (instruction.num_operands != 2) { spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc, "Missing operands for OpString."); return SPV_ERROR_INVALID_BINARY; } const auto& operand = instruction.operands[1]; const char* stringBegin = reinterpret_cast(instruction.words + operand.offset); const char* stringEnd = reinterpret_cast( instruction.words + operand.offset + operand.num_words); return ExtractStringLiteral(loc, stringBegin, stringEnd, output); } spv_result_t extractOpSourceContinued( const spv_position_t& loc, const spv_parsed_instruction_t& instruction, std::string* output) { assert(output != nullptr); assert(instruction.opcode == spv::Op::OpSourceContinued); if (instruction.num_operands != 1) { spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc, "Missing operands for OpSourceContinued."); return SPV_ERROR_INVALID_BINARY; } const auto& operand = instruction.operands[0]; const char* stringBegin = reinterpret_cast(instruction.words + operand.offset); const char* stringEnd = reinterpret_cast( instruction.words + operand.offset + operand.num_words); return ExtractStringLiteral(loc, stringBegin, stringEnd, output); } spv_result_t extractOpSource(const spv_position_t& loc, const spv_parsed_instruction_t& instruction, spv::Id* filename, std::string* code) { assert(filename != nullptr && code != nullptr); assert(instruction.opcode == spv::Op::OpSource); // OpCode [ Source Language | Version | File (optional) | Source (optional) ] if (instruction.num_words < 3) { spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc, "Missing operands for OpSource."); return SPV_ERROR_INVALID_BINARY; } *filename = 0; *code = ""; if (instruction.num_words < 4) { return SPV_SUCCESS; } *filename = instruction.words[3]; if (instruction.num_words < 5) { return SPV_SUCCESS; } const char* stringBegin = reinterpret_cast(instruction.words + 4); const char* stringEnd = reinterpret_cast(instruction.words + instruction.num_words); return ExtractStringLiteral(loc, stringBegin, stringEnd, code); } } // namespace bool ExtractSourceFromModule( const std::vector& binary, std::unordered_map* output) { auto context = spvtools::SpirvTools(kDefaultEnvironment); context.SetMessageConsumer(spvtools::utils::CLIMessageConsumer); // There is nothing valuable in the header. spvtools::HeaderParser headerParser = [](const spv_endianness_t, const spv_parsed_header_t&) { return SPV_SUCCESS; }; std::unordered_map stringMap; std::vector> sources; spv::Op lastOpcode = spv::Op::OpMax; size_t instructionIndex = 0; spvtools::InstructionParser instructionParser = [&stringMap, &sources, &lastOpcode, &instructionIndex](const spv_parsed_instruction_t& instruction) { const spv_position_t loc = {0, 0, instructionIndex + 1}; spv_result_t result = SPV_SUCCESS; if (instruction.opcode == spv::Op::OpString) { std::string content; result = extractOpString(loc, instruction, &content); if (result == SPV_SUCCESS) { stringMap.emplace(instruction.result_id, std::move(content)); } } else if (instruction.opcode == spv::Op::OpSource) { spv::Id filenameId; std::string code; result = extractOpSource(loc, instruction, &filenameId, &code); if (result == SPV_SUCCESS) { sources.emplace_back(std::make_pair(filenameId, std::move(code))); } } else if (instruction.opcode == spv::Op::OpSourceContinued) { if (lastOpcode != spv::Op::OpSource) { spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc, "OpSourceContinued MUST follow an OpSource."); return SPV_ERROR_INVALID_BINARY; } assert(sources.size() > 0); result = extractOpSourceContinued(loc, instruction, &sources.back().second); } ++instructionIndex; lastOpcode = static_cast(instruction.opcode); return result; }; if (!context.Parse(binary, headerParser, instructionParser)) { return false; } std::string defaultName = "unnamed-"; size_t unnamedCount = 0; for (auto & [ id, code ] : sources) { std::string filename; const auto it = stringMap.find(id); if (it == stringMap.cend() || it->second.empty()) { filename = "unnamed-" + std::to_string(unnamedCount) + ".hlsl"; ++unnamedCount; } else { filename = it->second; } if (output->count(filename) != 0) { spvtools::Error(spvtools::utils::CLIMessageConsumer, "", {}, "Source file name conflict."); return false; } output->insert({filename, code}); } return true; }